Merge pull request #4823 from cfromknecht/fwd-interceptor-fixes

itest: fwd interceptor fixes
This commit is contained in:
Conner Fromknecht 2020-12-04 16:36:14 -08:00 committed by GitHub
commit 125dbbf0da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,9 +44,20 @@ type interceptorTestCase struct {
// 4. When Interceptor disconnects it resumes all held htlcs, which result in // 4. When Interceptor disconnects it resumes all held htlcs, which result in
// valid payment (invoice is settled). // valid payment (invoice is settled).
func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) { func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
// initialize the test context with 3 connected nodes. // Initialize the test context with 3 connected nodes.
testContext := newInterceptorTestContext(t, net) alice, err := net.NewNode("alice", nil)
defer testContext.shutdownNodes() require.NoError(t.t, err, "unable to create alice")
defer shutdownAndAssert(net, t, alice)
bob, err := net.NewNode("bob", nil)
require.NoError(t.t, err, "unable to create bob")
defer shutdownAndAssert(net, t, alice)
carol, err := net.NewNode("carol", nil)
require.NoError(t.t, err, "unable to create carol")
defer shutdownAndAssert(net, t, alice)
testContext := newInterceptorTestContext(t, net, alice, bob, carol)
const ( const (
chanAmt = btcutil.Amount(300000) chanAmt = btcutil.Amount(300000)
@ -62,15 +73,10 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
ctx := context.Background() ctx := context.Background()
ctxt, cancelInterceptor := context.WithTimeout(ctx, defaultTimeout) ctxt, cancelInterceptor := context.WithTimeout(ctx, defaultTimeout)
interceptor, err := testContext.bob.RouterClient.HtlcInterceptor(ctxt) interceptor, err := testContext.bob.RouterClient.HtlcInterceptor(ctxt)
if err != nil { require.NoError(t.t, err, "failed to create HtlcInterceptor")
t.Fatalf("failed to create HtlcInterceptor %v", err)
}
// Prepare the test cases. // Prepare the test cases.
testCases, err := testContext.prepareTestCases() testCases := testContext.prepareTestCases()
if err != nil {
t.Fatalf("failed to prepare test cases")
}
// A channel for the interceptor go routine to send the requested packets. // A channel for the interceptor go routine to send the requested packets.
interceptedChan := make(chan *routerrpc.ForwardHtlcInterceptRequest, interceptedChan := make(chan *routerrpc.ForwardHtlcInterceptRequest,
@ -91,7 +97,7 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
return return
} }
// Otherwise it an unexpected error, we fail the test. // Otherwise it an unexpected error, we fail the test.
t.t.Errorf("unexpected error in interceptor.Recv() %v", err) require.NoError(t.t, err, "unexpected error in interceptor.Recv()")
return return
} }
interceptedChan <- request interceptedChan <- request
@ -114,26 +120,22 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
return return
} }
if err != nil { if err != nil {
t.t.Errorf("failed to send payment %v", err) require.NoError(t.t, err, "failed to send payment")
} }
switch tc.interceptorAction { switch tc.interceptorAction {
// For 'fail' interceptor action we make sure the payment failed. // For 'fail' interceptor action we make sure the payment failed.
case routerrpc.ResolveHoldForwardAction_FAIL: case routerrpc.ResolveHoldForwardAction_FAIL:
if attempt.Status != lnrpc.HTLCAttempt_FAILED { require.Equal(t.t, lnrpc.HTLCAttempt_FAILED,
t.t.Errorf("expected payment to fail, "+ attempt.Status, "expected payment to fail")
"instead got %v", attempt.Status)
}
// For settle and resume we make sure the payment is successful. // For settle and resume we make sure the payment is successful.
case routerrpc.ResolveHoldForwardAction_SETTLE: case routerrpc.ResolveHoldForwardAction_SETTLE:
fallthrough fallthrough
case routerrpc.ResolveHoldForwardAction_RESUME: case routerrpc.ResolveHoldForwardAction_RESUME:
if attempt.Status != lnrpc.HTLCAttempt_SUCCEEDED { require.Equal(t.t, lnrpc.HTLCAttempt_SUCCEEDED,
t.t.Errorf("expected payment to "+ attempt.Status, "expected payment to succeed")
"succeed, instead got %v", attempt.Status)
}
} }
} }
}() }()
@ -185,9 +187,8 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
// Alice's node. // Alice's node.
payments, err := testContext.alice.ListPayments(context.Background(), payments, err := testContext.alice.ListPayments(context.Background(),
&lnrpc.ListPaymentsRequest{IncludeIncomplete: true}) &lnrpc.ListPaymentsRequest{IncludeIncomplete: true})
if err != nil { require.NoError(t.t, err, "failed to fetch payment")
t.Fatalf("failed to fetch payments")
}
for _, testCase := range testCases { for _, testCase := range testCases {
if testCase.shouldHold { if testCase.shouldHold {
hashStr := hex.EncodeToString(testCase.invoice.RHash) hashStr := hex.EncodeToString(testCase.invoice.RHash)
@ -199,18 +200,14 @@ func testForwardInterceptor(net *lntest.NetworkHarness, t *harnessTest) {
break break
} }
} }
if foundPayment == nil { require.NotNil(t.t, foundPayment, fmt.Sprintf("expected "+
t.Fatalf("expected to find pending payment for held"+ "to find pending payment for held htlc %v",
"htlc %v", hashStr) hashStr))
} require.Equal(t.t, lnrpc.Payment_IN_FLIGHT,
if foundPayment.ValueMsat != expectedAmt || foundPayment.Status, "expected payment to be "+
foundPayment.Status != lnrpc.Payment_IN_FLIGHT { "in flight")
require.Equal(t.t, expectedAmt, foundPayment.ValueMsat,
t.Fatalf("expected to find in flight payment for"+ "incorrect in flight amount")
"amount %v, %v",
testCase.invoice.ValueMsat,
foundPayment.Status)
}
} }
} }
@ -236,32 +233,26 @@ type interceptorTestContext struct {
} }
func newInterceptorTestContext(t *harnessTest, func newInterceptorTestContext(t *harnessTest,
net *lntest.NetworkHarness) *interceptorTestContext { net *lntest.NetworkHarness,
alice, bob, carol *lntest.HarnessNode) *interceptorTestContext {
ctxb := context.Background() ctxb := context.Background()
// Create a three-node context consisting of Alice, Bob and Carol
carol, err := net.NewNode("carol", nil)
if err != nil {
t.Fatalf("unable to create carol: %v", err)
}
// Connect nodes // Connect nodes
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol} nodes := []*lntest.HarnessNode{alice, bob, carol}
for i := 0; i < len(nodes); i++ { for i := 0; i < len(nodes); i++ {
for j := i + 1; j < len(nodes); j++ { for j := i + 1; j < len(nodes); j++ {
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
if err := net.EnsureConnected(ctxt, nodes[i], nodes[j]); err != nil { err := net.EnsureConnected(ctxt, nodes[i], nodes[j])
t.Fatalf("unable to connect nodes: %v", err) require.NoError(t.t, err, "unable to connect nodes")
}
} }
} }
ctx := interceptorTestContext{ ctx := interceptorTestContext{
t: t, t: t,
net: net, net: net,
alice: net.Alice, alice: alice,
bob: net.Bob, bob: bob,
carol: carol, carol: carol,
nodes: nodes, nodes: nodes,
} }
@ -274,9 +265,7 @@ func newInterceptorTestContext(t *harnessTest,
// 2. resumed htlc. // 2. resumed htlc.
// 3. settling htlc externally. // 3. settling htlc externally.
// 4. held htlc that is resumed later. // 4. held htlc that is resumed later.
func (c *interceptorTestContext) prepareTestCases() ( func (c *interceptorTestContext) prepareTestCases() []*interceptorTestCase {
[]*interceptorTestCase, error) {
cases := []*interceptorTestCase{ cases := []*interceptorTestCase{
{amountMsat: 1000, shouldHold: false, {amountMsat: 1000, shouldHold: false,
interceptorAction: routerrpc.ResolveHoldForwardAction_FAIL}, interceptorAction: routerrpc.ResolveHoldForwardAction_FAIL},
@ -292,15 +281,12 @@ func (c *interceptorTestContext) prepareTestCases() (
addResponse, err := c.carol.AddInvoice(context.Background(), &lnrpc.Invoice{ addResponse, err := c.carol.AddInvoice(context.Background(), &lnrpc.Invoice{
ValueMsat: t.amountMsat, ValueMsat: t.amountMsat,
}) })
if err != nil { require.NoError(c.t.t, err, "unable to add invoice")
return nil, fmt.Errorf("unable to add invoice: %v", err)
}
invoice, err := c.carol.LookupInvoice(context.Background(), &lnrpc.PaymentHash{ invoice, err := c.carol.LookupInvoice(context.Background(), &lnrpc.PaymentHash{
RHashStr: hex.EncodeToString(addResponse.RHash), RHashStr: hex.EncodeToString(addResponse.RHash),
}) })
if err != nil { require.NoError(c.t.t, err, "unable to find invoice")
return nil, fmt.Errorf("unable to add invoice: %v", err)
}
// We'll need to also decode the returned invoice so we can // We'll need to also decode the returned invoice so we can
// grab the payment address which is now required for ALL // grab the payment address which is now required for ALL
@ -308,13 +294,12 @@ func (c *interceptorTestContext) prepareTestCases() (
payReq, err := c.carol.DecodePayReq(context.Background(), &lnrpc.PayReqString{ payReq, err := c.carol.DecodePayReq(context.Background(), &lnrpc.PayReqString{
PayReq: invoice.PaymentRequest, PayReq: invoice.PaymentRequest,
}) })
if err != nil { require.NoError(c.t.t, err, "unable to decode invoice")
return nil, fmt.Errorf("unable to decode invoice: %v", err)
}
t.invoice = invoice t.invoice = invoice
t.payAddr = payReq.PaymentAddr t.payAddr = payReq.PaymentAddr
} }
return cases, nil return cases
} }
func (c *interceptorTestContext) openChannel(from, to *lntest.HarnessNode, func (c *interceptorTestContext) openChannel(from, to *lntest.HarnessNode,
@ -324,9 +309,7 @@ func (c *interceptorTestContext) openChannel(from, to *lntest.HarnessNode,
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
err := c.net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, from) err := c.net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, from)
if err != nil { require.NoError(c.t.t, err, "unable to send coins")
c.t.Fatalf("unable to send coins : %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout) ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout)
chanPoint := openChannelAndAssert( chanPoint := openChannelAndAssert(
@ -352,10 +335,6 @@ func (c *interceptorTestContext) closeChannels() {
} }
} }
func (c *interceptorTestContext) shutdownNodes() {
shutdownAndAssert(c.net, c.t, c.carol)
}
func (c *interceptorTestContext) waitForChannels() { func (c *interceptorTestContext) waitForChannels() {
ctxb := context.Background() ctxb := context.Background()
@ -363,9 +342,8 @@ func (c *interceptorTestContext) waitForChannels() {
for _, chanPoint := range c.networkChans { for _, chanPoint := range c.networkChans {
for _, node := range c.nodes { for _, node := range c.nodes {
txid, err := lnd.GetChanPointFundingTxid(chanPoint) txid, err := lnd.GetChanPointFundingTxid(chanPoint)
if err != nil { require.NoError(c.t.t, err, "unable to get txid")
c.t.Fatalf("unable to get txid: %v", err)
}
point := wire.OutPoint{ point := wire.OutPoint{
Hash: *txid, Hash: *txid,
Index: chanPoint.OutputIndex, Index: chanPoint.OutputIndex,
@ -373,11 +351,9 @@ func (c *interceptorTestContext) waitForChannels() {
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint) err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil { require.NoError(c.t.t, err, fmt.Sprintf("(%d): timeout "+
c.t.Fatalf("(%d): timeout waiting for "+ "waiting for channel(%s) open", node.NodeID,
"channel(%s) open: %v", point))
node.NodeID, point, err)
}
} }
} }
} }