watchtower/wtclient: parameterize backup task with channel type
This commit is contained in:
parent
3856acce50
commit
dd325f04d2
@ -183,7 +183,8 @@ type TowerClient interface {
|
||||
// abide by the negotiated policy. If the channel we're trying to back
|
||||
// up doesn't have a tweak for the remote party's output, then
|
||||
// isTweakless should be true.
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
|
||||
channeldb.ChannelType) error
|
||||
}
|
||||
|
||||
// InterceptableHtlcForwarder is the interface to set the interceptor
|
||||
|
@ -1855,7 +1855,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
|
||||
chanType := l.channel.State().ChanType
|
||||
chanID := l.ChanID()
|
||||
err = l.cfg.TowerClient.BackupState(
|
||||
&chanID, breachInfo, chanType.IsTweakless(),
|
||||
&chanID, breachInfo, chanType,
|
||||
)
|
||||
if err != nil {
|
||||
l.fail(LinkFailureError{code: ErrInternalError},
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/btcsuite/btcutil/txsort"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
@ -54,7 +55,7 @@ type backupTask struct {
|
||||
// variables.
|
||||
func newBackupTask(chanID *lnwire.ChannelID,
|
||||
breachInfo *lnwallet.BreachRetribution,
|
||||
sweepPkScript []byte, isTweakless bool) *backupTask {
|
||||
sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask {
|
||||
|
||||
// Parse the non-dust outputs from the breach transaction,
|
||||
// simultaneously computing the total amount contained in the inputs
|
||||
@ -85,9 +86,12 @@ func newBackupTask(chanID *lnwire.ChannelID,
|
||||
totalAmt += breachInfo.RemoteOutputSignDesc.Output.Value
|
||||
}
|
||||
if breachInfo.LocalOutputSignDesc != nil {
|
||||
witnessType := input.CommitmentNoDelay
|
||||
if isTweakless {
|
||||
var witnessType input.WitnessType
|
||||
switch {
|
||||
case chanType.IsTweakless():
|
||||
witnessType = input.CommitSpendNoDelayTweakless
|
||||
default:
|
||||
witnessType = input.CommitmentNoDelay
|
||||
}
|
||||
|
||||
toRemoteInput = input.NewBaseInput(
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -74,7 +75,7 @@ type backupTaskTest struct {
|
||||
bindErr error
|
||||
expSweepScript []byte
|
||||
signer input.Signer
|
||||
tweakless bool
|
||||
chanType channeldb.ChannelType
|
||||
}
|
||||
|
||||
// genTaskTest creates a instance of a backupTaskTest using the passed
|
||||
@ -92,7 +93,7 @@ func genTaskTest(
|
||||
expSweepAmt int64,
|
||||
expRewardAmt int64,
|
||||
bindErr error,
|
||||
tweakless bool) backupTaskTest {
|
||||
chanType channeldb.ChannelType) backupTaskTest {
|
||||
|
||||
// Parse the key pairs for all keys used in the test.
|
||||
revSK, revPK := btcec.PrivKeyFromBytes(
|
||||
@ -192,9 +193,12 @@ func genTaskTest(
|
||||
Index: index,
|
||||
}
|
||||
|
||||
witnessType := input.CommitmentNoDelay
|
||||
if tweakless {
|
||||
var witnessType input.WitnessType
|
||||
switch {
|
||||
case chanType.IsTweakless():
|
||||
witnessType = input.CommitSpendNoDelayTweakless
|
||||
default:
|
||||
witnessType = input.CommitmentNoDelay
|
||||
}
|
||||
|
||||
toRemoteInput = input.NewBaseInput(
|
||||
@ -227,7 +231,7 @@ func genTaskTest(
|
||||
bindErr: bindErr,
|
||||
expSweepScript: makeAddrSlice(22),
|
||||
signer: signer,
|
||||
tweakless: tweakless,
|
||||
chanType: chanType,
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,8 +257,13 @@ var (
|
||||
func TestBackupTask(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chanTypes := []channeldb.ChannelType{
|
||||
channeldb.SingleFunderBit,
|
||||
channeldb.SingleFunderTweaklessBit,
|
||||
}
|
||||
|
||||
var backupTaskTests []backupTaskTest
|
||||
for _, tweakless := range []bool{true, false} {
|
||||
for _, chanType := range chanTypes {
|
||||
backupTaskTests = append(backupTaskTests, []backupTaskTest{
|
||||
genTaskTest(
|
||||
"commit no-reward, both outputs",
|
||||
@ -267,7 +276,7 @@ func TestBackupTask(t *testing.T) {
|
||||
299241, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit no-reward, to-local output only",
|
||||
@ -280,7 +289,7 @@ func TestBackupTask(t *testing.T) {
|
||||
199514, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit no-reward, to-remote output only",
|
||||
@ -293,7 +302,7 @@ func TestBackupTask(t *testing.T) {
|
||||
99561, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit no-reward, to-remote output only, creates dust",
|
||||
@ -306,7 +315,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrCreatesDust, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit no-reward, no outputs, fee rate exceeds inputs",
|
||||
@ -319,7 +328,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrFeeExceedsInputs, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit no-reward, no outputs, fee rate of 0 creates dust",
|
||||
@ -332,7 +341,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrCreatesDust, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, both outputs",
|
||||
@ -345,7 +354,7 @@ func TestBackupTask(t *testing.T) {
|
||||
296117, // expSweepAmt
|
||||
3000, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, to-local output only",
|
||||
@ -358,7 +367,7 @@ func TestBackupTask(t *testing.T) {
|
||||
197390, // expSweepAmt
|
||||
2000, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, to-remote output only",
|
||||
@ -371,7 +380,7 @@ func TestBackupTask(t *testing.T) {
|
||||
98437, // expSweepAmt
|
||||
1000, // expRewardAmt
|
||||
nil, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, to-remote output only, creates dust",
|
||||
@ -384,7 +393,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrCreatesDust, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, no outputs, fee rate exceeds inputs",
|
||||
@ -397,7 +406,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrFeeExceedsInputs, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
genTaskTest(
|
||||
"commit reward, no outputs, fee rate of 0 creates dust",
|
||||
@ -410,7 +419,7 @@ func TestBackupTask(t *testing.T) {
|
||||
0, // expSweepAmt
|
||||
0, // expRewardAmt
|
||||
wtpolicy.ErrCreatesDust, // bindErr
|
||||
tweakless,
|
||||
chanType,
|
||||
),
|
||||
}...)
|
||||
}
|
||||
@ -430,7 +439,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
|
||||
// Create a new backupTask from the channel id and breach info.
|
||||
task := newBackupTask(
|
||||
&test.chanID, test.breachInfo, test.expSweepScript,
|
||||
test.tweakless,
|
||||
test.chanType,
|
||||
)
|
||||
|
||||
// Assert that all parameters set during initialization are properly
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -103,7 +104,8 @@ type Client interface {
|
||||
// negotiated policy. If the channel we're trying to back up doesn't
|
||||
// have a tweak for the remote party's output, then isTweakless should
|
||||
// be true.
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution,
|
||||
channeldb.ChannelType) error
|
||||
|
||||
// Start initializes the watchtower client, allowing it process requests
|
||||
// to backup revoked channel states.
|
||||
@ -592,7 +594,8 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
// - breached outputs contain too little value to sweep at the target sweep fee
|
||||
// rate.
|
||||
func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
breachInfo *lnwallet.BreachRetribution, isTweakless bool) error {
|
||||
breachInfo *lnwallet.BreachRetribution,
|
||||
chanType channeldb.ChannelType) error {
|
||||
|
||||
// Retrieve the cached sweep pkscript used for this channel.
|
||||
c.backupMu.Lock()
|
||||
@ -618,7 +621,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
c.backupMu.Unlock()
|
||||
|
||||
task := newBackupTask(
|
||||
chanID, breachInfo, summary.SweepPkScript, isTweakless,
|
||||
chanID, breachInfo, summary.SweepPkScript, chanType,
|
||||
)
|
||||
|
||||
return c.pipeline.QueueBackupTask(task)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -632,7 +633,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) {
|
||||
_, retribution := h.channel(id).getState(i)
|
||||
|
||||
chanID := chanIDFromInt(id)
|
||||
err := h.client.BackupState(&chanID, retribution, false)
|
||||
err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit)
|
||||
if err != expErr {
|
||||
h.t.Fatalf("back error mismatch, want: %v, got: %v",
|
||||
expErr, err)
|
||||
|
Loading…
Reference in New Issue
Block a user