diff --git a/breacharbiter_test.go b/breacharbiter_test.go index 74722beb..e8c6afbb 100644 --- a/breacharbiter_test.go +++ b/breacharbiter_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "errors" "fmt" "io/ioutil" "os" @@ -10,6 +9,7 @@ import ( "sync" "testing" + "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwallet" "github.com/roasbeef/btcd/btcec" @@ -192,7 +192,8 @@ var ( }, } - retributions = []retributionInfo{ + retributionMap = make(map[wire.OutPoint]retributionInfo) + retributions = []retributionInfo{ { commitHash: [chainhash.HashSize]byte{ 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, @@ -227,6 +228,83 @@ var ( } ) +func init() { + // Ensure that breached outputs are initialized before starting tests. + if err := initBreachedOutputs(); err != nil { + panic(err) + } + + // Populate a retribution map to for convenience, to allow lookups by + // channel point. + for i := range retributions { + retInfo := &retributions[i] + retInfo.remoteIdentity = *breachedOutputs[i].signDescriptor.PubKey + retributionMap[retInfo.chanPoint] = *retInfo + } +} + +// FailingRetributionStore wraps a RetributionStore and supports controlled +// restarts of the persistent instance. This allows us to test (1) that no +// modifications to the entries are made between calls or through side effects, +// and (2) that the database is actually being persisted between actions. +type FailingRetributionStore interface { + RetributionStore + + Restart() +} + +// failingRetributionStore is a concrete implementation of a +// FailingRetributionStore. It wraps an underlying RetributionStore and is +// parameterized entirely by a restart function, which is intended to simulate a +// full stop/start of the store. +type failingRetributionStore struct { + mu sync.Mutex + + rs RetributionStore + + restart func() RetributionStore +} + +// newFailingRetributionStore creates a new failing retribution store. The given +// restart closure should ensure that it is reloading its contents from the +// persistent source. +func newFailingRetributionStore( + restart func() RetributionStore) *failingRetributionStore { + + return &failingRetributionStore{ + mu: sync.Mutex{}, + rs: restart(), + restart: restart, + } +} + +func (frs *failingRetributionStore) Restart() { + frs.mu.Lock() + frs.rs = frs.restart() + frs.mu.Unlock() +} + +func (frs *failingRetributionStore) Add(retInfo *retributionInfo) error { + frs.mu.Lock() + defer frs.mu.Unlock() + + return frs.rs.Add(retInfo) +} + +func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error { + frs.mu.Lock() + defer frs.mu.Unlock() + + return frs.rs.Remove(key) +} + +func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error { + frs.mu.Lock() + defer frs.mu.Unlock() + + return frs.rs.ForAll(cb) +} + // Parse the pubkeys in the breached outputs. func initBreachedOutputs() error { for i := range breachedOutputs { @@ -236,7 +314,8 @@ func initBreachedOutputs() error { sd := &breachSignDescs[i] pubkey, err := btcec.ParsePubKey(breachKeys[i], btcec.S256()) if err != nil { - return fmt.Errorf("unable to parse pubkey: %v", breachKeys[i]) + return fmt.Errorf("unable to parse pubkey: %v", + breachKeys[i]) } sd.PubKey = pubkey bo.signDescriptor = *sd @@ -247,26 +326,25 @@ func initBreachedOutputs() error { // Test that breachedOutput Encode/Decode works. func TestBreachedOutputSerialization(t *testing.T) { - if err := initBreachedOutputs(); err != nil { - t.Fatalf("unable to init breached outputs: %v", err) - } - for i := 0; i < len(breachedOutputs); i++ { bo := &breachedOutputs[i] var buf bytes.Buffer if err := bo.Encode(&buf); err != nil { - t.Fatalf("unable to serialize breached output [%v]: %v", i, err) + t.Fatalf("unable to serialize breached output [%v]: %v", + i, err) } desBo := &breachedOutput{} if err := desBo.Decode(&buf); err != nil { - t.Fatalf("unable to deserialize breached output [%v]: %v", i, err) + t.Fatalf("unable to deserialize "+ + "breached output [%v]: %v", i, err) } if !reflect.DeepEqual(bo, desBo) { - t.Fatalf("original and deserialized breached outputs not equal:\n"+ + t.Fatalf("original and deserialized "+ + "breached outputs not equal:\n"+ "original : %+v\n"+ "deserialized : %+v\n", bo, desBo) @@ -276,32 +354,25 @@ func TestBreachedOutputSerialization(t *testing.T) { // Test that retribution Encode/Decode works. func TestRetributionSerialization(t *testing.T) { - if err := initBreachedOutputs(); err != nil { - t.Fatalf("unable to init breached outputs: %v", err) - } - for i := 0; i < len(retributions); i++ { ret := &retributions[i] - remoteIdentity, err := btcec.ParsePubKey(breachKeys[i], btcec.S256()) - if err != nil { - t.Fatalf("unable to parse public key [%v]: %v", i, err) - } - ret.remoteIdentity = *remoteIdentity - var buf bytes.Buffer if err := ret.Encode(&buf); err != nil { - t.Fatalf("unable to serialize retribution [%v]: %v", i, err) + t.Fatalf("unable to serialize retribution [%v]: %v", + i, err) } desRet := &retributionInfo{} if err := desRet.Decode(&buf); err != nil { - t.Fatalf("unable to deserialize retribution [%v]: %v", i, err) + t.Fatalf("unable to deserialize retribution [%v]: %v", + i, err) } if !reflect.DeepEqual(ret, desRet) { - t.Fatalf("original and deserialized retribution infos not equal:\n"+ + t.Fatalf("original and deserialized "+ + "retribution infos not equal:\n"+ "original : %+v\n"+ "deserialized : %+v\n", ret, desRet) @@ -311,6 +382,8 @@ func TestRetributionSerialization(t *testing.T) { // copyRetInfo creates a complete copy of the given retributionInfo. func copyRetInfo(retInfo *retributionInfo) *retributionInfo { + nHtlcs := len(retInfo.htlcOutputs) + ret := &retributionInfo{ commitHash: retInfo.commitHash, chanPoint: retInfo.chanPoint, @@ -319,8 +392,8 @@ func copyRetInfo(retInfo *retributionInfo) *retributionInfo { settledBalance: retInfo.settledBalance, selfOutput: retInfo.selfOutput, revokedOutput: retInfo.revokedOutput, - htlcOutputs: make([]*breachedOutput, len(retInfo.htlcOutputs)), - doneChan: make(chan struct{}), + htlcOutputs: make([]*breachedOutput, nHtlcs), + doneChan: retInfo.doneChan, } for i, htlco := range retInfo.htlcOutputs { @@ -374,11 +447,47 @@ func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error { return nil } +var retributionStoreTestSuite = []struct { + name string + test func(FailingRetributionStore, *testing.T) +}{ + { + "Initialization", + testRetributionStoreInit, + }, + { + "Add/Remove", + testRetributionStoreAddRemove, + }, + { + "Persistence", + testRetributionStorePersistence, + }, + { + "Overwrite", + testRetributionStoreOverwrite, + }, + { + "RemoveEmpty", + testRetributionStoreRemoveEmpty, + }, +} + // TestMockRetributionStore instantiates a mockRetributionStore and tests its // behavior using the general RetributionStore test suite. func TestMockRetributionStore(t *testing.T) { - mrs := newMockRetributionStore() - testRetributionStore(mrs, t) + for _, test := range retributionStoreTestSuite { + t.Run( + "mockRetributionStore."+test.name, + func(tt *testing.T) { + mrs := newMockRetributionStore() + frs := newFailingRetributionStore( + func() RetributionStore { return mrs }, + ) + test.test(frs, tt) + }, + ) + } } // TestChannelDBRetributionStore instantiates a retributionStore backed by a @@ -389,10 +498,14 @@ func TestChannelDBRetributionStore(t *testing.T) { // this test. tempDirName, err := ioutil.TempDir("", "channeldb") if err != nil { - t.Fatalf("unable to initialize temp directory for channeldb: %v", err) + t.Fatalf("unable to initialize temp "+ + "directory for channeldb: %v", err) } defer os.RemoveAll(tempDirName) + // Disable logging to prevent panics bc. of global state + channeldb.UseLogger(btclog.Disabled) + // Next, create channeldb for the first time. db, err := channeldb.Open(tempDirName) if err != nil { @@ -400,12 +513,40 @@ func TestChannelDBRetributionStore(t *testing.T) { } defer db.Close() - // Finally, instantiate retribution store and execute RetributionStore test - // suite. - rs := newRetributionStore(db) - testRetributionStore(rs, t) + restartDb := func() RetributionStore { + // Close and reopen channeldb + if err = db.Close(); err != nil { + t.Fatalf("unalbe to close channeldb during restart: %v", + err) + } + db, err = channeldb.Open(tempDirName) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + return newRetributionStore(db) + } + + // Finally, instantiate retribution store and execute RetributionStore + // test suite. + for _, test := range retributionStoreTestSuite { + t.Run( + "channeldbDBRetributionStore."+test.name, + func(tt *testing.T) { + if err = db.Wipe(); err != nil { + t.Fatalf("unable to wipe channeldb: %v", + err) + } + + frs := newFailingRetributionStore(restartDb) + test.test(frs, tt) + }, + ) + } } +// countRetributions uses a retribution store's ForAll to count the number of +// elements emitted from the store. func countRetributions(t *testing.T, rs RetributionStore) int { count := 0 err := rs.ForAll(func(_ *retributionInfo) error { @@ -418,61 +559,224 @@ func countRetributions(t *testing.T, rs RetributionStore) int { return count } -// Test that the retribution persistence layer works. -func testRetributionStore(rs RetributionStore, t *testing.T) { - if err := initBreachedOutputs(); err != nil { - t.Fatalf("unable to init breached outputs: %v", err) - } - +// testRetributionStore executes a generic test suite for any concrete +// implementation of the RetributionStore interface. +func testRetributionStoreAddRemove(frs FailingRetributionStore, t *testing.T) { // Make sure that a new retribution store is actually emtpy. - if count := countRetributions(t, rs); count != 0 { + if count := countRetributions(t, frs); count != 0 { t.Fatalf("expected 0 retributions, found %v", count) } - // Add first retribution state to the store. - if err := rs.Add(&retributions[0]); err != nil { - t.Fatalf("unable to add to retribution store: %v", err) - } - // Ensure that the retribution store has one retribution. - if count := countRetributions(t, rs); count != 1 { + // Add all retributions, check that ForAll returns the correct + // information, and then remove all retributions. + testRetributionStoreAdds(frs, t, false) + testRetributionStoreForAll(frs, t, false) + testRetributionStoreRemoves(frs, t, false) +} + +func testRetributionStorePersistence(frs FailingRetributionStore, t *testing.T) { + // Make sure that a new retribution store is still emtpy after failing + // right off the bat. + frs.Restart() + if count := countRetributions(t, frs); count != 0 { t.Fatalf("expected 1 retributions, found %v", count) } - // Add second retribution state to the store. - if err := rs.Add(&retributions[1]); err != nil { - t.Fatalf("unable to add to retribution store: %v", err) - } - // There should be 2 retributions in the store. - if count := countRetributions(t, rs); count != 2 { - t.Fatalf("expected 2 retributions, found %v", count) - } + // Insert all retributions into the database, restarting and checking + // between subsequent calls to test that each intermediate additions are + // persisted. + testRetributionStoreAdds(frs, t, true) - // Retrieving the retribution states from the store should yield the same - // values as the originals. - rs.ForAll(func(ret *retributionInfo) error { - equal0 := reflect.DeepEqual(ret, &retributions[0]) - equal1 := reflect.DeepEqual(ret, &retributions[1]) - if !equal0 || !equal1 { - return errors.New("unexpected retribution retrieved from db") - } - return nil - }) + // After all retributions have been inserted, verify that the store + // emits a distinct set of retributions that are equivalent to the test + // vector. + testRetributionStoreForAll(frs, t, true) - // Remove the retribution states. - if err := rs.Remove(&retributions[0].chanPoint); err != nil { - t.Fatalf("unable to remove from retribution store: %v", err) - } - // Ensure that the retribution store has one retribution. - if count := countRetributions(t, rs); count != 1 { - t.Fatalf("expected 1 retributions, found %v", count) - } + // Remove all retributions from the database, restarting and checking + // between subsequent calls to test that each intermediate removals are + // persisted. + testRetributionStoreRemoves(frs, t, true) +} - if err := rs.Remove(&retributions[1].chanPoint); err != nil { - t.Fatalf("unable to remove from retribution store: %v", err) - } - - // Ensure that the retribution store is empty again. - if count := countRetributions(t, rs); count != 0 { +func testRetributionStoreInit(frs FailingRetributionStore, t *testing.T) { + // Make sure that a new retribution store starts empty. + if count := countRetributions(t, frs); count != 0 { t.Fatalf("expected 0 retributions, found %v", count) } } + +func testRetributionStoreRemoveEmpty(frs FailingRetributionStore, t *testing.T) { + testRetributionStoreRemoves(frs, t, false) +} + +func testRetributionStoreOverwrite(frs FailingRetributionStore, t *testing.T) { + // Initially, add all retributions to store. + testRetributionStoreAdds(frs, t, false) + + // Overwrite the initial entries again. + for i, retInfo := range retributions { + if err := frs.Add(&retInfo); err != nil { + t.Fatalf("unable to add to retribution %v to store: %v", + i, err) + } + } + + // Check that retribution store still has 2 entries. + if count := countRetributions(t, frs); count != 2 { + t.Fatalf("expected 2 retributions, found %v", count) + } +} + +func testRetributionStoreAdds( + frs FailingRetributionStore, + t *testing.T, + failing bool) { + + // Iterate over retributions, adding each from the store. If we are + // testing the store under failures, we restart the store and verify + // that the contents are the same. + for i, retInfo := range retributions { + // Snapshot number of entires before and after the addition. + nbefore := countRetributions(t, frs) + if err := frs.Add(&retInfo); err != nil { + t.Fatalf("unable to add to retribution %v to store: %v", + i, err) + } + nafter := countRetributions(t, frs) + + // Check that only one retribution was added. + if nafter-nbefore != 1 { + t.Fatalf("expected %v retributions, found %v", + nbefore+1, nafter) + } + + if failing { + frs.Restart() + + // Check that retribution store has persisted addition + // after restarting. + nrestart := countRetributions(t, frs) + if nrestart-nbefore != 1 { + t.Fatalf("expected %v retributions, found %v", + nbefore+1, nrestart) + } + } + } +} + +func testRetributionStoreRemoves( + frs FailingRetributionStore, + t *testing.T, + failing bool) { + + // Iterate over retributions, removing each from the store. If we are + // testing the store under failures, we restart the store and verify + // that the contents are the same. + for i, retInfo := range retributions { + // Snapshot number of entires before and after the removal. + nbefore := countRetributions(t, frs) + if err := frs.Remove(&retInfo.chanPoint); err != nil { + t.Fatalf("unable to remove to retribution %v "+ + "from store: %v", i, err) + } + nafter := countRetributions(t, frs) + + // If the store is empty, increment nbefore to simulate the + // removal of one element. + if nbefore == 0 { + nbefore++ + } + + // Check that only one retribution was removed. + if nbefore-nafter != 1 { + t.Fatalf("expected %v retributions, found %v", + nbefore-1, nafter) + } + + if failing { + frs.Restart() + + // Check that retribution store has persisted removal + // after restarting. + nrestart := countRetributions(t, frs) + if nbefore-nrestart != 1 { + t.Fatalf("expected %v retributions, found %v", + nbefore-1, nrestart) + } + } + } +} + +func testRetributionStoreForAll( + frs FailingRetributionStore, + t *testing.T, + failing bool) { + + // nrets is the number of retributions in the test vector + nrets := len(retributions) + + // isRestart indicates whether or not the database has been restarted. + // When testing for failures, this allows the test case to make a second + // attempt without causing a subsequent restart on the second pass. + var isRestart bool + +restartCheck: + // Construct a set of all channel points presented by the store. Entires + // are only be added to the set if their corresponding retribution + // infromation matches the test vector. + var foundSet = make(map[wire.OutPoint]struct{}) + + // Iterate through the stored retributions, checking to see if we have + // an equivalent retribution in the test vector. This will return an + // error unless all persisted retributions exist in the test vector. + if err := frs.ForAll(func(ret *retributionInfo) error { + // Fetch the retribution information from the test vector. If + // the entry does not exist, the test returns an error. + if exRetInfo, ok := retributionMap[ret.chanPoint]; ok { + // Compare the presented retribution information with + // the expected value, fail if they are inconsistent. + if !reflect.DeepEqual(ret, &exRetInfo) { + return fmt.Errorf("unexpected retribution "+ + "retrieved from db --\n"+ + "want: %#v\ngot: %#v", exRetInfo, ret, + ) + } + + // Retribution information from database matches the + // test vector, record the channel point in the found + // map. + foundSet[ret.chanPoint] = struct{}{} + + } else { + return fmt.Errorf("unkwown retribution "+ + "retrieved from db: %v", ret) + } + + return nil + }); err != nil { + t.Fatalf("failed to iterate over persistent retributions: %v", + err) + } + + // Check that retribution store emits nrets entires + if count := countRetributions(t, frs); count != nrets { + t.Fatalf("expected %v retributions, found %v", nrets, count) + } + + // Confirm that all of the retributions emitted from the iteration + // correspond to unique channel points. + nunique := len(foundSet) + if nunique != nrets { + t.Fatalf("expected %v unique retributions, only found %v", + nrets, nunique) + } + + // If in failure mode on only on first pass, restart the database and + // rexecute the test. + if failing && !isRestart { + frs.Restart() + isRestart = true + + goto restartCheck + } +}