diff --git a/routing/mock_test.go b/routing/mock_test.go index 8186902f..57586805 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/btcsuite/btcd/btcec" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -167,6 +168,18 @@ func (m *mockPaymentSession) RequestRoute(_, _ lnwire.MilliSatoshi, return r, nil } +func (m *mockPaymentSession) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, + _ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool { + + return false +} + +func (m *mockPaymentSession) GetAdditionalEdgePolicy(_ *btcec.PublicKey, + _ uint64) *channeldb.ChannelEdgePolicy { + + return nil +} + type mockPayer struct { sendResult chan error paymentResult chan *htlcswitch.PaymentResult diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index b55766e3..155b0f37 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -91,6 +91,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { shardTracker: p.shardTracker, shardErrors: make(chan error), quit: make(chan struct{}), + paySession: p.paySession, } // When the payment lifecycle loop exits, we make sure to signal any @@ -305,6 +306,7 @@ type shardHandler struct { identifier lntypes.Hash router *ChannelRouter shardTracker shards.ShardTracker + paySession PaymentSession // shardErrors is a channel where errors collected by calling // collectResultAsync will be delivered. These results are meant to be @@ -855,12 +857,42 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, return err } - // Apply channel update. + var ( + isAdditionalEdge bool + policy *channeldb.ChannelEdgePolicy + ) + + // Before we apply the channel update, we need to decide whether the + // update is for additional (ephemeral) edge or normal edge stored in + // db. + // + // Note: the p.paySession might be nil here if it's called inside + // SendToRoute where there's no payment lifecycle. + if p.paySession != nil { + policy = p.paySession.GetAdditionalEdgePolicy( + errSource, update.ShortChannelID.ToUint64(), + ) + if policy != nil { + isAdditionalEdge = true + } + } + + // Apply channel update to additional edge policy. + if isAdditionalEdge { + if !p.paySession.UpdateAdditionalEdge( + update, errSource, policy) { + + log.Debugf("Invalid channel update received: node=%v", + errVertex) + } + return nil + } + + // Apply channel update to the channel edge policy in our db. if !p.router.applyChannelUpdate(update, errSource) { log.Debugf("Invalid channel update received: node=%v", errVertex) } - return nil } diff --git a/routing/payment_session.go b/routing/payment_session.go index d2c02255..22e88090 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -138,6 +138,19 @@ type PaymentSession interface { // during path finding. RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, activeShards, height uint32) (*route.Route, error) + + // UpdateAdditionalEdge takes an additional channel edge policy + // (private channels) and applies the update from the message. Returns + // a boolean to indicate whether the update has been applied without + // error. + UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, + policy *channeldb.ChannelEdgePolicy) bool + + // GetAdditionalEdgePolicy uses the public key and channel ID to query + // the ephemeral channel edge policy for additional edges. Returns a nil + // if nothing found. + GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, + channelID uint64) *channeldb.ChannelEdgePolicy } // paymentSession is used during an HTLC routings session to prune the local