Merge pull request #3547 from joostjager/non-strict-error

htlcswitch: fix non-strict forwarding failures
Joost Jager 2019-10-23 11:42:41 +02:00 committed by GitHub
commit 3cfd1ebb03
7 changed files with 250 additions and 206 deletions
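The heart of the change is the non-strict selection loop in the switch: every candidate link towards the target peer is asked, via the renamed CheckHtlcForward, whether it can carry the HTLC, a failure is recorded per link, the first link that reports no failure wins, and only if none qualifies is an error returned to the sender, sourced at the requested outgoing channel. The sketch below restates that selection logic in isolation so the diff is easier to follow; all types and names in it are illustrative stand-ins, not lnd's actual API.

package main

import "fmt"

// failure stands in for lnwire.FailureMessage; the empty string means "no failure".
type failure string

// candidateLink is a simplified stand-in for a ChannelLink candidate.
type candidateLink struct {
	chanID    uint64
	eligible  bool
	policyErr failure
}

// checkHtlcForward mirrors the shape of the new CheckHtlcForward call: it
// reports why this link cannot carry the HTLC, or nothing if it can.
func (l candidateLink) checkHtlcForward() failure {
	if !l.eligible {
		return "UnknownNextPeer"
	}
	return l.policyErr
}

// pickDestination sketches the non-strict selection loop: record a failure
// for every candidate, stop at the first link that can forward, and if none
// can, prefer the failure attributed to the channel the sender asked for.
func pickDestination(links []candidateLink, requestedChanID uint64) (*candidateLink, failure) {
	linkErrs := make(map[uint64]failure)

	for i := range links {
		if f := links[i].checkHtlcForward(); f != "" {
			linkErrs[links[i].chanID] = f
			continue
		}
		return &links[i], ""
	}

	// No link could take the HTLC: report the error sourced at the
	// requested outgoing channel, falling back to any recorded error.
	if f, ok := linkErrs[requestedChanID]; ok {
		return nil, f
	}
	for _, f := range linkErrs {
		return nil, f
	}
	return nil, "UnknownNextPeer"
}

func main() {
	links := []candidateLink{
		{chanID: 1, eligible: true, policyErr: "FeeInsufficient"},
		{chanID: 2, eligible: true}, // sibling channel to the same peer
	}

	dest, f := pickDestination(links, 1)
	fmt.Println(dest.chanID, f) // prints "2": forwarded non-strictly over the sibling channel
}

Under the old code, ineligible or undersized links were simply skipped with continue, so a payment that only had unusable channels surfaced a generic insufficient-capacity error; recording a per-link failure is what lets the switch reply with the precise failure from the target link, as the diff below implements.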

@@ -100,22 +100,21 @@ type ChannelLink interface {
 	// policy to govern if it an incoming HTLC should be forwarded or not.
 	UpdateForwardingPolicy(ForwardingPolicy)
 
-	// HtlcSatifiesPolicy should return a nil error if the passed HTLC
-	// details satisfy the current forwarding policy fo the target link.
-	// Otherwise, a valid protocol failure message should be returned in
-	// order to signal to the source of the HTLC, the policy consistency
-	// issue.
-	HtlcSatifiesPolicy(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
+	// CheckHtlcForward should return a nil error if the passed HTLC details
+	// satisfy the current forwarding policy fo the target link. Otherwise,
+	// a valid protocol failure message should be returned in order to
+	// signal to the source of the HTLC, the policy consistency issue.
+	CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
 		amtToForward lnwire.MilliSatoshi,
 		incomingTimeout, outgoingTimeout uint32,
 		heightNow uint32) lnwire.FailureMessage
 
-	// HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC
-	// details satisfy the current channel policy. Otherwise, a valid
-	// protocol failure message should be returned in order to signal the
-	// violation. This call is intended to be used for locally initiated
-	// payments for which there is no corresponding incoming htlc.
-	HtlcSatifiesPolicyLocal(payHash [32]byte, amt lnwire.MilliSatoshi,
+	// CheckHtlcTransit should return a nil error if the passed HTLC details
+	// satisfy the current channel policy. Otherwise, a valid protocol
+	// failure message should be returned in order to signal the violation.
+	// This call is intended to be used for locally initiated payments for
+	// which there is no corresponding incoming htlc.
+	CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi,
 		timeout uint32, heightNow uint32) lnwire.FailureMessage
 
 	// Bandwidth returns the amount of milli-satoshis which current link

@@ -605,6 +605,19 @@ func shouldAdjustCommitFee(netFee, chanFee lnwallet.SatPerKWeight) bool {
 	}
 }
 
+// createFailureWithUpdate retrieves this link's last channel update message and
+// passes it into the callback. It expects a fully populated failure message.
+func (l *channelLink) createFailureWithUpdate(
+	cb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage) lnwire.FailureMessage {
+
+	update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
+	if err != nil {
+		return &lnwire.FailTemporaryNodeFailure{}
+	}
+
+	return cb(update)
+}
+
 // syncChanState attempts to synchronize channel states with the remote party.
 // This method is to be called upon reconnection after the initial funding
 // flow. We'll compare out commitment chains with the remote party, and re-send
@@ -1312,17 +1325,13 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
 				reason lnwire.OpaqueReason
 			)
 
-			var failure lnwire.FailureMessage
-			update, err := l.cfg.FetchLastChannelUpdate(
-				l.ShortChanID(),
-			)
-			if err != nil {
-				failure = &lnwire.FailTemporaryNodeFailure{}
-			} else {
-				failure = lnwire.NewTemporaryChannelFailure(
-					update,
-				)
-			}
+			failure := l.createFailureWithUpdate(
+				func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+					return lnwire.NewTemporaryChannelFailure(
+						upd,
+					)
+				},
+			)
 
 			// Encrypt the error back to the source unless
 			// the payment was generated locally.
@@ -2175,13 +2184,13 @@ func (l *channelLink) UpdateForwardingPolicy(newPolicy ForwardingPolicy) {
 	l.cfg.FwrdingPolicy = newPolicy
 }
 
-// HtlcSatifiesPolicy should return a nil error if the passed HTLC details
-// satisfy the current forwarding policy fo the target link. Otherwise, a
-// valid protocol failure message should be returned in order to signal to the
-// source of the HTLC, the policy consistency issue.
+// CheckHtlcForward should return a nil error if the passed HTLC details satisfy
+// the current forwarding policy fo the target link. Otherwise, a valid
+// protocol failure message should be returned in order to signal to the source
+// of the HTLC, the policy consistency issue.
 //
 // NOTE: Part of the ChannelLink interface.
-func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
+func (l *channelLink) CheckHtlcForward(payHash [32]byte,
 	incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
 	incomingTimeout, outgoingTimeout uint32,
 	heightNow uint32) lnwire.FailureMessage {
@@ -2191,7 +2200,7 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
 	l.RUnlock()
 
 	// First check whether the outgoing htlc satisfies the channel policy.
-	err := l.htlcSatifiesPolicyOutgoing(
+	err := l.canSendHtlc(
 		policy, payHash, amtToForward, outgoingTimeout, heightNow,
 	)
 	if err != nil {
@@ -2216,17 +2225,14 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
 		// As part of the returned error, we'll send our latest routing
 		// policy so the sending node obtains the most up to date data.
-		var failure lnwire.FailureMessage
-		update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
-		if err != nil {
-			failure = &lnwire.FailTemporaryNodeFailure{}
-		} else {
-			failure = lnwire.NewFeeInsufficient(
-				amtToForward, *update,
-			)
-		}
-
-		return failure
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewFeeInsufficient(
+					amtToForward, *upd,
+				)
+			},
+		)
 	}
 
 	// Finally, we'll ensure that the time-lock on the outgoing HTLC meets
@@ -2241,30 +2247,24 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
 		// Grab the latest routing policy so the sending node is up to
 		// date with our current policy.
-		var failure lnwire.FailureMessage
-		update, err := l.cfg.FetchLastChannelUpdate(
-			l.ShortChanID(),
-		)
-		if err != nil {
-			failure = lnwire.NewTemporaryChannelFailure(update)
-		} else {
-			failure = lnwire.NewIncorrectCltvExpiry(
-				incomingTimeout, *update,
-			)
-		}
-
-		return failure
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewIncorrectCltvExpiry(
+					incomingTimeout, *upd,
+				)
+			},
+		)
 	}
 
 	return nil
 }
 
-// HtlcSatifiesPolicyLocal should return a nil error if the passed HTLC details
-// satisfy the current channel policy. Otherwise, a valid protocol failure
-// message should be returned in order to signal the violation. This call is
-// intended to be used for locally initiated payments for which there is no
-// corresponding incoming htlc.
-func (l *channelLink) HtlcSatifiesPolicyLocal(payHash [32]byte,
+// CheckHtlcTransit should return a nil error if the passed HTLC details satisfy the
+// current channel policy. Otherwise, a valid protocol failure message should
+// be returned in order to signal the violation. This call is intended to be
+// used for locally initiated payments for which there is no corresponding
+// incoming htlc.
+func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
 	amt lnwire.MilliSatoshi, timeout uint32,
 	heightNow uint32) lnwire.FailureMessage {
@@ -2272,14 +2272,14 @@ func (l *channelLink) HtlcSatifiesPolicyLocal(payHash [32]byte,
 	policy := l.cfg.FwrdingPolicy
 	l.RUnlock()
 
-	return l.htlcSatifiesPolicyOutgoing(
+	return l.canSendHtlc(
 		policy, payHash, amt, timeout, heightNow,
 	)
 }
 
 // htlcSatifiesPolicyOutgoing checks whether the given htlc parameters satisfy
 // the channel's amount and time lock constraints.
-func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy,
+func (l *channelLink) canSendHtlc(policy ForwardingPolicy,
 	payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32,
 	heightNow uint32) lnwire.FailureMessage {
@@ -2293,36 +2293,28 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy,
 		// As part of the returned error, we'll send our latest routing
 		// policy so the sending node obtains the most up to date data.
-		var failure lnwire.FailureMessage
-		update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
-		if err != nil {
-			failure = &lnwire.FailTemporaryNodeFailure{}
-		} else {
-			failure = lnwire.NewAmountBelowMinimum(
-				amt, *update,
-			)
-		}
-
-		return failure
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewAmountBelowMinimum(
+					amt, *upd,
+				)
+			},
+		)
 	}
 
-	// Next, ensure that the passed HTLC isn't too large. If so, we'll cancel
-	// the HTLC directly.
+	// Next, ensure that the passed HTLC isn't too large. If so, we'll
+	// cancel the HTLC directly.
 	if policy.MaxHTLC != 0 && amt > policy.MaxHTLC {
 		l.log.Errorf("outgoing htlc(%x) is too large: max_htlc=%v, "+
 			"htlc_value=%v", payHash[:], policy.MaxHTLC, amt)
 
-		// As part of the returned error, we'll send our latest routing policy
-		// so the sending node obtains the most up-to-date data.
-		var failure lnwire.FailureMessage
-		update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
-		if err != nil {
-			failure = &lnwire.FailTemporaryNodeFailure{}
-		} else {
-			failure = lnwire.NewTemporaryChannelFailure(update)
-		}
-
-		return failure
+		// As part of the returned error, we'll send our latest routing
+		// policy so the sending node obtains the most up-to-date data.
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewTemporaryChannelFailure(upd)
+			},
+		)
 	}
 
 	// We want to avoid offering an HTLC which will expire in the near
@@ -2333,17 +2325,11 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy,
 			"outgoing_expiry=%v, best_height=%v", payHash[:],
 			timeout, heightNow)
 
-		var failure lnwire.FailureMessage
-		update, err := l.cfg.FetchLastChannelUpdate(
-			l.ShortChanID(),
-		)
-		if err != nil {
-			failure = lnwire.NewTemporaryChannelFailure(update)
-		} else {
-			failure = lnwire.NewExpiryTooSoon(*update)
-		}
-
-		return failure
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewExpiryTooSoon(*upd)
+			},
+		)
 	}
 
 	// Check absolute max delta.
@@ -2355,6 +2341,15 @@ func (l *channelLink) htlcSatifiesPolicyOutgoing(policy ForwardingPolicy,
 		return &lnwire.FailExpiryTooFar{}
 	}
 
+	// Check to see if there is enough balance in this channel.
+	if amt > l.Bandwidth() {
+		return l.createFailureWithUpdate(
+			func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+				return lnwire.NewTemporaryChannelFailure(upd)
+			},
+		)
+	}
+
 	return nil
 }
@@ -2764,17 +2759,13 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
 				l.log.Errorf("unable to encode the "+
 					"remaining route %v", err)
 
-				var failure lnwire.FailureMessage
-				update, err := l.cfg.FetchLastChannelUpdate(
-					l.ShortChanID(),
-				)
-				if err != nil {
-					failure = &lnwire.FailTemporaryNodeFailure{}
-				} else {
-					failure = lnwire.NewTemporaryChannelFailure(
-						update,
-					)
-				}
+				failure := l.createFailureWithUpdate(
+					func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage {
+						return lnwire.NewTemporaryChannelFailure(
+							upd,
+						)
+					},
+				)
 
 				l.sendHTLCError(
 					pd.HtlcIndex, failure, obfuscator, pd.SourceRef,

@@ -5396,9 +5396,9 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) {
 	}
 }
 
-// TestHtlcSatisfyPolicy tests that a link is properly enforcing the HTLC
+// TestCheckHtlcForward tests that a link is properly enforcing the HTLC
 // forwarding policy.
-func TestHtlcSatisfyPolicy(t *testing.T) {
+func TestCheckHtlcForward(t *testing.T) {
 
 	fetchLastChannelUpdate := func(lnwire.ShortChannelID) (
 		*lnwire.ChannelUpdate, error) {
@@ -5406,6 +5406,15 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 		return &lnwire.ChannelUpdate{}, nil
 	}
 
+	testChannel, _, fCleanUp, err := createTestChannel(
+		alicePrivKey, bobPrivKey, 100000, 100000,
+		1000, 1000, lnwire.ShortChannelID{},
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fCleanUp()
+
 	link := channelLink{
 		cfg: ChannelLinkConfig{
 			FwrdingPolicy: ForwardingPolicy{
@@ -5417,13 +5426,15 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 			FetchLastChannelUpdate: fetchLastChannelUpdate,
 			MaxOutgoingCltvExpiry:  DefaultMaxOutgoingCltvExpiry,
 		},
 		log: log,
+		channel:       testChannel.channel,
+		overflowQueue: newPacketQueue(input.MaxHTLCNumber / 2),
 	}
 
 	var hash [32]byte
 
 	t.Run("satisfied", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 1500, 1000,
+		result := link.CheckHtlcForward(hash, 1500, 1000,
 			200, 150, 0)
 		if result != nil {
 			t.Fatalf("expected policy to be satisfied")
@@ -5431,7 +5442,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 	})
 
 	t.Run("below minhtlc", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 100, 50,
+		result := link.CheckHtlcForward(hash, 100, 50,
 			200, 150, 0)
 		if _, ok := result.(*lnwire.FailAmountBelowMinimum); !ok {
 			t.Fatalf("expected FailAmountBelowMinimum failure code")
@@ -5439,7 +5450,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 	})
 
 	t.Run("above maxhtlc", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 1500, 1200,
+		result := link.CheckHtlcForward(hash, 1500, 1200,
 			200, 150, 0)
 		if _, ok := result.(*lnwire.FailTemporaryChannelFailure); !ok {
 			t.Fatalf("expected FailTemporaryChannelFailure failure code")
@@ -5447,7 +5458,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 	})
 
 	t.Run("insufficient fee", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 1005, 1000,
+		result := link.CheckHtlcForward(hash, 1005, 1000,
 			200, 150, 0)
 		if _, ok := result.(*lnwire.FailFeeInsufficient); !ok {
 			t.Fatalf("expected FailFeeInsufficient failure code")
@@ -5455,7 +5466,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 	})
 
 	t.Run("expiry too soon", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 1500, 1000,
+		result := link.CheckHtlcForward(hash, 1500, 1000,
 			200, 150, 190)
 		if _, ok := result.(*lnwire.FailExpiryTooSoon); !ok {
 			t.Fatalf("expected FailExpiryTooSoon failure code")
@@ -5463,7 +5474,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 	})
 
 	t.Run("incorrect cltv expiry", func(t *testing.T) {
-		result := link.HtlcSatifiesPolicy(hash, 1500, 1000,
+		result := link.CheckHtlcForward(hash, 1500, 1000,
 			200, 190, 0)
 		if _, ok := result.(*lnwire.FailIncorrectCltvExpiry); !ok {
 			t.Fatalf("expected FailIncorrectCltvExpiry failure code")
@@ -5473,7 +5484,7 @@ func TestHtlcSatisfyPolicy(t *testing.T) {
 
 	t.Run("cltv expiry too far in the future", func(t *testing.T) {
 		// Check that expiry isn't too far in the future.
-		result := link.HtlcSatifiesPolicy(hash, 1500, 1000,
+		result := link.CheckHtlcForward(hash, 1500, 1000,
 			10200, 10100, 0)
 		if _, ok := result.(*lnwire.FailExpiryTooFar); !ok {
 			t.Fatalf("expected FailExpiryTooFar failure code")

@@ -334,6 +334,7 @@ var _ hop.Iterator = (*mockHopIterator)(nil)
 // encodes the failure and do not makes any onion obfuscation.
 type mockObfuscator struct {
 	ogPacket *sphinx.OnionPacket
+	failure  lnwire.FailureMessage
 }
 
 // NewMockObfuscator initializes a dummy mockObfuscator used for testing.
@@ -366,6 +367,8 @@ func (o *mockObfuscator) Reextract(
 func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
 	lnwire.OpaqueReason, error) {
 
+	o.failure = failure
+
 	var b bytes.Buffer
 	if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
 		return nil, err
@@ -637,7 +640,9 @@ type mockChannelLink struct {
 	htlcID uint64
 
-	htlcSatifiesPolicyLocalResult lnwire.FailureMessage
+	checkHtlcTransitResult lnwire.FailureMessage
+
+	checkHtlcForwardResult lnwire.FailureMessage
 }
 
 // completeCircuit is a helper method for adding the finalized payment circuit
@@ -696,16 +701,17 @@ func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
 func (f *mockChannelLink) UpdateForwardingPolicy(_ ForwardingPolicy) {
 }
 
-func (f *mockChannelLink) HtlcSatifiesPolicy([32]byte, lnwire.MilliSatoshi,
+func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
 	lnwire.MilliSatoshi, uint32, uint32, uint32) lnwire.FailureMessage {
-	return nil
+
+	return f.checkHtlcForwardResult
 }
 
-func (f *mockChannelLink) HtlcSatifiesPolicyLocal(payHash [32]byte,
+func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
 	amt lnwire.MilliSatoshi, timeout uint32,
 	heightNow uint32) lnwire.FailureMessage {
 
-	return f.htlcSatifiesPolicyLocalResult
+	return f.checkHtlcTransitResult
 }
 
 func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) {

@@ -775,7 +775,7 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
 		// Ensure that the htlc satisfies the outgoing channel policy.
 		currentHeight := atomic.LoadUint32(&s.bestHeight)
-		htlcErr := link.HtlcSatifiesPolicyLocal(
+		htlcErr := link.CheckHtlcTransit(
 			htlc.PaymentHash,
 			htlc.Amount,
 			htlc.Expiry, currentHeight,
@@ -790,22 +790,6 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
 			}
 		}
 
-		if link.Bandwidth() < htlc.Amount {
-			err := fmt.Errorf("Link %v has insufficient capacity: "+
-				"need %v, has %v", pkt.outgoingChanID,
-				htlc.Amount, link.Bandwidth())
-			log.Error(err)
-
-			// The update does not need to be populated as the error
-			// will be returned back to the router.
-			htlcErr := lnwire.NewTemporaryChannelFailure(nil)
-			return &ForwardingError{
-				FailureSourceIdx: 0,
-				ExtraMsg:         err.Error(),
-				FailureMessage:   htlcErr,
-			}
-		}
-
 		return link.HandleSwitchPacket(pkt)
 	}
 
@@ -1034,69 +1018,38 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
 		// bandwidth.
 		var destination ChannelLink
 		for _, link := range interfaceLinks {
+			var failure lnwire.FailureMessage
+
 			// We'll skip any links that aren't yet eligible for
 			// forwarding.
-			switch {
-			case !link.EligibleToForward():
-				continue
-
-			// If the link doesn't yet have a source chan ID, then
-			// we'll skip it as well.
-			case link.ShortChanID() == hop.Source:
-				continue
+			if !link.EligibleToForward() {
+				failure = &lnwire.FailUnknownNextPeer{}
+			} else {
+				// We'll ensure that the HTLC satisfies the
+				// current forwarding conditions of this target
+				// link.
+				currentHeight := atomic.LoadUint32(&s.bestHeight)
+				failure = link.CheckHtlcForward(
+					htlc.PaymentHash, packet.incomingAmount,
+					packet.amount, packet.incomingTimeout,
+					packet.outgoingTimeout, currentHeight,
+				)
 			}
 
-			// Before we check the link's bandwidth, we'll ensure
-			// that the HTLC satisfies the current forwarding
-			// policy of this target link.
-			currentHeight := atomic.LoadUint32(&s.bestHeight)
-			err := link.HtlcSatifiesPolicy(
-				htlc.PaymentHash, packet.incomingAmount,
-				packet.amount, packet.incomingTimeout,
-				packet.outgoingTimeout, currentHeight,
-			)
-			if err != nil {
-				linkErrs[link.ShortChanID()] = err
-				continue
-			}
-
-			if link.Bandwidth() >= htlc.Amount {
+			// Stop searching if this link can forward the htlc.
+			if failure == nil {
 				destination = link
 				break
 			}
+
+			linkErrs[link.ShortChanID()] = failure
 		}
 
-		switch {
-
-		// If the channel link we're attempting to forward the update
-		// over has insufficient capacity, and didn't violate any
-		// forwarding policies, then we'll cancel the htlc as the
-		// payment cannot succeed.
-		case destination == nil && len(linkErrs) == 0:
-			// If packet was forwarded from another channel link
-			// than we should notify this link that some error
-			// occurred.
-			var failure lnwire.FailureMessage
-			update, err := s.cfg.FetchLastChannelUpdate(
-				packet.outgoingChanID,
-			)
-			if err != nil {
-				failure = &lnwire.FailTemporaryNodeFailure{}
-			} else {
-				failure = lnwire.NewTemporaryChannelFailure(update)
-			}
-
-			addErr := fmt.Errorf("unable to find appropriate "+
-				"channel link insufficient capacity, need "+
-				"%v towards node=%x", htlc.Amount, targetPeerKey)
-
-			return s.failAddPacket(packet, failure, addErr)
-
 		// If we had a forwarding failure due to the HTLC not
 		// satisfying the current policy, then we'll send back an
 		// error, but ensure we send back the error sourced at the
 		// *target* link.
-		case destination == nil && len(linkErrs) != 0:
+		if destination == nil {
 			// At this point, some or all of the links rejected the
 			// HTLC so we couldn't forward it. So we'll try to look
 			// up the error that came from the source.

@@ -1287,10 +1287,67 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
 	}
 }
 
+type multiHopFwdTest struct {
+	name                 string
+	eligible1, eligible2 bool
+	failure1, failure2   lnwire.FailureMessage
+	expectedReply        lnwire.FailCode
+}
+
 // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
 // along, then we won't attempt to froward it down al ink that isn't yet able
 // to forward any HTLC's.
 func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
+	tests := []multiHopFwdTest{
+		// None of the channels is eligible.
+		{
+			name:          "not eligible",
+			expectedReply: lnwire.CodeUnknownNextPeer,
+		},
+
+		// Channel one has a policy failure and the other channel isn't
+		// available.
+		{
+			name:          "policy fail",
+			eligible1:     true,
+			failure1:      lnwire.NewFinalIncorrectCltvExpiry(0),
+			expectedReply: lnwire.CodeFinalIncorrectCltvExpiry,
+		},
+
+		// The requested channel is not eligible, but the packet is
+		// forwarded through the other channel.
+		{
+			name:          "non-strict success",
+			eligible2:     true,
+			expectedReply: lnwire.CodeNone,
+		},
+
+		// The requested channel has insufficient bandwidth and the
+		// other channel's policy isn't satisfied.
+		{
+			name:          "non-strict policy fail",
+			eligible1:     true,
+			failure1:      lnwire.NewTemporaryChannelFailure(nil),
+			eligible2:     true,
+			failure2:      lnwire.NewFinalIncorrectCltvExpiry(0),
+			expectedReply: lnwire.CodeTemporaryChannelFailure,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			testSkipIneligibleLinksMultiHopForward(t, &test)
+		})
+	}
+}
+
+// testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
+// along, then we won't attempt to froward it down al ink that isn't yet able
+// to forward any HTLC's.
+func testSkipIneligibleLinksMultiHopForward(t *testing.T,
+	testCase *multiHopFwdTest) {
+
 	t.Parallel()
 
 	var packet *htlcPacket
@@ -1313,22 +1370,32 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
 	}
 	defer s.Stop()
 
-	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
+	chanID1, aliceChanID := genID()
 
 	aliceChannelLink := newMockChannelLink(
 		s, chanID1, aliceChanID, alicePeer, true,
 	)
 
 	// We'll create a link for Bob, but mark the link as unable to forward
 	// any new outgoing HTLC's.
-	bobChannelLink := newMockChannelLink(
-		s, chanID2, bobChanID, bobPeer, false,
+	chanID2, bobChanID2 := genID()
+	bobChannelLink1 := newMockChannelLink(
+		s, chanID2, bobChanID2, bobPeer, testCase.eligible1,
 	)
+	bobChannelLink1.checkHtlcForwardResult = testCase.failure1
+
+	chanID3, bobChanID3 := genID()
+	bobChannelLink2 := newMockChannelLink(
+		s, chanID3, bobChanID3, bobPeer, testCase.eligible2,
+	)
+	bobChannelLink2.checkHtlcForwardResult = testCase.failure2
 
 	if err := s.AddLink(aliceChannelLink); err != nil {
 		t.Fatalf("unable to add alice link: %v", err)
 	}
-	if err := s.AddLink(bobChannelLink); err != nil {
+	if err := s.AddLink(bobChannelLink1); err != nil {
+		t.Fatalf("unable to add bob link: %v", err)
+	}
+	if err := s.AddLink(bobChannelLink2); err != nil {
 		t.Fatalf("unable to add bob link: %v", err)
 	}
 
@@ -1336,21 +1403,37 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
 	// Alice.
 	preimage := [sha256.Size]byte{1}
 	rhash := fastsha256.Sum256(preimage[:])
+	obfuscator := NewMockObfuscator()
 	packet = &htlcPacket{
 		incomingChanID: aliceChannelLink.ShortChanID(),
 		incomingHTLCID: 0,
-		outgoingChanID: bobChannelLink.ShortChanID(),
+		outgoingChanID: bobChannelLink1.ShortChanID(),
 		htlc: &lnwire.UpdateAddHTLC{
 			PaymentHash: rhash,
 			Amount:      1,
 		},
-		obfuscator: NewMockObfuscator(),
+		obfuscator: obfuscator,
 	}
 
 	// The request to forward should fail as
 	err = s.forward(packet)
-	if err == nil {
-		t.Fatalf("forwarding should have failed due to inactive link")
+
+	failure := obfuscator.(*mockObfuscator).failure
+	if testCase.expectedReply == lnwire.CodeNone {
+		if err != nil {
+			t.Fatalf("forwarding should have succeeded")
+		}
+		if failure != nil {
+			t.Fatalf("unexpected failure %T", failure)
+		}
+	} else {
+		if err == nil {
+			t.Fatalf("forwarding should have failed due to " +
+				"inactive link")
+		}
+		if failure.Code() != testCase.expectedReply {
+			t.Fatalf("unexpected failure %T", failure)
+		}
 	}
 
 	if s.circuits.NumOpen() != 0 {
@@ -1399,7 +1482,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
 	aliceChannelLink := newMockChannelLink(
 		s, chanID1, aliceChanID, alicePeer, eligible,
 	)
-	aliceChannelLink.htlcSatifiesPolicyLocalResult = policyResult
+	aliceChannelLink.checkHtlcTransitResult = policyResult
 	if err := s.AddLink(aliceChannelLink); err != nil {
 		t.Fatalf("unable to add alice link: %v", err)
 	}

@@ -92,27 +92,28 @@ var (
 var idSeqNum uint64
 
-func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID,
-	lnwire.ShortChannelID) {
-
-	id := atomic.AddUint64(&idSeqNum, 2)
+// genID generates a unique tuple to identify a test channel.
+func genID() (lnwire.ChannelID, lnwire.ShortChannelID) {
+	id := atomic.AddUint64(&idSeqNum, 1)
 
 	var scratch [8]byte
 
 	binary.BigEndian.PutUint64(scratch[:], id)
 	hash1, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
 
-	binary.BigEndian.PutUint64(scratch[:], id+1)
-	hash2, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
-
 	chanPoint1 := wire.NewOutPoint(hash1, uint32(id))
-	chanPoint2 := wire.NewOutPoint(hash2, uint32(id+1))
-
 	chanID1 := lnwire.NewChanIDFromOutPoint(chanPoint1)
-	chanID2 := lnwire.NewChanIDFromOutPoint(chanPoint2)
-
 	aliceChanID := lnwire.NewShortChanIDFromInt(id)
-	bobChanID := lnwire.NewShortChanIDFromInt(id + 1)
+
+	return chanID1, aliceChanID
+}
+
+// genIDs generates ids for two test channels.
+func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID,
+	lnwire.ShortChannelID) {
+
+	chanID1, aliceChanID := genID()
+	chanID2, bobChanID := genID()
 
 	return chanID1, chanID2, aliceChanID, bobChanID
 }