diff --git a/breacharbiter.go b/breacharbiter.go index ab1a8bd5..334e72f7 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -866,6 +866,12 @@ func (bo *breachedOutput) OutPoint() *wire.OutPoint { return &bo.outpoint } +// RequiredLockTime returns whether this input commits to a tx locktime that +// must be used in the transaction including it. +func (bo *breachedOutput) RequiredLockTime() (uint32, bool) { + return 0, false +} + // WitnessType returns the type of witness that must be generated to spend the // breached output. func (bo *breachedOutput) WitnessType() input.WitnessType { diff --git a/input/input.go b/input/input.go index 2e3a71c0..70b6ba93 100644 --- a/input/input.go +++ b/input/input.go @@ -15,6 +15,10 @@ type Input interface { // construct the corresponding transaction input. OutPoint() *wire.OutPoint + // RequiredLockTime returns whether this input commits to a tx locktime + // that must be used in the transaction including it. + RequiredLockTime() (uint32, bool) + // WitnessType returns an enum specifying the type of witness that must // be generated in order to spend this output. WitnessType() WitnessType @@ -75,6 +79,13 @@ func (i *inputKit) OutPoint() *wire.OutPoint { return &i.outpoint } +// RequiredLockTime returns whether this input commits to a tx locktime that +// must be used in the transaction including it. This will be false for the +// base input type since we can re-sign for any lock time. +func (i *inputKit) RequiredLockTime() (uint32, bool) { + return 0, false +} + // WitnessType returns the type of witness that must be generated to spend the // breached output. func (i *inputKit) WitnessType() WitnessType { diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 46e0ba69..09eb5a02 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -752,12 +752,86 @@ func (s *UtxoSweeper) bucketForFeeRate( } // createInputClusters creates a list of input clusters from the set of pending -// inputs known by the UtxoSweeper. +// inputs known by the UtxoSweeper. It clusters inputs by +// 1) Required tx locktime +// 2) Similar fee rates func (s *UtxoSweeper) createInputClusters() []inputCluster { inputs := s.pendingInputs - feeClusters := s.clusterBySweepFeeRate(inputs) - return feeClusters + // We start by getting the inputs clusters by locktime. Since the + // inputs commit to the locktime, they can only be clustered together + // if the locktime is equal. + lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) + + // Cluster the the remaining inputs by sweep fee rate. + feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) + + // Since the inputs that we clustered by fee rate don't commit to a + // specific locktime, we can try to merge a locktime cluster with a fee + // cluster. + return zipClusters(lockTimeClusters, feeClusters) +} + +// clusterByLockTime takes the given set of pending inputs and clusters those +// with equal locktime together. Each cluster contains a sweep fee rate, which +// is determined by calculating the average fee rate of all inputs within that +// cluster. In addition to the created clusters, inputs that did not specify a +// required lock time are returned. +func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster, + pendingInputs) { + + locktimes := make(map[uint32]pendingInputs) + inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) + rem := make(pendingInputs) + + // Go through all inputs and check if they require a certain locktime. + for op, input := range inputs { + lt, ok := input.RequiredLockTime() + if !ok { + rem[op] = input + continue + } + + // Check if we already have inputs with this locktime. + p, ok := locktimes[lt] + if !ok { + p = make(pendingInputs) + } + + p[op] = input + locktimes[lt] = p + + // We also get the preferred fee rate for this input. + feeRate, err := s.feeRateForPreference(input.params.Fee) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + input.lastFeeRate = feeRate + inputFeeRates[op] = feeRate + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(locktimes)) + for lt, inputs := range locktimes { + lt := lt + + var sweepFeeRate chainfee.SatPerKWeight + for op := range inputs { + sweepFeeRate += inputFeeRates[op] + } + + sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) + inputClusters = append(inputClusters, inputCluster{ + lockTime: <, + sweepFeeRate: sweepFeeRate, + inputs: inputs, + }) + } + + return inputClusters, rem } // clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index d8c62001..77e9d2b5 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -90,7 +90,7 @@ func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput func init() { // Create a set of test spendable inputs. - for i := 0; i < 5; i++ { + for i := 0; i < 20; i++ { input := createTestInput(int64(10000+i*500), input.CommitmentTimeLock) @@ -1596,3 +1596,128 @@ func TestZipClusters(t *testing.T) { } } } + +type testInput struct { + *input.BaseInput + + locktime *uint32 +} + +func (i *testInput) RequiredLockTime() (uint32, bool) { + if i.locktime != nil { + return *i.locktime, true + } + + return 0, false +} + +// TestLockTimes checks that the sweeper properly groups inputs requiring the +// same locktime together into sweep transactions. +func TestLockTimes(t *testing.T) { + ctx := createSweeperTestContext(t) + + // We increase the number of max inputs to a tx so that won't + // impact our test. + ctx.sweeper.cfg.MaxInputsPerTx = 100 + + // We will set up the lock times in such a way that we expect the + // sweeper to divide the inputs into 4 diffeerent transactions. + const numSweeps = 4 + + // Sweep 8 inputs, using 4 different lock times. + var ( + results []chan Result + inputs = make(map[wire.OutPoint]input.Input) + ) + for i := 0; i < numSweeps*2; i++ { + lt := uint32(10 + (i % numSweeps)) + inp := &testInput{ + BaseInput: spendableInputs[i], + locktime: <, + } + + result, err := ctx.sweeper.SweepInput( + inp, Params{ + Fee: FeePreference{ConfTarget: 6}, + }, + ) + if err != nil { + t.Fatal(err) + } + results = append(results, result) + + op := inp.OutPoint() + inputs[*op] = inp + } + + // We also add 3 regular inputs that don't require any specific lock + // time. + for i := 0; i < 3; i++ { + inp := spendableInputs[i+numSweeps*2] + result, err := ctx.sweeper.SweepInput( + inp, Params{ + Fee: FeePreference{ConfTarget: 6}, + }, + ) + if err != nil { + t.Fatal(err) + } + + results = append(results, result) + + op := inp.OutPoint() + inputs[*op] = inp + } + + // We expect all inputs to be published in separate transactions, even + // though they share the same fee preference. + ctx.tick() + + // Check the sweeps transactions, ensuring all inputs are there, and + // all the locktimes are satisfied. + for i := 0; i < numSweeps; i++ { + sweepTx := ctx.receiveTx() + if len(sweepTx.TxOut) != 1 { + t.Fatal("expected a single tx out in the sweep tx") + } + + for _, txIn := range sweepTx.TxIn { + op := txIn.PreviousOutPoint + inp, ok := inputs[op] + if !ok { + t.Fatalf("Unexpected outpoint: %v", op) + } + + delete(inputs, op) + + // If this input had a required locktime, ensure the tx + // has that set correctly. + lt, ok := inp.RequiredLockTime() + if !ok { + continue + } + + if lt != sweepTx.LockTime { + t.Fatalf("Input required locktime %v, sweep "+ + "tx had locktime %v", lt, sweepTx.LockTime) + } + + } + } + + // The should be no inputs not foud in any of the sweeps. + if len(inputs) != 0 { + t.Fatalf("had unsweeped inputs") + } + + // Mine the first sweeps + ctx.backend.mine() + + // Results should all come back. + for i := range results { + result := <-results[i] + if result.Err != nil { + t.Fatal("expected input to be swept") + } + } +} diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index f023c317..d0c71c49 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -138,33 +138,51 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, txFee := estimator.fee() - // Sum up the total value contained in the inputs. + // Create the sweep transaction that we will be building. We use + // version 2 as it is required for CSV. + sweepTx := wire.NewMsgTx(2) + + // Track whether any of the inputs require a certain locktime. + locktime := int32(-1) + + // Sum up the total value contained in the inputs, and add all inputs + // to the sweep transaction. Ensure that for each csvInput, we set the + // sequence number properly. var totalSum btcutil.Amount for _, o := range inputs { + sweepTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *o.OutPoint(), + Sequence: o.BlocksToMaturity(), + }) + + if lt, ok := o.RequiredLockTime(); ok { + // If another input commits to a different locktime, + // they cannot be combined in the same transcation. + if locktime != -1 && locktime != int32(lt) { + return nil, fmt.Errorf("incompatible locktime") + } + + locktime = int32(lt) + } + totalSum += btcutil.Amount(o.SignDesc().Output.Value) } // Sweep as much possible, after subtracting txn fees. sweepAmt := int64(totalSum - txFee) - // Create the sweep transaction that we will be building. We use - // version 2 as it is required for CSV. The txn will sweep the amount - // after fees to the pkscript generated above. - sweepTx := wire.NewMsgTx(2) + // The txn will sweep the amount after fees to the pkscript generated + // above. sweepTx.AddTxOut(&wire.TxOut{ PkScript: outputPkScript, Value: sweepAmt, }) + // We'll default to using the current block height as locktime, if none + // of the inputs commits to a different locktime. sweepTx.LockTime = currentBlockHeight - - // Add all inputs to the sweep transaction. Ensure that for each - // csvInput, we set the sequence number properly. - for _, input := range inputs { - sweepTx.AddTxIn(&wire.TxIn{ - PreviousOutPoint: *input.OutPoint(), - Sequence: input.BlocksToMaturity(), - }) + if locktime != -1 { + sweepTx.LockTime = uint32(locktime) } // Before signing the transaction, check to ensure that it meets some