diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 5e162795..cd050856 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -100,22 +100,21 @@ type ChannelLink interface { // policy to govern if it an incoming HTLC should be forwarded or not. UpdateForwardingPolicy(ForwardingPolicy) - // HtlcSatifiesPolicy should return a nil error if the passed HTLC - // details satisfy the current forwarding policy fo the target link. - // Otherwise, a valid protocol failure message should be returned in - // order to signal to the source of the HTLC, the policy consistency - // issue. - HtlcSatifiesPolicy(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, + // CheckHtlcForward should return a nil error if the passed HTLC details + // satisfy the current forwarding policy fo the target link. Otherwise, + // a valid protocol failure message should be returned in order to + // signal to the source of the HTLC, the policy consistency issue. + CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, amtToForward lnwire.MilliSatoshi, incomingTimeout, outgoingTimeout uint32, heightNow uint32) lnwire.FailureMessage - // HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC - // details satisfy the current channel policy. Otherwise, a valid - // protocol failure message should be returned in order to signal the - // violation. This call is intended to be used for locally initiated - // payments for which there is no corresponding incoming htlc. - HtlcSatifiesPolicyLocal(payHash [32]byte, amt lnwire.MilliSatoshi, + // CheckHtlcTransit should return a nil error if the passed HTLC details + // satisfy the current channel policy. Otherwise, a valid protocol + // failure message should be returned in order to signal the violation. + // This call is intended to be used for locally initiated payments for + // which there is no corresponding incoming htlc. + CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32) lnwire.FailureMessage // Bandwidth returns the amount of milli-satoshis which current link diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 3e356b35..c0b62633 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -605,6 +605,19 @@ func shouldAdjustCommitFee(netFee, chanFee lnwallet.SatPerKWeight) bool { } } +// createFailureWithUpdate retrieves this link's last channel update message and +// passes it into the callback. It expects a fully populated failure message. +func (l *channelLink) createFailureWithUpdate( + cb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage) lnwire.FailureMessage { + + update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) + if err != nil { + return &lnwire.FailTemporaryNodeFailure{} + } + + return cb(update) +} + // syncChanState attempts to synchronize channel states with the remote party. // This method is to be called upon reconnection after the initial funding // flow. We'll compare out commitment chains with the remote party, and re-send @@ -1312,17 +1325,13 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { reason lnwire.OpaqueReason ) - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), + failure := l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure( + upd, + ) + }, ) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure( - update, - ) - } // Encrypt the error back to the source unless // the payment was generated locally. @@ -2175,13 +2184,13 @@ func (l *channelLink) UpdateForwardingPolicy(newPolicy ForwardingPolicy) { l.cfg.FwrdingPolicy = newPolicy } -// HtlcSatifiesPolicy should return a nil error if the passed HTLC details -// satisfy the current forwarding policy fo the target link. Otherwise, a -// valid protocol failure message should be returned in order to signal to the -// source of the HTLC, the policy consistency issue. +// CheckHtlcForward should return a nil error if the passed HTLC details satisfy +// the current forwarding policy fo the target link. Otherwise, a valid +// protocol failure message should be returned in order to signal to the source +// of the HTLC, the policy consistency issue. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, +func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, incomingTimeout, outgoingTimeout uint32, heightNow uint32) lnwire.FailureMessage { @@ -2191,7 +2200,7 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, l.RUnlock() // First check whether the outgoing htlc satisfies the channel policy. - err := l.htlcSatifiesPolicyOutgoing( + err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, ) if err != nil { @@ -2216,17 +2225,14 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewFeeInsufficient( - amtToForward, *update, - ) - } - return failure + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewFeeInsufficient( + amtToForward, *upd, + ) + }, + ) } // Finally, we'll ensure that the time-lock on the outgoing HTLC meets @@ -2241,30 +2247,24 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, // Grab the latest routing policy so the sending node is up to // date with our current policy. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewIncorrectCltvExpiry( + incomingTimeout, *upd, + ) + }, ) - if err != nil { - failure = lnwire.NewTemporaryChannelFailure(update) - } else { - failure = lnwire.NewIncorrectCltvExpiry( - incomingTimeout, *update, - ) - } - - return failure } return nil } -// HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC details -// satisfy the current channel policy. Otherwise, a valid protocol failure -// message should be returned in order to signal the violation. This call is -// intended to be used for locally initiated payments for which there is no -// corresponding incoming htlc. -func (l *channelLink) HtlcSatifiesPolicyLocal(payHash [32]byte, +// CheckHtlcTransit should return a nil error if the passed HTLC details satisfy the +// current channel policy. Otherwise, a valid protocol failure message should +// be returned in order to signal the violation. This call is intended to be +// used for locally initiated payments for which there is no corresponding +// incoming htlc. +func (l *channelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32) lnwire.FailureMessage { @@ -2272,14 +2272,14 @@ func (l *channelLink) HtlcSatifiesPolicyLocal(payHash [32]byte, policy := l.cfg.FwrdingPolicy l.RUnlock() - return l.htlcSatifiesPolicyOutgoing( + return l.canSendHtlc( policy, payHash, amt, timeout, heightNow, ) } // htlcSatifiesPolicyOutgoing checks whether the given htlc parameters satisfy // the channel's amount and time lock constraints. -func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy, +func (l *channelLink) canSendHtlc(policy ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32) lnwire.FailureMessage { @@ -2293,36 +2293,28 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewAmountBelowMinimum( - amt, *update, - ) - } - - return failure + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewAmountBelowMinimum( + amt, *upd, + ) + }, + ) } - // Next, ensure that the passed HTLC isn't too large. If so, we'll cancel - // the HTLC directly. + // Next, ensure that the passed HTLC isn't too large. If so, we'll + // cancel the HTLC directly. if policy.MaxHTLC != 0 && amt > policy.MaxHTLC { l.log.Errorf("outgoing htlc(%x) is too large: max_htlc=%v, "+ "htlc_value=%v", payHash[:], policy.MaxHTLC, amt) - // As part of the returned error, we'll send our latest routing policy - // so the sending node obtains the most up-to-date data. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure(update) - } - - return failure + // As part of the returned error, we'll send our latest routing + // policy so the sending node obtains the most up-to-date data. + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure(upd) + }, + ) } // We want to avoid offering an HTLC which will expire in the near @@ -2333,17 +2325,11 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy, "outgoing_expiry=%v, best_height=%v", payHash[:], timeout, heightNow) - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewExpiryTooSoon(*upd) + }, ) - if err != nil { - failure = lnwire.NewTemporaryChannelFailure(update) - } else { - failure = lnwire.NewExpiryTooSoon(*update) - } - - return failure } // Check absolute max delta. @@ -2355,6 +2341,15 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy, return &lnwire.FailExpiryTooFar{} } + // Check to see if there is enough balance in this channel. + if amt > l.Bandwidth() { + return l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure(upd) + }, + ) + } + return nil } @@ -2764,17 +2759,13 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, l.log.Errorf("unable to encode the "+ "remaining route %v", err) - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), + failure := l.createFailureWithUpdate( + func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewTemporaryChannelFailure( + upd, + ) + }, ) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure( - update, - ) - } l.sendHTLCError( pd.HtlcIndex, failure, obfuscator, pd.SourceRef, diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 60d05eca..86463ee4 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -5396,9 +5396,9 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) { } } -// TestHtlcSatisfyPolicy tests that a link is properly enforcing the HTLC +// TestCheckHtlcForward tests that a link is properly enforcing the HTLC // forwarding policy. -func TestHtlcSatisfyPolicy(t *testing.T) { +func TestCheckHtlcForward(t *testing.T) { fetchLastChannelUpdate := func(lnwire.ShortChannelID) ( *lnwire.ChannelUpdate, error) { @@ -5406,6 +5406,15 @@ func TestHtlcSatisfyPolicy(t *testing.T) { return &lnwire.ChannelUpdate{}, nil } + testChannel, _, fCleanUp, err := createTestChannel( + alicePrivKey, bobPrivKey, 100000, 100000, + 1000, 1000, lnwire.ShortChannelID{}, + ) + if err != nil { + t.Fatal(err) + } + defer fCleanUp() + link := channelLink{ cfg: ChannelLinkConfig{ FwrdingPolicy: ForwardingPolicy{ @@ -5417,13 +5426,15 @@ func TestHtlcSatisfyPolicy(t *testing.T) { FetchLastChannelUpdate: fetchLastChannelUpdate, MaxOutgoingCltvExpiry: DefaultMaxOutgoingCltvExpiry, }, - log: log, + log: log, + channel: testChannel.channel, + overflowQueue: newPacketQueue(input.MaxHTLCNumber / 2), } var hash [32]byte t.Run("satisfied", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 1500, 1000, + result := link.CheckHtlcForward(hash, 1500, 1000, 200, 150, 0) if result != nil { t.Fatalf("expected policy to be satisfied") @@ -5431,7 +5442,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { }) t.Run("below minhtlc", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 100, 50, + result := link.CheckHtlcForward(hash, 100, 50, 200, 150, 0) if _, ok := result.(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") @@ -5439,7 +5450,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { }) t.Run("above maxhtlc", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 1500, 1200, + result := link.CheckHtlcForward(hash, 1500, 1200, 200, 150, 0) if _, ok := result.(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") @@ -5447,7 +5458,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { }) t.Run("insufficient fee", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 1005, 1000, + result := link.CheckHtlcForward(hash, 1005, 1000, 200, 150, 0) if _, ok := result.(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") @@ -5455,7 +5466,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { }) t.Run("expiry too soon", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 1500, 1000, + result := link.CheckHtlcForward(hash, 1500, 1000, 200, 150, 190) if _, ok := result.(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") @@ -5463,7 +5474,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { }) t.Run("incorrect cltv expiry", func(t *testing.T) { - result := link.HtlcSatifiesPolicy(hash, 1500, 1000, + result := link.CheckHtlcForward(hash, 1500, 1000, 200, 190, 0) if _, ok := result.(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") @@ -5473,7 +5484,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. - result := link.HtlcSatifiesPolicy(hash, 1500, 1000, + result := link.CheckHtlcForward(hash, 1500, 1000, 10200, 10100, 0) if _, ok := result.(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 626ebe82..53d95720 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -334,6 +334,7 @@ var _ hop.Iterator = (*mockHopIterator)(nil) // encodes the failure and do not makes any onion obfuscation. type mockObfuscator struct { ogPacket *sphinx.OnionPacket + failure lnwire.FailureMessage } // NewMockObfuscator initializes a dummy mockObfuscator used for testing. @@ -366,6 +367,8 @@ func (o *mockObfuscator) Reextract( func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( lnwire.OpaqueReason, error) { + o.failure = failure + var b bytes.Buffer if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { return nil, err @@ -637,7 +640,9 @@ type mockChannelLink struct { htlcID uint64 - htlcSatifiesPolicyLocalResult lnwire.FailureMessage + checkHtlcTransitResult lnwire.FailureMessage + + checkHtlcForwardResult lnwire.FailureMessage } // completeCircuit is a helper method for adding the finalized payment circuit @@ -696,16 +701,17 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) { func (f *mockChannelLink) UpdateForwardingPolicy(_ ForwardingPolicy) { } -func (f *mockChannelLink) HtlcSatifiesPolicy([32]byte, lnwire.MilliSatoshi, +func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32, uint32) lnwire.FailureMessage { - return nil + + return f.checkHtlcForwardResult } -func (f *mockChannelLink) HtlcSatifiesPolicyLocal(payHash [32]byte, +func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32) lnwire.FailureMessage { - return f.htlcSatifiesPolicyLocalResult + return f.checkHtlcTransitResult } func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index e2c38d1b..b135e1fd 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -775,7 +775,7 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // Ensure that the htlc satisfies the outgoing channel policy. currentHeight := atomic.LoadUint32(&s.bestHeight) - htlcErr := link.HtlcSatifiesPolicyLocal( + htlcErr := link.CheckHtlcTransit( htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight, @@ -790,22 +790,6 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { } } - if link.Bandwidth() < htlc.Amount { - err := fmt.Errorf("Link %v has insufficient capacity: "+ - "need %v, has %v", pkt.outgoingChanID, - htlc.Amount, link.Bandwidth()) - log.Error(err) - - // The update does not need to be populated as the error - // will be returned back to the router. - htlcErr := lnwire.NewTemporaryChannelFailure(nil) - return &ForwardingError{ - FailureSourceIdx: 0, - ExtraMsg: err.Error(), - FailureMessage: htlcErr, - } - } - return link.HandleSwitchPacket(pkt) } @@ -1034,69 +1018,38 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // bandwidth. var destination ChannelLink for _, link := range interfaceLinks { + var failure lnwire.FailureMessage + // We'll skip any links that aren't yet eligible for // forwarding. - switch { - case !link.EligibleToForward(): - continue - - // If the link doesn't yet have a source chan ID, then - // we'll skip it as well. - case link.ShortChanID() == hop.Source: - continue + if !link.EligibleToForward() { + failure = &lnwire.FailUnknownNextPeer{} + } else { + // We'll ensure that the HTLC satisfies the + // current forwarding conditions of this target + // link. + currentHeight := atomic.LoadUint32(&s.bestHeight) + failure = link.CheckHtlcForward( + htlc.PaymentHash, packet.incomingAmount, + packet.amount, packet.incomingTimeout, + packet.outgoingTimeout, currentHeight, + ) } - // Before we check the link's bandwidth, we'll ensure - // that the HTLC satisfies the current forwarding - // policy of this target link. - currentHeight := atomic.LoadUint32(&s.bestHeight) - err := link.HtlcSatifiesPolicy( - htlc.PaymentHash, packet.incomingAmount, - packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, currentHeight, - ) - if err != nil { - linkErrs[link.ShortChanID()] = err - continue - } - - if link.Bandwidth() >= htlc.Amount { + // Stop searching if this link can forward the htlc. + if failure == nil { destination = link - break } + + linkErrs[link.ShortChanID()] = failure } - switch { - // If the channel link we're attempting to forward the update - // over has insufficient capacity, and didn't violate any - // forwarding policies, then we'll cancel the htlc as the - // payment cannot succeed. - case destination == nil && len(linkErrs) == 0: - // If packet was forwarded from another channel link - // than we should notify this link that some error - // occurred. - var failure lnwire.FailureMessage - update, err := s.cfg.FetchLastChannelUpdate( - packet.outgoingChanID, - ) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure(update) - } - - addErr := fmt.Errorf("unable to find appropriate "+ - "channel link insufficient capacity, need "+ - "%v towards node=%x", htlc.Amount, targetPeerKey) - - return s.failAddPacket(packet, failure, addErr) - // If we had a forwarding failure due to the HTLC not // satisfying the current policy, then we'll send back an // error, but ensure we send back the error sourced at the // *target* link. - case destination == nil && len(linkErrs) != 0: + if destination == nil { // At this point, some or all of the links rejected the // HTLC so we couldn't forward it. So we'll try to look // up the error that came from the source. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index b0a53eb3..28493c83 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1287,10 +1287,67 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { } } +type multiHopFwdTest struct { + name string + eligible1, eligible2 bool + failure1, failure2 lnwire.FailureMessage + expectedReply lnwire.FailCode +} + // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes // along, then we won't attempt to froward it down al ink that isn't yet able // to forward any HTLC's. func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { + tests := []multiHopFwdTest{ + // None of the channels is eligible. + { + name: "not eligible", + expectedReply: lnwire.CodeUnknownNextPeer, + }, + + // Channel one has a policy failure and the other channel isn't + // available. + { + name: "policy fail", + eligible1: true, + failure1: lnwire.NewFinalIncorrectCltvExpiry(0), + expectedReply: lnwire.CodeFinalIncorrectCltvExpiry, + }, + + // The requested channel is not eligible, but the packet is + // forwarded through the other channel. + { + name: "non-strict success", + eligible2: true, + expectedReply: lnwire.CodeNone, + }, + + // The requested channel has insufficient bandwidth and the + // other channel's policy isn't satisfied. + { + name: "non-strict policy fail", + eligible1: true, + failure1: lnwire.NewTemporaryChannelFailure(nil), + eligible2: true, + failure2: lnwire.NewFinalIncorrectCltvExpiry(0), + expectedReply: lnwire.CodeTemporaryChannelFailure, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testSkipIneligibleLinksMultiHopForward(t, &test) + }) + } +} + +// testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes +// along, then we won't attempt to froward it down al ink that isn't yet able +// to forward any HTLC's. +func testSkipIneligibleLinksMultiHopForward(t *testing.T, + testCase *multiHopFwdTest) { + t.Parallel() var packet *htlcPacket @@ -1313,22 +1370,32 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { } defer s.Stop() - chanID1, chanID2, aliceChanID, bobChanID := genIDs() - + chanID1, aliceChanID := genID() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, ) // We'll create a link for Bob, but mark the link as unable to forward // any new outgoing HTLC's. - bobChannelLink := newMockChannelLink( - s, chanID2, bobChanID, bobPeer, false, + chanID2, bobChanID2 := genID() + bobChannelLink1 := newMockChannelLink( + s, chanID2, bobChanID2, bobPeer, testCase.eligible1, ) + bobChannelLink1.checkHtlcForwardResult = testCase.failure1 + + chanID3, bobChanID3 := genID() + bobChannelLink2 := newMockChannelLink( + s, chanID3, bobChanID3, bobPeer, testCase.eligible2, + ) + bobChannelLink2.checkHtlcForwardResult = testCase.failure2 if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } - if err := s.AddLink(bobChannelLink); err != nil { + if err := s.AddLink(bobChannelLink1); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + if err := s.AddLink(bobChannelLink2); err != nil { t.Fatalf("unable to add bob link: %v", err) } @@ -1336,21 +1403,37 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { // Alice. preimage := [sha256.Size]byte{1} rhash := fastsha256.Sum256(preimage[:]) + obfuscator := NewMockObfuscator() packet = &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, - outgoingChanID: bobChannelLink.ShortChanID(), + outgoingChanID: bobChannelLink1.ShortChanID(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, }, - obfuscator: NewMockObfuscator(), + obfuscator: obfuscator, } // The request to forward should fail as err = s.forward(packet) - if err == nil { - t.Fatalf("forwarding should have failed due to inactive link") + + failure := obfuscator.(*mockObfuscator).failure + if testCase.expectedReply == lnwire.CodeNone { + if err != nil { + t.Fatalf("forwarding should have succeeded") + } + if failure != nil { + t.Fatalf("unexpected failure %T", failure) + } + } else { + if err == nil { + t.Fatalf("forwarding should have failed due to " + + "inactive link") + } + if failure.Code() != testCase.expectedReply { + t.Fatalf("unexpected failure %T", failure) + } } if s.circuits.NumOpen() != 0 { @@ -1399,7 +1482,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, eligible, ) - aliceChannelLink.htlcSatifiesPolicyLocalResult = policyResult + aliceChannelLink.checkHtlcTransitResult = policyResult if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index f2f86b75..3c11c3fd 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -92,27 +92,28 @@ var ( var idSeqNum uint64 -func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID, - lnwire.ShortChannelID) { - - id := atomic.AddUint64(&idSeqNum, 2) +// genID generates a unique tuple to identify a test channel. +func genID() (lnwire.ChannelID, lnwire.ShortChannelID) { + id := atomic.AddUint64(&idSeqNum, 1) var scratch [8]byte binary.BigEndian.PutUint64(scratch[:], id) hash1, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4)) - binary.BigEndian.PutUint64(scratch[:], id+1) - hash2, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4)) - chanPoint1 := wire.NewOutPoint(hash1, uint32(id)) - chanPoint2 := wire.NewOutPoint(hash2, uint32(id+1)) - chanID1 := lnwire.NewChanIDFromOutPoint(chanPoint1) - chanID2 := lnwire.NewChanIDFromOutPoint(chanPoint2) - aliceChanID := lnwire.NewShortChanIDFromInt(id) - bobChanID := lnwire.NewShortChanIDFromInt(id + 1) + + return chanID1, aliceChanID +} + +// genIDs generates ids for two test channels. +func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID, + lnwire.ShortChannelID) { + + chanID1, aliceChanID := genID() + chanID2, bobChanID := genID() return chanID1, chanID2, aliceChanID, bobChanID }