From 450da3d2f4f346090c4b1be0110fc47a0c9b21c8 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Wed, 18 Nov 2020 22:45:35 +0100 Subject: [PATCH] contractcourt/chain_watcher test: do proper state rollback The tests didn't really roll back the channel state, so we would only rely on the state number to determine whether we had lost state. Now we properly roll back the channel to a previous state, in preparation for upcoming changes. --- contractcourt/chain_watcher_test.go | 108 +++++++++++++++++++--------- contractcourt/utils_test.go | 72 +++++++++++++++++++ 2 files changed, 147 insertions(+), 33 deletions(-) diff --git a/contractcourt/chain_watcher_test.go b/contractcourt/chain_watcher_test.go index 50e73ec5..b6e5d8e1 100644 --- a/contractcourt/chain_watcher_test.go +++ b/contractcourt/chain_watcher_test.go @@ -204,9 +204,32 @@ type dlpTestCase struct { NumUpdates uint8 } +// executeStateTransitions execute the given number of state transitions. +// Copies of Alice's channel state before each transition (including initial +// state) are returned. func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, aliceChannel, bobChannel *lnwallet.LightningChannel, - numUpdates uint8) error { + numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) { + + // We'll make a copy of the channel state before each transition. + var ( + chanStates []*channeldb.OpenChannel + cleanupFuncs []func() + ) + + cleanAll := func() { + for _, f := range cleanupFuncs { + f() + } + } + + state, f, err := copyChannelState(aliceChannel.State()) + if err != nil { + return nil, nil, err + } + + chanStates = append(chanStates, state) + cleanupFuncs = append(cleanupFuncs, f) for i := 0; i < int(numUpdates); i++ { addFakeHTLC( @@ -215,11 +238,21 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, err := lnwallet.ForceStateTransition(aliceChannel, bobChannel) if err != nil { - return err + cleanAll() + return nil, nil, err } + + state, f, err := copyChannelState(aliceChannel.State()) + if err != nil { + cleanAll() + return nil, nil, err + } + + chanStates = append(chanStates, state) + cleanupFuncs = append(cleanupFuncs, f) } - return nil + return chanStates, cleanAll, nil } // TestChainWatcherDataLossProtect tests that if we've lost data (and are @@ -250,6 +283,24 @@ func TestChainWatcherDataLossProtect(t *testing.T) { } defer cleanUp() + // Based on the number of random updates for this state, make a + // new HTLC to add to the commitment, and then lock in a state + // transition. + const htlcAmt = 1000 + states, cleanStates, err := executeStateTransitions( + t, htlcAmt, aliceChannel, bobChannel, + testCase.BroadcastStateNum, + ) + if err != nil { + t.Errorf("unable to trigger state "+ + "transition: %v", err) + return false + } + defer cleanStates() + + // We'll use the state this test case wants Alice to start at. + aliceChanState := states[testCase.NumUpdates] + // With the channels created, we'll now create a chain watcher // instance which will be watching for any closes of Alice's // channel. @@ -259,7 +310,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) { ConfChan: make(chan *chainntnfs.TxConfirmation), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), + chanState: aliceChanState, notifier: aliceNotifier, signer: aliceChannel.Signer, extractStateNumHint: func(*wire.MsgTx, @@ -279,19 +330,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) { } defer aliceChainWatcher.Stop() - // Based on the number of random updates for this state, make a - // new HTLC to add to the commitment, and then lock in a state - // transition. - const htlcAmt = 1000 - err = executeStateTransitions( - t, htlcAmt, aliceChannel, bobChannel, testCase.NumUpdates, - ) - if err != nil { - t.Errorf("unable to trigger state "+ - "transition: %v", err) - return false - } - // We'll request a new channel event subscription from Alice's // chain watcher so we can be notified of our fake close below. chanEvents := aliceChainWatcher.SubscribeChannelEvents() @@ -299,7 +337,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) { // Otherwise, we'll feed in this new state number as a response // to the query, and insert the expected DLP commit point. dlpPoint := aliceChannel.State().RemoteCurrentRevocation - err = aliceChannel.State().MarkDataLoss(dlpPoint) + err = aliceChanState.MarkDataLoss(dlpPoint) if err != nil { t.Errorf("unable to insert dlp point: %v", err) return false @@ -421,6 +459,24 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { } defer cleanUp() + // We'll execute a number of state transitions based on the + // randomly selected number from testing/quick. We do this to + // get more coverage of various state hint encodings beyond 0 + // and 1. + const htlcAmt = 1000 + states, cleanStates, err := executeStateTransitions( + t, htlcAmt, aliceChannel, bobChannel, numUpdates, + ) + if err != nil { + t.Errorf("unable to trigger state "+ + "transition: %v", err) + return false + } + defer cleanStates() + + // We'll use the state this test case wants Alice to start at. + aliceChanState := states[numUpdates] + // With the channels created, we'll now create a chain watcher // instance which will be watching for any closes of Alice's // channel. @@ -430,7 +486,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { ConfChan: make(chan *chainntnfs.TxConfirmation), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), + chanState: aliceChanState, notifier: aliceNotifier, signer: aliceChannel.Signer, extractStateNumHint: lnwallet.GetStateNumHint, @@ -443,20 +499,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) { } defer aliceChainWatcher.Stop() - // We'll execute a number of state transitions based on the - // randomly selected number from testing/quick. We do this to - // get more coverage of various state hint encodings beyond 0 - // and 1. - const htlcAmt = 1000 - err = executeStateTransitions( - t, htlcAmt, aliceChannel, bobChannel, numUpdates, - ) - if err != nil { - t.Errorf("unable to trigger state "+ - "transition: %v", err) - return false - } - // We'll request a new channel event subscription from Alice's // chain watcher so we can be notified of our fake close below. chanEvents := aliceChainWatcher.SubscribeChannelEvents() diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 2bf81b41..11f23d8c 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -1,10 +1,16 @@ package contractcourt import ( + "fmt" + "io" + "io/ioutil" "os" + "path/filepath" "runtime/pprof" "testing" "time" + + "github.com/lightningnetwork/lnd/channeldb" ) // timeout implements a test level timeout. @@ -24,3 +30,69 @@ func timeout(t *testing.T) func() { close(done) } } + +func copyFile(dest, src string) error { + s, err := os.Open(src) + if err != nil { + return err + } + defer s.Close() + + d, err := os.Create(dest) + if err != nil { + return err + } + + if _, err := io.Copy(d, s); err != nil { + d.Close() + return err + } + + return d.Close() +} + +// copyChannelState copies the OpenChannel state by copying the database and +// creating a new struct from it. The copied state and a cleanup function are +// returned. +func copyChannelState(state *channeldb.OpenChannel) ( + *channeldb.OpenChannel, func(), error) { + + // Make a copy of the DB. + dbFile := filepath.Join(state.Db.Path(), "channel.db") + tempDbPath, err := ioutil.TempDir("", "past-state") + if err != nil { + return nil, nil, err + } + + cleanup := func() { + os.RemoveAll(tempDbPath) + } + + tempDbFile := filepath.Join(tempDbPath, "channel.db") + err = copyFile(tempDbFile, dbFile) + if err != nil { + cleanup() + return nil, nil, err + } + + newDb, err := channeldb.Open(tempDbPath) + if err != nil { + cleanup() + return nil, nil, err + } + + chans, err := newDb.FetchAllChannels() + if err != nil { + cleanup() + return nil, nil, err + } + + // We only support DBs with a single channel, for now. + if len(chans) != 1 { + cleanup() + return nil, nil, fmt.Errorf("found %d chans in the db", + len(chans)) + } + + return chans[0], cleanup, nil +}