From dd325f04d2b3db4a27989b7a5755d1126ef0a025 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 25 Nov 2020 15:04:12 -0800 Subject: [PATCH] watchtower/wtclient: parameterize backup task with channel type --- htlcswitch/interfaces.go | 3 +- htlcswitch/link.go | 2 +- watchtower/wtclient/backup_task.go | 10 ++-- .../wtclient/backup_task_internal_test.go | 47 +++++++++++-------- watchtower/wtclient/client.go | 9 ++-- watchtower/wtclient/client_test.go | 3 +- 6 files changed, 46 insertions(+), 28 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index c881341f..e121e061 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -183,7 +183,8 @@ type TowerClient interface { // abide by the negotiated policy. If the channel we're trying to back // up doesn't have a tweak for the remote party's output, then // isTweakless should be true. - BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error + BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, + channeldb.ChannelType) error } // InterceptableHtlcForwarder is the interface to set the interceptor diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 70fff3db..3b0e923c 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -1855,7 +1855,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { chanType := l.channel.State().ChanType chanID := l.ChanID() err = l.cfg.TowerClient.BackupState( - &chanID, breachInfo, chanType.IsTweakless(), + &chanID, breachInfo, chanType, ) if err != nil { l.fail(LinkFailureError{code: ErrInternalError}, diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index 9994f7d0..5f4385d7 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil/txsort" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -54,7 +55,7 @@ type backupTask struct { // variables. func newBackupTask(chanID *lnwire.ChannelID, breachInfo *lnwallet.BreachRetribution, - sweepPkScript []byte, isTweakless bool) *backupTask { + sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask { // Parse the non-dust outputs from the breach transaction, // simultaneously computing the total amount contained in the inputs @@ -85,9 +86,12 @@ func newBackupTask(chanID *lnwire.ChannelID, totalAmt += breachInfo.RemoteOutputSignDesc.Output.Value } if breachInfo.LocalOutputSignDesc != nil { - witnessType := input.CommitmentNoDelay - if isTweakless { + var witnessType input.WitnessType + switch { + case chanType.IsTweakless(): witnessType = input.CommitSpendNoDelayTweakless + default: + witnessType = input.CommitmentNoDelay } toRemoteInput = input.NewBaseInput( diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 3c8ed17f..4841ca02 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -74,7 +75,7 @@ type backupTaskTest struct { bindErr error expSweepScript []byte signer input.Signer - tweakless bool + chanType channeldb.ChannelType } // genTaskTest creates a instance of a backupTaskTest using the passed @@ -92,7 +93,7 @@ func genTaskTest( expSweepAmt int64, expRewardAmt int64, bindErr error, - tweakless bool) backupTaskTest { + chanType channeldb.ChannelType) backupTaskTest { // Parse the key pairs for all keys used in the test. revSK, revPK := btcec.PrivKeyFromBytes( @@ -192,9 +193,12 @@ func genTaskTest( Index: index, } - witnessType := input.CommitmentNoDelay - if tweakless { + var witnessType input.WitnessType + switch { + case chanType.IsTweakless(): witnessType = input.CommitSpendNoDelayTweakless + default: + witnessType = input.CommitmentNoDelay } toRemoteInput = input.NewBaseInput( @@ -227,7 +231,7 @@ func genTaskTest( bindErr: bindErr, expSweepScript: makeAddrSlice(22), signer: signer, - tweakless: tweakless, + chanType: chanType, } } @@ -253,8 +257,13 @@ var ( func TestBackupTask(t *testing.T) { t.Parallel() + chanTypes := []channeldb.ChannelType{ + channeldb.SingleFunderBit, + channeldb.SingleFunderTweaklessBit, + } + var backupTaskTests []backupTaskTest - for _, tweakless := range []bool{true, false} { + for _, chanType := range chanTypes { backupTaskTests = append(backupTaskTests, []backupTaskTest{ genTaskTest( "commit no-reward, both outputs", @@ -267,7 +276,7 @@ func TestBackupTask(t *testing.T) { 299241, // expSweepAmt 0, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit no-reward, to-local output only", @@ -280,7 +289,7 @@ func TestBackupTask(t *testing.T) { 199514, // expSweepAmt 0, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit no-reward, to-remote output only", @@ -293,7 +302,7 @@ func TestBackupTask(t *testing.T) { 99561, // expSweepAmt 0, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit no-reward, to-remote output only, creates dust", @@ -306,7 +315,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrCreatesDust, // bindErr - tweakless, + chanType, ), genTaskTest( "commit no-reward, no outputs, fee rate exceeds inputs", @@ -319,7 +328,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrFeeExceedsInputs, // bindErr - tweakless, + chanType, ), genTaskTest( "commit no-reward, no outputs, fee rate of 0 creates dust", @@ -332,7 +341,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrCreatesDust, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, both outputs", @@ -345,7 +354,7 @@ func TestBackupTask(t *testing.T) { 296117, // expSweepAmt 3000, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, to-local output only", @@ -358,7 +367,7 @@ func TestBackupTask(t *testing.T) { 197390, // expSweepAmt 2000, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, to-remote output only", @@ -371,7 +380,7 @@ func TestBackupTask(t *testing.T) { 98437, // expSweepAmt 1000, // expRewardAmt nil, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, to-remote output only, creates dust", @@ -384,7 +393,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrCreatesDust, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, no outputs, fee rate exceeds inputs", @@ -397,7 +406,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrFeeExceedsInputs, // bindErr - tweakless, + chanType, ), genTaskTest( "commit reward, no outputs, fee rate of 0 creates dust", @@ -410,7 +419,7 @@ func TestBackupTask(t *testing.T) { 0, // expSweepAmt 0, // expRewardAmt wtpolicy.ErrCreatesDust, // bindErr - tweakless, + chanType, ), }...) } @@ -430,7 +439,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Create a new backupTask from the channel id and breach info. task := newBackupTask( &test.chanID, test.breachInfo, test.expSweepScript, - test.tweakless, + test.chanType, ) // Assert that all parameters set during initialization are properly diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 71d767ef..d0cb4283 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -103,7 +104,8 @@ type Client interface { // negotiated policy. If the channel we're trying to back up doesn't // have a tweak for the remote party's output, then isTweakless should // be true. - BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error + BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, + channeldb.ChannelType) error // Start initializes the watchtower client, allowing it process requests // to backup revoked channel states. @@ -592,7 +594,8 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // - breached outputs contain too little value to sweep at the target sweep fee // rate. func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, - breachInfo *lnwallet.BreachRetribution, isTweakless bool) error { + breachInfo *lnwallet.BreachRetribution, + chanType channeldb.ChannelType) error { // Retrieve the cached sweep pkscript used for this channel. c.backupMu.Lock() @@ -618,7 +621,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, c.backupMu.Unlock() task := newBackupTask( - chanID, breachInfo, summary.SweepPkScript, isTweakless, + chanID, breachInfo, summary.SweepPkScript, chanType, ) return c.pipeline.QueueBackupTask(task) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 166ab228..a341b849 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -632,7 +633,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { _, retribution := h.channel(id).getState(i) chanID := chanIDFromInt(id) - err := h.client.BackupState(&chanID, retribution, false) + err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit) if err != expErr { h.t.Fatalf("back error mismatch, want: %v, got: %v", expErr, err)