routing/payment_lifecycle: use ShardTracker to track shards
We'll let the payment's lifecycle register each shard it's sending with the ShardTracker, canceling failed shards. This will be the foundation for correct AMP derivation for each shard we'll send.
This commit is contained in:
parent
6474b253d6
commit
41ae3530a3
@ -194,6 +194,8 @@ func newRoute(sourceVertex route.Vertex,
|
||||
}
|
||||
|
||||
// Otherwise attach the mpp record if it exists.
|
||||
// TODO(halseth): move this to payment life cycle,
|
||||
// where AMP options are set.
|
||||
if finalHop.paymentAddr != nil {
|
||||
mpp = record.NewMPP(
|
||||
finalHop.totalAmt,
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/routing/shards"
|
||||
)
|
||||
|
||||
// errShardHandlerExiting is returned from the shardHandler when it exits.
|
||||
@ -25,6 +26,7 @@ type paymentLifecycle struct {
|
||||
feeLimit lnwire.MilliSatoshi
|
||||
paymentHash lntypes.Hash
|
||||
paySession PaymentSession
|
||||
shardTracker shards.ShardTracker
|
||||
timeoutChan <-chan time.Time
|
||||
currentHeight int32
|
||||
}
|
||||
@ -85,6 +87,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
||||
shardHandler := &shardHandler{
|
||||
router: p.router,
|
||||
paymentHash: p.paymentHash,
|
||||
shardTracker: p.shardTracker,
|
||||
shardErrors: make(chan error),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
@ -246,8 +249,12 @@ lifecycle:
|
||||
continue lifecycle
|
||||
}
|
||||
|
||||
// If this route will consume the last remeining amount to send
|
||||
// to the receiver, this will be our last shard (for now).
|
||||
lastShard := rt.ReceiverAmt() == state.remainingAmt
|
||||
|
||||
// We found a route to try, launch a new shard.
|
||||
attempt, outcome, err := shardHandler.launchShard(rt)
|
||||
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
|
||||
switch {
|
||||
// We may get a terminal error if we've processed a shard with
|
||||
// a terminal state (settled or permanent failure), while we
|
||||
@ -296,6 +303,7 @@ lifecycle:
|
||||
type shardHandler struct {
|
||||
paymentHash lntypes.Hash
|
||||
router *ChannelRouter
|
||||
shardTracker shards.ShardTracker
|
||||
|
||||
// shardErrors is a channel where errors collected by calling
|
||||
// collectResultAsync will be delivered. These results are meant to be
|
||||
@ -366,19 +374,20 @@ type launchOutcome struct {
|
||||
}
|
||||
|
||||
// launchShard creates and sends an HTLC attempt along the given route,
|
||||
// registering it with the control tower before sending it. It returns the
|
||||
// HTLCAttemptInfo that was created for the shard, along with a launchOutcome.
|
||||
// The launchOutcome is used to indicate whether the attempt was successfully
|
||||
// sent. If the launchOutcome wraps a non-nil error, it means that the attempt
|
||||
// was not sent onto the network, so no result will be available in the future
|
||||
// for it.
|
||||
func (p *shardHandler) launchShard(rt *route.Route) (*channeldb.HTLCAttemptInfo,
|
||||
*launchOutcome, error) {
|
||||
// registering it with the control tower before sending it. The lastShard
|
||||
// argument should be true if this shard will consume the remainder of the
|
||||
// amount to send. It returns the HTLCAttemptInfo that was created for the
|
||||
// shard, along with a launchOutcome. The launchOutcome is used to indicate
|
||||
// whether the attempt was successfully sent. If the launchOutcome wraps a
|
||||
// non-nil error, it means that the attempt was not sent onto the network, so
|
||||
// no result will be available in the future for it.
|
||||
func (p *shardHandler) launchShard(rt *route.Route,
|
||||
lastShard bool) (*channeldb.HTLCAttemptInfo, *launchOutcome, error) {
|
||||
|
||||
// Using the route received from the payment session, create a new
|
||||
// shard to send.
|
||||
firstHop, htlcAdd, attempt, err := p.createNewPaymentAttempt(
|
||||
rt,
|
||||
rt, lastShard,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -480,10 +489,17 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
|
||||
func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
|
||||
*shardResult, error) {
|
||||
|
||||
// We'll retrieve the hash specific to this shard from the
|
||||
// shardTracker, since it will be needed to regenerate the circuit
|
||||
// below.
|
||||
hash, err := p.shardTracker.GetHash(attempt.AttemptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Regenerate the circuit for this attempt.
|
||||
_, circuit, err := generateSphinxPacket(
|
||||
&attempt.Route, p.paymentHash[:],
|
||||
attempt.SessionKey,
|
||||
&attempt.Route, hash[:], attempt.SessionKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -597,7 +613,7 @@ func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
|
||||
}
|
||||
|
||||
// createNewPaymentAttempt creates a new payment attempt from the given route.
|
||||
func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
|
||||
func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, lastShard bool) (
|
||||
lnwire.ShortChannelID, *lnwire.UpdateAddHTLC,
|
||||
*channeldb.HTLCAttemptInfo, error) {
|
||||
|
||||
@ -607,12 +623,39 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
|
||||
return lnwire.ShortChannelID{}, nil, nil, err
|
||||
}
|
||||
|
||||
// We generate a new, unique payment ID that we will use for
|
||||
// this HTLC.
|
||||
attemptID, err := p.router.cfg.NextPaymentID()
|
||||
if err != nil {
|
||||
return lnwire.ShortChannelID{}, nil, nil, err
|
||||
}
|
||||
|
||||
// Requesst a new shard from the ShardTracker. If this is an AMP
|
||||
// payment, and this is the last shard, the outstanding shards together
|
||||
// with ths one will be enough for the receiver to derive all HTLC
|
||||
// preimages. If this a non-AMP payment, the ShardTracker will return a
|
||||
// simple shard with the payment's static payment hash.
|
||||
shard, err := p.shardTracker.NewShard(attemptID, lastShard)
|
||||
if err != nil {
|
||||
return lnwire.ShortChannelID{}, nil, nil, err
|
||||
}
|
||||
|
||||
// It this shard carries MPP or AMP options, add them to the last hop
|
||||
// on the route.
|
||||
hop := rt.Hops[len(rt.Hops)-1]
|
||||
if shard.MPP() != nil {
|
||||
hop.MPP = shard.MPP()
|
||||
}
|
||||
|
||||
if shard.AMP() != nil {
|
||||
hop.AMP = shard.AMP()
|
||||
}
|
||||
|
||||
// Generate the raw encoded sphinx packet to be included along
|
||||
// with the htlcAdd message that we send directly to the
|
||||
// switch.
|
||||
onionBlob, _, err := generateSphinxPacket(
|
||||
rt, p.paymentHash[:], sessionKey,
|
||||
)
|
||||
hash := shard.Hash()
|
||||
onionBlob, _, err := generateSphinxPacket(rt, hash[:], sessionKey)
|
||||
if err != nil {
|
||||
return lnwire.ShortChannelID{}, nil, nil, err
|
||||
}
|
||||
@ -623,7 +666,7 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
|
||||
htlcAdd := &lnwire.UpdateAddHTLC{
|
||||
Amount: rt.TotalAmount,
|
||||
Expiry: rt.TotalTimeLock,
|
||||
PaymentHash: p.paymentHash,
|
||||
PaymentHash: hash,
|
||||
}
|
||||
copy(htlcAdd.OnionBlob[:], onionBlob)
|
||||
|
||||
@ -634,13 +677,6 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
|
||||
rt.Hops[0].ChannelID,
|
||||
)
|
||||
|
||||
// We generate a new, unique payment ID that we will use for
|
||||
// this HTLC.
|
||||
attemptID, err := p.router.cfg.NextPaymentID()
|
||||
if err != nil {
|
||||
return lnwire.ShortChannelID{}, nil, nil, err
|
||||
}
|
||||
|
||||
// We now have all the information needed to populate
|
||||
// the current attempt information.
|
||||
attempt := &channeldb.HTLCAttemptInfo{
|
||||
@ -722,6 +758,13 @@ func (p *shardHandler) failAttempt(attempt *channeldb.HTLCAttemptInfo,
|
||||
p.router.cfg.Clock.Now(),
|
||||
)
|
||||
|
||||
// Now that we are failing this payment attempt, cancel the shard with
|
||||
// the ShardTracker such that it can derive the correct hash for the
|
||||
// next attempt.
|
||||
if err := p.shardTracker.CancelShard(attempt.AttemptID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.router.cfg.Control.FailAttempt(
|
||||
p.paymentHash, attempt.AttemptID,
|
||||
failInfo,
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/chainview"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/routing/shards"
|
||||
"github.com/lightningnetwork/lnd/ticker"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
)
|
||||
@ -603,19 +604,40 @@ func (r *ChannelRouter) Start() error {
|
||||
go func(payment *channeldb.MPPayment) {
|
||||
defer r.wg.Done()
|
||||
|
||||
// Get the hashes used for the outstanding HTLCs.
|
||||
htlcs := make(map[uint64]lntypes.Hash)
|
||||
for _, a := range payment.HTLCs {
|
||||
a := a
|
||||
|
||||
hash := payment.Info.PaymentHash
|
||||
htlcs[a.AttemptID] = hash
|
||||
}
|
||||
|
||||
// Since we are not supporting creating more shards
|
||||
// after a restart (only receiving the result of the
|
||||
// shards already outstanding), we create a simple
|
||||
// shard tracker that will map the attempt IDs to
|
||||
// hashes used for the HTLCs. This will be enough also
|
||||
// for AMP payments, since we only need the hashes for
|
||||
// the individual HTLCs to regenerate the circuits, and
|
||||
// we don't currently persist the root share necessary
|
||||
// to re-derive them.
|
||||
shardTracker := shards.NewSimpleShardTracker(
|
||||
payment.Info.PaymentHash, htlcs,
|
||||
)
|
||||
|
||||
// We create a dummy, empty payment session such that
|
||||
// we won't make another payment attempt when the
|
||||
// result for the in-flight attempt is received.
|
||||
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
|
||||
|
||||
// We pass in a zero timeout value, to indicate we
|
||||
// don't need it to timeout. It will stop immediately
|
||||
// after the existing attempt has finished anyway. We
|
||||
// also set a zero fee limit, as no more routes should
|
||||
// be tried.
|
||||
_, _, err := r.sendPayment(
|
||||
payment.Info.Value, 0,
|
||||
payment.Info.PaymentHash, 0, paySession,
|
||||
payment.Info.Value, 0, payment.Info.PaymentHash,
|
||||
0, paySession, shardTracker,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Resuming payment with hash %v "+
|
||||
@ -1770,7 +1792,7 @@ type LightningPayment struct {
|
||||
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
||||
*route.Route, error) {
|
||||
|
||||
paySession, err := r.preparePayment(payment)
|
||||
paySession, shardTracker, err := r.preparePayment(payment)
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
@ -1782,14 +1804,14 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
||||
// for the existing attempt.
|
||||
return r.sendPayment(
|
||||
payment.Amount, payment.FeeLimit, payment.PaymentHash,
|
||||
payment.PayAttemptTimeout, paySession,
|
||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
||||
)
|
||||
}
|
||||
|
||||
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
|
||||
// result needs to be retrieved via the control tower.
|
||||
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
|
||||
paySession, err := r.preparePayment(payment)
|
||||
paySession, shardTracker, err := r.preparePayment(payment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1805,7 +1827,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
|
||||
|
||||
_, _, err := r.sendPayment(
|
||||
payment.Amount, payment.FeeLimit, payment.PaymentHash,
|
||||
payment.PayAttemptTimeout, paySession,
|
||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Payment with hash %x failed: %v",
|
||||
@ -1841,14 +1863,14 @@ func spewPayment(payment *LightningPayment) logClosure {
|
||||
// preparePayment creates the payment session and registers the payment with the
|
||||
// control tower.
|
||||
func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
|
||||
PaymentSession, error) {
|
||||
PaymentSession, shards.ShardTracker, error) {
|
||||
|
||||
// Before starting the HTLC routing attempt, we'll create a fresh
|
||||
// payment session which will report our errors back to mission
|
||||
// control.
|
||||
paySession, err := r.cfg.SessionSource.NewPaymentSession(payment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Record this payment hash with the ControlTower, ensuring it is not
|
||||
@ -1862,12 +1884,18 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
|
||||
PaymentRequest: payment.PaymentRequest,
|
||||
}
|
||||
|
||||
// Create a new ShardTracker that we'll use during the life cycle of
|
||||
// this payment.
|
||||
shardTracker := shards.NewSimpleShardTracker(
|
||||
payment.PaymentHash, nil,
|
||||
)
|
||||
|
||||
err = r.cfg.Control.InitPayment(payment.PaymentHash, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return paySession, nil
|
||||
return paySession, shardTracker, nil
|
||||
}
|
||||
|
||||
// SendToRoute attempts to send a payment with the given hash through the
|
||||
@ -1915,14 +1943,22 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
|
||||
}),
|
||||
)
|
||||
|
||||
// Since the HTLC hashes and preimages are specified manually over the
|
||||
// RPC for SendToRoute requests, we don't have to worry about creating
|
||||
// a ShardTracker that can generate hashes for AMP payments. Instead we
|
||||
// create a simple tracker that can just return the hash for the single
|
||||
// shard we'll now launch.
|
||||
shardTracker := shards.NewSimpleShardTracker(hash, nil)
|
||||
|
||||
// Launch a shard along the given route.
|
||||
sh := &shardHandler{
|
||||
router: r,
|
||||
paymentHash: hash,
|
||||
shardTracker: shardTracker,
|
||||
}
|
||||
|
||||
var shardError error
|
||||
attempt, outcome, err := sh.launchShard(rt)
|
||||
attempt, outcome, err := sh.launchShard(rt, false)
|
||||
|
||||
// With SendToRoute, it can happen that the route exceeds protocol
|
||||
// constraints. Mark the payment as failed with an internal error.
|
||||
@ -2007,8 +2043,8 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
|
||||
// the ControlTower.
|
||||
func (r *ChannelRouter) sendPayment(
|
||||
totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash,
|
||||
timeout time.Duration,
|
||||
paySession PaymentSession) ([32]byte, *route.Route, error) {
|
||||
timeout time.Duration, paySession PaymentSession,
|
||||
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
|
||||
|
||||
// We'll also fetch the current block height so we can properly
|
||||
// calculate the required HTLC time locks within the route.
|
||||
@ -2025,6 +2061,7 @@ func (r *ChannelRouter) sendPayment(
|
||||
feeLimit: feeLimit,
|
||||
paymentHash: paymentHash,
|
||||
paySession: paySession,
|
||||
shardTracker: shardTracker,
|
||||
currentHeight: currentHeight,
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package shards
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
@ -74,6 +75,7 @@ func (s *Shard) AMP() *record.AMP {
|
||||
type SimpleShardTracker struct {
|
||||
hash lntypes.Hash
|
||||
shards map[uint64]lntypes.Hash
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// A compile time check to ensure SimpleShardTracker implements the
|
||||
@ -100,7 +102,9 @@ func NewSimpleShardTracker(paymentHash lntypes.Hash,
|
||||
// if it ends up not being used by the overall payment, i.e. if the attempt
|
||||
// fails.
|
||||
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
|
||||
m.Lock()
|
||||
m.shards[id] = m.hash
|
||||
m.Unlock()
|
||||
|
||||
return &Shard{
|
||||
hash: m.hash,
|
||||
@ -109,14 +113,19 @@ func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
|
||||
|
||||
// CancelShard cancel's the shard corresponding to the given attempt ID.
|
||||
func (m *SimpleShardTracker) CancelShard(id uint64) error {
|
||||
m.Lock()
|
||||
delete(m.shards, id)
|
||||
m.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHash retrieves the hash used by the shard of the given attempt ID. This
|
||||
// will return an error if the attempt ID is unknown.
|
||||
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
|
||||
m.Lock()
|
||||
hash, ok := m.shards[id]
|
||||
m.Unlock()
|
||||
if !ok {
|
||||
return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
|
||||
"not found", id)
|
||||
|
Loading…
Reference in New Issue
Block a user