channeldb: validate MPP options when registering attempts
We add validation making sure we are not trying to register MPP shards for non-MPP payments, and vice versa. We also add validtion of total sent amount against payment value, and matching MPP options. We also add methods for copying Route/Hop, since it is useful to use for modifying the route amount in the test.
This commit is contained in:
parent
9a1ec950bd
commit
864e64e725
@ -45,6 +45,32 @@ var (
|
|||||||
// failed HTLC attempt.
|
// failed HTLC attempt.
|
||||||
ErrAttemptAlreadyFailed = errors.New("attempt already failed")
|
ErrAttemptAlreadyFailed = errors.New("attempt already failed")
|
||||||
|
|
||||||
|
// ErrValueMismatch is returned if we try to register a non-MPP attempt
|
||||||
|
// with an amount that doesn't match the payment amount.
|
||||||
|
ErrValueMismatch = errors.New("attempted value doesn't match payment" +
|
||||||
|
"amount")
|
||||||
|
|
||||||
|
// ErrValueExceedsAmt is returned if we try to register an attempt that
|
||||||
|
// would take the total sent amount above the payment amount.
|
||||||
|
ErrValueExceedsAmt = errors.New("attempted value exceeds payment" +
|
||||||
|
"amount")
|
||||||
|
|
||||||
|
// ErrNonMPPayment is returned if we try to register an MPP attempt for
|
||||||
|
// a payment that already has a non-MPP attempt regitered.
|
||||||
|
ErrNonMPPayment = errors.New("payment has non-MPP attempts")
|
||||||
|
|
||||||
|
// ErrMPPayment is returned if we try to register a non-MPP attempt for
|
||||||
|
// a payment that already has an MPP attempt regitered.
|
||||||
|
ErrMPPayment = errors.New("payment has MPP attempts")
|
||||||
|
|
||||||
|
// ErrMPPPaymentAddrMismatch is returned if we try to register an MPP
|
||||||
|
// shard where the payment address doesn't match existing shards.
|
||||||
|
ErrMPPPaymentAddrMismatch = errors.New("payment address mismatch")
|
||||||
|
|
||||||
|
// ErrMPPTotalAmountMismatch is returned if we try to register an MPP
|
||||||
|
// shard where the total amount doesn't match existing shards.
|
||||||
|
ErrMPPTotalAmountMismatch = errors.New("mp payment total amount mismatch")
|
||||||
|
|
||||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
||||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
||||||
"inflight payment")
|
"inflight payment")
|
||||||
@ -189,11 +215,59 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We cannot register a new attempt if the payment already has
|
||||||
|
// reached a terminal condition:
|
||||||
settle, fail := payment.TerminalInfo()
|
settle, fail := payment.TerminalInfo()
|
||||||
if settle != nil || fail != nil {
|
if settle != nil || fail != nil {
|
||||||
return ErrPaymentTerminal
|
return ErrPaymentTerminal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure any existing shards match the new one with regards
|
||||||
|
// to MPP options.
|
||||||
|
mpp := attempt.Route.FinalHop().MPP
|
||||||
|
for _, h := range payment.InFlightHTLCs() {
|
||||||
|
hMpp := h.Route.FinalHop().MPP
|
||||||
|
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// We tried to register a non-MPP attempt for a MPP
|
||||||
|
// payment.
|
||||||
|
case mpp == nil && hMpp != nil:
|
||||||
|
return ErrMPPayment
|
||||||
|
|
||||||
|
// We tried to register a MPP shard for a non-MPP
|
||||||
|
// payment.
|
||||||
|
case mpp != nil && hMpp == nil:
|
||||||
|
return ErrNonMPPayment
|
||||||
|
|
||||||
|
// Non-MPP payment, nothing more to validate.
|
||||||
|
case mpp == nil:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that MPP options match.
|
||||||
|
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
|
||||||
|
return ErrMPPPaymentAddrMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
if mpp.TotalMsat() != hMpp.TotalMsat() {
|
||||||
|
return ErrMPPTotalAmountMismatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a non-MPP attempt, it must match the total amount
|
||||||
|
// exactly.
|
||||||
|
amt := attempt.Route.ReceiverAmt()
|
||||||
|
if mpp == nil && amt != payment.Info.Value {
|
||||||
|
return ErrValueMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we aren't sending more than the total payment amount.
|
||||||
|
sentAmt, _ := payment.SentAmt()
|
||||||
|
if sentAmt+amt > payment.Info.Value {
|
||||||
|
return ErrValueExceedsAmt
|
||||||
|
}
|
||||||
|
|
||||||
htlcsBucket, err := bucket.CreateBucketIfNotExists(
|
htlcsBucket, err := bucket.CreateBucketIfNotExists(
|
||||||
paymentHtlcsBucket,
|
paymentHtlcsBucket,
|
||||||
)
|
)
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/btcsuite/fastsha256"
|
"github.com/btcsuite/fastsha256"
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initDB() (*DB, error) {
|
func initDB() (*DB, error) {
|
||||||
@ -48,14 +49,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo,
|
|||||||
rhash := fastsha256.Sum256(preimage[:])
|
rhash := fastsha256.Sum256(preimage[:])
|
||||||
return &PaymentCreationInfo{
|
return &PaymentCreationInfo{
|
||||||
PaymentHash: rhash,
|
PaymentHash: rhash,
|
||||||
Value: 1,
|
Value: testRoute.ReceiverAmt(),
|
||||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||||
PaymentRequest: []byte("hola"),
|
PaymentRequest: []byte("hola"),
|
||||||
},
|
},
|
||||||
&HTLCAttemptInfo{
|
&HTLCAttemptInfo{
|
||||||
AttemptID: 0,
|
AttemptID: 0,
|
||||||
SessionKey: priv,
|
SessionKey: priv,
|
||||||
Route: testRoute,
|
Route: *testRoute.Copy(),
|
||||||
}, preimage, nil
|
}, preimage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -504,7 +505,15 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Create three unique attempts we'll use for the test, and
|
// Create three unique attempts we'll use for the test, and
|
||||||
// register them with the payment control.
|
// register them with the payment control. We set each
|
||||||
|
// attempts's value to one third of the payment amount, and
|
||||||
|
// populate the MPP options.
|
||||||
|
shardAmt := info.Value / 3
|
||||||
|
attempt.Route.FinalHop().AmtToForward = shardAmt
|
||||||
|
attempt.Route.FinalHop().MPP = record.NewMPP(
|
||||||
|
info.Value, [32]byte{1},
|
||||||
|
)
|
||||||
|
|
||||||
var attempts []*HTLCAttemptInfo
|
var attempts []*HTLCAttemptInfo
|
||||||
for i := uint64(0); i < 3; i++ {
|
for i := uint64(0); i < 3; i++ {
|
||||||
a := *attempt
|
a := *attempt
|
||||||
@ -527,6 +536,17 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For a fourth attempt, check that attempting to
|
||||||
|
// register it will fail since the total sent amount
|
||||||
|
// will be too large.
|
||||||
|
b := *attempt
|
||||||
|
b.AttemptID = 3
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
|
if err != ErrValueExceedsAmt {
|
||||||
|
t.Fatalf("expected ErrValueExceedsAmt, got: %v",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
// Fail the second attempt.
|
// Fail the second attempt.
|
||||||
a := attempts[1]
|
a := attempts[1]
|
||||||
htlcFail := HTLCFailUnreadable
|
htlcFail := HTLCFailUnreadable
|
||||||
@ -612,7 +632,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
|
|
||||||
// Try to register yet another attempt. This should fail now
|
// Try to register yet another attempt. This should fail now
|
||||||
// that the payment has reached a terminal condition.
|
// that the payment has reached a terminal condition.
|
||||||
b := *attempt
|
b = *attempt
|
||||||
b.AttemptID = 3
|
b.AttemptID = 3
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrPaymentTerminal {
|
if err != ErrPaymentTerminal {
|
||||||
@ -705,6 +725,100 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, err := initDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to init db: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pControl := NewPaymentControl(db)
|
||||||
|
|
||||||
|
info, attempt, _, err := genInfo()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init the payment.
|
||||||
|
err = pControl.InitPayment(info.PaymentHash, info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create three unique attempts we'll use for the test, and
|
||||||
|
// register them with the payment control. We set each
|
||||||
|
// attempts's value to one third of the payment amount, and
|
||||||
|
// populate the MPP options.
|
||||||
|
shardAmt := info.Value / 3
|
||||||
|
attempt.Route.FinalHop().AmtToForward = shardAmt
|
||||||
|
attempt.Route.FinalHop().MPP = record.NewMPP(
|
||||||
|
info.Value, [32]byte{1},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now try to register a non-MPP attempt, which should fail.
|
||||||
|
b := *attempt
|
||||||
|
b.AttemptID = 1
|
||||||
|
b.Route.FinalHop().MPP = nil
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
|
if err != ErrMPPayment {
|
||||||
|
t.Fatalf("expected ErrMPPayment, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to register attempt one with a different payment address.
|
||||||
|
b.Route.FinalHop().MPP = record.NewMPP(
|
||||||
|
info.Value, [32]byte{2},
|
||||||
|
)
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
|
if err != ErrMPPPaymentAddrMismatch {
|
||||||
|
t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try registering one with a different total amount.
|
||||||
|
b.Route.FinalHop().MPP = record.NewMPP(
|
||||||
|
info.Value/2, [32]byte{1},
|
||||||
|
)
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
|
if err != ErrMPPTotalAmountMismatch {
|
||||||
|
t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and init a new payment. This time we'll check that we cannot
|
||||||
|
// register an MPP attempt if we already registered a non-MPP one.
|
||||||
|
info, attempt, _, err = genInfo()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pControl.InitPayment(info.PaymentHash, info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.Route.FinalHop().MPP = nil
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to register an MPP attempt, which should fail.
|
||||||
|
b = *attempt
|
||||||
|
b.AttemptID = 1
|
||||||
|
b.Route.FinalHop().MPP = record.NewMPP(
|
||||||
|
info.Value, [32]byte{1},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
|
if err != ErrNonMPPayment {
|
||||||
|
t.Fatalf("expected ErrNonMPPayment, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// assertPaymentStatus retrieves the status of the payment referred to by hash
|
// assertPaymentStatus retrieves the status of the payment referred to by hash
|
||||||
// and compares it with the expected state.
|
// and compares it with the expected state.
|
||||||
func assertPaymentStatus(t *testing.T, p *PaymentControl,
|
func assertPaymentStatus(t *testing.T, p *PaymentControl,
|
||||||
|
@ -324,7 +324,7 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo,
|
|||||||
rhash := sha256.Sum256(preimage[:])
|
rhash := sha256.Sum256(preimage[:])
|
||||||
return &channeldb.PaymentCreationInfo{
|
return &channeldb.PaymentCreationInfo{
|
||||||
PaymentHash: rhash,
|
PaymentHash: rhash,
|
||||||
Value: 1,
|
Value: testRoute.ReceiverAmt(),
|
||||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||||
PaymentRequest: []byte("hola"),
|
PaymentRequest: []byte("hola"),
|
||||||
},
|
},
|
||||||
|
@ -129,6 +129,23 @@ type Hop struct {
|
|||||||
LegacyPayload bool
|
LegacyPayload bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the Hop.
|
||||||
|
func (h *Hop) Copy() *Hop {
|
||||||
|
c := *h
|
||||||
|
|
||||||
|
if h.MPP != nil {
|
||||||
|
m := *h.MPP
|
||||||
|
c.MPP = &m
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.AMP != nil {
|
||||||
|
a := *h.AMP
|
||||||
|
c.AMP = &a
|
||||||
|
}
|
||||||
|
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
// PackHopPayload writes to the passed io.Writer, the series of byes that can
|
// PackHopPayload writes to the passed io.Writer, the series of byes that can
|
||||||
// be placed directly into the per-hop payload (EOB) for this hop. This will
|
// be placed directly into the per-hop payload (EOB) for this hop. This will
|
||||||
// include the required routing fields, as well as serializing any of the
|
// include the required routing fields, as well as serializing any of the
|
||||||
@ -287,6 +304,18 @@ type Route struct {
|
|||||||
Hops []*Hop
|
Hops []*Hop
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the Route.
|
||||||
|
func (r *Route) Copy() *Route {
|
||||||
|
c := *r
|
||||||
|
|
||||||
|
c.Hops = make([]*Hop, len(r.Hops))
|
||||||
|
for i := range r.Hops {
|
||||||
|
c.Hops[i] = r.Hops[i].Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
|
||||||
// HopFee returns the fee charged by the route hop indicated by hopIndex.
|
// HopFee returns the fee charged by the route hop indicated by hopIndex.
|
||||||
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
|
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
|
||||||
var incomingAmt lnwire.MilliSatoshi
|
var incomingAmt lnwire.MilliSatoshi
|
||||||
@ -320,6 +349,15 @@ func (r *Route) ReceiverAmt() lnwire.MilliSatoshi {
|
|||||||
return r.Hops[len(r.Hops)-1].AmtToForward
|
return r.Hops[len(r.Hops)-1].AmtToForward
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FinalHop returns the last hop of the route, or nil if the route is empty.
|
||||||
|
func (r *Route) FinalHop() *Hop {
|
||||||
|
if len(r.Hops) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Hops[len(r.Hops)-1]
|
||||||
|
}
|
||||||
|
|
||||||
// NewRouteFromHops creates a new Route structure from the minimally required
|
// NewRouteFromHops creates a new Route structure from the minimally required
|
||||||
// information to perform the payment. It infers fee amounts and populates the
|
// information to perform the payment. It infers fee amounts and populates the
|
||||||
// node, chan and prev/next hop maps.
|
// node, chan and prev/next hop maps.
|
||||||
|
Loading…
Reference in New Issue
Block a user