diff --git a/contractcourt/chain_watcher_test.go b/contractcourt/chain_watcher_test.go index 50e73ec5..b6e5d8e1 100644 --- a/contractcourt/chain_watcher_test.go +++ b/contractcourt/chain_watcher_test.go @@ -204,9 +204,32 @@ type dlpTestCase struct { NumUpdates uint8 } +// executeStateTransitions execute the given number of state transitions. +// Copies of Alice's channel state before each transition (including initial +// state) are returned. func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, aliceChannel, bobChannel *lnwallet.LightningChannel, - numUpdates uint8) error { + numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) { + + // We'll make a copy of the channel state before each transition. + var ( + chanStates []*channeldb.OpenChannel + cleanupFuncs []func() + ) + + cleanAll := func() { + for _, f := range cleanupFuncs { + f() + } + } + + state, f, err := copyChannelState(aliceChannel.State()) + if err != nil { + return nil, nil, err + } + + chanStates = append(chanStates, state) + cleanupFuncs = append(cleanupFuncs, f) for i := 0; i < int(numUpdates); i++ { addFakeHTLC( @@ -215,11 +238,21 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, err := lnwallet.ForceStateTransition(aliceChannel, bobChannel) if err != nil { - return err + cleanAll() + return nil, nil, err } + + state, f, err := copyChannelState(aliceChannel.State()) + if err != nil { + cleanAll() + return nil, nil, err + } + + chanStates = append(chanStates, state) + cleanupFuncs = append(cleanupFuncs, f) } - return nil + return chanStates, cleanAll, nil } // TestChainWatcherDataLossProtect tests that if we've lost data (and are @@ -250,6 +283,24 @@ func TestChainWatcherDataLossProtect(t *testing.T) { } defer cleanUp() + // Based on the number of random updates for this state, make a + // new HTLC to add to the commitment, and then lock in a state + // transition. + const htlcAmt = 1000 + states, cleanStates, err := executeStateTransitions( + t, htlcAmt, aliceChannel, bobChannel, + testCase.BroadcastStateNum, + ) + if err != nil { + t.Errorf("unable to trigger state "+ + "transition: %v", err) + return false + } + defer cleanStates() + + // We'll use the state this test case wants Alice to start at. + aliceChanState := states[testCase.NumUpdates] + // With the channels created, we'll now create a chain watcher // instance which will be watching for any closes of Alice's // channel. @@ -259,7 +310,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) { ConfChan: make(chan *chainntnfs.TxConfirmation), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), + chanState: aliceChanState, notifier: aliceNotifier, signer: aliceChannel.Signer, extractStateNumHint: func(*wire.MsgTx, @@ -279,19 +330,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) { } defer aliceChainWatcher.Stop() - // Based on the number of random updates for this state, make a - // new HTLC to add to the commitment, and then lock in a state - // transition. - const htlcAmt = 1000 - err = executeStateTransitions( - t, htlcAmt, aliceChannel, bobChannel, testCase.NumUpdates, - ) - if err != nil { - t.Errorf("unable to trigger state "+ - "transition: %v", err) - return false - } - // We'll request a new channel event subscription from Alice's // chain watcher so we can be notified of our fake close below. chanEvents := aliceChainWatcher.SubscribeChannelEvents() @@ -299,7 +337,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) { // Otherwise, we'll feed in this new state number as a response // to the query, and insert the expected DLP commit point. dlpPoint := aliceChannel.State().RemoteCurrentRevocation - err = aliceChannel.State().MarkDataLoss(dlpPoint) + err = aliceChanState.MarkDataLoss(dlpPoint) if err != nil { t.Errorf("unable to insert dlp point: %v", err) return false @@ -421,6 +459,24 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { } defer cleanUp() + // We'll execute a number of state transitions based on the + // randomly selected number from testing/quick. We do this to + // get more coverage of various state hint encodings beyond 0 + // and 1. + const htlcAmt = 1000 + states, cleanStates, err := executeStateTransitions( + t, htlcAmt, aliceChannel, bobChannel, numUpdates, + ) + if err != nil { + t.Errorf("unable to trigger state "+ + "transition: %v", err) + return false + } + defer cleanStates() + + // We'll use the state this test case wants Alice to start at. + aliceChanState := states[numUpdates] + // With the channels created, we'll now create a chain watcher // instance which will be watching for any closes of Alice's // channel. @@ -430,7 +486,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { ConfChan: make(chan *chainntnfs.TxConfirmation), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), + chanState: aliceChanState, notifier: aliceNotifier, signer: aliceChannel.Signer, extractStateNumHint: lnwallet.GetStateNumHint, @@ -443,20 +499,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { } defer aliceChainWatcher.Stop() - // We'll execute a number of state transitions based on the - // randomly selected number from testing/quick. We do this to - // get more coverage of various state hint encodings beyond 0 - // and 1. - const htlcAmt = 1000 - err = executeStateTransitions( - t, htlcAmt, aliceChannel, bobChannel, numUpdates, - ) - if err != nil { - t.Errorf("unable to trigger state "+ - "transition: %v", err) - return false - } - // We'll request a new channel event subscription from Alice's // chain watcher so we can be notified of our fake close below. chanEvents := aliceChainWatcher.SubscribeChannelEvents() diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 2bf81b41..11f23d8c 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -1,10 +1,16 @@ package contractcourt import ( + "fmt" + "io" + "io/ioutil" "os" + "path/filepath" "runtime/pprof" "testing" "time" + + "github.com/lightningnetwork/lnd/channeldb" ) // timeout implements a test level timeout. @@ -24,3 +30,69 @@ func timeout(t *testing.T) func() { close(done) } } + +func copyFile(dest, src string) error { + s, err := os.Open(src) + if err != nil { + return err + } + defer s.Close() + + d, err := os.Create(dest) + if err != nil { + return err + } + + if _, err := io.Copy(d, s); err != nil { + d.Close() + return err + } + + return d.Close() +} + +// copyChannelState copies the OpenChannel state by copying the database and +// creating a new struct from it. The copied state and a cleanup function are +// returned. +func copyChannelState(state *channeldb.OpenChannel) ( + *channeldb.OpenChannel, func(), error) { + + // Make a copy of the DB. + dbFile := filepath.Join(state.Db.Path(), "channel.db") + tempDbPath, err := ioutil.TempDir("", "past-state") + if err != nil { + return nil, nil, err + } + + cleanup := func() { + os.RemoveAll(tempDbPath) + } + + tempDbFile := filepath.Join(tempDbPath, "channel.db") + err = copyFile(tempDbFile, dbFile) + if err != nil { + cleanup() + return nil, nil, err + } + + newDb, err := channeldb.Open(tempDbPath) + if err != nil { + cleanup() + return nil, nil, err + } + + chans, err := newDb.FetchAllChannels() + if err != nil { + cleanup() + return nil, nil, err + } + + // We only support DBs with a single channel, for now. + if len(chans) != 1 { + cleanup() + return nil, nil, fmt.Errorf("found %d chans in the db", + len(chans)) + } + + return chans[0], cleanup, nil +}