diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 9339b295..96e989cc 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1287,10 +1287,69 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { } } +type multiHopFwdTest struct { + name string + eligible1, eligible2 bool + failure1, failure2 lnwire.FailureMessage + expectedReply lnwire.FailCode +} + // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes // along, then we won't attempt to froward it down al ink that isn't yet able // to forward any HTLC's. func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { + tests := []multiHopFwdTest{ + // None of the channels is eligible or has enough bandwidth. + { + name: "not eligible", + expectedReply: lnwire.CodeTemporaryChannelFailure, + }, + + // Channel one has a policy failure and the other channel isn't + // available. + { + name: "policy fail", + eligible1: true, + failure1: lnwire.NewFinalIncorrectCltvExpiry(0), + expectedReply: lnwire.CodeFinalIncorrectCltvExpiry, + }, + + // The requested channel is not eligible or has insufficient + // bandwidth, but the packet is forwarded through the other + // channel. + { + name: "non-strict success", + eligible2: true, + expectedReply: lnwire.CodeNone, + }, + + // The requested channel is not eligible or has insufficient + // bandwidth and the other channel's policy isn't satisfied. + // + // NOTE: We expect a temporary channel failure here, but don't + // receive it! + { + name: "non-strict policy fail", + eligible2: true, + failure2: lnwire.NewFinalIncorrectCltvExpiry(0), + expectedReply: lnwire.CodeUnknownNextPeer, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testSkipIneligibleLinksMultiHopForward(t, &test) + }) + } +} + +// testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes +// along, then we won't attempt to froward it down al ink that isn't yet able +// to forward any HTLC's. +func testSkipIneligibleLinksMultiHopForward(t *testing.T, + testCase *multiHopFwdTest) { + t.Parallel() var packet *htlcPacket @@ -1313,22 +1372,32 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { } defer s.Stop() - chanID1, chanID2, aliceChanID, bobChanID := genIDs() - + chanID1, aliceChanID := genID() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, ) // We'll create a link for Bob, but mark the link as unable to forward // any new outgoing HTLC's. - bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, false, + chanID2, bobChanID2 := genID() + bobChannelLink1 := newMockChannelLink( + s, chanID2, bobChanID2, bobPeer, testCase.eligible1, ) + bobChannelLink1.checkHtlcForwardResult = testCase.failure1 + + chanID3, bobChanID3 := genID() + bobChannelLink2 := newMockChannelLink( + s, chanID3, bobChanID3, bobPeer, testCase.eligible2, + ) + bobChannelLink2.checkHtlcForwardResult = testCase.failure2 if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } - if err := s.AddLink(bobChannelLink); err != nil { + if err := s.AddLink(bobChannelLink1); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + if err := s.AddLink(bobChannelLink2); err != nil { t.Fatalf("unable to add bob link: %v", err) } @@ -1340,7 +1409,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { packet = &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, - outgoingChanID: bobChannelLink.ShortChanID(), + outgoingChanID: bobChannelLink1.ShortChanID(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -1350,13 +1419,23 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { // The request to forward should fail as err = s.forward(packet) - if err == nil { - t.Fatalf("forwarding should have failed due to inactive link") - } failure := obfuscator.(*mockObfuscator).failure - if _, ok := failure.(*lnwire.FailTemporaryChannelFailure); !ok { - t.Fatalf("unexpected failure %T", failure) + if testCase.expectedReply == lnwire.CodeNone { + if err != nil { + t.Fatalf("forwarding should have succeeded") + } + if failure != nil { + t.Fatalf("unexpected failure %T", failure) + } + } else { + if err == nil { + t.Fatalf("forwarding should have failed due to " + + "inactive link") + } + if failure.Code() != testCase.expectedReply { + t.Fatalf("unexpected failure %T", failure) + } } if s.circuits.NumOpen() != 0 {