routing: use AMP shard tracker

We'll use the AMP-specific ShardTracker for payments having non-nil
AMPOptions.
This commit is contained in:
Johan T. Halseth 2021-04-12 15:05:48 +02:00
parent 2d397b12b1
commit 5531b812e3
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26

@ -15,6 +15,7 @@ import (
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/amp"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
@ -1722,6 +1723,10 @@ type LightningPayment struct {
// the first hop.
PaymentHash [32]byte
// amp is an optional field that is set if and only if this is am AMP
// payment.
amp *AMPOptions
// FinalCLTVDelta is the CTLV expiry delta to use for the _final_ hop
// in the route. This means that the final hop will have a CLTV delta
// of at least: currentHeight + FinalCLTVDelta.
@ -1789,6 +1794,13 @@ type LightningPayment struct {
MaxShardAmt *lnwire.MilliSatoshi
}
// AMPOptions houses information that must be known in order to send an AMP
// payment.
type AMPOptions struct {
SetID [32]byte
RootShare [32]byte
}
// SendPayment attempts to send a payment as described within the passed
// LightningPayment. This function is blocking and will return either: when the
// payment is successful, or all candidates routes have been attempted and
@ -1893,9 +1905,23 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
// Create a new ShardTracker that we'll use during the life cycle of
// this payment.
shardTracker := shards.NewSimpleShardTracker(
var shardTracker shards.ShardTracker
switch {
// If this is an AMP payment, we'll use the AMP shard tracker.
case payment.amp != nil:
shardTracker = amp.NewShardTracker(
payment.amp.RootShare, payment.amp.SetID,
*payment.PaymentAddr, payment.Amount,
)
// Otherwise we'll use the simple tracker that will map each attempt to
// the same payment hash.
default:
shardTracker = shards.NewSimpleShardTracker(
payment.PaymentHash, nil,
)
}
err = r.cfg.Control.InitPayment(payment.PaymentHash, info)
if err != nil {