From c4fc72d573ec834b4651725eca59e9a7258e7883 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Wed, 31 Mar 2021 12:44:59 +0200 Subject: [PATCH] routerrpc+routing: set AMP options for payments specified as AMP in SendPayment --- lnrpc/routerrpc/router_backend.go | 91 +++++++++++++++++++++++++------ routing/router.go | 19 +++++++ 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 66b7244c..45ec0dfe 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -2,6 +2,7 @@ package routerrpc import ( "context" + "crypto/rand" "encoding/hex" "errors" "fmt" @@ -741,17 +742,6 @@ func (r *RouterBackend) extractIntentFromSendRequest( payIntent.Amount = reqAmt - // Payment hash. - paymentHash, err := lntypes.MakeHash(rpcPayReq.PaymentHash) - if err != nil { - return nil, err - } - - err = payIntent.SetPaymentHash(paymentHash) - if err != nil { - return nil, err - } - // Parse destination feature bits. features, err := UnmarshalFeatures(rpcPayReq.DestFeatures) if err != nil { @@ -766,13 +756,82 @@ func (r *RouterBackend) extractIntentFromSendRequest( } } - // If the payment addresses is specified, then we'll also - // populate that now as well. - if len(rpcPayReq.PaymentAddr) != 0 { - var payAddr [32]byte - copy(payAddr[:], rpcPayReq.PaymentAddr) + // If this is an AMP payment, we must generate the initial + // randomness. + if rpcPayReq.Amp { + // If no destination features were specified, we set + // those necessary for AMP payments. + if features == nil { + ampFeatures := []lnrpc.FeatureBit{ + lnrpc.FeatureBit_TLV_ONION_OPT, + lnrpc.FeatureBit_PAYMENT_ADDR_OPT, + lnrpc.FeatureBit_MPP_OPT, + lnrpc.FeatureBit_AMP_OPT, + } + features, err = UnmarshalFeatures(ampFeatures) + if err != nil { + return nil, err + } + } + + // First make sure the destination supports AMP. + if !features.HasFeature(lnwire.AMPOptional) { + return nil, fmt.Errorf("destination doesn't " + + "support AMP payments") + } + + // If no payment address is set, generate a random one. + var payAddr [32]byte + if len(rpcPayReq.PaymentAddr) == 0 { + _, err = rand.Read(payAddr[:]) + if err != nil { + return nil, err + } + } else { + copy(payAddr[:], rpcPayReq.PaymentAddr) + } payIntent.PaymentAddr = &payAddr + + // Generate random SetID and root share. + var setID [32]byte + _, err = rand.Read(setID[:]) + if err != nil { + return nil, err + } + + var rootShare [32]byte + _, err = rand.Read(rootShare[:]) + if err != nil { + return nil, err + } + err := payIntent.SetAMP(&routing.AMPOptions{ + SetID: setID, + RootShare: rootShare, + }) + if err != nil { + return nil, err + } + } else { + // Payment hash. + paymentHash, err := lntypes.MakeHash(rpcPayReq.PaymentHash) + if err != nil { + return nil, err + } + + err = payIntent.SetPaymentHash(paymentHash) + if err != nil { + return nil, err + } + + // If the payment addresses is specified, then we'll + // also populate that now as well. + if len(rpcPayReq.PaymentAddr) != 0 { + var payAddr [32]byte + copy(payAddr[:], rpcPayReq.PaymentAddr) + + payIntent.PaymentAddr = &payAddr + } } payIntent.DestFeatures = features diff --git a/routing/router.go b/routing/router.go index 205bd214..20f7ee0c 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1806,14 +1806,33 @@ type AMPOptions struct { // SetPaymentHash sets the given hash as the payment's overall hash. This // should only be used for non-AMP payments. func (l *LightningPayment) SetPaymentHash(hash lntypes.Hash) error { + if l.amp != nil { + return fmt.Errorf("cannot set payment hash for AMP payment") + } + l.paymentHash = &hash return nil } +// SetAMP sets the given AMP options for the payment. +func (l *LightningPayment) SetAMP(amp *AMPOptions) error { + if l.paymentHash != nil { + return fmt.Errorf("cannot set amp options for payment " + + "with payment hash") + } + + l.amp = amp + return nil +} + // Identifier returns a 32-byte slice that uniquely identifies this single // payment. For non-AMP payments this will be the payment hash, for AMP // payments this will be the used SetID. func (l *LightningPayment) Identifier() [32]byte { + if l.amp != nil { + return l.amp.SetID + } + return *l.paymentHash }