sweep+input: add RequiredLockTime to inputs

This commit is contained in:
Johan T. Halseth 2020-11-06 19:35:01 +01:00
parent fa4fd02cf1
commit efd6bc9501
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
5 changed files with 251 additions and 17 deletions

@ -866,6 +866,12 @@ func (bo *breachedOutput) OutPoint() *wire.OutPoint {
return &bo.outpoint return &bo.outpoint
} }
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it.
func (bo *breachedOutput) RequiredLockTime() (uint32, bool) {
return 0, false
}
// WitnessType returns the type of witness that must be generated to spend the // WitnessType returns the type of witness that must be generated to spend the
// breached output. // breached output.
func (bo *breachedOutput) WitnessType() input.WitnessType { func (bo *breachedOutput) WitnessType() input.WitnessType {

@ -15,6 +15,10 @@ type Input interface {
// construct the corresponding transaction input. // construct the corresponding transaction input.
OutPoint() *wire.OutPoint OutPoint() *wire.OutPoint
// RequiredLockTime returns whether this input commits to a tx locktime
// that must be used in the transaction including it.
RequiredLockTime() (uint32, bool)
// WitnessType returns an enum specifying the type of witness that must // WitnessType returns an enum specifying the type of witness that must
// be generated in order to spend this output. // be generated in order to spend this output.
WitnessType() WitnessType WitnessType() WitnessType
@ -75,6 +79,13 @@ func (i *inputKit) OutPoint() *wire.OutPoint {
return &i.outpoint return &i.outpoint
} }
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it. This will be false for the
// base input type since we can re-sign for any lock time.
func (i *inputKit) RequiredLockTime() (uint32, bool) {
return 0, false
}
// WitnessType returns the type of witness that must be generated to spend the // WitnessType returns the type of witness that must be generated to spend the
// breached output. // breached output.
func (i *inputKit) WitnessType() WitnessType { func (i *inputKit) WitnessType() WitnessType {

@ -752,12 +752,86 @@ func (s *UtxoSweeper) bucketForFeeRate(
} }
// createInputClusters creates a list of input clusters from the set of pending // createInputClusters creates a list of input clusters from the set of pending
// inputs known by the UtxoSweeper. // inputs known by the UtxoSweeper. It clusters inputs by
// 1) Required tx locktime
// 2) Similar fee rates
func (s *UtxoSweeper) createInputClusters() []inputCluster { func (s *UtxoSweeper) createInputClusters() []inputCluster {
inputs := s.pendingInputs inputs := s.pendingInputs
feeClusters := s.clusterBySweepFeeRate(inputs) // We start by getting the inputs clusters by locktime. Since the
return feeClusters // inputs commit to the locktime, they can only be clustered together
// if the locktime is equal.
lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs)
// Cluster the the remaining inputs by sweep fee rate.
feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs)
// Since the inputs that we clustered by fee rate don't commit to a
// specific locktime, we can try to merge a locktime cluster with a fee
// cluster.
return zipClusters(lockTimeClusters, feeClusters)
}
// clusterByLockTime takes the given set of pending inputs and clusters those
// with equal locktime together. Each cluster contains a sweep fee rate, which
// is determined by calculating the average fee rate of all inputs within that
// cluster. In addition to the created clusters, inputs that did not specify a
// required lock time are returned.
func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
pendingInputs) {
locktimes := make(map[uint32]pendingInputs)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
rem := make(pendingInputs)
// Go through all inputs and check if they require a certain locktime.
for op, input := range inputs {
lt, ok := input.RequiredLockTime()
if !ok {
rem[op] = input
continue
}
// Check if we already have inputs with this locktime.
p, ok := locktimes[lt]
if !ok {
p = make(pendingInputs)
}
p[op] = input
locktimes[lt] = p
// We also get the preferred fee rate for this input.
feeRate, err := s.feeRateForPreference(input.params.Fee)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
input.lastFeeRate = feeRate
inputFeeRates[op] = feeRate
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(locktimes))
for lt, inputs := range locktimes {
lt := lt
var sweepFeeRate chainfee.SatPerKWeight
for op := range inputs {
sweepFeeRate += inputFeeRates[op]
}
sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
inputClusters = append(inputClusters, inputCluster{
lockTime: &lt,
sweepFeeRate: sweepFeeRate,
inputs: inputs,
})
}
return inputClusters, rem
} }
// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper // clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper

