Merge pull request #4838 from halseth/sweeper-input-script-ordering

sweep/txgenerator: fix input witness ordering
This commit is contained in:
Olaoluwa Osuntokun 2020-12-08 20:03:05 -08:00 committed by GitHub
commit c95c423703
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 159 additions and 14 deletions

@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
@ -1617,6 +1618,60 @@ func (i *testInput) RequiredTxOut() *wire.TxOut {
return i.reqTxOut return i.reqTxOut
} }
// CraftInputScript is a custom sign method for the testInput type that will
// encode the spending outpoint and the tx input index as part of the returned
// witness.
func (i *testInput) CraftInputScript(_ input.Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes, txinIdx int) (*input.Script, error) {
// We'll encode the outpoint in the witness, so we can assert that the
// expected input was signed at the correct index.
op := i.OutPoint()
return &input.Script{
Witness: [][]byte{
// We encode the hash of the outpoint...
op.Hash[:],
// ..the outpoint index...
{byte(op.Index)},
// ..and finally the tx input index.
{byte(txinIdx)},
},
}, nil
}
// assertSignedIndex goes through all inputs to the tx and checks that all
// testInputs have witnesses corresponding to the outpoints they are spending,
// and are signed at the correct tx input index. All found testInputs are
// returned such that we can sum up and sanity check that all testInputs were
// part of the sweep.
func assertSignedIndex(t *testing.T, tx *wire.MsgTx,
testInputs map[wire.OutPoint]*testInput) map[wire.OutPoint]struct{} {
found := make(map[wire.OutPoint]struct{})
for idx, txIn := range tx.TxIn {
op := txIn.PreviousOutPoint
// Not a testInput, it won't have the test encoding we require
// to check outpoint and index.
if _, ok := testInputs[op]; !ok {
continue
}
if _, ok := found[op]; ok {
t.Fatalf("input already used")
}
// Check it was signes spending the correct outpoint, and at
// the expected tx input index.
require.Equal(t, txIn.Witness[0], op.Hash[:])
require.Equal(t, txIn.Witness[1], []byte{byte(op.Index)})
require.Equal(t, txIn.Witness[2], []byte{byte(idx)})
found[op] = struct{}{}
}
return found
}
// TestLockTimes checks that the sweeper properly groups inputs requiring the // TestLockTimes checks that the sweeper properly groups inputs requiring the
// same locktime together into sweep transactions. // same locktime together into sweep transactions.
func TestLockTimes(t *testing.T) { func TestLockTimes(t *testing.T) {
@ -1824,6 +1879,66 @@ func TestRequiredTxOuts(t *testing.T) {
) )
}, },
}, },
{
// Two inputs, where the first one required no tx out.
name: "two inputs, one with required tx out",
inputs: []*testInput{
{
// We add a normal, non-requiredTxOut
// input. We use test input 10, to make
// sure this has a higher yield than
// the other input, and will be
// attempted added first to the sweep
// tx.
BaseInput: inputs[10],
},
{
// The second input requires a TxOut.
BaseInput: inputs[0],
reqTxOut: &wire.TxOut{
PkScript: []byte("aaa"),
Value: inputs[0].SignDesc().Output.Value,
},
},
},
// We expect the inputs to have been reordered.
assertSweeps: func(t *testing.T,
_ map[wire.OutPoint]*testInput,
txs []*wire.MsgTx) {
require.Equal(t, 1, len(txs))
tx := txs[0]
require.Equal(t, 2, len(tx.TxIn))
require.Equal(t, 2, len(tx.TxOut))
// The required TxOut should be the first one.
out := tx.TxOut[0]
require.Equal(t, []byte("aaa"), out.PkScript)
require.Equal(
t, inputs[0].SignDesc().Output.Value,
out.Value,
)
// The first input should be the one having the
// required TxOut.
require.Len(t, tx.TxIn, 2)
require.Equal(
t, inputs[0].OutPoint(),
&tx.TxIn[0].PreviousOutPoint,
)
// Second one is the one without a required tx
// out.
require.Equal(
t, inputs[10].OutPoint(),
&tx.TxIn[1].PreviousOutPoint,
)
},
},
{ {
// An input committing to an output of equal value, just // An input committing to an output of equal value, just
// add input to pay fees. // add input to pay fees.
@ -2076,6 +2191,30 @@ func TestRequiredTxOuts(t *testing.T) {
// Assert the transactions are what we expect. // Assert the transactions are what we expect.
testCase.assertSweeps(t, inputs, sweeps) testCase.assertSweeps(t, inputs, sweeps)
// Finally we assert that all our test inputs were part
// of the sweeps, and that they were signed correctly.
sweptInputs := make(map[wire.OutPoint]struct{})
for _, sweep := range sweeps {
swept := assertSignedIndex(t, sweep, inputs)
for op := range swept {
if _, ok := sweptInputs[op]; ok {
t.Fatalf("outpoint %v part of "+
"previous sweep", op)
}
sweptInputs[op] = struct{}{}
}
}
require.Equal(t, len(inputs), len(sweptInputs))
for op := range sweptInputs {
_, ok := inputs[op]
if !ok {
t.Fatalf("swept input %v not part of "+
"test inputs", op)
}
}
}) })
} }
} }

@ -137,28 +137,35 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte,
dustLimit btcutil.Amount, signer input.Signer) (*wire.MsgTx, error) { dustLimit btcutil.Amount, signer input.Signer) (*wire.MsgTx, error) {
inputs, estimator := getWeightEstimate(inputs, feePerKw) inputs, estimator := getWeightEstimate(inputs, feePerKw)
txFee := estimator.fee() txFee := estimator.fee()
// Create the sweep transaction that we will be building. We use var (
// version 2 as it is required for CSV. // Create the sweep transaction that we will be building. We
sweepTx := wire.NewMsgTx(2) // use version 2 as it is required for CSV.
sweepTx = wire.NewMsgTx(2)
// Track whether any of the inputs require a certain locktime. // Track whether any of the inputs require a certain locktime.
locktime := int32(-1) locktime = int32(-1)
// We keep track of total input amount, and required output
// amount to use for calculating the change amount below.
totalInput btcutil.Amount
requiredOutput btcutil.Amount
// We'll add the inputs as we go so we know the final ordering
// of inputs to sign.
idxs []input.Input
)
// We start by adding all inputs that commit to an output. We do this // We start by adding all inputs that commit to an output. We do this
// since the input and output index must stay the same for the // since the input and output index must stay the same for the
// signatures to be valid. // signatures to be valid.
var (
totalInput btcutil.Amount
requiredOutput btcutil.Amount
)
for _, o := range inputs { for _, o := range inputs {
if o.RequiredTxOut() == nil { if o.RequiredTxOut() == nil {
continue continue
} }
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{ sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: *o.OutPoint(), PreviousOutPoint: *o.OutPoint(),
Sequence: o.BlocksToMaturity(), Sequence: o.BlocksToMaturity(),
@ -186,6 +193,7 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte,
continue continue
} }
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{ sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: *o.OutPoint(), PreviousOutPoint: *o.OutPoint(),
Sequence: o.BlocksToMaturity(), Sequence: o.BlocksToMaturity(),
@ -255,10 +263,8 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte,
return nil return nil
} }
// Finally we'll attach a valid input script to each csv and cltv input for idx, inp := range idxs {
// within the sweeping transaction. if err := addInputScript(idx, inp); err != nil {
for i, input := range inputs {
if err := addInputScript(i, input); err != nil {
return nil, err return nil, err
} }
} }