sweep: add mergeClusters

This commit is contained in:
Johan T. Halseth 2020-11-06 14:34:49 +01:00
parent 128087044f
commit fa4fd02cf1
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
2 changed files with 314 additions and 0 deletions

@ -144,6 +144,7 @@ type pendingInputs = map[wire.OutPoint]*pendingInput
// inputCluster is a helper struct to gather a set of pending inputs that should // inputCluster is a helper struct to gather a set of pending inputs that should
// be swept with the specified fee rate. // be swept with the specified fee rate.
type inputCluster struct { type inputCluster struct {
lockTime *uint32
sweepFeeRate chainfee.SatPerKWeight sweepFeeRate chainfee.SatPerKWeight
inputs pendingInputs inputs pendingInputs
} }
@ -833,6 +834,99 @@ func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster
return inputClusters return inputClusters
} }
// zipClusters merges pairwise clusters from as and bs such that cluster a from
// as is merged with a cluster from bs that has at least the fee rate of a.
// This to ensure we don't delay confirmation by decreasing the fee rate (the
// lock time inputs are typically second level HTLC transactions, that are time
// sensitive).
func zipClusters(as, bs []inputCluster) []inputCluster {
// Sort the clusters by decreasing fee rates.
sort.Slice(as, func(i, j int) bool {
return as[i].sweepFeeRate >
as[j].sweepFeeRate
})
sort.Slice(bs, func(i, j int) bool {
return bs[i].sweepFeeRate >
bs[j].sweepFeeRate
})
var (
finalClusters []inputCluster
j int
)
// Go through each cluster in as, and merge with the next one from bs
// if it has at least the fee rate needed.
for i := range as {
a := as[i]
switch {
// If the fee rate for the next one from bs is at least a's, we
// merge.
case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate:
merged := mergeClusters(a, bs[j])
finalClusters = append(finalClusters, merged...)
// Increment j for the next round.
j++
// We did not merge, meaning all the remining clusters from bs
// have lower fee rate. Instead we add a directly to the final
// clusters.
default:
finalClusters = append(finalClusters, a)
}
}
// Add any remaining clusters from bs.
for ; j < len(bs); j++ {
b := bs[j]
finalClusters = append(finalClusters, b)
}
return finalClusters
}
// mergeClusters attempts to merge cluster a and b if they are compatible. The
// new cluster will have the locktime set if a or b had a locktime set, and a
// sweep fee rate that is the maximum of a and b's. If the two clusters are not
// compatible, they will be returned unchanged.
func mergeClusters(a, b inputCluster) []inputCluster {
newCluster := inputCluster{}
switch {
// Incompatible locktimes, return the sets without merging them.
case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime:
return []inputCluster{a, b}
case a.lockTime != nil:
newCluster.lockTime = a.lockTime
case b.lockTime != nil:
newCluster.lockTime = b.lockTime
}
if a.sweepFeeRate > b.sweepFeeRate {
newCluster.sweepFeeRate = a.sweepFeeRate
} else {
newCluster.sweepFeeRate = b.sweepFeeRate
}
newCluster.inputs = make(pendingInputs)
for op, in := range a.inputs {
newCluster.inputs[op] = in
}
for op, in := range b.inputs {
newCluster.inputs[op] = in
}
return []inputCluster{newCluster}
}
// scheduleSweep starts the sweep timer to create an opportunity for more inputs // scheduleSweep starts the sweep timer to create an opportunity for more inputs
// to be added. // to be added.
func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error { func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error {

@ -2,6 +2,7 @@ package sweep
import ( import (
"os" "os"
"reflect"
"runtime/debug" "runtime/debug"
"runtime/pprof" "runtime/pprof"
"testing" "testing"
@ -11,6 +12,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
@ -1376,3 +1378,221 @@ func TestCpfp(t *testing.T) {
ctx.finish(1) ctx.finish(1)
} }
var (
testInputsA = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
}
testInputsB = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
testInputsC = pendingInputs{
wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{},
wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{},
}
)
// TestMergeClusters check that we properly can merge clusters together,
// according to their required locktime.
func TestMergeClusters(t *testing.T) {
t.Parallel()
lockTime1 := uint32(100)
lockTime2 := uint32(200)
testCases := []struct {
name string
a inputCluster
b inputCluster
res []inputCluster
}{
{
name: "max fee rate",
a: inputCluster{
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "same locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 7000,
inputs: testInputsC,
},
},
},
{
name: "diff locktime",
a: inputCluster{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
b: inputCluster{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
res: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: 5000,
inputs: testInputsA,
},
{
lockTime: &lockTime2,
sweepFeeRate: 7000,
inputs: testInputsB,
},
},
},
}
for _, test := range testCases {
merged := mergeClusters(test.a, test.b)
if !reflect.DeepEqual(merged, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(merged))
}
}
}
// TestZipClusters tests that we can merge lists of inputs clusters correctly.
func TestZipClusters(t *testing.T) {
t.Parallel()
createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster {
return inputCluster{
sweepFeeRate: f,
inputs: inp,
}
}
testCases := []struct {
name string
as []inputCluster
bs []inputCluster
res []inputCluster
}{
{
name: "merge A into B",
as: []inputCluster{
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
},
},
{
name: "A can't merge with B",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsA, 7000),
createCluster(testInputsB, 5000),
},
},
{
name: "empty bs",
as: []inputCluster{
createCluster(testInputsA, 7000),
},
bs: []inputCluster{},
res: []inputCluster{
createCluster(testInputsA, 7000),
},
},
{
name: "empty as",
as: []inputCluster{},
bs: []inputCluster{
createCluster(testInputsB, 5000),
},
res: []inputCluster{
createCluster(testInputsB, 5000),
},
},
{
name: "zip 3xA into 3xB",
as: []inputCluster{
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
createCluster(testInputsA, 5000),
},
bs: []inputCluster{
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
createCluster(testInputsB, 7000),
},
res: []inputCluster{
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
createCluster(testInputsC, 7000),
},
},
{
name: "zip A into 3xB",
as: []inputCluster{
createCluster(testInputsA, 2500),
},
bs: []inputCluster{
createCluster(testInputsB, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
res: []inputCluster{
createCluster(testInputsC, 3000),
createCluster(testInputsB, 2000),
createCluster(testInputsB, 1000),
},
},
}
for _, test := range testCases {
zipped := zipClusters(test.as, test.bs)
if !reflect.DeepEqual(zipped, test.res) {
t.Fatalf("[%s] unexpected result: %v",
test.name, spew.Sdump(zipped))
}
}
}