contractcourt/chain_watcher test: do proper state rollback

The tests didn't really roll back the channel state, so we would only
rely on the state number to determine whether we had lost state. Now we
properly roll back the channel to a previous state, in preparation for
upcoming changes.
This commit is contained in:
Johan T. Halseth 2020-11-18 22:45:35 +01:00
parent 18f79e20d5
commit 450da3d2f4
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
2 changed files with 147 additions and 33 deletions

@ -204,9 +204,32 @@ type dlpTestCase struct {
NumUpdates uint8 NumUpdates uint8
} }
// executeStateTransitions execute the given number of state transitions.
// Copies of Alice's channel state before each transition (including initial
// state) are returned.
func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi, func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
aliceChannel, bobChannel *lnwallet.LightningChannel, aliceChannel, bobChannel *lnwallet.LightningChannel,
numUpdates uint8) error { numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) {
// We'll make a copy of the channel state before each transition.
var (
chanStates []*channeldb.OpenChannel
cleanupFuncs []func()
)
cleanAll := func() {
for _, f := range cleanupFuncs {
f()
}
}
state, f, err := copyChannelState(aliceChannel.State())
if err != nil {
return nil, nil, err
}
chanStates = append(chanStates, state)
cleanupFuncs = append(cleanupFuncs, f)
for i := 0; i < int(numUpdates); i++ { for i := 0; i < int(numUpdates); i++ {
addFakeHTLC( addFakeHTLC(
@ -215,11 +238,21 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
err := lnwallet.ForceStateTransition(aliceChannel, bobChannel) err := lnwallet.ForceStateTransition(aliceChannel, bobChannel)
if err != nil { if err != nil {
return err cleanAll()
return nil, nil, err
} }
state, f, err := copyChannelState(aliceChannel.State())
if err != nil {
cleanAll()
return nil, nil, err
}
chanStates = append(chanStates, state)
cleanupFuncs = append(cleanupFuncs, f)
} }
return nil return chanStates, cleanAll, nil
} }
// TestChainWatcherDataLossProtect tests that if we've lost data (and are // TestChainWatcherDataLossProtect tests that if we've lost data (and are
@ -250,6 +283,24 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// Based on the number of random updates for this state, make a
// new HTLC to add to the commitment, and then lock in a state
// transition.
const htlcAmt = 1000
states, cleanStates, err := executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel,
testCase.BroadcastStateNum,
)
if err != nil {
t.Errorf("unable to trigger state "+
"transition: %v", err)
return false
}
defer cleanStates()
// We'll use the state this test case wants Alice to start at.
aliceChanState := states[testCase.NumUpdates]
// With the channels created, we'll now create a chain watcher // With the channels created, we'll now create a chain watcher
// instance which will be watching for any closes of Alice's // instance which will be watching for any closes of Alice's
// channel. // channel.
@ -259,7 +310,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
ConfChan: make(chan *chainntnfs.TxConfirmation), ConfChan: make(chan *chainntnfs.TxConfirmation),
} }
aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{
chanState: aliceChannel.State(), chanState: aliceChanState,
notifier: aliceNotifier, notifier: aliceNotifier,
signer: aliceChannel.Signer, signer: aliceChannel.Signer,
extractStateNumHint: func(*wire.MsgTx, extractStateNumHint: func(*wire.MsgTx,
@ -279,19 +330,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
} }
defer aliceChainWatcher.Stop() defer aliceChainWatcher.Stop()
// Based on the number of random updates for this state, make a
// new HTLC to add to the commitment, and then lock in a state
// transition.
const htlcAmt = 1000
err = executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel, testCase.NumUpdates,
)
if err != nil {
t.Errorf("unable to trigger state "+
"transition: %v", err)
return false
}
// We'll request a new channel event subscription from Alice's // We'll request a new channel event subscription from Alice's
// chain watcher so we can be notified of our fake close below. // chain watcher so we can be notified of our fake close below.
chanEvents := aliceChainWatcher.SubscribeChannelEvents() chanEvents := aliceChainWatcher.SubscribeChannelEvents()
@ -299,7 +337,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
// Otherwise, we'll feed in this new state number as a response // Otherwise, we'll feed in this new state number as a response
// to the query, and insert the expected DLP commit point. // to the query, and insert the expected DLP commit point.
dlpPoint := aliceChannel.State().RemoteCurrentRevocation dlpPoint := aliceChannel.State().RemoteCurrentRevocation
err = aliceChannel.State().MarkDataLoss(dlpPoint) err = aliceChanState.MarkDataLoss(dlpPoint)
if err != nil { if err != nil {
t.Errorf("unable to insert dlp point: %v", err) t.Errorf("unable to insert dlp point: %v", err)
return false return false
@ -421,6 +459,24 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
} }
defer cleanUp() defer cleanUp()
// We'll execute a number of state transitions based on the
// randomly selected number from testing/quick. We do this to
// get more coverage of various state hint encodings beyond 0
// and 1.
const htlcAmt = 1000
states, cleanStates, err := executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
)
if err != nil {
t.Errorf("unable to trigger state "+
"transition: %v", err)
return false
}
defer cleanStates()
// We'll use the state this test case wants Alice to start at.
aliceChanState := states[numUpdates]
// With the channels created, we'll now create a chain watcher // With the channels created, we'll now create a chain watcher
// instance which will be watching for any closes of Alice's // instance which will be watching for any closes of Alice's
// channel. // channel.
@ -430,7 +486,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
ConfChan: make(chan *chainntnfs.TxConfirmation), ConfChan: make(chan *chainntnfs.TxConfirmation),
} }
aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{
chanState: aliceChannel.State(), chanState: aliceChanState,
notifier: aliceNotifier, notifier: aliceNotifier,
signer: aliceChannel.Signer, signer: aliceChannel.Signer,
extractStateNumHint: lnwallet.GetStateNumHint, extractStateNumHint: lnwallet.GetStateNumHint,
@ -443,20 +499,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
} }
defer aliceChainWatcher.Stop() defer aliceChainWatcher.Stop()
// We'll execute a number of state transitions based on the
// randomly selected number from testing/quick. We do this to
// get more coverage of various state hint encodings beyond 0
// and 1.
const htlcAmt = 1000
err = executeStateTransitions(
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
)
if err != nil {
t.Errorf("unable to trigger state "+
"transition: %v", err)
return false
}
// We'll request a new channel event subscription from Alice's // We'll request a new channel event subscription from Alice's
// chain watcher so we can be notified of our fake close below. // chain watcher so we can be notified of our fake close below.
chanEvents := aliceChainWatcher.SubscribeChannelEvents() chanEvents := aliceChainWatcher.SubscribeChannelEvents()

