From 864e64e725b520fa22c771b755904b82bd43060c Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Wed, 1 Apr 2020 00:13:27 +0200 Subject: [PATCH] channeldb: validate MPP options when registering attempts We add validation making sure we are not trying to register MPP shards for non-MPP payments, and vice versa. We also add validtion of total sent amount against payment value, and matching MPP options. We also add methods for copying Route/Hop, since it is useful to use for modifying the route amount in the test. --- channeldb/payment_control.go | 74 ++++++++++++++++++ channeldb/payment_control_test.go | 122 +++++++++++++++++++++++++++++- routing/control_tower_test.go | 2 +- routing/route/route.go | 38 ++++++++++ 4 files changed, 231 insertions(+), 5 deletions(-) diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index d0bbae75..ca5b6998 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -45,6 +45,32 @@ var ( // failed HTLC attempt. ErrAttemptAlreadyFailed = errors.New("attempt already failed") + // ErrValueMismatch is returned if we try to register a non-MPP attempt + // with an amount that doesn't match the payment amount. + ErrValueMismatch = errors.New("attempted value doesn't match payment" + + "amount") + + // ErrValueExceedsAmt is returned if we try to register an attempt that + // would take the total sent amount above the payment amount. + ErrValueExceedsAmt = errors.New("attempted value exceeds payment" + + "amount") + + // ErrNonMPPayment is returned if we try to register an MPP attempt for + // a payment that already has a non-MPP attempt regitered. + ErrNonMPPayment = errors.New("payment has non-MPP attempts") + + // ErrMPPayment is returned if we try to register a non-MPP attempt for + // a payment that already has an MPP attempt regitered. + ErrMPPayment = errors.New("payment has MPP attempts") + + // ErrMPPPaymentAddrMismatch is returned if we try to register an MPP + // shard where the payment address doesn't match existing shards. + ErrMPPPaymentAddrMismatch = errors.New("payment address mismatch") + + // ErrMPPTotalAmountMismatch is returned if we try to register an MPP + // shard where the total amount doesn't match existing shards. + ErrMPPTotalAmountMismatch = errors.New("mp payment total amount mismatch") + // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") @@ -189,11 +215,59 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, return err } + // We cannot register a new attempt if the payment already has + // reached a terminal condition: settle, fail := payment.TerminalInfo() if settle != nil || fail != nil { return ErrPaymentTerminal } + // Make sure any existing shards match the new one with regards + // to MPP options. + mpp := attempt.Route.FinalHop().MPP + for _, h := range payment.InFlightHTLCs() { + hMpp := h.Route.FinalHop().MPP + + switch { + + // We tried to register a non-MPP attempt for a MPP + // payment. + case mpp == nil && hMpp != nil: + return ErrMPPayment + + // We tried to register a MPP shard for a non-MPP + // payment. + case mpp != nil && hMpp == nil: + return ErrNonMPPayment + + // Non-MPP payment, nothing more to validate. + case mpp == nil: + continue + } + + // Check that MPP options match. + if mpp.PaymentAddr() != hMpp.PaymentAddr() { + return ErrMPPPaymentAddrMismatch + } + + if mpp.TotalMsat() != hMpp.TotalMsat() { + return ErrMPPTotalAmountMismatch + } + } + + // If this is a non-MPP attempt, it must match the total amount + // exactly. + amt := attempt.Route.ReceiverAmt() + if mpp == nil && amt != payment.Info.Value { + return ErrValueMismatch + } + + // Ensure we aren't sending more than the total payment amount. + sentAmt, _ := payment.SentAmt() + if sentAmt+amt > payment.Info.Value { + return ErrValueExceedsAmt + } + htlcsBucket, err := bucket.CreateBucketIfNotExists( paymentHtlcsBucket, ) diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 015aa231..abd2722a 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/fastsha256" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/record" ) func initDB() (*DB, error) { @@ -48,14 +49,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, rhash := fastsha256.Sum256(preimage[:]) return &PaymentCreationInfo{ PaymentHash: rhash, - Value: 1, + Value: testRoute.ReceiverAmt(), CreationTime: time.Unix(time.Now().Unix(), 0), PaymentRequest: []byte("hola"), }, &HTLCAttemptInfo{ AttemptID: 0, SessionKey: priv, - Route: testRoute, + Route: *testRoute.Copy(), }, preimage, nil } @@ -504,7 +505,15 @@ func TestPaymentControlMultiShard(t *testing.T) { ) // Create three unique attempts we'll use for the test, and - // register them with the payment control. + // register them with the payment control. We set each + // attempts's value to one third of the payment amount, and + // populate the MPP options. + shardAmt := info.Value / 3 + attempt.Route.FinalHop().AmtToForward = shardAmt + attempt.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + var attempts []*HTLCAttemptInfo for i := uint64(0); i < 3; i++ { a := *attempt @@ -527,6 +536,17 @@ func TestPaymentControlMultiShard(t *testing.T) { ) } + // For a fourth attempt, check that attempting to + // register it will fail since the total sent amount + // will be too large. + b := *attempt + b.AttemptID = 3 + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrValueExceedsAmt { + t.Fatalf("expected ErrValueExceedsAmt, got: %v", + err) + } + // Fail the second attempt. a := attempts[1] htlcFail := HTLCFailUnreadable @@ -612,7 +632,7 @@ func TestPaymentControlMultiShard(t *testing.T) { // Try to register yet another attempt. This should fail now // that the payment has reached a terminal condition. - b := *attempt + b = *attempt b.AttemptID = 3 err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrPaymentTerminal { @@ -705,6 +725,100 @@ func TestPaymentControlMultiShard(t *testing.T) { } } +func TestPaymentControlMPPRecordValidation(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, _, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Init the payment. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + // Create three unique attempts we'll use for the test, and + // register them with the payment control. We set each + // attempts's value to one third of the payment amount, and + // populate the MPP options. + shardAmt := info.Value / 3 + attempt.Route.FinalHop().AmtToForward = shardAmt + attempt.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + // Now try to register a non-MPP attempt, which should fail. + b := *attempt + b.AttemptID = 1 + b.Route.FinalHop().MPP = nil + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrMPPayment { + t.Fatalf("expected ErrMPPayment, got: %v", err) + } + + // Try to register attempt one with a different payment address. + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{2}, + ) + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrMPPPaymentAddrMismatch { + t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err) + } + + // Try registering one with a different total amount. + b.Route.FinalHop().MPP = record.NewMPP( + info.Value/2, [32]byte{1}, + ) + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrMPPTotalAmountMismatch { + t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err) + } + + // Create and init a new payment. This time we'll check that we cannot + // register an MPP attempt if we already registered a non-MPP one. + info, attempt, _, err = genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + attempt.Route.FinalHop().MPP = nil + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + // Attempt to register an MPP attempt, which should fail. + b = *attempt + b.AttemptID = 1 + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + err = pControl.RegisterAttempt(info.PaymentHash, &b) + if err != ErrNonMPPayment { + t.Fatalf("expected ErrNonMPPayment, got: %v", err) + } +} + // assertPaymentStatus retrieves the status of the payment referred to by hash // and compares it with the expected state. func assertPaymentStatus(t *testing.T, p *PaymentControl, diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 6bc8ffd7..82dc2706 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -324,7 +324,7 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, rhash := sha256.Sum256(preimage[:]) return &channeldb.PaymentCreationInfo{ PaymentHash: rhash, - Value: 1, + Value: testRoute.ReceiverAmt(), CreationTime: time.Unix(time.Now().Unix(), 0), PaymentRequest: []byte("hola"), }, diff --git a/routing/route/route.go b/routing/route/route.go index 31fa3bf5..63944af1 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -129,6 +129,23 @@ type Hop struct { LegacyPayload bool } +// Copy returns a deep copy of the Hop. +func (h *Hop) Copy() *Hop { + c := *h + + if h.MPP != nil { + m := *h.MPP + c.MPP = &m + } + + if h.AMP != nil { + a := *h.AMP + c.AMP = &a + } + + return &c +} + // PackHopPayload writes to the passed io.Writer, the series of byes that can // be placed directly into the per-hop payload (EOB) for this hop. This will // include the required routing fields, as well as serializing any of the @@ -287,6 +304,18 @@ type Route struct { Hops []*Hop } +// Copy returns a deep copy of the Route. +func (r *Route) Copy() *Route { + c := *r + + c.Hops = make([]*Hop, len(r.Hops)) + for i := range r.Hops { + c.Hops[i] = r.Hops[i].Copy() + } + + return &c +} + // HopFee returns the fee charged by the route hop indicated by hopIndex. func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi { var incomingAmt lnwire.MilliSatoshi @@ -320,6 +349,15 @@ func (r *Route) ReceiverAmt() lnwire.MilliSatoshi { return r.Hops[len(r.Hops)-1].AmtToForward } +// FinalHop returns the last hop of the route, or nil if the route is empty. +func (r *Route) FinalHop() *Hop { + if len(r.Hops) == 0 { + return nil + } + + return r.Hops[len(r.Hops)-1] +} + // NewRouteFromHops creates a new Route structure from the minimally required // information to perform the payment. It infers fee amounts and populates the // node, chan and prev/next hop maps.