channeldb: validate MPP options when registering attempts

We add validation making sure we are not trying to register MPP shards
for non-MPP payments, and vice versa. We also add validtion of total
sent amount against payment value, and matching MPP options.

We also add methods for copying Route/Hop, since it is useful to use
for modifying the route amount in the test.
This commit is contained in:
Johan T. Halseth 2020-04-01 00:13:27 +02:00
parent 9a1ec950bd
commit 864e64e725
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 231 additions and 5 deletions

@ -45,6 +45,32 @@ var (
// failed HTLC attempt.
ErrAttemptAlreadyFailed = errors.New("attempt already failed")
// ErrValueMismatch is returned if we try to register a non-MPP attempt
// with an amount that doesn't match the payment amount.
ErrValueMismatch = errors.New("attempted value doesn't match payment" +
"amount")
// ErrValueExceedsAmt is returned if we try to register an attempt that
// would take the total sent amount above the payment amount.
ErrValueExceedsAmt = errors.New("attempted value exceeds payment" +
"amount")
// ErrNonMPPayment is returned if we try to register an MPP attempt for
// a payment that already has a non-MPP attempt regitered.
ErrNonMPPayment = errors.New("payment has non-MPP attempts")
// ErrMPPayment is returned if we try to register a non-MPP attempt for
// a payment that already has an MPP attempt regitered.
ErrMPPayment = errors.New("payment has MPP attempts")
// ErrMPPPaymentAddrMismatch is returned if we try to register an MPP
// shard where the payment address doesn't match existing shards.
ErrMPPPaymentAddrMismatch = errors.New("payment address mismatch")
// ErrMPPTotalAmountMismatch is returned if we try to register an MPP
// shard where the total amount doesn't match existing shards.
ErrMPPTotalAmountMismatch = errors.New("mp payment total amount mismatch")
// errNoAttemptInfo is returned when no attempt info is stored yet.
errNoAttemptInfo = errors.New("unable to find attempt info for " +
"inflight payment")
@ -189,11 +215,59 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
return err
}
// We cannot register a new attempt if the payment already has
// reached a terminal condition:
settle, fail := payment.TerminalInfo()
if settle != nil || fail != nil {
return ErrPaymentTerminal
}
// Make sure any existing shards match the new one with regards
// to MPP options.
mpp := attempt.Route.FinalHop().MPP
for _, h := range payment.InFlightHTLCs() {
hMpp := h.Route.FinalHop().MPP
switch {
// We tried to register a non-MPP attempt for a MPP
// payment.
case mpp == nil && hMpp != nil:
return ErrMPPayment
// We tried to register a MPP shard for a non-MPP
// payment.
case mpp != nil && hMpp == nil:
return ErrNonMPPayment
// Non-MPP payment, nothing more to validate.
case mpp == nil:
continue
}
// Check that MPP options match.
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
return ErrMPPPaymentAddrMismatch
}
if mpp.TotalMsat() != hMpp.TotalMsat() {
return ErrMPPTotalAmountMismatch
}
}
// If this is a non-MPP attempt, it must match the total amount
// exactly.
amt := attempt.Route.ReceiverAmt()
if mpp == nil && amt != payment.Info.Value {
return ErrValueMismatch
}
// Ensure we aren't sending more than the total payment amount.
sentAmt, _ := payment.SentAmt()
if sentAmt+amt > payment.Info.Value {
return ErrValueExceedsAmt
}
htlcsBucket, err := bucket.CreateBucketIfNotExists(
paymentHtlcsBucket,
)

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record"
)
func initDB() (*DB, error) {
@ -48,14 +49,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo,
rhash := fastsha256.Sum256(preimage[:])
return &PaymentCreationInfo{
PaymentHash: rhash,
Value: 1,
Value: testRoute.ReceiverAmt(),
CreationTime: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte("hola"),
},
&HTLCAttemptInfo{
AttemptID: 0,
SessionKey: priv,
Route: testRoute,
Route: *testRoute.Copy(),
}, preimage, nil
}
@ -504,7 +505,15 @@ func TestPaymentControlMultiShard(t *testing.T) {
)
// Create three unique attempts we'll use for the test, and
// register them with the payment control.
// register them with the payment control. We set each
// attempts's value to one third of the payment amount, and
// populate the MPP options.
shardAmt := info.Value / 3
attempt.Route.FinalHop().AmtToForward = shardAmt
attempt.Route.FinalHop().MPP = record.NewMPP(
info.Value, [32]byte{1},
)
var attempts []*HTLCAttemptInfo
for i := uint64(0); i < 3; i++ {
a := *attempt
@ -527,6 +536,17 @@ func TestPaymentControlMultiShard(t *testing.T) {
)
}
// For a fourth attempt, check that attempting to
// register it will fail since the total sent amount
// will be too large.
b := *attempt
b.AttemptID = 3
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrValueExceedsAmt {
t.Fatalf("expected ErrValueExceedsAmt, got: %v",
err)
}
// Fail the second attempt.
a := attempts[1]
htlcFail := HTLCFailUnreadable
@ -612,7 +632,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
// Try to register yet another attempt. This should fail now
// that the payment has reached a terminal condition.
b := *attempt
b = *attempt
b.AttemptID = 3
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrPaymentTerminal {
@ -705,6 +725,100 @@ func TestPaymentControlMultiShard(t *testing.T) {
}
}
func TestPaymentControlMPPRecordValidation(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(db)
info, attempt, _, err := genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Init the payment.
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
// Create three unique attempts we'll use for the test, and
// register them with the payment control. We set each
// attempts's value to one third of the payment amount, and
// populate the MPP options.
shardAmt := info.Value / 3
attempt.Route.FinalHop().AmtToForward = shardAmt
attempt.Route.FinalHop().MPP = record.NewMPP(
info.Value, [32]byte{1},
)
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
// Now try to register a non-MPP attempt, which should fail.
b := *attempt
b.AttemptID = 1
b.Route.FinalHop().MPP = nil
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrMPPayment {
t.Fatalf("expected ErrMPPayment, got: %v", err)
}
// Try to register attempt one with a different payment address.
b.Route.FinalHop().MPP = record.NewMPP(
info.Value, [32]byte{2},
)
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrMPPPaymentAddrMismatch {
t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err)
}
// Try registering one with a different total amount.
b.Route.FinalHop().MPP = record.NewMPP(
info.Value/2, [32]byte{1},
)
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrMPPTotalAmountMismatch {
t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err)
}
// Create and init a new payment. This time we'll check that we cannot
// register an MPP attempt if we already registered a non-MPP one.
info, attempt, _, err = genInfo()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
err = pControl.InitPayment(info.PaymentHash, info)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
attempt.Route.FinalHop().MPP = nil
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
// Attempt to register an MPP attempt, which should fail.
b = *attempt
b.AttemptID = 1
b.Route.FinalHop().MPP = record.NewMPP(
info.Value, [32]byte{1},
)
err = pControl.RegisterAttempt(info.PaymentHash, &b)
if err != ErrNonMPPayment {
t.Fatalf("expected ErrNonMPPayment, got: %v", err)
}
}
// assertPaymentStatus retrieves the status of the payment referred to by hash
// and compares it with the expected state.
func assertPaymentStatus(t *testing.T, p *PaymentControl,

@ -324,7 +324,7 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo,
rhash := sha256.Sum256(preimage[:])
return &channeldb.PaymentCreationInfo{
PaymentHash: rhash,
Value: 1,
Value: testRoute.ReceiverAmt(),
CreationTime: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte("hola"),
},

@ -129,6 +129,23 @@ type Hop struct {
LegacyPayload bool
}
// Copy returns a deep copy of the Hop.
func (h *Hop) Copy() *Hop {
c := *h
if h.MPP != nil {
m := *h.MPP
c.MPP = &m
}
if h.AMP != nil {
a := *h.AMP
c.AMP = &a
}
return &c
}
// PackHopPayload writes to the passed io.Writer, the series of byes that can
// be placed directly into the per-hop payload (EOB) for this hop. This will
// include the required routing fields, as well as serializing any of the
@ -287,6 +304,18 @@ type Route struct {
Hops []*Hop
}
// Copy returns a deep copy of the Route.
func (r *Route) Copy() *Route {
c := *r
c.Hops = make([]*Hop, len(r.Hops))
for i := range r.Hops {
c.Hops[i] = r.Hops[i].Copy()
}
return &c
}
// HopFee returns the fee charged by the route hop indicated by hopIndex.
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
var incomingAmt lnwire.MilliSatoshi
@ -320,6 +349,15 @@ func (r *Route) ReceiverAmt() lnwire.MilliSatoshi {
return r.Hops[len(r.Hops)-1].AmtToForward
}
// FinalHop returns the last hop of the route, or nil if the route is empty.
func (r *Route) FinalHop() *Hop {
if len(r.Hops) == 0 {
return nil
}
return r.Hops[len(r.Hops)-1]
}
// NewRouteFromHops creates a new Route structure from the minimally required
// information to perform the payment. It infers fee amounts and populates the
// node, chan and prev/next hop maps.