From 8cb17d5c1f88def8e63a2c43ee3d78c6b125d882 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Mon, 4 Nov 2019 15:11:23 -0800 Subject: [PATCH] lnrpc/routerrpc/router_backend: populate and unmarshal mpp fields This commit parses mpp_total_amt_msat and mpp_payment_addr from the SendToRoute rpc and populates an MPP record on the internal hop reprsentation. When the router goes to encode the onion packet, these fields will be serialized for the destination. We also populate the mpp fields when marshalling routes in rpc responses. --- lnrpc/routerrpc/router_backend.go | 68 ++++++++++++++ lnrpc/routerrpc/router_backend_test.go | 120 +++++++++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 99942b0d..c90e5513 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -13,7 +13,9 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -352,6 +354,17 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) chanCapacity = incomingAmt.ToSatoshis() } + // Extract the MPP fields if present on this hop. + var mpp *lnrpc.MPPRecord + if hop.MPP != nil { + addr := hop.MPP.PaymentAddr() + + mpp = &lnrpc.MPPRecord{ + PaymentAddr: addr[:], + TotalAmtMsat: int64(hop.MPP.TotalMsat()), + } + } + resp.Hops[i] = &lnrpc.Hop{ ChanId: hop.ChannelID, ChanCapacity: int64(chanCapacity), @@ -364,6 +377,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) hop.PubKeyBytes[:], ), TlvPayload: !hop.LegacyPayload, + MppRecord: mpp, } incomingAmt = hop.AmtToForward } @@ -396,6 +410,11 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop, var tlvRecords []tlv.Record + mpp, err := UnmarshalMPP(hop.MppRecord) + if err != nil { + return nil, err + } + return &route.Hop{ OutgoingTimeLock: hop.Expiry, AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat), @@ -403,6 +422,7 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop, ChannelID: hop.ChanId, TLVRecords: tlvRecords, LegacyPayload: !hop.TlvPayload, + MPP: mpp, }, nil } @@ -420,6 +440,11 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) { var tlvRecords []tlv.Record + mpp, err := UnmarshalMPP(hop.MppRecord) + if err != nil { + return nil, err + } + return &route.Hop{ OutgoingTimeLock: hop.Expiry, AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat), @@ -427,6 +452,7 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) { ChannelID: hop.ChanId, TLVRecords: tlvRecords, LegacyPayload: !hop.TlvPayload, + MPP: mpp, }, nil } @@ -712,3 +738,45 @@ func ValidateCLTVLimit(val, max uint32) (uint32, error) { return val, nil } } + +// UnmarshalMPP accepts the mpp_total_amt_msat and mpp_payment_addr fields from +// an RPC request and converts into an record.MPP object. An error is returned +// if the payment address is not 0 or 32 bytes. If the total amount and payment +// address are zero-value, the return value will be nil signaling there is no +// MPP record to attach to this hop. Otherwise, a non-nil reocrd will be +// contained combining the provided values. +func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) { + // If no MPP record was submitted, assume the user wants to send a + // regular payment. + if reqMPP == nil { + return nil, nil + } + + reqTotal := reqMPP.TotalAmtMsat + reqAddr := reqMPP.PaymentAddr + + switch { + + // No MPP fields were provided. + case reqTotal == 0 && len(reqAddr) == 0: + return nil, fmt.Errorf("missing total_msat and payment_addr") + + // Total is present, but payment address is missing. + case reqTotal > 0 && len(reqAddr) == 0: + return nil, fmt.Errorf("missing payment_addr") + + // Payment address is present, but total is missing. + case reqTotal == 0 && len(reqAddr) > 0: + return nil, fmt.Errorf("missing total_msat") + } + + addr, err := lntypes.MakeHash(reqAddr) + if err != nil { + return nil, fmt.Errorf("unable to parse "+ + "payment_addr: %v", err) + } + + total := lnwire.MilliSatoshi(reqTotal) + + return record.NewMPP(total, addr), nil +} diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 48e5b8fd..779231f8 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -180,3 +180,123 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode, return routing.TimedPairResult{} } + +type mppOutcome byte + +const ( + valid mppOutcome = iota + invalid + nompp +) + +type unmarshalMPPTest struct { + name string + mpp *lnrpc.MPPRecord + outcome mppOutcome +} + +// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to +// assert that an MPP record is only returned when both fields are properly +// specified. It also asserts that zero-values for both inputs is also valid, +// but returns a nil record. +func TestUnmarshalMPP(t *testing.T) { + tests := []unmarshalMPPTest{ + { + name: "nil record", + mpp: nil, + outcome: nompp, + }, + { + name: "invalid total or addr", + mpp: &lnrpc.MPPRecord{ + PaymentAddr: nil, + TotalAmtMsat: 0, + }, + outcome: invalid, + }, + { + name: "valid total only", + mpp: &lnrpc.MPPRecord{ + PaymentAddr: nil, + TotalAmtMsat: 8, + }, + outcome: invalid, + }, + { + name: "valid addr only", + mpp: &lnrpc.MPPRecord{ + PaymentAddr: bytes.Repeat([]byte{0x02}, 32), + TotalAmtMsat: 0, + }, + outcome: invalid, + }, + { + name: "valid total and invalid addr", + mpp: &lnrpc.MPPRecord{ + PaymentAddr: []byte{0x02}, + TotalAmtMsat: 8, + }, + outcome: invalid, + }, + { + name: "valid total and valid addr", + mpp: &lnrpc.MPPRecord{ + PaymentAddr: bytes.Repeat([]byte{0x02}, 32), + TotalAmtMsat: 8, + }, + outcome: valid, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + testUnmarshalMPP(t, test) + }) + } +} + +func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) { + mpp, err := UnmarshalMPP(test.mpp) + switch test.outcome { + + // Valid arguments should result in no error, a non-nil MPP record, and + // the fields should be set correctly. + case valid: + if err != nil { + t.Fatalf("unable to parse mpp record: %v", err) + } + if mpp == nil { + t.Fatalf("mpp payload should be non-nil") + } + if int64(mpp.TotalMsat()) != test.mpp.TotalAmtMsat { + t.Fatalf("incorrect total msat") + } + addr := mpp.PaymentAddr() + if !bytes.Equal(addr[:], test.mpp.PaymentAddr) { + t.Fatalf("incorrect payment addr") + } + + // Invalid arguments should produce a failure and nil MPP record. + case invalid: + if err == nil { + t.Fatalf("expected failure for invalid mpp") + } + if mpp != nil { + t.Fatalf("mpp payload should be nil for failure") + } + + // Arguments that produce no MPP field should return no error and no MPP + // record. + case nompp: + if err != nil { + t.Fatalf("failure for args resulting for no-mpp") + } + if mpp != nil { + t.Fatalf("mpp payload should be nil for no-mpp") + } + + default: + t.Fatalf("test case has non-standard outcome") + } +}