diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index f6689919..a9dc2bc5 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -107,6 +107,14 @@ type ChannelLink interface { incomingTimeout, outgoingTimeout uint32, heightNow uint32) lnwire.FailureMessage + // HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC + // details satisfy the current channel policy. Otherwise, a valid + // protocol failure message should be returned in order to signal the + // violation. This call is intended to be used for locally initiated + // payments for which there is no corresponding incoming htlc. + HtlcSatifiesPolicyLocal(payHash [32]byte, amt lnwire.MilliSatoshi, + timeout uint32, heightNow uint32) lnwire.FailureMessage + // Bandwidth returns the amount of milli-satoshis which current link // might pass through channel link. The value returned from this method // represents the up to date available flow through the channel. This diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 75a5204e..481c90c0 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2093,76 +2093,12 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, policy := l.cfg.FwrdingPolicy l.RUnlock() - // As our first sanity check, we'll ensure that the passed HTLC isn't - // too small for the next hop. If so, then we'll cancel the HTLC - // directly. - if amtToForward < policy.MinHTLC { - l.errorf("outgoing htlc(%x) is too small: min_htlc=%v, "+ - "htlc_value=%v", payHash[:], policy.MinHTLC, - amtToForward) - - // As part of the returned error, we'll send our latest routing - // policy so the sending node obtains the most up to date data. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewAmountBelowMinimum( - amtToForward, *update, - ) - } - - return failure - } - - // Next, ensure that the passed HTLC isn't too large. If so, we'll cancel - // the HTLC directly. - if policy.MaxHTLC != 0 && amtToForward > policy.MaxHTLC { - l.errorf("outgoing htlc(%x) is too large: max_htlc=%v, "+ - "htlc_value=%v", payHash[:], policy.MaxHTLC, amtToForward) - - // As part of the returned error, we'll send our latest routing policy - // so the sending node obtains the most up-to-date data. - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) - if err != nil { - failure = &lnwire.FailTemporaryNodeFailure{} - } else { - failure = lnwire.NewTemporaryChannelFailure(update) - } - - return failure - } - - // We want to avoid offering an HTLC which will expire in the near - // future, so we'll reject an HTLC if the outgoing expiration time is - // too close to the current height. - if outgoingTimeout <= heightNow+l.cfg.OutgoingCltvRejectDelta { - l.errorf("htlc(%x) has an expiry that's too soon: "+ - "outgoing_expiry=%v, best_height=%v", payHash[:], - outgoingTimeout, heightNow) - - var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.ShortChanID(), - ) - if err != nil { - failure = lnwire.NewTemporaryChannelFailure(update) - } else { - failure = lnwire.NewExpiryTooSoon(*update) - } - - return failure - } - - // Check absolute max delta. - if outgoingTimeout > maxCltvExpiry+heightNow { - l.errorf("outgoing htlc(%x) has a time lock too far in the "+ - "future: got %v, but maximum is %v", payHash[:], - outgoingTimeout-heightNow, maxCltvExpiry) - - return &lnwire.FailExpiryTooFar{} + // First check whether the outgoing htlc satisfies the channel policy. + err := l.htlcSatifiesPolicyOutgoing( + policy, payHash, amtToForward, outgoingTimeout, heightNow, + ) + if err != nil { + return err } // Next, using the amount of the incoming HTLC, we'll calculate the @@ -2225,6 +2161,105 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, return nil } +// HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC details +// satisfy the current channel policy. Otherwise, a valid protocol failure +// message should be returned in order to signal the violation. This call is +// intended to be used for locally initiated payments for which there is no +// corresponding incoming htlc. +func (l *channelLink) HtlcSatifiesPolicyLocal(payHash [32]byte, + amt lnwire.MilliSatoshi, timeout uint32, + heightNow uint32) lnwire.FailureMessage { + + l.RLock() + policy := l.cfg.FwrdingPolicy + l.RUnlock() + + return l.htlcSatifiesPolicyOutgoing( + policy, payHash, amt, timeout, heightNow, + ) +} + +// htlcSatifiesPolicyOutgoing checks whether the given htlc parameters satisfy +// the channel's amount and time lock constraints. +func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy, + payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, + heightNow uint32) lnwire.FailureMessage { + + // As our first sanity check, we'll ensure that the passed HTLC isn't + // too small for the next hop. If so, then we'll cancel the HTLC + // directly. + if amt < policy.MinHTLC { + l.errorf("outgoing htlc(%x) is too small: min_htlc=%v, "+ + "htlc_value=%v", payHash[:], policy.MinHTLC, + amt) + + // As part of the returned error, we'll send our latest routing + // policy so the sending node obtains the most up to date data. + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewAmountBelowMinimum( + amt, *update, + ) + } + + return failure + } + + // Next, ensure that the passed HTLC isn't too large. If so, we'll cancel + // the HTLC directly. + if policy.MaxHTLC != 0 && amt > policy.MaxHTLC { + l.errorf("outgoing htlc(%x) is too large: max_htlc=%v, "+ + "htlc_value=%v", payHash[:], policy.MaxHTLC, amt) + + // As part of the returned error, we'll send our latest routing policy + // so the sending node obtains the most up-to-date data. + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure(update) + } + + return failure + } + + // We want to avoid offering an HTLC which will expire in the near + // future, so we'll reject an HTLC if the outgoing expiration time is + // too close to the current height. + if timeout <= heightNow+l.cfg.OutgoingCltvRejectDelta { + l.errorf("htlc(%x) has an expiry that's too soon: "+ + "outgoing_expiry=%v, best_height=%v", payHash[:], + timeout, heightNow) + + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate( + l.ShortChanID(), + ) + if err != nil { + failure = lnwire.NewTemporaryChannelFailure(update) + } else { + failure = lnwire.NewExpiryTooSoon(*update) + } + + return failure + } + + // Check absolute max delta. + if timeout > maxCltvExpiry+heightNow { + l.errorf("outgoing htlc(%x) has a time lock too far in the "+ + "future: got %v, but maximum is %v", payHash[:], + timeout-heightNow, maxCltvExpiry) + + return &lnwire.FailExpiryTooFar{} + } + + return nil +} + // Stats returns the statistics of channel link. // // NOTE: Part of the ChannelLink interface. diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 8483c09e..7aaf1da3 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1365,13 +1365,12 @@ func TestChannelLinkExpiryTooSoonExitNode(t *testing.T) { amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - // We'll craft an HTLC packet, but set the final hop CLTV to 3 blocks - // after the current true height. This is less or equal to the expiry - // grace delta of 3, so we expect the incoming htlc to be failed by the + // We'll craft an HTLC packet, but set the final hop CLTV to 5 blocks + // after the current true height. This is less than the test invoice + // cltv delta of 6, so we expect the incoming htlc to be failed by the // exit hop. - lastHopDelta := n.firstBobChannelLink.cfg.FwrdingPolicy.TimeLockDelta htlcAmt, totalTimelock, hops := generateHops(amount, - startingHeight+3-lastHopDelta, n.firstBobChannelLink) + startingHeight-1, n.firstBobChannelLink) // Now we'll send out the payment from Alice to Bob. firstHop := n.firstBobChannelLink.ShortChanID() diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index a4d8f1bb..bf481f75 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -612,6 +612,8 @@ type mockChannelLink struct { eligible bool htlcID uint64 + + htlcSatifiesPolicyLocalResult lnwire.FailureMessage } // completeCircuit is a helper method for adding the finalized payment circuit @@ -675,6 +677,13 @@ func (f *mockChannelLink) HtlcSatifiesPolicy([32]byte, lnwire.MilliSatoshi, return nil } +func (f *mockChannelLink) HtlcSatifiesPolicyLocal(payHash [32]byte, + amt lnwire.MilliSatoshi, timeout uint32, + heightNow uint32) lnwire.FailureMessage { + + return f.htlcSatifiesPolicyLocalResult +} + func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) { return 0, 0, 0 } @@ -728,7 +737,7 @@ func newDB() (*channeldb.DB, func(), error) { return cdb, cleanUp, nil } -const testInvoiceCltvExpiry = 4 +const testInvoiceCltvExpiry = 6 type mockInvoiceRegistry struct { settleChan chan lntypes.Hash diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f48cbdc9..66beeed0 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -799,6 +799,23 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { } } + // Ensure that the htlc satisfies the outgoing channel policy. + currentHeight := atomic.LoadUint32(&s.bestHeight) + htlcErr := link.HtlcSatifiesPolicyLocal( + htlc.PaymentHash, + htlc.Amount, + htlc.Expiry, currentHeight, + ) + if htlcErr != nil { + log.Errorf("Link %v policy for local forward not "+ + "satisfied", pkt.outgoingChanID) + + return &ForwardingError{ + ErrorSource: s.cfg.SelfKey, + FailureMessage: htlcErr, + } + } + if link.Bandwidth() < htlc.Amount { err := fmt.Errorf("Link %v has insufficient capacity: "+ "need %v, has %v", pkt.outgoingChanID, diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 59b0e4ed..a2360b0d 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1363,6 +1363,21 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { func TestSkipIneligibleLinksLocalForward(t *testing.T) { t.Parallel() + testSkipLinkLocalForward(t, false, nil) +} + +// TestSkipPolicyUnsatisfiedLinkLocalForward ensures that the switch will not +// attempt to send locally initiated HTLCs that would violate the channel policy +// down a link. +func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) { + t.Parallel() + + testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil)) +} + +func testSkipLinkLocalForward(t *testing.T, eligible bool, + policyResult lnwire.FailureMessage) { + // We'll create a single link for this test, marking it as being unable // to forward form the get go. alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) @@ -1382,8 +1397,9 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { chanID1, _, aliceChanID, _ := genIDs() aliceChannelLink := newMockChannelLink( - s, chanID1, aliceChanID, alicePeer, false, + s, chanID1, aliceChanID, alicePeer, eligible, ) + aliceChannelLink.htlcSatifiesPolicyLocalResult = policyResult if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 89762d75..7826021b 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1052,7 +1052,7 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, MinFeeUpdateTimeout: minFeeUpdateTimeout, MaxFeeUpdateTimeout: maxFeeUpdateTimeout, OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, - FinalCltvRejectDelta: 3, + FinalCltvRejectDelta: 5, OutgoingCltvRejectDelta: 3, }, channel,