diff --git a/lnwallet/chanfunding/psbt_assembler.go b/lnwallet/chanfunding/psbt_assembler.go index 8b53d0fa..a40bae3c 100644 --- a/lnwallet/chanfunding/psbt_assembler.go +++ b/lnwallet/chanfunding/psbt_assembler.go @@ -1,7 +1,6 @@ package chanfunding import ( - "bytes" "crypto/sha256" "errors" "fmt" @@ -226,7 +225,7 @@ func (i *PsbtIntent) Verify(packet *psbt.Packet) error { outputSum := int64(0) for _, out := range packet.UnsignedTx.TxOut { outputSum += out.Value - if txOutsEqual(out, expectedOutput) { + if psbt.TxOutsEqual(out, expectedOutput) { outputFound = true } } @@ -241,7 +240,7 @@ func (i *PsbtIntent) Verify(packet *psbt.Packet) error { if len(packet.UnsignedTx.TxIn) == 0 { return fmt.Errorf("PSBT has no inputs") } - sum, err := sumUtxoInputValues(packet) + sum, err := psbt.SumUtxoInputValues(packet) if err != nil { return fmt.Errorf("error determining input sum: %v", err) } @@ -305,11 +304,13 @@ func (i *PsbtIntent) FinalizeRawTX(rawTx *wire.MsgTx) error { if i.PendingPsbt == nil { return fmt.Errorf("PSBT was not verified first") } - err := verifyOutputsEqual(rawTx.TxOut, i.PendingPsbt.UnsignedTx.TxOut) + err := psbt.VerifyOutputsEqual( + rawTx.TxOut, i.PendingPsbt.UnsignedTx.TxOut, + ) if err != nil { return fmt.Errorf("outputs differ from verified PSBT: %v", err) } - err = verifyInputPrevOutpointsEqual( + err = psbt.VerifyInputPrevOutpointsEqual( rawTx.TxIn, i.PendingPsbt.UnsignedTx.TxIn, ) if err != nil { @@ -472,82 +473,6 @@ func (p *PsbtAssembler) ShouldPublishFundingTx() bool { // ConditionalPublishAssembler interface. var _ ConditionalPublishAssembler = (*PsbtAssembler)(nil) -// sumUtxoInputValues tries to extract the sum of all inputs specified in the -// UTXO fields of the PSBT. An error is returned if an input is specified that -// does not contain any UTXO information. -func sumUtxoInputValues(packet *psbt.Packet) (int64, error) { - // We take the TX ins of the unsigned TX as the truth for how many - // inputs there should be, as the fields in the extra data part of the - // PSBT can be empty. - if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { - return 0, fmt.Errorf("TX input length doesn't match PSBT " + - "input length") - } - inputSum := int64(0) - for idx, in := range packet.Inputs { - switch { - case in.WitnessUtxo != nil: - // Witness UTXOs only need to reference the TxOut. - inputSum += in.WitnessUtxo.Value - - case in.NonWitnessUtxo != nil: - // Non-witness UTXOs reference to the whole transaction - // the UTXO resides in. - utxOuts := in.NonWitnessUtxo.TxOut - txIn := packet.UnsignedTx.TxIn[idx] - inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value - - default: - return 0, fmt.Errorf("input %d has no UTXO information", - idx) - } - } - return inputSum, nil -} - -// txOutsEqual returns true if two transaction outputs are equal. -func txOutsEqual(out1, out2 *wire.TxOut) bool { - if out1 == nil || out2 == nil { - return out1 == out2 - } - return out1.Value == out2.Value && - bytes.Equal(out1.PkScript, out2.PkScript) -} - -// verifyOutputsEqual verifies that the two slices of transaction outputs are -// deep equal to each other. We do the length check and manual loop to provide -// better error messages to the user than just returning "not equal". -func verifyOutputsEqual(outs1, outs2 []*wire.TxOut) error { - if len(outs1) != len(outs2) { - return fmt.Errorf("number of outputs are different") - } - for idx, out := range outs1 { - // There is a byte slice in the output so we can't use the - // equality operator. - if !txOutsEqual(out, outs2[idx]) { - return fmt.Errorf("output %d is different", idx) - } - } - return nil -} - -// verifyInputPrevOutpointsEqual verifies that the previous outpoints of the -// two slices of transaction inputs are deep equal to each other. We do the -// length check and manual loop to provide better error messages to the user -// than just returning "not equal". -func verifyInputPrevOutpointsEqual(ins1, ins2 []*wire.TxIn) error { - if len(ins1) != len(ins2) { - return fmt.Errorf("number of inputs are different") - } - for idx, in := range ins1 { - if in.PreviousOutPoint != ins2[idx].PreviousOutPoint { - return fmt.Errorf("previous outpoint of input %d is "+ - "different", idx) - } - } - return nil -} - // verifyInputsSigned verifies that the given list of inputs is non-empty and // that all the inputs either contain a script signature or a witness stack. func verifyInputsSigned(ins []*wire.TxIn) error {