routing/payment_lifecycle: use ShardTracker to track shards

We'll let the payment's lifecycle register each shard it's sending with
the ShardTracker, canceling failed shards. This will be the foundation
for correct AMP derivation for each shard we'll send.
This commit is contained in:
Johan T. Halseth 2021-04-12 15:21:59 +02:00
parent 6474b253d6
commit 41ae3530a3
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 137 additions and 46 deletions

View File

@ -194,6 +194,8 @@ func newRoute(sourceVertex route.Vertex,
} }
// Otherwise attach the mpp record if it exists. // Otherwise attach the mpp record if it exists.
// TODO(halseth): move this to payment life cycle,
// where AMP options are set.
if finalHop.paymentAddr != nil { if finalHop.paymentAddr != nil {
mpp = record.NewMPP( mpp = record.NewMPP(
finalHop.totalAmt, finalHop.totalAmt,

View File

@ -12,6 +12,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
) )
// errShardHandlerExiting is returned from the shardHandler when it exits. // errShardHandlerExiting is returned from the shardHandler when it exits.
@ -25,6 +26,7 @@ type paymentLifecycle struct {
feeLimit lnwire.MilliSatoshi feeLimit lnwire.MilliSatoshi
paymentHash lntypes.Hash paymentHash lntypes.Hash
paySession PaymentSession paySession PaymentSession
shardTracker shards.ShardTracker
timeoutChan <-chan time.Time timeoutChan <-chan time.Time
currentHeight int32 currentHeight int32
} }
@ -83,10 +85,11 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// resumePayment resumes the paymentLifecycle from the current state. // resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
shardHandler := &shardHandler{ shardHandler := &shardHandler{
router: p.router, router: p.router,
paymentHash: p.paymentHash, paymentHash: p.paymentHash,
shardErrors: make(chan error), shardTracker: p.shardTracker,
quit: make(chan struct{}), shardErrors: make(chan error),
quit: make(chan struct{}),
} }
// When the payment lifecycle loop exits, we make sure to signal any // When the payment lifecycle loop exits, we make sure to signal any
@ -246,8 +249,12 @@ lifecycle:
continue lifecycle continue lifecycle
} }
// If this route will consume the last remaining amount to send
// to the receiver, this will be our last shard (for now).
lastShard := rt.ReceiverAmt() == state.remainingAmt
// We found a route to try, launch a new shard. // We found a route to try, launch a new shard.
attempt, outcome, err := shardHandler.launchShard(rt) attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
switch { switch {
// We may get a terminal error if we've processed a shard with // We may get a terminal error if we've processed a shard with
// a terminal state (settled or permanent failure), while we // a terminal state (settled or permanent failure), while we
@ -294,8 +301,9 @@ lifecycle:
// shardHandler holds what is necessary to send and collect the result of // shardHandler holds what is necessary to send and collect the result of
// shards. // shards.
type shardHandler struct { type shardHandler struct {
paymentHash lntypes.Hash paymentHash lntypes.Hash
router *ChannelRouter router *ChannelRouter
shardTracker shards.ShardTracker
// shardErrors is a channel where errors collected by calling // shardErrors is a channel where errors collected by calling
// collectResultAsync will be delivered. These results are meant to be // collectResultAsync will be delivered. These results are meant to be
@ -366,19 +374,20 @@ type launchOutcome struct {
} }
// launchShard creates and sends an HTLC attempt along the given route, // launchShard creates and sends an HTLC attempt along the given route,
// registering it with the control tower before sending it. It returns the // registering it with the control tower before sending it. The lastShard
// HTLCAttemptInfo that was created for the shard, along with a launchOutcome. // argument should be true if this shard will consume the remainder of the
// The launchOutcome is used to indicate whether the attempt was successfully // amount to send. It returns the HTLCAttemptInfo that was created for the
// sent. If the launchOutcome wraps a non-nil error, it means that the attempt // shard, along with a launchOutcome. The launchOutcome is used to indicate
// was not sent onto the network, so no result will be available in the future // whether the attempt was successfully sent. If the launchOutcome wraps a
// for it. // non-nil error, it means that the attempt was not sent onto the network, so
func (p *shardHandler) launchShard(rt *route.Route) (*channeldb.HTLCAttemptInfo, // no result will be available in the future for it.
*launchOutcome, error) { func (p *shardHandler) launchShard(rt *route.Route,
lastShard bool) (*channeldb.HTLCAttemptInfo, *launchOutcome, error) {
// Using the route received from the payment session, create a new // Using the route received from the payment session, create a new
// shard to send. // shard to send.
firstHop, htlcAdd, attempt, err := p.createNewPaymentAttempt( firstHop, htlcAdd, attempt, err := p.createNewPaymentAttempt(
rt, rt, lastShard,
) )
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -480,10 +489,17 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) ( func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
*shardResult, error) { *shardResult, error) {
// We'll retrieve the hash specific to this shard from the
// shardTracker, since it will be needed to regenerate the circuit
// below.
hash, err := p.shardTracker.GetHash(attempt.AttemptID)
if err != nil {
return nil, err
}
// Regenerate the circuit for this attempt. // Regenerate the circuit for this attempt.
_, circuit, err := generateSphinxPacket( _, circuit, err := generateSphinxPacket(
&attempt.Route, p.paymentHash[:], &attempt.Route, hash[:], attempt.SessionKey,
attempt.SessionKey,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -597,7 +613,7 @@ func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
} }
// createNewPaymentAttempt creates a new payment attempt from the given route. // createNewPaymentAttempt creates a new payment attempt from the given route.
func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) ( func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, lastShard bool) (
lnwire.ShortChannelID, *lnwire.UpdateAddHTLC, lnwire.ShortChannelID, *lnwire.UpdateAddHTLC,
*channeldb.HTLCAttemptInfo, error) { *channeldb.HTLCAttemptInfo, error) {
@ -607,12 +623,39 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
return lnwire.ShortChannelID{}, nil, nil, err return lnwire.ShortChannelID{}, nil, nil, err
} }
// We generate a new, unique payment ID that we will use for
// this HTLC.
attemptID, err := p.router.cfg.NextPaymentID()
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// Request a new shard from the ShardTracker. If this is an AMP
// payment, and this is the last shard, the outstanding shards together
// with this one will be enough for the receiver to derive all HTLC
// preimages. If this is a non-AMP payment, the ShardTracker will return a
// simple shard with the payment's static payment hash.
shard, err := p.shardTracker.NewShard(attemptID, lastShard)
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// If this shard carries MPP or AMP options, add them to the last hop
// on the route.
hop := rt.Hops[len(rt.Hops)-1]
if shard.MPP() != nil {
hop.MPP = shard.MPP()
}
if shard.AMP() != nil {
hop.AMP = shard.AMP()
}
// Generate the raw encoded sphinx packet to be included along // Generate the raw encoded sphinx packet to be included along
// with the htlcAdd message that we send directly to the // with the htlcAdd message that we send directly to the
// switch. // switch.
onionBlob, _, err := generateSphinxPacket( hash := shard.Hash()
rt, p.paymentHash[:], sessionKey, onionBlob, _, err := generateSphinxPacket(rt, hash[:], sessionKey)
)
if err != nil { if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err return lnwire.ShortChannelID{}, nil, nil, err
} }
@ -623,7 +666,7 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
htlcAdd := &lnwire.UpdateAddHTLC{ htlcAdd := &lnwire.UpdateAddHTLC{
Amount: rt.TotalAmount, Amount: rt.TotalAmount,
Expiry: rt.TotalTimeLock, Expiry: rt.TotalTimeLock,
PaymentHash: p.paymentHash, PaymentHash: hash,
} }
copy(htlcAdd.OnionBlob[:], onionBlob) copy(htlcAdd.OnionBlob[:], onionBlob)
@ -634,13 +677,6 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
rt.Hops[0].ChannelID, rt.Hops[0].ChannelID,
) )
// We generate a new, unique payment ID that we will use for
// this HTLC.
attemptID, err := p.router.cfg.NextPaymentID()
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// We now have all the information needed to populate // We now have all the information needed to populate
// the current attempt information. // the current attempt information.
attempt := &channeldb.HTLCAttemptInfo{ attempt := &channeldb.HTLCAttemptInfo{
@ -722,6 +758,13 @@ func (p *shardHandler) failAttempt(attempt *channeldb.HTLCAttemptInfo,
p.router.cfg.Clock.Now(), p.router.cfg.Clock.Now(),
) )
// Now that we are failing this payment attempt, cancel the shard with
// the ShardTracker such that it can derive the correct hash for the
// next attempt.
if err := p.shardTracker.CancelShard(attempt.AttemptID); err != nil {
return nil, err
}
return p.router.cfg.Control.FailAttempt( return p.router.cfg.Control.FailAttempt(
p.paymentHash, attempt.AttemptID, p.paymentHash, attempt.AttemptID,
failInfo, failInfo,

View File

@ -29,6 +29,7 @@ import (
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/chainview" "github.com/lightningnetwork/lnd/routing/chainview"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
) )
@ -603,19 +604,40 @@ func (r *ChannelRouter) Start() error {
go func(payment *channeldb.MPPayment) { go func(payment *channeldb.MPPayment) {
defer r.wg.Done() defer r.wg.Done()
// Get the hashes used for the outstanding HTLCs.
htlcs := make(map[uint64]lntypes.Hash)
for _, a := range payment.HTLCs {
a := a
hash := payment.Info.PaymentHash
htlcs[a.AttemptID] = hash
}
// Since we are not supporting creating more shards
// after a restart (only receiving the result of the
// shards already outstanding), we create a simple
// shard tracker that will map the attempt IDs to
// hashes used for the HTLCs. This will be enough also
// for AMP payments, since we only need the hashes for
// the individual HTLCs to regenerate the circuits, and
// we don't currently persist the root share necessary
// to re-derive them.
shardTracker := shards.NewSimpleShardTracker(
payment.Info.PaymentHash, htlcs,
)
// We create a dummy, empty payment session such that // We create a dummy, empty payment session such that
// we won't make another payment attempt when the // we won't make another payment attempt when the
// result for the in-flight attempt is received. // result for the in-flight attempt is received.
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
// We pass in a zero timeout value, to indicate we // We pass in a zero timeout value, to indicate we
// don't need it to timeout. It will stop immediately // don't need it to timeout. It will stop immediately
// after the existing attempt has finished anyway. We // after the existing attempt has finished anyway. We
// also set a zero fee limit, as no more routes should // also set a zero fee limit, as no more routes should
// be tried. // be tried.
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
payment.Info.Value, 0, payment.Info.Value, 0, payment.Info.PaymentHash,
payment.Info.PaymentHash, 0, paySession, 0, paySession, shardTracker,
) )
if err != nil { if err != nil {
log.Errorf("Resuming payment with hash %v "+ log.Errorf("Resuming payment with hash %v "+
@ -1770,7 +1792,7 @@ type LightningPayment struct {
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) { *route.Route, error) {
paySession, err := r.preparePayment(payment) paySession, shardTracker, err := r.preparePayment(payment)
if err != nil { if err != nil {
return [32]byte{}, nil, err return [32]byte{}, nil, err
} }
@ -1782,14 +1804,14 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// for the existing attempt. // for the existing attempt.
return r.sendPayment( return r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash, payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession, payment.PayAttemptTimeout, paySession, shardTracker,
) )
} }
// SendPaymentAsync is the non-blocking version of SendPayment. The payment // SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower. // result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error { func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
paySession, err := r.preparePayment(payment) paySession, shardTracker, err := r.preparePayment(payment)
if err != nil { if err != nil {
return err return err
} }
@ -1805,7 +1827,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash, payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession, payment.PayAttemptTimeout, paySession, shardTracker,
) )
if err != nil { if err != nil {
log.Errorf("Payment with hash %x failed: %v", log.Errorf("Payment with hash %x failed: %v",
@ -1841,14 +1863,14 @@ func spewPayment(payment *LightningPayment) logClosure {
// preparePayment creates the payment session and registers the payment with the // preparePayment creates the payment session and registers the payment with the
// control tower. // control tower.
func (r *ChannelRouter) preparePayment(payment *LightningPayment) ( func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentSession, error) { PaymentSession, shards.ShardTracker, error) {
// Before starting the HTLC routing attempt, we'll create a fresh // Before starting the HTLC routing attempt, we'll create a fresh
// payment session which will report our errors back to mission // payment session which will report our errors back to mission
// control. // control.
paySession, err := r.cfg.SessionSource.NewPaymentSession(payment) paySession, err := r.cfg.SessionSource.NewPaymentSession(payment)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// Record this payment hash with the ControlTower, ensuring it is not // Record this payment hash with the ControlTower, ensuring it is not
@ -1862,12 +1884,18 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentRequest: payment.PaymentRequest, PaymentRequest: payment.PaymentRequest,
} }
// Create a new ShardTracker that we'll use during the life cycle of
// this payment.
shardTracker := shards.NewSimpleShardTracker(
payment.PaymentHash, nil,
)
err = r.cfg.Control.InitPayment(payment.PaymentHash, info) err = r.cfg.Control.InitPayment(payment.PaymentHash, info)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return paySession, nil return paySession, shardTracker, nil
} }
// SendToRoute attempts to send a payment with the given hash through the // SendToRoute attempts to send a payment with the given hash through the
@ -1915,14 +1943,22 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
}), }),
) )
// Since the HTLC hashes and preimages are specified manually over the
// RPC for SendToRoute requests, we don't have to worry about creating
// a ShardTracker that can generate hashes for AMP payments. Instead we
// create a simple tracker that can just return the hash for the single
// shard we'll now launch.
shardTracker := shards.NewSimpleShardTracker(hash, nil)
// Launch a shard along the given route. // Launch a shard along the given route.
sh := &shardHandler{ sh := &shardHandler{
router: r, router: r,
paymentHash: hash, paymentHash: hash,
shardTracker: shardTracker,
} }
var shardError error var shardError error
attempt, outcome, err := sh.launchShard(rt) attempt, outcome, err := sh.launchShard(rt, false)
// With SendToRoute, it can happen that the route exceeds protocol // With SendToRoute, it can happen that the route exceeds protocol
// constraints. Mark the payment as failed with an internal error. // constraints. Mark the payment as failed with an internal error.
@ -2007,8 +2043,8 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
// the ControlTower. // the ControlTower.
func (r *ChannelRouter) sendPayment( func (r *ChannelRouter) sendPayment(
totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash, totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash,
timeout time.Duration, timeout time.Duration, paySession PaymentSession,
paySession PaymentSession) ([32]byte, *route.Route, error) { shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
// We'll also fetch the current block height so we can properly // We'll also fetch the current block height so we can properly
// calculate the required HTLC time locks within the route. // calculate the required HTLC time locks within the route.
@ -2025,6 +2061,7 @@ func (r *ChannelRouter) sendPayment(
feeLimit: feeLimit, feeLimit: feeLimit,
paymentHash: paymentHash, paymentHash: paymentHash,
paySession: paySession, paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight, currentHeight: currentHeight,
} }

View File

@ -2,6 +2,7 @@ package shards
import ( import (
"fmt" "fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
@ -74,6 +75,7 @@ func (s *Shard) AMP() *record.AMP {
type SimpleShardTracker struct { type SimpleShardTracker struct {
hash lntypes.Hash hash lntypes.Hash
shards map[uint64]lntypes.Hash shards map[uint64]lntypes.Hash
sync.Mutex
} }
// A compile time check to ensure SimpleShardTracker implements the // A compile time check to ensure SimpleShardTracker implements the
@ -100,7 +102,9 @@ func NewSimpleShardTracker(paymentHash lntypes.Hash,
// if it ends up not being used by the overall payment, i.e. if the attempt // if it ends up not being used by the overall payment, i.e. if the attempt
// fails. // fails.
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) { func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
m.Lock()
m.shards[id] = m.hash m.shards[id] = m.hash
m.Unlock()
return &Shard{ return &Shard{
hash: m.hash, hash: m.hash,
@ -109,14 +113,19 @@ func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
// CancelShard cancels the shard corresponding to the given attempt ID. // CancelShard cancels the shard corresponding to the given attempt ID.
func (m *SimpleShardTracker) CancelShard(id uint64) error { func (m *SimpleShardTracker) CancelShard(id uint64) error {
m.Lock()
delete(m.shards, id) delete(m.shards, id)
m.Unlock()
return nil return nil
} }
// GetHash retrieves the hash used by the shard of the given attempt ID. This // GetHash retrieves the hash used by the shard of the given attempt ID. This
// will return an error if the attempt ID is unknown. // will return an error if the attempt ID is unknown.
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) { func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
m.Lock()
hash, ok := m.shards[id] hash, ok := m.shards[id]
m.Unlock()
if !ok { if !ok {
return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+ return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
"not found", id) "not found", id)