diff --git a/routing/pathfind.go b/routing/pathfind.go index 387e7150..3d722c82 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -194,6 +194,8 @@ func newRoute(sourceVertex route.Vertex, } // Otherwise attach the mpp record if it exists. + // TODO(halseth): move this to payment life cycle, + // where AMP options are set. if finalHop.paymentAddr != nil { mpp = record.NewMPP( finalHop.totalAmt, diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index be9ade0d..1760488a 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/routing/shards" ) // errShardHandlerExiting is returned from the shardHandler when it exits. @@ -25,6 +26,7 @@ type paymentLifecycle struct { feeLimit lnwire.MilliSatoshi paymentHash lntypes.Hash paySession PaymentSession + shardTracker shards.ShardTracker timeoutChan <-chan time.Time currentHeight int32 } @@ -83,10 +85,11 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) ( // resumePayment resumes the paymentLifecycle from the current state. func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { shardHandler := &shardHandler{ - router: p.router, - paymentHash: p.paymentHash, - shardErrors: make(chan error), - quit: make(chan struct{}), + router: p.router, + paymentHash: p.paymentHash, + shardTracker: p.shardTracker, + shardErrors: make(chan error), + quit: make(chan struct{}), } // When the payment lifecycle loop exits, we make sure to signal any @@ -246,8 +249,12 @@ lifecycle: continue lifecycle } + // If this route will consume the last remeining amount to send + // to the receiver, this will be our last shard (for now). + lastShard := rt.ReceiverAmt() == state.remainingAmt + // We found a route to try, launch a new shard. - attempt, outcome, err := shardHandler.launchShard(rt) + attempt, outcome, err := shardHandler.launchShard(rt, lastShard) switch { // We may get a terminal error if we've processed a shard with // a terminal state (settled or permanent failure), while we @@ -294,8 +301,9 @@ lifecycle: // shardHandler holds what is necessary to send and collect the result of // shards. type shardHandler struct { - paymentHash lntypes.Hash - router *ChannelRouter + paymentHash lntypes.Hash + router *ChannelRouter + shardTracker shards.ShardTracker // shardErrors is a channel where errors collected by calling // collectResultAsync will be delivered. These results are meant to be @@ -366,19 +374,20 @@ type launchOutcome struct { } // launchShard creates and sends an HTLC attempt along the given route, -// registering it with the control tower before sending it. It returns the -// HTLCAttemptInfo that was created for the shard, along with a launchOutcome. -// The launchOutcome is used to indicate whether the attempt was successfully -// sent. If the launchOutcome wraps a non-nil error, it means that the attempt -// was not sent onto the network, so no result will be available in the future -// for it. -func (p *shardHandler) launchShard(rt *route.Route) (*channeldb.HTLCAttemptInfo, - *launchOutcome, error) { +// registering it with the control tower before sending it. The lastShard +// argument should be true if this shard will consume the remainder of the +// amount to send. It returns the HTLCAttemptInfo that was created for the +// shard, along with a launchOutcome. The launchOutcome is used to indicate +// whether the attempt was successfully sent. If the launchOutcome wraps a +// non-nil error, it means that the attempt was not sent onto the network, so +// no result will be available in the future for it. +func (p *shardHandler) launchShard(rt *route.Route, + lastShard bool) (*channeldb.HTLCAttemptInfo, *launchOutcome, error) { // Using the route received from the payment session, create a new // shard to send. firstHop, htlcAdd, attempt, err := p.createNewPaymentAttempt( - rt, + rt, lastShard, ) if err != nil { return nil, nil, err @@ -480,10 +489,17 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) { func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) ( *shardResult, error) { + // We'll retrieve the hash specific to this shard from the + // shardTracker, since it will be needed to regenerate the circuit + // below. + hash, err := p.shardTracker.GetHash(attempt.AttemptID) + if err != nil { + return nil, err + } + // Regenerate the circuit for this attempt. _, circuit, err := generateSphinxPacket( - &attempt.Route, p.paymentHash[:], - attempt.SessionKey, + &attempt.Route, hash[:], attempt.SessionKey, ) if err != nil { return nil, err @@ -597,7 +613,7 @@ func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) ( } // createNewPaymentAttempt creates a new payment attempt from the given route. -func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) ( +func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, lastShard bool) ( lnwire.ShortChannelID, *lnwire.UpdateAddHTLC, *channeldb.HTLCAttemptInfo, error) { @@ -607,12 +623,39 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) ( return lnwire.ShortChannelID{}, nil, nil, err } + // We generate a new, unique payment ID that we will use for + // this HTLC. + attemptID, err := p.router.cfg.NextPaymentID() + if err != nil { + return lnwire.ShortChannelID{}, nil, nil, err + } + + // Requesst a new shard from the ShardTracker. If this is an AMP + // payment, and this is the last shard, the outstanding shards together + // with ths one will be enough for the receiver to derive all HTLC + // preimages. If this a non-AMP payment, the ShardTracker will return a + // simple shard with the payment's static payment hash. + shard, err := p.shardTracker.NewShard(attemptID, lastShard) + if err != nil { + return lnwire.ShortChannelID{}, nil, nil, err + } + + // It this shard carries MPP or AMP options, add them to the last hop + // on the route. + hop := rt.Hops[len(rt.Hops)-1] + if shard.MPP() != nil { + hop.MPP = shard.MPP() + } + + if shard.AMP() != nil { + hop.AMP = shard.AMP() + } + // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. - onionBlob, _, err := generateSphinxPacket( - rt, p.paymentHash[:], sessionKey, - ) + hash := shard.Hash() + onionBlob, _, err := generateSphinxPacket(rt, hash[:], sessionKey) if err != nil { return lnwire.ShortChannelID{}, nil, nil, err } @@ -623,7 +666,7 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) ( htlcAdd := &lnwire.UpdateAddHTLC{ Amount: rt.TotalAmount, Expiry: rt.TotalTimeLock, - PaymentHash: p.paymentHash, + PaymentHash: hash, } copy(htlcAdd.OnionBlob[:], onionBlob) @@ -634,13 +677,6 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) ( rt.Hops[0].ChannelID, ) - // We generate a new, unique payment ID that we will use for - // this HTLC. - attemptID, err := p.router.cfg.NextPaymentID() - if err != nil { - return lnwire.ShortChannelID{}, nil, nil, err - } - // We now have all the information needed to populate // the current attempt information. attempt := &channeldb.HTLCAttemptInfo{ @@ -722,6 +758,13 @@ func (p *shardHandler) failAttempt(attempt *channeldb.HTLCAttemptInfo, p.router.cfg.Clock.Now(), ) + // Now that we are failing this payment attempt, cancel the shard with + // the ShardTracker such that it can derive the correct hash for the + // next attempt. + if err := p.shardTracker.CancelShard(attempt.AttemptID); err != nil { + return nil, err + } + return p.router.cfg.Control.FailAttempt( p.paymentHash, attempt.AttemptID, failInfo, diff --git a/routing/router.go b/routing/router.go index 72b17dcd..7f54f915 100644 --- a/routing/router.go +++ b/routing/router.go @@ -29,6 +29,7 @@ import ( "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/chainview" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/routing/shards" "github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/zpay32" ) @@ -603,19 +604,40 @@ func (r *ChannelRouter) Start() error { go func(payment *channeldb.MPPayment) { defer r.wg.Done() + // Get the hashes used for the outstanding HTLCs. + htlcs := make(map[uint64]lntypes.Hash) + for _, a := range payment.HTLCs { + a := a + + hash := payment.Info.PaymentHash + htlcs[a.AttemptID] = hash + } + + // Since we are not supporting creating more shards + // after a restart (only receiving the result of the + // shards already outstanding), we create a simple + // shard tracker that will map the attempt IDs to + // hashes used for the HTLCs. This will be enough also + // for AMP payments, since we only need the hashes for + // the individual HTLCs to regenerate the circuits, and + // we don't currently persist the root share necessary + // to re-derive them. + shardTracker := shards.NewSimpleShardTracker( + payment.Info.PaymentHash, htlcs, + ) + // We create a dummy, empty payment session such that // we won't make another payment attempt when the // result for the in-flight attempt is received. paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() - // We pass in a zero timeout value, to indicate we // don't need it to timeout. It will stop immediately // after the existing attempt has finished anyway. We // also set a zero fee limit, as no more routes should // be tried. _, _, err := r.sendPayment( - payment.Info.Value, 0, - payment.Info.PaymentHash, 0, paySession, + payment.Info.Value, 0, payment.Info.PaymentHash, + 0, paySession, shardTracker, ) if err != nil { log.Errorf("Resuming payment with hash %v "+ @@ -1770,7 +1792,7 @@ type LightningPayment struct { func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, *route.Route, error) { - paySession, err := r.preparePayment(payment) + paySession, shardTracker, err := r.preparePayment(payment) if err != nil { return [32]byte{}, nil, err } @@ -1782,14 +1804,14 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, // for the existing attempt. return r.sendPayment( payment.Amount, payment.FeeLimit, payment.PaymentHash, - payment.PayAttemptTimeout, paySession, + payment.PayAttemptTimeout, paySession, shardTracker, ) } // SendPaymentAsync is the non-blocking version of SendPayment. The payment // result needs to be retrieved via the control tower. func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error { - paySession, err := r.preparePayment(payment) + paySession, shardTracker, err := r.preparePayment(payment) if err != nil { return err } @@ -1805,7 +1827,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error { _, _, err := r.sendPayment( payment.Amount, payment.FeeLimit, payment.PaymentHash, - payment.PayAttemptTimeout, paySession, + payment.PayAttemptTimeout, paySession, shardTracker, ) if err != nil { log.Errorf("Payment with hash %x failed: %v", @@ -1841,14 +1863,14 @@ func spewPayment(payment *LightningPayment) logClosure { // preparePayment creates the payment session and registers the payment with the // control tower. func (r *ChannelRouter) preparePayment(payment *LightningPayment) ( - PaymentSession, error) { + PaymentSession, shards.ShardTracker, error) { // Before starting the HTLC routing attempt, we'll create a fresh // payment session which will report our errors back to mission // control. paySession, err := r.cfg.SessionSource.NewPaymentSession(payment) if err != nil { - return nil, err + return nil, nil, err } // Record this payment hash with the ControlTower, ensuring it is not @@ -1862,12 +1884,18 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) ( PaymentRequest: payment.PaymentRequest, } + // Create a new ShardTracker that we'll use during the life cycle of + // this payment. + shardTracker := shards.NewSimpleShardTracker( + payment.PaymentHash, nil, + ) + err = r.cfg.Control.InitPayment(payment.PaymentHash, info) if err != nil { - return nil, err + return nil, nil, err } - return paySession, nil + return paySession, shardTracker, nil } // SendToRoute attempts to send a payment with the given hash through the @@ -1915,14 +1943,22 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) ( }), ) + // Since the HTLC hashes and preimages are specified manually over the + // RPC for SendToRoute requests, we don't have to worry about creating + // a ShardTracker that can generate hashes for AMP payments. Instead we + // create a simple tracker that can just return the hash for the single + // shard we'll now launch. + shardTracker := shards.NewSimpleShardTracker(hash, nil) + // Launch a shard along the given route. sh := &shardHandler{ - router: r, - paymentHash: hash, + router: r, + paymentHash: hash, + shardTracker: shardTracker, } var shardError error - attempt, outcome, err := sh.launchShard(rt) + attempt, outcome, err := sh.launchShard(rt, false) // With SendToRoute, it can happen that the route exceeds protocol // constraints. Mark the payment as failed with an internal error. @@ -2007,8 +2043,8 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) ( // the ControlTower. func (r *ChannelRouter) sendPayment( totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash, - timeout time.Duration, - paySession PaymentSession) ([32]byte, *route.Route, error) { + timeout time.Duration, paySession PaymentSession, + shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { // We'll also fetch the current block height so we can properly // calculate the required HTLC time locks within the route. @@ -2025,6 +2061,7 @@ func (r *ChannelRouter) sendPayment( feeLimit: feeLimit, paymentHash: paymentHash, paySession: paySession, + shardTracker: shardTracker, currentHeight: currentHeight, } diff --git a/routing/shards/shard_tracker.go b/routing/shards/shard_tracker.go index 474c85bc..a41461eb 100644 --- a/routing/shards/shard_tracker.go +++ b/routing/shards/shard_tracker.go @@ -2,6 +2,7 @@ package shards import ( "fmt" + "sync" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" @@ -74,6 +75,7 @@ func (s *Shard) AMP() *record.AMP { type SimpleShardTracker struct { hash lntypes.Hash shards map[uint64]lntypes.Hash + sync.Mutex } // A compile time check to ensure SimpleShardTracker implements the @@ -100,7 +102,9 @@ func NewSimpleShardTracker(paymentHash lntypes.Hash, // if it ends up not being used by the overall payment, i.e. if the attempt // fails. func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) { + m.Lock() m.shards[id] = m.hash + m.Unlock() return &Shard{ hash: m.hash, @@ -109,14 +113,19 @@ func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) { // CancelShard cancel's the shard corresponding to the given attempt ID. func (m *SimpleShardTracker) CancelShard(id uint64) error { + m.Lock() delete(m.shards, id) + m.Unlock() + return nil } // GetHash retrieves the hash used by the shard of the given attempt ID. This // will return an error if the attempt ID is unknown. func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) { + m.Lock() hash, ok := m.shards[id] + m.Unlock() if !ok { return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+ "not found", id)