routing/router_test: add TestSendToRouteMultiShardSend

This commit is contained in:
Johan T. Halseth 2020-04-01 00:13:27 +02:00
parent 864e64e725
commit 95c5a123c8
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
2 changed files with 145 additions and 4 deletions

@ -15,6 +15,8 @@ import (
type mockPaymentAttemptDispatcher struct {
onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
results map[uint64]*htlcswitch.PaymentResult
sync.Mutex
}
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
@ -27,10 +29,6 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
return nil
}
if m.results == nil {
m.results = make(map[uint64]*htlcswitch.PaymentResult)
}
var result *htlcswitch.PaymentResult
preimage, err := m.onPayment(firstHop)
if err != nil {
@ -45,7 +43,13 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
result = &htlcswitch.PaymentResult{Preimage: preimage}
}
m.Lock()
if m.results == nil {
m.results = make(map[uint64]*htlcswitch.PaymentResult)
}
m.results[pid] = result
m.Unlock()
return nil
}
@ -55,7 +59,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
<-chan *htlcswitch.PaymentResult, error) {
c := make(chan *htlcswitch.PaymentResult, 1)
m.Lock()
res, ok := m.results[paymentID]
m.Unlock()
if !ok {
return nil, htlcswitch.ErrPaymentIDNotFound
}

@ -21,6 +21,7 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
)
@ -2725,6 +2726,138 @@ func TestSendToRouteStructuredError(t *testing.T) {
}
}
// TestSendToRouteMultiShardSend checks that a 3-shard payment can be executed
// using SendToRoute.
func TestSendToRouteMultiShardSend(t *testing.T) {
t.Parallel()
ctx, cleanup, err := createTestCtxSingleNode(0)
if err != nil {
t.Fatal(err)
}
defer cleanup()
const numShards = 3
const payAmt = lnwire.MilliSatoshi(numShards * 10000)
node, err := createTestNode()
if err != nil {
t.Fatal(err)
}
// Create a simple 1-hop route that we will use for all three shards.
hops := []*route.Hop{
{
ChannelID: 1,
PubKeyBytes: node.PubKeyBytes,
AmtToForward: payAmt / numShards,
MPP: record.NewMPP(payAmt, [32]byte{}),
},
}
sourceNode, err := ctx.graph.SourceNode()
if err != nil {
t.Fatal(err)
}
rt, err := route.NewRouteFromHops(
payAmt, 100, sourceNode.PubKeyBytes, hops,
)
if err != nil {
t.Fatalf("unable to create route: %v", err)
}
// The first shard we send we'll fail immediately, to check that we are
// still allowed to retry with other shards after a failed one.
ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult(
func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
return [32]byte{}, htlcswitch.NewForwardingError(
&lnwire.FailFeeInsufficient{
Update: lnwire.ChannelUpdate{},
}, 1,
)
})
// The payment parameter is mostly redundant in SendToRoute. Can be left
// empty for this test.
var payment lntypes.Hash
// Send the shard using the created route, and expect an error to be
// returned.
_, err = ctx.router.SendToRoute(payment, rt)
if err == nil {
t.Fatalf("expected forwarding error")
}
// Now we'll modify the SendToSwitch method again to wait until all
// three shards are initiated before returning a result. We do this by
// signalling when the method has been called, and then stop to wait
// for the test to deliver the final result on the channel below.
waitForResultSignal := make(chan struct{}, numShards)
results := make(chan lntypes.Preimage, numShards)
ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult(
func(firstHop lnwire.ShortChannelID) ([32]byte, error) {
// Signal that the shard has been initiated and is
// waiting for a result.
waitForResultSignal <- struct{}{}
// Wait for a result before returning it.
res, ok := <-results
if !ok {
return [32]byte{}, fmt.Errorf("failure")
}
return res, nil
})
// Launch three shards by calling SendToRoute in three goroutines,
// returning their final error on the channel.
errChan := make(chan error)
successes := make(chan lntypes.Preimage)
for i := 0; i < numShards; i++ {
go func() {
preimg, err := ctx.router.SendToRoute(payment, rt)
if err != nil {
errChan <- err
return
}
successes <- preimg
}()
}
// Wait for all shards to signal they have been initiated.
for i := 0; i < numShards; i++ {
select {
case <-waitForResultSignal:
case <-time.After(5 * time.Second):
t.Fatalf("not waiting for results")
}
}
// Deliver a dummy preimage to all the shard handlers.
preimage := lntypes.Preimage{}
preimage[4] = 42
for i := 0; i < numShards; i++ {
results <- preimage
}
// Finally expect all shards to return with the above preimage.
for i := 0; i < numShards; i++ {
select {
case p := <-successes:
if p != preimage {
t.Fatalf("preimage mismatch")
}
case err := <-errChan:
t.Fatalf("unexpected error from SendToRoute: %v", err)
case <-time.After(5 * time.Second):
t.Fatalf("result not received")
}
}
}
// TestSendToRouteMaxHops asserts that SendToRoute fails when using a route that
// exceeds the maximum number of hops.
func TestSendToRouteMaxHops(t *testing.T) {