@ -90,7 +90,7 @@ func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput
func init() { func init() {
// Create a set of test spendable inputs. // Create a set of test spendable inputs.
for i := 0; i < 5; i++ { for i := 0; i < 20; i++ {
input := createTestInput(int64(10000+i*500), input := createTestInput(int64(10000+i*500),
input.CommitmentTimeLock) input.CommitmentTimeLock)
@ -1596,3 +1596,128 @@ func TestZipClusters(t *testing.T) {
} }
} }
} }
type testInput struct {
*input.BaseInput
locktime *uint32
}
func (i *testInput) RequiredLockTime() (uint32, bool) {
if i.locktime != nil {
return *i.locktime, true
}
return 0, false
}
// TestLockTimes checks that the sweeper properly groups inputs requiring the
// same locktime together into sweep transactions.
func TestLockTimes(t *testing.T) {
ctx := createSweeperTestContext(t)
// We increase the number of max inputs to a tx so that won't
// impact our test.
ctx.sweeper.cfg.MaxInputsPerTx = 100
// We will set up the lock times in such a way that we expect the
// sweeper to divide the inputs into 4 diffeerent transactions.
const numSweeps = 4
// Sweep 8 inputs, using 4 different lock times.
var (
results []chan Result
inputs = make(map[wire.OutPoint]input.Input)
)
for i := 0; i < numSweeps*2; i++ {
lt := uint32(10 + (i % numSweeps))
inp := &testInput{
BaseInput: spendableInputs[i],
locktime: &lt,
}
result, err := ctx.sweeper.SweepInput(
inp, Params{
Fee: FeePreference{ConfTarget: 6},
},
)
if err != nil {
t.Fatal(err)
}
results = append(results, result)
op := inp.OutPoint()
inputs[*op] = inp
}
// We also add 3 regular inputs that don't require any specific lock
// time.
for i := 0; i < 3; i++ {
inp := spendableInputs[i+numSweeps*2]
result, err := ctx.sweeper.SweepInput(
inp, Params{
Fee: FeePreference{ConfTarget: 6},
},
)
if err != nil {
t.Fatal(err)
}
results = append(results, result)
op := inp.OutPoint()
inputs[*op] = inp
}
// We expect all inputs to be published in separate transactions, even
// though they share the same fee preference.
ctx.tick()
// Check the sweeps transactions, ensuring all inputs are there, and
// all the locktimes are satisfied.
for i := 0; i < numSweeps; i++ {
sweepTx := ctx.receiveTx()
if len(sweepTx.TxOut) != 1 {
t.Fatal("expected a single tx out in the sweep tx")
}
for _, txIn := range sweepTx.TxIn {
op := txIn.PreviousOutPoint
inp, ok := inputs[op]
if !ok {
t.Fatalf("Unexpected outpoint: %v", op)
}
delete(inputs, op)
// If this input had a required locktime, ensure the tx
// has that set correctly.
lt, ok := inp.RequiredLockTime()
if !ok {
continue
}
if lt != sweepTx.LockTime {
t.Fatalf("Input required locktime %v, sweep "+
"tx had locktime %v", lt, sweepTx.LockTime)
}
}
}
// The should be no inputs not foud in any of the sweeps.
if len(inputs) != 0 {
t.Fatalf("had unsweeped inputs")
}
// Mine the first sweeps
ctx.backend.mine()
// Results should all come back.
for i := range results {
result := <-results[i]
if result.Err != nil {
t.Fatal("expected input to be swept")
}
}
}

@ -138,33 +138,51 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte,
txFee := estimator.fee() txFee := estimator.fee()
// Sum up the total value contained in the inputs. // Create the sweep transaction that we will be building. We use
// version 2 as it is required for CSV.
sweepTx := wire.NewMsgTx(2)
// Track whether any of the inputs require a certain locktime.
locktime := int32(-1)
// Sum up the total value contained in the inputs, and add all inputs
// to the sweep transaction. Ensure that for each csvInput, we set the
// sequence number properly.
var totalSum btcutil.Amount var totalSum btcutil.Amount
for _, o := range inputs { for _, o := range inputs {
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: *o.OutPoint(),
Sequence: o.BlocksToMaturity(),
})
if lt, ok := o.RequiredLockTime(); ok {
// If another input commits to a different locktime,
// they cannot be combined in the same transcation.
if locktime != -1 && locktime != int32(lt) {
return nil, fmt.Errorf("incompatible locktime")
}
locktime = int32(lt)
}
totalSum += btcutil.Amount(o.SignDesc().Output.Value) totalSum += btcutil.Amount(o.SignDesc().Output.Value)
} }
// Sweep as much possible, after subtracting txn fees. // Sweep as much possible, after subtracting txn fees.
sweepAmt := int64(totalSum - txFee) sweepAmt := int64(totalSum - txFee)
// Create the sweep transaction that we will be building. We use // The txn will sweep the amount after fees to the pkscript generated
// version 2 as it is required for CSV. The txn will sweep the amount // above.
// after fees to the pkscript generated above.
sweepTx := wire.NewMsgTx(2)
sweepTx.AddTxOut(&wire.TxOut{ sweepTx.AddTxOut(&wire.TxOut{
PkScript: outputPkScript, PkScript: outputPkScript,
Value: sweepAmt, Value: sweepAmt,
}) })
// We'll default to using the current block height as locktime, if none
// of the inputs commits to a different locktime.
sweepTx.LockTime = currentBlockHeight sweepTx.LockTime = currentBlockHeight
if locktime != -1 {
// Add all inputs to the sweep transaction. Ensure that for each sweepTx.LockTime = uint32(locktime)
// csvInput, we set the sequence number properly.
for _, input := range inputs {
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: *input.OutPoint(),
Sequence: input.BlocksToMaturity(),
})
} }
// Before signing the transaction, check to ensure that it meets some // Before signing the transaction, check to ensure that it meets some