From 95c5a123c8661918723566740e50de573b9d0942 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Wed, 1 Apr 2020 00:13:27 +0200 Subject: [PATCH] routing/router_test: add TestSendToRouteMultiShardSend --- routing/mock_test.go | 16 +++-- routing/router_test.go | 133 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 4 deletions(-) diff --git a/routing/mock_test.go b/routing/mock_test.go index 36523a97..f47a9420 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -15,6 +15,8 @@ import ( type mockPaymentAttemptDispatcher struct { onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error) results map[uint64]*htlcswitch.PaymentResult + + sync.Mutex } var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) @@ -27,10 +29,6 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, return nil } - if m.results == nil { - m.results = make(map[uint64]*htlcswitch.PaymentResult) - } - var result *htlcswitch.PaymentResult preimage, err := m.onPayment(firstHop) if err != nil { @@ -45,7 +43,13 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, result = &htlcswitch.PaymentResult{Preimage: preimage} } + m.Lock() + if m.results == nil { + m.results = make(map[uint64]*htlcswitch.PaymentResult) + } + m.results[pid] = result + m.Unlock() return nil } @@ -55,7 +59,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64, <-chan *htlcswitch.PaymentResult, error) { c := make(chan *htlcswitch.PaymentResult, 1) + + m.Lock() res, ok := m.results[paymentID] + m.Unlock() + if !ok { return nil, htlcswitch.ErrPaymentIDNotFound } diff --git a/routing/router_test.go b/routing/router_test.go index 0bb90d71..7ac7527c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" ) @@ -2725,6 +2726,138 @@ func TestSendToRouteStructuredError(t *testing.T) { } } +// TestSendToRouteMultiShardSend checks that a 3-shard payment can be executed +// using SendToRoute. +func TestSendToRouteMultiShardSend(t *testing.T) { + t.Parallel() + + ctx, cleanup, err := createTestCtxSingleNode(0) + if err != nil { + t.Fatal(err) + } + defer cleanup() + + const numShards = 3 + const payAmt = lnwire.MilliSatoshi(numShards * 10000) + node, err := createTestNode() + if err != nil { + t.Fatal(err) + } + + // Create a simple 1-hop route that we will use for all three shards. + hops := []*route.Hop{ + { + ChannelID: 1, + PubKeyBytes: node.PubKeyBytes, + AmtToForward: payAmt / numShards, + MPP: record.NewMPP(payAmt, [32]byte{}), + }, + } + + sourceNode, err := ctx.graph.SourceNode() + if err != nil { + t.Fatal(err) + } + + rt, err := route.NewRouteFromHops( + payAmt, 100, sourceNode.PubKeyBytes, hops, + ) + if err != nil { + t.Fatalf("unable to create route: %v", err) + } + + // The first shard we send we'll fail immediately, to check that we are + // still allowed to retry with other shards after a failed one. + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { + return [32]byte{}, htlcswitch.NewForwardingError( + &lnwire.FailFeeInsufficient{ + Update: lnwire.ChannelUpdate{}, + }, 1, + ) + }) + + // The payment parameter is mostly redundant in SendToRoute. Can be left + // empty for this test. + var payment lntypes.Hash + + // Send the shard using the created route, and expect an error to be + // returned. + _, err = ctx.router.SendToRoute(payment, rt) + if err == nil { + t.Fatalf("expected forwarding error") + } + + // Now we'll modify the SendToSwitch method again to wait until all + // three shards are initiated before returning a result. We do this by + // signalling when the method has been called, and then stop to wait + // for the test to deliver the final result on the channel below. + waitForResultSignal := make(chan struct{}, numShards) + results := make(chan lntypes.Preimage, numShards) + + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { + + // Signal that the shard has been initiated and is + // waiting for a result. + waitForResultSignal <- struct{}{} + + // Wait for a result before returning it. + res, ok := <-results + if !ok { + return [32]byte{}, fmt.Errorf("failure") + } + return res, nil + }) + + // Launch three shards by calling SendToRoute in three goroutines, + // returning their final error on the channel. + errChan := make(chan error) + successes := make(chan lntypes.Preimage) + + for i := 0; i < numShards; i++ { + go func() { + preimg, err := ctx.router.SendToRoute(payment, rt) + if err != nil { + errChan <- err + return + } + + successes <- preimg + }() + } + + // Wait for all shards to signal they have been initiated. + for i := 0; i < numShards; i++ { + select { + case <-waitForResultSignal: + case <-time.After(5 * time.Second): + t.Fatalf("not waiting for results") + } + } + + // Deliver a dummy preimage to all the shard handlers. + preimage := lntypes.Preimage{} + preimage[4] = 42 + for i := 0; i < numShards; i++ { + results <- preimage + } + + // Finally expect all shards to return with the above preimage. + for i := 0; i < numShards; i++ { + select { + case p := <-successes: + if p != preimage { + t.Fatalf("preimage mismatch") + } + case err := <-errChan: + t.Fatalf("unexpected error from SendToRoute: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("result not received") + } + } +} + // TestSendToRouteMaxHops asserts that SendToRoute fails when using a route that // exceeds the maximum number of hops. func TestSendToRouteMaxHops(t *testing.T) {