diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index deea56f4..54d622d5 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" @@ -1617,6 +1618,60 @@ func (i *testInput) RequiredTxOut() *wire.TxOut { return i.reqTxOut } +// CraftInputScript is a custom sign method for the testInput type that will +// encode the spending outpoint and the tx input index as part of the returned +// witness. +func (i *testInput) CraftInputScript(_ input.Signer, txn *wire.MsgTx, + hashCache *txscript.TxSigHashes, txinIdx int) (*input.Script, error) { + + // We'll encode the outpoint in the witness, so we can assert that the + // expected input was signed at the correct index. + op := i.OutPoint() + return &input.Script{ + Witness: [][]byte{ + // We encode the hash of the outpoint... + op.Hash[:], + // ..the outpoint index... + {byte(op.Index)}, + // ..and finally the tx input index. + {byte(txinIdx)}, + }, + }, nil +} + +// assertSignedIndex goes through all inputs to the tx and checks that all +// testInputs have witnesses corresponding to the outpoints they are spending, +// and are signed at the correct tx input index. All found testInputs are +// returned such that we can sum up and sanity check that all testInputs were +// part of the sweep. +func assertSignedIndex(t *testing.T, tx *wire.MsgTx, + testInputs map[wire.OutPoint]*testInput) map[wire.OutPoint]struct{} { + + found := make(map[wire.OutPoint]struct{}) + for idx, txIn := range tx.TxIn { + op := txIn.PreviousOutPoint + + // Not a testInput, it won't have the test encoding we require + // to check outpoint and index. + if _, ok := testInputs[op]; !ok { + continue + } + + if _, ok := found[op]; ok { + t.Fatalf("input already used") + } + + // Check it was signes spending the correct outpoint, and at + // the expected tx input index. + require.Equal(t, txIn.Witness[0], op.Hash[:]) + require.Equal(t, txIn.Witness[1], []byte{byte(op.Index)}) + require.Equal(t, txIn.Witness[2], []byte{byte(idx)}) + found[op] = struct{}{} + } + + return found +} + // TestLockTimes checks that the sweeper properly groups inputs requiring the // same locktime together into sweep transactions. func TestLockTimes(t *testing.T) { @@ -1824,6 +1879,66 @@ func TestRequiredTxOuts(t *testing.T) { ) }, }, + { + // Two inputs, where the first one required no tx out. + name: "two inputs, one with required tx out", + inputs: []*testInput{ + { + + // We add a normal, non-requiredTxOut + // input. We use test input 10, to make + // sure this has a higher yield than + // the other input, and will be + // attempted added first to the sweep + // tx. + BaseInput: inputs[10], + }, + { + // The second input requires a TxOut. + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: inputs[0].SignDesc().Output.Value, + }, + }, + }, + + // We expect the inputs to have been reordered. + assertSweeps: func(t *testing.T, + _ map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 2, len(tx.TxIn)) + require.Equal(t, 2, len(tx.TxOut)) + + // The required TxOut should be the first one. + out := tx.TxOut[0] + require.Equal(t, []byte("aaa"), out.PkScript) + require.Equal( + t, inputs[0].SignDesc().Output.Value, + out.Value, + ) + + // The first input should be the one having the + // required TxOut. + require.Len(t, tx.TxIn, 2) + require.Equal( + t, inputs[0].OutPoint(), + &tx.TxIn[0].PreviousOutPoint, + ) + + // Second one is the one without a required tx + // out. + require.Equal( + t, inputs[10].OutPoint(), + &tx.TxIn[1].PreviousOutPoint, + ) + }, + }, + { // An input committing to an output of equal value, just // add input to pay fees. @@ -2076,6 +2191,30 @@ func TestRequiredTxOuts(t *testing.T) { // Assert the transactions are what we expect. testCase.assertSweeps(t, inputs, sweeps) + + // Finally we assert that all our test inputs were part + // of the sweeps, and that they were signed correctly. + sweptInputs := make(map[wire.OutPoint]struct{}) + for _, sweep := range sweeps { + swept := assertSignedIndex(t, sweep, inputs) + for op := range swept { + if _, ok := sweptInputs[op]; ok { + t.Fatalf("outpoint %v part of "+ + "previous sweep", op) + } + + sweptInputs[op] = struct{}{} + } + } + + require.Equal(t, len(inputs), len(sweptInputs)) + for op := range sweptInputs { + _, ok := inputs[op] + if !ok { + t.Fatalf("swept input %v not part of "+ + "test inputs", op) + } + } }) } }