@ -1,10 +1,16 @@
package contractcourt package contractcourt
import ( import (
"fmt"
"io"
"io/ioutil"
"os" "os"
"path/filepath"
"runtime/pprof" "runtime/pprof"
"testing" "testing"
"time" "time"
"github.com/lightningnetwork/lnd/channeldb"
) )
// timeout implements a test level timeout. // timeout implements a test level timeout.
@ -24,3 +30,69 @@ func timeout(t *testing.T) func() {
close(done) close(done)
} }
} }
func copyFile(dest, src string) error {
s, err := os.Open(src)
if err != nil {
return err
}
defer s.Close()
d, err := os.Create(dest)
if err != nil {
return err
}
if _, err := io.Copy(d, s); err != nil {
d.Close()
return err
}
return d.Close()
}
// copyChannelState copies the OpenChannel state by copying the database and
// creating a new struct from it. The copied state and a cleanup function are
// returned.
func copyChannelState(state *channeldb.OpenChannel) (
*channeldb.OpenChannel, func(), error) {
// Make a copy of the DB.
dbFile := filepath.Join(state.Db.Path(), "channel.db")
tempDbPath, err := ioutil.TempDir("", "past-state")
if err != nil {
return nil, nil, err
}
cleanup := func() {
os.RemoveAll(tempDbPath)
}
tempDbFile := filepath.Join(tempDbPath, "channel.db")
err = copyFile(tempDbFile, dbFile)
if err != nil {
cleanup()
return nil, nil, err
}
newDb, err := channeldb.Open(tempDbPath)
if err != nil {
cleanup()
return nil, nil, err
}
chans, err := newDb.FetchAllChannels()
if err != nil {
cleanup()
return nil, nil, err
}
// We only support DBs with a single channel, for now.
if len(chans) != 1 {
cleanup()
return nil, nil, fmt.Errorf("found %d chans in the db",
len(chans))
}
return chans[0], cleanup, nil
}