channeldb: validate MPP options when registering attempts
We add validation making sure we are not trying to register MPP shards for non-MPP payments, and vice versa. We also add validtion of total sent amount against payment value, and matching MPP options. We also add methods for copying Route/Hop, since it is useful to use for modifying the route amount in the test.
This commit is contained in:
parent
9a1ec950bd
commit
864e64e725
@ -45,6 +45,32 @@ var (
|
||||
// failed HTLC attempt.
|
||||
ErrAttemptAlreadyFailed = errors.New("attempt already failed")
|
||||
|
||||
// ErrValueMismatch is returned if we try to register a non-MPP attempt
|
||||
// with an amount that doesn't match the payment amount.
|
||||
ErrValueMismatch = errors.New("attempted value doesn't match payment" +
|
||||
"amount")
|
||||
|
||||
// ErrValueExceedsAmt is returned if we try to register an attempt that
|
||||
// would take the total sent amount above the payment amount.
|
||||
ErrValueExceedsAmt = errors.New("attempted value exceeds payment" +
|
||||
"amount")
|
||||
|
||||
// ErrNonMPPayment is returned if we try to register an MPP attempt for
|
||||
// a payment that already has a non-MPP attempt regitered.
|
||||
ErrNonMPPayment = errors.New("payment has non-MPP attempts")
|
||||
|
||||
// ErrMPPayment is returned if we try to register a non-MPP attempt for
|
||||
// a payment that already has an MPP attempt regitered.
|
||||
ErrMPPayment = errors.New("payment has MPP attempts")
|
||||
|
||||
// ErrMPPPaymentAddrMismatch is returned if we try to register an MPP
|
||||
// shard where the payment address doesn't match existing shards.
|
||||
ErrMPPPaymentAddrMismatch = errors.New("payment address mismatch")
|
||||
|
||||
// ErrMPPTotalAmountMismatch is returned if we try to register an MPP
|
||||
// shard where the total amount doesn't match existing shards.
|
||||
ErrMPPTotalAmountMismatch = errors.New("mp payment total amount mismatch")
|
||||
|
||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
||||
"inflight payment")
|
||||
@ -189,11 +215,59 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
||||
return err
|
||||
}
|
||||
|
||||
// We cannot register a new attempt if the payment already has
|
||||
// reached a terminal condition:
|
||||
settle, fail := payment.TerminalInfo()
|
||||
if settle != nil || fail != nil {
|
||||
return ErrPaymentTerminal
|
||||
}
|
||||
|
||||
// Make sure any existing shards match the new one with regards
|
||||
// to MPP options.
|
||||
mpp := attempt.Route.FinalHop().MPP
|
||||
for _, h := range payment.InFlightHTLCs() {
|
||||
hMpp := h.Route.FinalHop().MPP
|
||||
|
||||
switch {
|
||||
|
||||
// We tried to register a non-MPP attempt for a MPP
|
||||
// payment.
|
||||
case mpp == nil && hMpp != nil:
|
||||
return ErrMPPayment
|
||||
|
||||
// We tried to register a MPP shard for a non-MPP
|
||||
// payment.
|
||||
case mpp != nil && hMpp == nil:
|
||||
return ErrNonMPPayment
|
||||
|
||||
// Non-MPP payment, nothing more to validate.
|
||||
case mpp == nil:
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that MPP options match.
|
||||
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
|
||||
return ErrMPPPaymentAddrMismatch
|
||||
}
|
||||
|
||||
if mpp.TotalMsat() != hMpp.TotalMsat() {
|
||||
return ErrMPPTotalAmountMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// If this is a non-MPP attempt, it must match the total amount
|
||||
// exactly.
|
||||
amt := attempt.Route.ReceiverAmt()
|
||||
if mpp == nil && amt != payment.Info.Value {
|
||||
return ErrValueMismatch
|
||||
}
|
||||
|
||||
// Ensure we aren't sending more than the total payment amount.
|
||||
sentAmt, _ := payment.SentAmt()
|
||||
if sentAmt+amt > payment.Info.Value {
|
||||
return ErrValueExceedsAmt
|
||||
}
|
||||
|
||||
htlcsBucket, err := bucket.CreateBucketIfNotExists(
|
||||
paymentHtlcsBucket,
|
||||
)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
)
|
||||
|
||||
func initDB() (*DB, error) {
|
||||
@ -48,14 +49,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo,
|
||||
rhash := fastsha256.Sum256(preimage[:])
|
||||
return &PaymentCreationInfo{
|
||||
PaymentHash: rhash,
|
||||
Value: 1,
|
||||
Value: testRoute.ReceiverAmt(),
|
||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||
PaymentRequest: []byte("hola"),
|
||||
},
|
||||
&HTLCAttemptInfo{
|
||||
AttemptID: 0,
|
||||
SessionKey: priv,
|
||||
Route: testRoute,
|
||||
Route: *testRoute.Copy(),
|
||||
}, preimage, nil
|
||||
}
|
||||
|
||||
@ -504,7 +505,15 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
||||
)
|
||||
|
||||
// Create three unique attempts we'll use for the test, and
|
||||
// register them with the payment control.
|
||||
// register them with the payment control. We set each
|
||||
// attempts's value to one third of the payment amount, and
|
||||
// populate the MPP options.
|
||||
shardAmt := info.Value / 3
|
||||
attempt.Route.FinalHop().AmtToForward = shardAmt
|
||||
attempt.Route.FinalHop().MPP = record.NewMPP(
|
||||
info.Value, [32]byte{1},
|
||||
)
|
||||
|
||||
var attempts []*HTLCAttemptInfo
|
||||
for i := uint64(0); i < 3; i++ {
|
||||
a := *attempt
|
||||
@ -527,6 +536,17 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
// For a fourth attempt, check that attempting to
|
||||
// register it will fail since the total sent amount
|
||||
// will be too large.
|
||||
b := *attempt
|
||||
b.AttemptID = 3
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrValueExceedsAmt {
|
||||
t.Fatalf("expected ErrValueExceedsAmt, got: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
// Fail the second attempt.
|
||||
a := attempts[1]
|
||||
htlcFail := HTLCFailUnreadable
|
||||
@ -612,7 +632,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
||||
|
||||
// Try to register yet another attempt. This should fail now
|
||||
// that the payment has reached a terminal condition.
|
||||
b := *attempt
|
||||
b = *attempt
|
||||
b.AttemptID = 3
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrPaymentTerminal {
|
||||
@ -705,6 +725,100 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := initDB()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
info, attempt, _, err := genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Init the payment.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Create three unique attempts we'll use for the test, and
|
||||
// register them with the payment control. We set each
|
||||
// attempts's value to one third of the payment amount, and
|
||||
// populate the MPP options.
|
||||
shardAmt := info.Value / 3
|
||||
attempt.Route.FinalHop().AmtToForward = shardAmt
|
||||
attempt.Route.FinalHop().MPP = record.NewMPP(
|
||||
info.Value, [32]byte{1},
|
||||
)
|
||||
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Now try to register a non-MPP attempt, which should fail.
|
||||
b := *attempt
|
||||
b.AttemptID = 1
|
||||
b.Route.FinalHop().MPP = nil
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrMPPayment {
|
||||
t.Fatalf("expected ErrMPPayment, got: %v", err)
|
||||
}
|
||||
|
||||
// Try to register attempt one with a different payment address.
|
||||
b.Route.FinalHop().MPP = record.NewMPP(
|
||||
info.Value, [32]byte{2},
|
||||
)
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrMPPPaymentAddrMismatch {
|
||||
t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err)
|
||||
}
|
||||
|
||||
// Try registering one with a different total amount.
|
||||
b.Route.FinalHop().MPP = record.NewMPP(
|
||||
info.Value/2, [32]byte{1},
|
||||
)
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrMPPTotalAmountMismatch {
|
||||
t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err)
|
||||
}
|
||||
|
||||
// Create and init a new payment. This time we'll check that we cannot
|
||||
// register an MPP attempt if we already registered a non-MPP one.
|
||||
info, attempt, _, err = genInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate htlc message: %v", err)
|
||||
}
|
||||
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
attempt.Route.FinalHop().MPP = nil
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to register an MPP attempt, which should fail.
|
||||
b = *attempt
|
||||
b.AttemptID = 1
|
||||
b.Route.FinalHop().MPP = record.NewMPP(
|
||||
info.Value, [32]byte{1},
|
||||
)
|
||||
|
||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||
if err != ErrNonMPPayment {
|
||||
t.Fatalf("expected ErrNonMPPayment, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// assertPaymentStatus retrieves the status of the payment referred to by hash
|
||||
// and compares it with the expected state.
|
||||
func assertPaymentStatus(t *testing.T, p *PaymentControl,
|
||||
|
@ -324,7 +324,7 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo,
|
||||
rhash := sha256.Sum256(preimage[:])
|
||||
return &channeldb.PaymentCreationInfo{
|
||||
PaymentHash: rhash,
|
||||
Value: 1,
|
||||
Value: testRoute.ReceiverAmt(),
|
||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||
PaymentRequest: []byte("hola"),
|
||||
},
|
||||
|
@ -129,6 +129,23 @@ type Hop struct {
|
||||
LegacyPayload bool
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the Hop.
|
||||
func (h *Hop) Copy() *Hop {
|
||||
c := *h
|
||||
|
||||
if h.MPP != nil {
|
||||
m := *h.MPP
|
||||
c.MPP = &m
|
||||
}
|
||||
|
||||
if h.AMP != nil {
|
||||
a := *h.AMP
|
||||
c.AMP = &a
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// PackHopPayload writes to the passed io.Writer, the series of byes that can
|
||||
// be placed directly into the per-hop payload (EOB) for this hop. This will
|
||||
// include the required routing fields, as well as serializing any of the
|
||||
@ -287,6 +304,18 @@ type Route struct {
|
||||
Hops []*Hop
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the Route.
|
||||
func (r *Route) Copy() *Route {
|
||||
c := *r
|
||||
|
||||
c.Hops = make([]*Hop, len(r.Hops))
|
||||
for i := range r.Hops {
|
||||
c.Hops[i] = r.Hops[i].Copy()
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// HopFee returns the fee charged by the route hop indicated by hopIndex.
|
||||
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
|
||||
var incomingAmt lnwire.MilliSatoshi
|
||||
@ -320,6 +349,15 @@ func (r *Route) ReceiverAmt() lnwire.MilliSatoshi {
|
||||
return r.Hops[len(r.Hops)-1].AmtToForward
|
||||
}
|
||||
|
||||
// FinalHop returns the last hop of the route, or nil if the route is empty.
|
||||
func (r *Route) FinalHop() *Hop {
|
||||
if len(r.Hops) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.Hops[len(r.Hops)-1]
|
||||
}
|
||||
|
||||
// NewRouteFromHops creates a new Route structure from the minimally required
|
||||
// information to perform the payment. It infers fee amounts and populates the
|
||||
// node, chan and prev/next hop maps.
|
||||
|
Loading…
Reference in New Issue
Block a user