lnrpc/routerrpc/router_backend: populate and unmarshal mpp fields
This commit parses mpp_total_amt_msat and mpp_payment_addr from the SendToRoute rpc and populates an MPP record on the internal hop reprsentation. When the router goes to encode the onion packet, these fields will be serialized for the destination. We also populate the mpp fields when marshalling routes in rpc responses.
This commit is contained in:
parent
92aa78dd5f
commit
8cb17d5c1f
@ -13,7 +13,9 @@ import (
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
@ -352,6 +354,17 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
chanCapacity = incomingAmt.ToSatoshis()
|
||||
}
|
||||
|
||||
// Extract the MPP fields if present on this hop.
|
||||
var mpp *lnrpc.MPPRecord
|
||||
if hop.MPP != nil {
|
||||
addr := hop.MPP.PaymentAddr()
|
||||
|
||||
mpp = &lnrpc.MPPRecord{
|
||||
PaymentAddr: addr[:],
|
||||
TotalAmtMsat: int64(hop.MPP.TotalMsat()),
|
||||
}
|
||||
}
|
||||
|
||||
resp.Hops[i] = &lnrpc.Hop{
|
||||
ChanId: hop.ChannelID,
|
||||
ChanCapacity: int64(chanCapacity),
|
||||
@ -364,6 +377,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
hop.PubKeyBytes[:],
|
||||
),
|
||||
TlvPayload: !hop.LegacyPayload,
|
||||
MppRecord: mpp,
|
||||
}
|
||||
incomingAmt = hop.AmtToForward
|
||||
}
|
||||
@ -396,6 +410,11 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop,
|
||||
|
||||
var tlvRecords []tlv.Record
|
||||
|
||||
mpp, err := UnmarshalMPP(hop.MppRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route.Hop{
|
||||
OutgoingTimeLock: hop.Expiry,
|
||||
AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat),
|
||||
@ -403,6 +422,7 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop,
|
||||
ChannelID: hop.ChanId,
|
||||
TLVRecords: tlvRecords,
|
||||
LegacyPayload: !hop.TlvPayload,
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -420,6 +440,11 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) {
|
||||
|
||||
var tlvRecords []tlv.Record
|
||||
|
||||
mpp, err := UnmarshalMPP(hop.MppRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route.Hop{
|
||||
OutgoingTimeLock: hop.Expiry,
|
||||
AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat),
|
||||
@ -427,6 +452,7 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) {
|
||||
ChannelID: hop.ChanId,
|
||||
TLVRecords: tlvRecords,
|
||||
LegacyPayload: !hop.TlvPayload,
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -712,3 +738,45 @@ func ValidateCLTVLimit(val, max uint32) (uint32, error) {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMPP accepts the mpp_total_amt_msat and mpp_payment_addr fields from
|
||||
// an RPC request and converts into an record.MPP object. An error is returned
|
||||
// if the payment address is not 0 or 32 bytes. If the total amount and payment
|
||||
// address are zero-value, the return value will be nil signaling there is no
|
||||
// MPP record to attach to this hop. Otherwise, a non-nil reocrd will be
|
||||
// contained combining the provided values.
|
||||
func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) {
|
||||
// If no MPP record was submitted, assume the user wants to send a
|
||||
// regular payment.
|
||||
if reqMPP == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reqTotal := reqMPP.TotalAmtMsat
|
||||
reqAddr := reqMPP.PaymentAddr
|
||||
|
||||
switch {
|
||||
|
||||
// No MPP fields were provided.
|
||||
case reqTotal == 0 && len(reqAddr) == 0:
|
||||
return nil, fmt.Errorf("missing total_msat and payment_addr")
|
||||
|
||||
// Total is present, but payment address is missing.
|
||||
case reqTotal > 0 && len(reqAddr) == 0:
|
||||
return nil, fmt.Errorf("missing payment_addr")
|
||||
|
||||
// Payment address is present, but total is missing.
|
||||
case reqTotal == 0 && len(reqAddr) > 0:
|
||||
return nil, fmt.Errorf("missing total_msat")
|
||||
}
|
||||
|
||||
addr, err := lntypes.MakeHash(reqAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse "+
|
||||
"payment_addr: %v", err)
|
||||
}
|
||||
|
||||
total := lnwire.MilliSatoshi(reqTotal)
|
||||
|
||||
return record.NewMPP(total, addr), nil
|
||||
}
|
||||
|
@ -180,3 +180,123 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
|
||||
|
||||
return routing.TimedPairResult{}
|
||||
}
|
||||
|
||||
type mppOutcome byte
|
||||
|
||||
const (
|
||||
valid mppOutcome = iota
|
||||
invalid
|
||||
nompp
|
||||
)
|
||||
|
||||
type unmarshalMPPTest struct {
|
||||
name string
|
||||
mpp *lnrpc.MPPRecord
|
||||
outcome mppOutcome
|
||||
}
|
||||
|
||||
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
|
||||
// assert that an MPP record is only returned when both fields are properly
|
||||
// specified. It also asserts that zero-values for both inputs is also valid,
|
||||
// but returns a nil record.
|
||||
func TestUnmarshalMPP(t *testing.T) {
|
||||
tests := []unmarshalMPPTest{
|
||||
{
|
||||
name: "nil record",
|
||||
mpp: nil,
|
||||
outcome: nompp,
|
||||
},
|
||||
{
|
||||
name: "invalid total or addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: nil,
|
||||
TotalAmtMsat: 0,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total only",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: nil,
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid addr only",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: bytes.Repeat([]byte{0x02}, 32),
|
||||
TotalAmtMsat: 0,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total and invalid addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: []byte{0x02},
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total and valid addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: bytes.Repeat([]byte{0x02}, 32),
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: valid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testUnmarshalMPP(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
||||
mpp, err := UnmarshalMPP(test.mpp)
|
||||
switch test.outcome {
|
||||
|
||||
// Valid arguments should result in no error, a non-nil MPP record, and
|
||||
// the fields should be set correctly.
|
||||
case valid:
|
||||
if err != nil {
|
||||
t.Fatalf("unable to parse mpp record: %v", err)
|
||||
}
|
||||
if mpp == nil {
|
||||
t.Fatalf("mpp payload should be non-nil")
|
||||
}
|
||||
if int64(mpp.TotalMsat()) != test.mpp.TotalAmtMsat {
|
||||
t.Fatalf("incorrect total msat")
|
||||
}
|
||||
addr := mpp.PaymentAddr()
|
||||
if !bytes.Equal(addr[:], test.mpp.PaymentAddr) {
|
||||
t.Fatalf("incorrect payment addr")
|
||||
}
|
||||
|
||||
// Invalid arguments should produce a failure and nil MPP record.
|
||||
case invalid:
|
||||
if err == nil {
|
||||
t.Fatalf("expected failure for invalid mpp")
|
||||
}
|
||||
if mpp != nil {
|
||||
t.Fatalf("mpp payload should be nil for failure")
|
||||
}
|
||||
|
||||
// Arguments that produce no MPP field should return no error and no MPP
|
||||
// record.
|
||||
case nompp:
|
||||
if err != nil {
|
||||
t.Fatalf("failure for args resulting for no-mpp")
|
||||
}
|
||||
if mpp != nil {
|
||||
t.Fatalf("mpp payload should be nil for no-mpp")
|
||||
}
|
||||
|
||||
default:
|
||||
t.Fatalf("test case has non-standard outcome")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user