contractcourt/chain_watcher test: do proper state rollback
The tests didn't really roll back the channel state, so we would only rely on the state number to determine whether we had lost state. Now we properly roll back the channel to a previous state, in preparation for upcoming changes.
This commit is contained in:
parent
18f79e20d5
commit
450da3d2f4
@ -204,9 +204,32 @@ type dlpTestCase struct {
|
||||
NumUpdates uint8
|
||||
}
|
||||
|
||||
// executeStateTransitions execute the given number of state transitions.
|
||||
// Copies of Alice's channel state before each transition (including initial
|
||||
// state) are returned.
|
||||
func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
|
||||
aliceChannel, bobChannel *lnwallet.LightningChannel,
|
||||
numUpdates uint8) error {
|
||||
numUpdates uint8) ([]*channeldb.OpenChannel, func(), error) {
|
||||
|
||||
// We'll make a copy of the channel state before each transition.
|
||||
var (
|
||||
chanStates []*channeldb.OpenChannel
|
||||
cleanupFuncs []func()
|
||||
)
|
||||
|
||||
cleanAll := func() {
|
||||
for _, f := range cleanupFuncs {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
state, f, err := copyChannelState(aliceChannel.State())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
chanStates = append(chanStates, state)
|
||||
cleanupFuncs = append(cleanupFuncs, f)
|
||||
|
||||
for i := 0; i < int(numUpdates); i++ {
|
||||
addFakeHTLC(
|
||||
@ -215,11 +238,21 @@ func executeStateTransitions(t *testing.T, htlcAmount lnwire.MilliSatoshi,
|
||||
|
||||
err := lnwallet.ForceStateTransition(aliceChannel, bobChannel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cleanAll()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
state, f, err := copyChannelState(aliceChannel.State())
|
||||
if err != nil {
|
||||
cleanAll()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
chanStates = append(chanStates, state)
|
||||
cleanupFuncs = append(cleanupFuncs, f)
|
||||
}
|
||||
|
||||
return chanStates, cleanAll, nil
|
||||
}
|
||||
|
||||
// TestChainWatcherDataLossProtect tests that if we've lost data (and are
|
||||
@ -250,6 +283,24 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// Based on the number of random updates for this state, make a
|
||||
// new HTLC to add to the commitment, and then lock in a state
|
||||
// transition.
|
||||
const htlcAmt = 1000
|
||||
states, cleanStates, err := executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel,
|
||||
testCase.BroadcastStateNum,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("unable to trigger state "+
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
defer cleanStates()
|
||||
|
||||
// We'll use the state this test case wants Alice to start at.
|
||||
aliceChanState := states[testCase.NumUpdates]
|
||||
|
||||
// With the channels created, we'll now create a chain watcher
|
||||
// instance which will be watching for any closes of Alice's
|
||||
// channel.
|
||||
@ -259,7 +310,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
ConfChan: make(chan *chainntnfs.TxConfirmation),
|
||||
}
|
||||
aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{
|
||||
chanState: aliceChannel.State(),
|
||||
chanState: aliceChanState,
|
||||
notifier: aliceNotifier,
|
||||
signer: aliceChannel.Signer,
|
||||
extractStateNumHint: func(*wire.MsgTx,
|
||||
@ -279,19 +330,6 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
}
|
||||
defer aliceChainWatcher.Stop()
|
||||
|
||||
// Based on the number of random updates for this state, make a
|
||||
// new HTLC to add to the commitment, and then lock in a state
|
||||
// transition.
|
||||
const htlcAmt = 1000
|
||||
err = executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel, testCase.NumUpdates,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("unable to trigger state "+
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// We'll request a new channel event subscription from Alice's
|
||||
// chain watcher so we can be notified of our fake close below.
|
||||
chanEvents := aliceChainWatcher.SubscribeChannelEvents()
|
||||
@ -299,7 +337,7 @@ func TestChainWatcherDataLossProtect(t *testing.T) {
|
||||
// Otherwise, we'll feed in this new state number as a response
|
||||
// to the query, and insert the expected DLP commit point.
|
||||
dlpPoint := aliceChannel.State().RemoteCurrentRevocation
|
||||
err = aliceChannel.State().MarkDataLoss(dlpPoint)
|
||||
err = aliceChanState.MarkDataLoss(dlpPoint)
|
||||
if err != nil {
|
||||
t.Errorf("unable to insert dlp point: %v", err)
|
||||
return false
|
||||
@ -421,6 +459,24 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
// We'll execute a number of state transitions based on the
|
||||
// randomly selected number from testing/quick. We do this to
|
||||
// get more coverage of various state hint encodings beyond 0
|
||||
// and 1.
|
||||
const htlcAmt = 1000
|
||||
states, cleanStates, err := executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("unable to trigger state "+
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
defer cleanStates()
|
||||
|
||||
// We'll use the state this test case wants Alice to start at.
|
||||
aliceChanState := states[numUpdates]
|
||||
|
||||
// With the channels created, we'll now create a chain watcher
|
||||
// instance which will be watching for any closes of Alice's
|
||||
// channel.
|
||||
@ -430,7 +486,7 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
|
||||
ConfChan: make(chan *chainntnfs.TxConfirmation),
|
||||
}
|
||||
aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{
|
||||
chanState: aliceChannel.State(),
|
||||
chanState: aliceChanState,
|
||||
notifier: aliceNotifier,
|
||||
signer: aliceChannel.Signer,
|
||||
extractStateNumHint: lnwallet.GetStateNumHint,
|
||||
@ -443,20 +499,6 @@ func TestChainWatcherLocalForceCloseDetect(t *testing.T) {
|
||||
}
|
||||
defer aliceChainWatcher.Stop()
|
||||
|
||||
// We'll execute a number of state transitions based on the
|
||||
// randomly selected number from testing/quick. We do this to
|
||||
// get more coverage of various state hint encodings beyond 0
|
||||
// and 1.
|
||||
const htlcAmt = 1000
|
||||
err = executeStateTransitions(
|
||||
t, htlcAmt, aliceChannel, bobChannel, numUpdates,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("unable to trigger state "+
|
||||
"transition: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// We'll request a new channel event subscription from Alice's
|
||||
// chain watcher so we can be notified of our fake close below.
|
||||
chanEvents := aliceChainWatcher.SubscribeChannelEvents()
|
||||
|
@ -1,10 +1,16 @@
|
||||
package contractcourt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
// timeout implements a test level timeout.
|
||||
@ -24,3 +30,69 @@ func timeout(t *testing.T) func() {
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
|
||||
func copyFile(dest, src string) error {
|
||||
s, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
d, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(d, s); err != nil {
|
||||
d.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return d.Close()
|
||||
}
|
||||
|
||||
// copyChannelState copies the OpenChannel state by copying the database and
|
||||
// creating a new struct from it. The copied state and a cleanup function are
|
||||
// returned.
|
||||
func copyChannelState(state *channeldb.OpenChannel) (
|
||||
*channeldb.OpenChannel, func(), error) {
|
||||
|
||||
// Make a copy of the DB.
|
||||
dbFile := filepath.Join(state.Db.Path(), "channel.db")
|
||||
tempDbPath, err := ioutil.TempDir("", "past-state")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
os.RemoveAll(tempDbPath)
|
||||
}
|
||||
|
||||
tempDbFile := filepath.Join(tempDbPath, "channel.db")
|
||||
err = copyFile(tempDbFile, dbFile)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
newDb, err := channeldb.Open(tempDbPath)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
chans, err := newDb.FetchAllChannels()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// We only support DBs with a single channel, for now.
|
||||
if len(chans) != 1 {
|
||||
cleanup()
|
||||
return nil, nil, fmt.Errorf("found %d chans in the db",
|
||||
len(chans))
|
||||
}
|
||||
|
||||
return chans[0], cleanup, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user