diff --git a/sweep/sweeper.go b/sweep/sweeper.go index c6839d44..46e0ba69 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -144,6 +144,7 @@ type pendingInputs = map[wire.OutPoint]*pendingInput // inputCluster is a helper struct to gather a set of pending inputs that should // be swept with the specified fee rate. type inputCluster struct { + lockTime *uint32 sweepFeeRate chainfee.SatPerKWeight inputs pendingInputs } @@ -833,6 +834,99 @@ func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster return inputClusters } +// zipClusters merges pairwise clusters from as and bs such that cluster a from +// as is merged with a cluster from bs that has at least the fee rate of a. +// This to ensure we don't delay confirmation by decreasing the fee rate (the +// lock time inputs are typically second level HTLC transactions, that are time +// sensitive). +func zipClusters(as, bs []inputCluster) []inputCluster { + // Sort the clusters by decreasing fee rates. + sort.Slice(as, func(i, j int) bool { + return as[i].sweepFeeRate > + as[j].sweepFeeRate + }) + sort.Slice(bs, func(i, j int) bool { + return bs[i].sweepFeeRate > + bs[j].sweepFeeRate + }) + + var ( + finalClusters []inputCluster + j int + ) + + // Go through each cluster in as, and merge with the next one from bs + // if it has at least the fee rate needed. + for i := range as { + a := as[i] + + switch { + + // If the fee rate for the next one from bs is at least a's, we + // merge. + case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: + merged := mergeClusters(a, bs[j]) + finalClusters = append(finalClusters, merged...) + + // Increment j for the next round. + j++ + + // We did not merge, meaning all the remining clusters from bs + // have lower fee rate. Instead we add a directly to the final + // clusters. + default: + finalClusters = append(finalClusters, a) + } + } + + // Add any remaining clusters from bs. + for ; j < len(bs); j++ { + b := bs[j] + finalClusters = append(finalClusters, b) + } + + return finalClusters +} + +// mergeClusters attempts to merge cluster a and b if they are compatible. The +// new cluster will have the locktime set if a or b had a locktime set, and a +// sweep fee rate that is the maximum of a and b's. If the two clusters are not +// compatible, they will be returned unchanged. +func mergeClusters(a, b inputCluster) []inputCluster { + newCluster := inputCluster{} + + switch { + + // Incompatible locktimes, return the sets without merging them. + case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime: + return []inputCluster{a, b} + + case a.lockTime != nil: + newCluster.lockTime = a.lockTime + + case b.lockTime != nil: + newCluster.lockTime = b.lockTime + } + + if a.sweepFeeRate > b.sweepFeeRate { + newCluster.sweepFeeRate = a.sweepFeeRate + } else { + newCluster.sweepFeeRate = b.sweepFeeRate + } + + newCluster.inputs = make(pendingInputs) + + for op, in := range a.inputs { + newCluster.inputs[op] = in + } + + for op, in := range b.inputs { + newCluster.inputs[op] = in + } + + return []inputCluster{newCluster} +} + // scheduleSweep starts the sweep timer to create an opportunity for more inputs // to be added. func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error { diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2f71c35b..d8c62001 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -2,6 +2,7 @@ package sweep import ( "os" + "reflect" "runtime/debug" "runtime/pprof" "testing" @@ -11,6 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -1376,3 +1378,221 @@ func TestCpfp(t *testing.T) { ctx.finish(1) } + +var ( + testInputsA = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + } + + testInputsB = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } + + testInputsC = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } +) + +// TestMergeClusters check that we properly can merge clusters together, +// according to their required locktime. +func TestMergeClusters(t *testing.T) { + t.Parallel() + + lockTime1 := uint32(100) + lockTime2 := uint32(200) + + testCases := []struct { + name string + a inputCluster + b inputCluster + res []inputCluster + }{ + { + name: "max fee rate", + a: inputCluster{ + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "same locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "diff locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + { + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + }, + }, + } + + for _, test := range testCases { + merged := mergeClusters(test.a, test.b) + if !reflect.DeepEqual(merged, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(merged)) + } + } +} + +// TestZipClusters tests that we can merge lists of inputs clusters correctly. +func TestZipClusters(t *testing.T) { + t.Parallel() + + createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster { + return inputCluster{ + sweepFeeRate: f, + inputs: inp, + } + } + + testCases := []struct { + name string + as []inputCluster + bs []inputCluster + res []inputCluster + }{ + { + name: "merge A into B", + as: []inputCluster{ + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + }, + }, + { + name: "A can't merge with B", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsA, 7000), + createCluster(testInputsB, 5000), + }, + }, + { + name: "empty bs", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{}, + res: []inputCluster{ + createCluster(testInputsA, 7000), + }, + }, + { + name: "empty as", + as: []inputCluster{}, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsB, 5000), + }, + }, + + { + name: "zip 3xA into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + }, + }, + { + name: "zip A into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 2500), + }, + bs: []inputCluster{ + createCluster(testInputsB, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + res: []inputCluster{ + createCluster(testInputsC, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + }, + } + + for _, test := range testCases { + zipped := zipClusters(test.as, test.bs) + if !reflect.DeepEqual(zipped, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(zipped)) + } + } +}