watchtower/wtclient: parameterize backup task with channel type

This commit is contained in:
Conner Fromknecht 2020-11-25 15:04:12 -08:00
parent 3856acce50
commit dd325f04d2
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
6 changed files with 46 additions and 28 deletions

@ -183,7 +183,8 @@ type TowerClient interface {
// abide by the negotiated policy. If the channel we're trying to back // abide by the negotiated policy. If the channel we're trying to back
// up doesn't have a tweak for the remote party's output, then // up doesn't have a tweak for the remote party's output, then
// isTweakless should be true. // isTweakless should be true.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
channeldb.ChannelType) error
} }
// InterceptableHtlcForwarder is the interface to set the interceptor // InterceptableHtlcForwarder is the interface to set the interceptor

@ -1855,7 +1855,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
chanType := l.channel.State().ChanType chanType := l.channel.State().ChanType
chanID := l.ChanID() chanID := l.ChanID()
err = l.cfg.TowerClient.BackupState( err = l.cfg.TowerClient.BackupState(
&chanID, breachInfo, chanType.IsTweakless(), &chanID, breachInfo, chanType,
) )
if err != nil { if err != nil {
l.fail(LinkFailureError{code: ErrInternalError}, l.fail(LinkFailureError{code: ErrInternalError},

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/txsort" "github.com/btcsuite/btcutil/txsort"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -54,7 +55,7 @@ type backupTask struct {
// variables. // variables.
func newBackupTask(chanID *lnwire.ChannelID, func newBackupTask(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution, breachInfo *lnwallet.BreachRetribution,
sweepPkScript []byte, isTweakless bool) *backupTask { sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask {
// Parse the non-dust outputs from the breach transaction, // Parse the non-dust outputs from the breach transaction,
// simultaneously computing the total amount contained in the inputs // simultaneously computing the total amount contained in the inputs
@ -85,9 +86,12 @@ func newBackupTask(chanID *lnwire.ChannelID,
totalAmt += breachInfo.RemoteOutputSignDesc.Output.Value totalAmt += breachInfo.RemoteOutputSignDesc.Output.Value
} }
if breachInfo.LocalOutputSignDesc != nil { if breachInfo.LocalOutputSignDesc != nil {
witnessType := input.CommitmentNoDelay var witnessType input.WitnessType
if isTweakless { switch {
case chanType.IsTweakless():
witnessType = input.CommitSpendNoDelayTweakless witnessType = input.CommitSpendNoDelayTweakless
default:
witnessType = input.CommitmentNoDelay
} }
toRemoteInput = input.NewBaseInput( toRemoteInput = input.NewBaseInput(

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
@ -74,7 +75,7 @@ type backupTaskTest struct {
bindErr error bindErr error
expSweepScript []byte expSweepScript []byte
signer input.Signer signer input.Signer
tweakless bool chanType channeldb.ChannelType
} }
// genTaskTest creates a instance of a backupTaskTest using the passed // genTaskTest creates a instance of a backupTaskTest using the passed
@ -92,7 +93,7 @@ func genTaskTest(
expSweepAmt int64, expSweepAmt int64,
expRewardAmt int64, expRewardAmt int64,
bindErr error, bindErr error,
tweakless bool) backupTaskTest { chanType channeldb.ChannelType) backupTaskTest {
// Parse the key pairs for all keys used in the test. // Parse the key pairs for all keys used in the test.
revSK, revPK := btcec.PrivKeyFromBytes( revSK, revPK := btcec.PrivKeyFromBytes(
@ -192,9 +193,12 @@ func genTaskTest(
Index: index, Index: index,
} }
witnessType := input.CommitmentNoDelay var witnessType input.WitnessType
if tweakless { switch {
case chanType.IsTweakless():
witnessType = input.CommitSpendNoDelayTweakless witnessType = input.CommitSpendNoDelayTweakless
default:
witnessType = input.CommitmentNoDelay
} }
toRemoteInput = input.NewBaseInput( toRemoteInput = input.NewBaseInput(
@ -227,7 +231,7 @@ func genTaskTest(
bindErr: bindErr, bindErr: bindErr,
expSweepScript: makeAddrSlice(22), expSweepScript: makeAddrSlice(22),
signer: signer, signer: signer,
tweakless: tweakless, chanType: chanType,
} }
} }
@ -253,8 +257,13 @@ var (
func TestBackupTask(t *testing.T) { func TestBackupTask(t *testing.T) {
t.Parallel() t.Parallel()
chanTypes := []channeldb.ChannelType{
channeldb.SingleFunderBit,
channeldb.SingleFunderTweaklessBit,
}
var backupTaskTests []backupTaskTest var backupTaskTests []backupTaskTest
for _, tweakless := range []bool{true, false} { for _, chanType := range chanTypes {
backupTaskTests = append(backupTaskTests, []backupTaskTest{ backupTaskTests = append(backupTaskTests, []backupTaskTest{
genTaskTest( genTaskTest(
"commit no-reward, both outputs", "commit no-reward, both outputs",
@ -267,7 +276,7 @@ func TestBackupTask(t *testing.T) {
299241, // expSweepAmt 299241, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit no-reward, to-local output only", "commit no-reward, to-local output only",
@ -280,7 +289,7 @@ func TestBackupTask(t *testing.T) {
199514, // expSweepAmt 199514, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit no-reward, to-remote output only", "commit no-reward, to-remote output only",
@ -293,7 +302,7 @@ func TestBackupTask(t *testing.T) {
99561, // expSweepAmt 99561, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit no-reward, to-remote output only, creates dust", "commit no-reward, to-remote output only, creates dust",
@ -306,7 +315,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrCreatesDust, // bindErr wtpolicy.ErrCreatesDust, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit no-reward, no outputs, fee rate exceeds inputs", "commit no-reward, no outputs, fee rate exceeds inputs",
@ -319,7 +328,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrFeeExceedsInputs, // bindErr wtpolicy.ErrFeeExceedsInputs, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit no-reward, no outputs, fee rate of 0 creates dust", "commit no-reward, no outputs, fee rate of 0 creates dust",
@ -332,7 +341,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrCreatesDust, // bindErr wtpolicy.ErrCreatesDust, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, both outputs", "commit reward, both outputs",
@ -345,7 +354,7 @@ func TestBackupTask(t *testing.T) {
296117, // expSweepAmt 296117, // expSweepAmt
3000, // expRewardAmt 3000, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, to-local output only", "commit reward, to-local output only",
@ -358,7 +367,7 @@ func TestBackupTask(t *testing.T) {
197390, // expSweepAmt 197390, // expSweepAmt
2000, // expRewardAmt 2000, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, to-remote output only", "commit reward, to-remote output only",
@ -371,7 +380,7 @@ func TestBackupTask(t *testing.T) {
98437, // expSweepAmt 98437, // expSweepAmt
1000, // expRewardAmt 1000, // expRewardAmt
nil, // bindErr nil, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, to-remote output only, creates dust", "commit reward, to-remote output only, creates dust",
@ -384,7 +393,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrCreatesDust, // bindErr wtpolicy.ErrCreatesDust, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, no outputs, fee rate exceeds inputs", "commit reward, no outputs, fee rate exceeds inputs",
@ -397,7 +406,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrFeeExceedsInputs, // bindErr wtpolicy.ErrFeeExceedsInputs, // bindErr
tweakless, chanType,
), ),
genTaskTest( genTaskTest(
"commit reward, no outputs, fee rate of 0 creates dust", "commit reward, no outputs, fee rate of 0 creates dust",
@ -410,7 +419,7 @@ func TestBackupTask(t *testing.T) {
0, // expSweepAmt 0, // expSweepAmt
0, // expRewardAmt 0, // expRewardAmt
wtpolicy.ErrCreatesDust, // bindErr wtpolicy.ErrCreatesDust, // bindErr
tweakless, chanType,
), ),
}...) }...)
} }
@ -430,7 +439,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Create a new backupTask from the channel id and breach info. // Create a new backupTask from the channel id and breach info.
task := newBackupTask( task := newBackupTask(
&test.chanID, test.breachInfo, test.expSweepScript, &test.chanID, test.breachInfo, test.expSweepScript,
test.tweakless, test.chanType,
) )
// Assert that all parameters set during initialization are properly // Assert that all parameters set during initialization are properly

@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
@ -103,7 +104,8 @@ type Client interface {
// negotiated policy. If the channel we're trying to back up doesn't // negotiated policy. If the channel we're trying to back up doesn't
// have a tweak for the remote party's output, then isTweakless should // have a tweak for the remote party's output, then isTweakless should
// be true. // be true.
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
channeldb.ChannelType) error
// Start initializes the watchtower client, allowing it process requests // Start initializes the watchtower client, allowing it process requests
// to backup revoked channel states. // to backup revoked channel states.
@ -592,7 +594,8 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
// - breached outputs contain too little value to sweep at the target sweep fee // - breached outputs contain too little value to sweep at the target sweep fee
// rate. // rate.
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
breachInfo *lnwallet.BreachRetribution, isTweakless bool) error { breachInfo *lnwallet.BreachRetribution,
chanType channeldb.ChannelType) error {
// Retrieve the cached sweep pkscript used for this channel. // Retrieve the cached sweep pkscript used for this channel.
c.backupMu.Lock() c.backupMu.Lock()
@ -618,7 +621,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
c.backupMu.Unlock() c.backupMu.Unlock()
task := newBackupTask( task := newBackupTask(
chanID, breachInfo, summary.SweepPkScript, isTweakless, chanID, breachInfo, summary.SweepPkScript, chanType,
) )
return c.pipeline.QueueBackupTask(task) return c.pipeline.QueueBackupTask(task)

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
@ -632,7 +633,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
_, retribution := h.channel(id).getState(i) _, retribution := h.channel(id).getState(i)
chanID := chanIDFromInt(id) chanID := chanIDFromInt(id)
err := h.client.BackupState(&chanID, retribution, false) err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit)
if err != expErr { if err != expErr {
h.t.Fatalf("back error mismatch, want: %v, got: %v", h.t.Fatalf("back error mismatch, want: %v, got: %v",
expErr, err) expErr, err)