diff --git a/input/script_utils_test.go b/input/script_utils_test.go index 1c6c5bc9..b716d4db 100644 --- a/input/script_utils_test.go +++ b/input/script_utils_test.go @@ -15,6 +15,73 @@ import ( "github.com/lightningnetwork/lnd/keychain" ) +// assertEngineExecution executes the VM returned by the newEngine closure, +// asserting the result matches the validity expectation. In the case where it +// doesn't match the expectation, it executes the script step-by-step and +// prints debug information to stdout. +func assertEngineExecution(t *testing.T, testNum int, valid bool, + newEngine func() (*txscript.Engine, error)) { + t.Helper() + + // Get a new VM to execute. + vm, err := newEngine() + if err != nil { + t.Fatalf("unable to create engine: %v", err) + } + + // Execute the VM, only go on to the step-by-step execution if + // it doesn't validate as expected. + vmErr := vm.Execute() + if valid == (vmErr == nil) { + return + } + + // Now that the execution didn't match what we expected, fetch a new VM + // to step through. + vm, err = newEngine() + if err != nil { + t.Fatalf("unable to create engine: %v", err) + } + + // This buffer will trace execution of the Script, dumping out + // to stdout. + var debugBuf bytes.Buffer + + done := false + for !done { + dis, err := vm.DisasmPC() + if err != nil { + t.Fatalf("stepping (%v)\n", err) + } + debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) + + done, err = vm.Step() + if err != nil && valid { + fmt.Println(debugBuf.String()) + t.Fatalf("spend test case #%v failed, spend "+ + "should be valid: %v", testNum, err) + } else if err == nil && !valid && done { + fmt.Println(debugBuf.String()) + t.Fatalf("spend test case #%v succeed, spend "+ + "should be invalid: %v", testNum, err) + } + + debugBuf.WriteString(fmt.Sprintf("Stack: %v", vm.GetStack())) + debugBuf.WriteString(fmt.Sprintf("AltStack: %v", vm.GetAltStack())) + } + + // If we get to this point the unexpected case was not reached + // during step execution, which happens for some checks, like + // the clean-stack rule. + validity := "invalid" + if valid { + validity = "valid" + } + + fmt.Println(debugBuf.String()) + t.Fatalf("%v spend test case #%v execution ended with: %v", validity, testNum, vmErr) +} + // TestRevocationKeyDerivation tests that given a public key, and a revocation // hash, the homomorphic revocation public and private key derivation work // properly. @@ -308,39 +375,13 @@ func TestHTLCSenderSpendValidation(t *testing.T) { for i, testCase := range testCases { sweepTx.TxIn[0].Witness = testCase.witness() - vm, err := txscript.NewEngine(htlcPkScript, - sweepTx, 0, txscript.StandardVerifyFlags, nil, - nil, int64(paymentAmt)) - if err != nil { - t.Fatalf("unable to create engine: %v", err) + newEngine := func() (*txscript.Engine, error) { + return txscript.NewEngine(htlcPkScript, + sweepTx, 0, txscript.StandardVerifyFlags, nil, + nil, int64(paymentAmt)) } - // This buffer will trace execution of the Script, only dumping - // out to stdout in the case that a test fails. - var debugBuf bytes.Buffer - - done := false - for !done { - dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } - debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) - - done, err = vm.Step() - if err != nil && testCase.valid { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v failed, spend "+ - "should be valid: %v", i, err) - } else if err == nil && !testCase.valid && done { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v succeed, spend "+ - "should be invalid: %v", i, err) - } - - debugBuf.WriteString(fmt.Sprintf("Stack: %v", vm.GetStack())) - debugBuf.WriteString(fmt.Sprintf("AltStack: %v", vm.GetAltStack())) - } + assertEngineExecution(t, i, testCase.valid, newEngine) } } @@ -581,37 +622,13 @@ func TestHTLCReceiverSpendValidation(t *testing.T) { for i, testCase := range testCases { sweepTx.TxIn[0].Witness = testCase.witness() - vm, err := txscript.NewEngine(htlcPkScript, - sweepTx, 0, txscript.StandardVerifyFlags, nil, - nil, int64(paymentAmt)) - if err != nil { - t.Fatalf("unable to create engine: %v", err) + newEngine := func() (*txscript.Engine, error) { + return txscript.NewEngine(htlcPkScript, + sweepTx, 0, txscript.StandardVerifyFlags, nil, + nil, int64(paymentAmt)) } - // This buffer will trace execution of the Script, only dumping - // out to stdout in the case that a test fails. - var debugBuf bytes.Buffer - - done := false - for !done { - dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } - debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) - - done, err = vm.Step() - if err != nil && testCase.valid { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v failed, spend should be valid: %v", i, err) - } else if err == nil && !testCase.valid && done { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v succeed, spend should be invalid: %v", i, err) - } - - debugBuf.WriteString(fmt.Sprintf("Stack: %v", vm.GetStack())) - debugBuf.WriteString(fmt.Sprintf("AltStack: %v", vm.GetAltStack())) - } + assertEngineExecution(t, i, testCase.valid, newEngine) } } @@ -811,39 +828,13 @@ func TestSecondLevelHtlcSpends(t *testing.T) { for i, testCase := range testCases { sweepTx.TxIn[0].Witness = testCase.witness() - vm, err := txscript.NewEngine(htlcPkScript, - sweepTx, 0, txscript.StandardVerifyFlags, nil, - nil, int64(htlcAmt)) - if err != nil { - t.Fatalf("unable to create engine: %v", err) + newEngine := func() (*txscript.Engine, error) { + return txscript.NewEngine(htlcPkScript, + sweepTx, 0, txscript.StandardVerifyFlags, nil, + nil, int64(htlcAmt)) } - // This buffer will trace execution of the Script, only dumping - // out to stdout in the case that a test fails. - var debugBuf bytes.Buffer - - done := false - for !done { - dis, err := vm.DisasmPC() - if err != nil { - t.Fatalf("stepping (%v)\n", err) - } - debugBuf.WriteString(fmt.Sprintf("stepping %v\n", dis)) - - done, err = vm.Step() - if err != nil && testCase.valid { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v failed, spend "+ - "should be valid: %v", i, err) - } else if err == nil && !testCase.valid && done { - fmt.Println(debugBuf.String()) - t.Fatalf("spend test case #%v succeed, spend "+ - "should be invalid: %v", i, err) - } - - debugBuf.WriteString(fmt.Sprintf("Stack: %v", vm.GetStack())) - debugBuf.WriteString(fmt.Sprintf("AltStack: %v", vm.GetAltStack())) - } + assertEngineExecution(t, i, testCase.valid, newEngine) } }