diff --git a/lntest/itest/lnd_forward_interceptor_test.go b/lntest/itest/lnd_forward_interceptor_test.go index 5f9b5d45..ba18eb2f 100644 --- a/lntest/itest/lnd_forward_interceptor_test.go +++ b/lntest/itest/lnd_forward_interceptor_test.go @@ -44,9 +44,20 @@ type interceptorTestCase struct { // 4. When Interceptor disconnects it resumes all held htlcs, which result in // valid payment (invoice is settled). func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { - // initialize the test context with 3 connected nodes. - testContext := newInterceptorTestContext(t, net) - defer testContext.shutdownNodes() + // Initialize the test context with 3 connected nodes. + alice, err := net.NewNode("alice", nil) + require.NoError(t.t, err, "unable to create alice") + defer shutdownAndAssert(net, t, alice) + + bob, err := net.NewNode("bob", nil) + require.NoError(t.t, err, "unable to create bob") + defer shutdownAndAssert(net, t, alice) + + carol, err := net.NewNode("carol", nil) + require.NoError(t.t, err, "unable to create carol") + defer shutdownAndAssert(net, t, alice) + + testContext := newInterceptorTestContext(t, net, alice, bob, carol) const ( chanAmt = btcutil.Amount(300000) @@ -62,15 +73,10 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { ctx := context.Background() ctxt, cancelInterceptor := context.WithTimeout(ctx, defaultTimeout) interceptor, err := testContext.bob.RouterClient.HtlcInterceptor(ctxt) - if err != nil { - t.Fatalf("failed to create HtlcInterceptor %v", err) - } + require.NoError(t.t, err, "failed to create HtlcInterceptor") // Prepare the test cases. - testCases, err := testContext.prepareTestCases() - if err != nil { - t.Fatalf("failed to prepare test cases") - } + testCases := testContext.prepareTestCases() // A channel for the interceptor go routine to send the requested packets. interceptedChan := make(chan *routerrpc.ForwardHtlcInterceptRequest, @@ -91,7 +97,7 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { return } // Otherwise it an unexpected error, we fail the test. - t.t.Errorf("unexpected error in interceptor.Recv() %v", err) + require.NoError(t.t, err, "unexpected error in interceptor.Recv()") return } interceptedChan <- request @@ -114,26 +120,22 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { return } if err != nil { - t.t.Errorf("failed to send payment %v", err) + require.NoError(t.t, err, "failed to send payment") } switch tc.interceptorAction { // For 'fail' interceptor action we make sure the payment failed. case routerrpc.ResolveHoldForwardAction_FAIL: - if attempt.Status != lnrpc.HTLCAttempt_FAILED { - t.t.Errorf("expected payment to fail, "+ - "instead got %v", attempt.Status) - } + require.Equal(t.t, lnrpc.HTLCAttempt_FAILED, + attempt.Status, "expected payment to fail") // For settle and resume we make sure the payment is successful. case routerrpc.ResolveHoldForwardAction_SETTLE: fallthrough case routerrpc.ResolveHoldForwardAction_RESUME: - if attempt.Status != lnrpc.HTLCAttempt_SUCCEEDED { - t.t.Errorf("expected payment to "+ - "succeed, instead got %v", attempt.Status) - } + require.Equal(t.t, lnrpc.HTLCAttempt_SUCCEEDED, + attempt.Status, "expected payment to succeed") } } }() @@ -185,9 +187,8 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { // Alice's node. payments, err := testContext.alice.ListPayments(context.Background(), &lnrpc.ListPaymentsRequest{IncludeIncomplete: true}) - if err != nil { - t.Fatalf("failed to fetch payments") - } + require.NoError(t.t, err, "failed to fetch payment") + for _, testCase := range testCases { if testCase.shouldHold { hashStr := hex.EncodeToString(testCase.invoice.RHash) @@ -199,18 +200,14 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { break } } - if foundPayment == nil { - t.Fatalf("expected to find pending payment for held"+ - "htlc %v", hashStr) - } - if foundPayment.ValueMsat != expectedAmt || - foundPayment.Status != lnrpc.Payment_IN_FLIGHT { - - t.Fatalf("expected to find in flight payment for"+ - "amount %v, %v", - testCase.invoice.ValueMsat, - foundPayment.Status) - } + require.NotNil(t.t, foundPayment, fmt.Sprintf("expected "+ + "to find pending payment for held htlc %v", + hashStr)) + require.Equal(t.t, lnrpc.Payment_IN_FLIGHT, + foundPayment.Status, "expected payment to be "+ + "in flight") + require.Equal(t.t, expectedAmt, foundPayment.ValueMsat, + "incorrect in flight amount") } } @@ -236,32 +233,26 @@ type interceptorTestContext struct { } func newInterceptorTestContext(t *harnessTest, - net *lntest.NetworkHarness) *interceptorTestContext { + net *lntest.NetworkHarness, + alice, bob, carol *lntest.HarnessNode) *interceptorTestContext { ctxb := context.Background() - // Create a three-node context consisting of Alice, Bob and Carol - carol, err := net.NewNode("carol", nil) - if err != nil { - t.Fatalf("unable to create carol: %v", err) - } - // Connect nodes - nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol} + nodes := []*lntest.HarnessNode{alice, bob, carol} for i := 0; i < len(nodes); i++ { for j := i + 1; j < len(nodes); j++ { ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) - if err := net.EnsureConnected(ctxt, nodes[i], nodes[j]); err != nil { - t.Fatalf("unable to connect nodes: %v", err) - } + err := net.EnsureConnected(ctxt, nodes[i], nodes[j]) + require.NoError(t.t, err, "unable to connect nodes") } } ctx := interceptorTestContext{ t: t, net: net, - alice: net.Alice, - bob: net.Bob, + alice: alice, + bob: bob, carol: carol, nodes: nodes, } @@ -274,9 +265,7 @@ func newInterceptorTestContext(t *harnessTest, // 2. resumed htlc. // 3. settling htlc externally. // 4. held htlc that is resumed later. -func (c *interceptorTestContext) prepareTestCases() ( - []*interceptorTestCase, error) { - +func (c *interceptorTestContext) prepareTestCases() []*interceptorTestCase { cases := []*interceptorTestCase{ {amountMsat: 1000, shouldHold: false, interceptorAction: routerrpc.ResolveHoldForwardAction_FAIL}, @@ -292,15 +281,12 @@ func (c *interceptorTestContext) prepareTestCases() ( addResponse, err := c.carol.AddInvoice(context.Background(), &lnrpc.Invoice{ ValueMsat: t.amountMsat, }) - if err != nil { - return nil, fmt.Errorf("unable to add invoice: %v", err) - } + require.NoError(c.t.t, err, "unable to add invoice") + invoice, err := c.carol.LookupInvoice(context.Background(), &lnrpc.PaymentHash{ RHashStr: hex.EncodeToString(addResponse.RHash), }) - if err != nil { - return nil, fmt.Errorf("unable to add invoice: %v", err) - } + require.NoError(c.t.t, err, "unable to find invoice") // We'll need to also decode the returned invoice so we can // grab the payment address which is now required for ALL @@ -308,13 +294,12 @@ func (c *interceptorTestContext) prepareTestCases() ( payReq, err := c.carol.DecodePayReq(context.Background(), &lnrpc.PayReqString{ PayReq: invoice.PaymentRequest, }) - if err != nil { - return nil, fmt.Errorf("unable to decode invoice: %v", err) - } + require.NoError(c.t.t, err, "unable to decode invoice") + t.invoice = invoice t.payAddr = payReq.PaymentAddr } - return cases, nil + return cases } func (c *interceptorTestContext) openChannel(from, to *lntest.HarnessNode, @@ -324,9 +309,7 @@ func (c *interceptorTestContext) openChannel(from, to *lntest.HarnessNode, ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) err := c.net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, from) - if err != nil { - c.t.Fatalf("unable to send coins : %v", err) - } + require.NoError(c.t.t, err, "unable to send coins") ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout) chanPoint := openChannelAndAssert( @@ -352,10 +335,6 @@ func (c *interceptorTestContext) closeChannels() { } } -func (c *interceptorTestContext) shutdownNodes() { - shutdownAndAssert(c.net, c.t, c.carol) -} - func (c *interceptorTestContext) waitForChannels() { ctxb := context.Background() @@ -363,9 +342,8 @@ func (c *interceptorTestContext) waitForChannels() { for _, chanPoint := range c.networkChans { for _, node := range c.nodes { txid, err := lnd.GetChanPointFundingTxid(chanPoint) - if err != nil { - c.t.Fatalf("unable to get txid: %v", err) - } + require.NoError(c.t.t, err, "unable to get txid") + point := wire.OutPoint{ Hash: *txid, Index: chanPoint.OutputIndex, @@ -373,11 +351,9 @@ func (c *interceptorTestContext) waitForChannels() { ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) err = node.WaitForNetworkChannelOpen(ctxt, chanPoint) - if err != nil { - c.t.Fatalf("(%d): timeout waiting for "+ - "channel(%s) open: %v", - node.NodeID, point, err) - } + require.NoError(c.t.t, err, fmt.Sprintf("(%d): timeout "+ + "waiting for channel(%s) open", node.NodeID, + point)) } } }