contractcourt: revamp HTLC success unit test

We add checkpoint assertions and resume the resolver from every
checkpoint to ensure it can handle restarts.
This commit is contained in:
Johan T. Halseth 2020-12-09 12:24:03 +01:00
parent 7142a302c9
commit d02b486195
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
2 changed files with 232 additions and 74 deletions

@ -103,17 +103,19 @@ func (i *commitSweepResolverTestContext) waitForResult() {
} }
type mockSweeper struct { type mockSweeper struct {
sweptInputs chan input.Input sweptInputs chan input.Input
updatedInputs chan wire.OutPoint updatedInputs chan wire.OutPoint
sweepTx *wire.MsgTx sweepTx *wire.MsgTx
sweepErr error sweepErr error
createSweepTxChan chan *wire.MsgTx
} }
func newMockSweeper() *mockSweeper { func newMockSweeper() *mockSweeper {
return &mockSweeper{ return &mockSweeper{
sweptInputs: make(chan input.Input), sweptInputs: make(chan input.Input),
updatedInputs: make(chan wire.OutPoint), updatedInputs: make(chan wire.OutPoint),
sweepTx: &wire.MsgTx{}, sweepTx: &wire.MsgTx{},
createSweepTxChan: make(chan *wire.MsgTx),
} }
} }
@ -133,7 +135,9 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) (
func (s *mockSweeper) CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, func (s *mockSweeper) CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference,
currentBlockHeight uint32) (*wire.MsgTx, error) { currentBlockHeight uint32) (*wire.MsgTx, error) {
return nil, nil // We will wait for the test to supply the sweep tx to return.
sweepTx := <-s.createSweepTxChan
return sweepTx, nil
} }
func (s *mockSweeper) RelayFeePerKW() chainfee.SatPerKWeight { func (s *mockSweeper) RelayFeePerKW() chainfee.SatPerKWeight {

@ -1,10 +1,14 @@
package contractcourt package contractcourt
import ( import (
"bytes"
"io"
"reflect"
"testing" "testing"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/channeldb/kvdb"
@ -22,11 +26,11 @@ type htlcSuccessResolverTestContext struct {
t *testing.T t *testing.T
} }
func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestContext { func newHtlcSuccessResolverTextContext(t *testing.T, checkpoint io.Reader) *htlcSuccessResolverTestContext {
notifier := &mock.ChainNotifier{ notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch), EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
SpendChan: make(chan *chainntnfs.SpendDetail), SpendChan: make(chan *chainntnfs.SpendDetail, 1),
ConfChan: make(chan *chainntnfs.TxConfirmation), ConfChan: make(chan *chainntnfs.TxConfirmation, 1),
} }
checkPointChan := make(chan struct{}, 1) checkPointChan := make(chan struct{}, 1)
@ -42,6 +46,11 @@ func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestCon
PublishTx: func(_ *wire.MsgTx, _ string) error { PublishTx: func(_ *wire.MsgTx, _ string) error {
return nil return nil
}, },
Sweeper: newMockSweeper(),
IncubateOutputs: func(wire.OutPoint, *lnwallet.OutgoingHtlcResolution,
*lnwallet.IncomingHtlcResolution, uint32) error {
return nil
},
}, },
PutResolverReport: func(_ kvdb.RwTx, PutResolverReport: func(_ kvdb.RwTx,
report *channeldb.ResolverReport) error { report *channeldb.ResolverReport) error {
@ -59,15 +68,27 @@ func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestCon
return nil return nil
}, },
} }
htlc := channeldb.HTLC{
RHash: testResHash,
OnionBlob: testOnionBlob,
Amt: testHtlcAmt,
}
if checkpoint != nil {
var err error
testCtx.resolver, err = newSuccessResolverFromReader(checkpoint, cfg)
if err != nil {
t.Fatal(err)
}
testCtx.resolver = &htlcSuccessResolver{ testCtx.resolver.Supplement(htlc)
contractResolverKit: *newContractResolverKit(cfg),
htlcResolution: lnwallet.IncomingHtlcResolution{}, } else {
htlc: channeldb.HTLC{
RHash: testResHash, testCtx.resolver = &htlcSuccessResolver{
OnionBlob: testOnionBlob, contractResolverKit: *newContractResolverKit(cfg),
Amt: testHtlcAmt, htlcResolution: lnwallet.IncomingHtlcResolution{},
}, htlc: htlc,
}
} }
return testCtx return testCtx
@ -98,8 +119,9 @@ func (i *htlcSuccessResolverTestContext) waitForResult() {
} }
} }
// TestSingleStageSuccess tests successful sweep of a single stage htlc claim. // TestHtlcSuccessSingleStage tests successful sweep of a single stage htlc
func TestSingleStageSuccess(t *testing.T) { // claim.
func TestHtlcSuccessSingleStage(t *testing.T) {
htlcOutpoint := wire.OutPoint{Index: 3} htlcOutpoint := wire.OutPoint{Index: 3}
sweepTx := &wire.MsgTx{ sweepTx := &wire.MsgTx{
@ -114,15 +136,6 @@ func TestSingleStageSuccess(t *testing.T) {
ClaimOutpoint: htlcOutpoint, ClaimOutpoint: htlcOutpoint,
} }
// We send a confirmation for our sweep tx to indicate that our sweep
// succeeded.
resolve := func(ctx *htlcSuccessResolverTestContext) {
ctx.notifier.ConfChan <- &chainntnfs.TxConfirmation{
Tx: ctx.resolver.sweepTx,
BlockHeight: testInitialBlockHeight - 1,
}
}
sweepTxid := sweepTx.TxHash() sweepTxid := sweepTx.TxHash()
claim := &channeldb.ResolverReport{ claim := &channeldb.ResolverReport{
OutPoint: htlcOutpoint, OutPoint: htlcOutpoint,
@ -131,14 +144,45 @@ func TestSingleStageSuccess(t *testing.T) {
ResolverOutcome: channeldb.ResolverOutcomeClaimed, ResolverOutcome: channeldb.ResolverOutcomeClaimed,
SpendTxID: &sweepTxid, SpendTxID: &sweepTxid,
} }
checkpoints := []checkpoint{
{
// We send a confirmation for our sweep tx to indicate
// that our sweep succeeded.
preCheckpoint: func(ctx *htlcSuccessResolverTestContext,
_ bool) error {
// The resolver will create and publish a sweep
// tx.
ctx.resolver.Sweeper.(*mockSweeper).
createSweepTxChan <- sweepTx
// Confirm the sweep, which should resolve it.
ctx.notifier.ConfChan <- &chainntnfs.TxConfirmation{
Tx: sweepTx,
BlockHeight: testInitialBlockHeight - 1,
}
return nil
},
// After the sweep has confirmed, we expect the
// checkpoint to be resolved, and with the above
// report.
resolved: true,
reports: []*channeldb.ResolverReport{
claim,
},
},
}
testHtlcSuccess( testHtlcSuccess(
t, singleStageResolution, resolve, sweepTx, claim, t, singleStageResolution, checkpoints,
) )
} }
// TestSecondStageResolution tests successful sweep of a second stage htlc // TestSecondStageResolution tests successful sweep of a second stage htlc
// claim. // claim, going through the Nursery.
func TestSecondStageResolution(t *testing.T) { func TestHtlcSuccessSecondStageResolution(t *testing.T) {
commitOutpoint := wire.OutPoint{Index: 2} commitOutpoint := wire.OutPoint{Index: 2}
htlcOutpoint := wire.OutPoint{Index: 3} htlcOutpoint := wire.OutPoint{Index: 3}
@ -158,20 +202,17 @@ func TestSecondStageResolution(t *testing.T) {
PreviousOutPoint: commitOutpoint, PreviousOutPoint: commitOutpoint,
}, },
}, },
TxOut: []*wire.TxOut{}, TxOut: []*wire.TxOut{
{
Value: 111,
PkScript: []byte{0xaa, 0xaa},
},
},
}, },
ClaimOutpoint: htlcOutpoint, ClaimOutpoint: htlcOutpoint,
SweepSignDesc: testSignDesc, SweepSignDesc: testSignDesc,
} }
// We send a spend notification for our output to resolve our htlc.
resolve := func(ctx *htlcSuccessResolverTestContext) {
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: sweepTx,
SpenderTxHash: &sweepHash,
}
}
successTx := twoStageResolution.SignedSuccessTx.TxHash() successTx := twoStageResolution.SignedSuccessTx.TxHash()
firstStage := &channeldb.ResolverReport{ firstStage := &channeldb.ResolverReport{
OutPoint: commitOutpoint, OutPoint: commitOutpoint,
@ -189,54 +230,167 @@ func TestSecondStageResolution(t *testing.T) {
SpendTxID: &sweepHash, SpendTxID: &sweepHash,
} }
checkpoints := []checkpoint{
{
// The resolver will send the output to the Nursery.
incubating: true,
},
{
// It will then wait for the Nursery to spend the
// output. We send a spend notification for our output
// to resolve our htlc.
preCheckpoint: func(ctx *htlcSuccessResolverTestContext,
_ bool) error {
ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{
SpendingTx: sweepTx,
SpenderTxHash: &sweepHash,
}
return nil
},
incubating: true,
resolved: true,
reports: []*channeldb.ResolverReport{
secondStage,
firstStage,
},
},
}
testHtlcSuccess( testHtlcSuccess(
t, twoStageResolution, resolve, sweepTx, secondStage, firstStage, t, twoStageResolution, checkpoints,
) )
} }
// testHtlcSuccess tests resolution of a success resolver. It takes a resolve // checkpoint holds expected data we expect the resolver to checkpoint itself
// function which triggers resolution and the sweeptxid that will resolve it. // to the DB next.
type checkpoint struct {
// preCheckpoint is a method that will be called before we reach the
// checkpoint, to carry out any needed operations to drive the resolver
// in this stage.
preCheckpoint func(*htlcSuccessResolverTestContext, bool) error
// data we expect the resolver to be checkpointed with next.
incubating bool
resolved bool
reports []*channeldb.ResolverReport
}
// testHtlcSuccess tests resolution of a success resolver. It takes a a list of
// checkpoints that it expects the resolver to go through. And will run the
// resolver all the way through these checkpoints, and also attempt to resume
// the resolver from every checkpoint.
func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution, func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution,
resolve func(*htlcSuccessResolverTestContext), checkpoints []checkpoint) {
sweepTx *wire.MsgTx, reports ...*channeldb.ResolverReport) {
defer timeout(t)() defer timeout(t)()
ctx := newHtlcSuccessResolverTextContext(t) // We first run the resolver from start to finish, ensuring it gets
// checkpointed at every expected stage. We store the checkpointed data
// Replace our checkpoint with one which will push reports into a // for the next portion of the test.
// channel for us to consume. We replace this function on the resolver ctx := newHtlcSuccessResolverTextContext(t, nil)
// itself because it is created by the test context.
reportChan := make(chan *channeldb.ResolverReport)
ctx.resolver.Checkpoint = func(_ ContractResolver,
reports ...*channeldb.ResolverReport) error {
// Send all of our reports into the channel.
for _, report := range reports {
reportChan <- report
}
return nil
}
ctx.resolver.htlcResolution = resolution ctx.resolver.htlcResolution = resolution
// We set the sweepTx to be non-nil and mark the output as already checkpointedState := runFromCheckpoint(t, ctx, checkpoints)
// incubating so that we do not need to set test values for crafting
// our own sweep transaction. // Now, from every checkpoint created, we re-create the resolver, and
ctx.resolver.sweepTx = sweepTx // run the test from that checkpoint.
ctx.resolver.outputIncubating = true for i := range checkpointedState {
cp := bytes.NewReader(checkpointedState[i])
ctx := newHtlcSuccessResolverTextContext(t, cp)
ctx.resolver.htlcResolution = resolution
// Run from the given checkpoint, ensuring we'll hit the rest.
_ = runFromCheckpoint(t, ctx, checkpoints[i+1:])
}
}
// runFromCheckpoint executes the Resolve method on the success resolver, and
// asserts that it checkpoints itself according to the expected checkpoints.
func runFromCheckpoint(t *testing.T, ctx *htlcSuccessResolverTestContext,
expectedCheckpoints []checkpoint) [][]byte {
defer timeout(t)()
var checkpointedState [][]byte
// Replace our checkpoint method with one which we'll use to assert the
// checkpointed state and reports are equal to what we expect.
nextCheckpoint := 0
checkpointChan := make(chan struct{})
ctx.resolver.Checkpoint = func(resolver ContractResolver,
reports ...*channeldb.ResolverReport) error {
if nextCheckpoint >= len(expectedCheckpoints) {
t.Fatal("did not expect more checkpoints")
}
h := resolver.(*htlcSuccessResolver)
cp := expectedCheckpoints[nextCheckpoint]
if h.resolved != cp.resolved {
t.Fatalf("expected checkpoint to be resolve=%v, had %v",
cp.resolved, h.resolved)
}
if !reflect.DeepEqual(h.outputIncubating, cp.incubating) {
t.Fatalf("expected checkpoint to be have "+
"incubating=%v, had %v", cp.incubating,
h.outputIncubating)
}
// Check we go the expected reports.
if len(reports) != len(cp.reports) {
t.Fatalf("unexpected number of reports. Expected %v "+
"got %v", len(cp.reports), len(reports))
}
for i, report := range reports {
if !reflect.DeepEqual(report, cp.reports[i]) {
t.Fatalf("expected: %v, got: %v",
spew.Sdump(cp.reports[i]),
spew.Sdump(report))
}
}
// Finally encode the resolver, and store it for later use.
b := bytes.Buffer{}
if err := resolver.Encode(&b); err != nil {
t.Fatal(err)
}
checkpointedState = append(checkpointedState, b.Bytes())
nextCheckpoint++
checkpointChan <- struct{}{}
return nil
}
// Start the htlc success resolver. // Start the htlc success resolver.
ctx.resolve() ctx.resolve()
// Trigger and event that will resolve our test context. // Go through our list of expected checkpoints, so we can run the
resolve(ctx) // preCheckpoint logic if needed.
resumed := true
for i, cp := range expectedCheckpoints {
if cp.preCheckpoint != nil {
if err := cp.preCheckpoint(ctx, resumed); err != nil {
t.Fatalf("failure at stage %d: %v", i, err)
}
for _, report := range reports { }
assertResolverReport(t, reportChan, report) resumed = false
// Wait for the resolver to have checkpointed its state.
<-checkpointChan
} }
// Wait for the resolver to fully complete. // Wait for the resolver to fully complete.
ctx.waitForResult() ctx.waitForResult()
if nextCheckpoint < len(expectedCheckpoints) {
t.Fatalf("not all checkpoints hit")
}
return checkpointedState
} }