diff --git a/channeldb/payments.go b/channeldb/payments.go index 00d29366..d576c65c 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -14,6 +14,7 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) @@ -508,7 +509,9 @@ func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) { func serializeHop(w io.Writer, h *route.Hop) error { if err := WriteElements(w, - h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.PubKeyBytes[:], + h.ChannelID, + h.OutgoingTimeLock, h.AmtToForward, ); err != nil { return err @@ -525,10 +528,23 @@ func serializeHop(w io.Writer, h *route.Hop) error { return WriteElements(w, uint32(0)) } + // Gather all non-primitive TLV records so that they can be serialized + // as a single blob. + // + // TODO(conner): add migration to unify all fields in a single TLV + // blobs. The split approach will cause headaches down the road as more + // fields are added, which we can avoid by having a single TLV stream + // for all payload fields. + var records []tlv.Record + if h.MPP != nil { + records = append(records, h.MPP.Record()) + } + records = append(records, h.TLVRecords...) + // Otherwise, we'll transform our slice of records into a map of the // raw bytes, then serialize them in-line with a length (number of // elements) prefix. - mapRecords, err := tlv.RecordsToMap(h.TLVRecords) + mapRecords, err := tlv.RecordsToMap(records) if err != nil { return err } @@ -604,6 +620,29 @@ func deserializeHop(r io.Reader) (*route.Hop, error) { tlvMap[tlvType] = rawRecordBytes } + // If the MPP type is present, remove it from the generic TLV map and + // parse it back into a proper MPP struct. + // + // TODO(conner): add migration to unify all fields in a single TLV + // blobs. The split approach will cause headaches down the road as more + // fields are added, which we can avoid by having a single TLV stream + // for all payload fields. + mppType := uint64(record.MPPOnionType) + if mppBytes, ok := tlvMap[mppType]; ok { + delete(tlvMap, mppType) + + var ( + mpp = &record.MPP{} + mppRec = mpp.Record() + r = bytes.NewReader(mppBytes) + ) + err := mppRec.Decode(r, uint64(len(mppBytes))) + if err != nil { + return nil, err + } + h.MPP = mpp + } + tlvRecords, err := tlv.MapToRecords(tlvMap) if err != nil { return nil, err diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index a792f965..bec8e528 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) @@ -31,6 +32,7 @@ var ( tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), }, + MPP: record.NewMPP(32, [32]byte{0x42}), } testHop2 = &route.Hop{ @@ -46,8 +48,8 @@ var ( TotalAmount: 1234567, SourcePubKey: route.NewVertex(pub), Hops: []*route.Hop{ - testHop1, testHop2, + testHop1, }, } ) diff --git a/routing/route/route.go b/routing/route/route.go index ed51d7d0..62511c3f 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" "io" "strconv" @@ -19,9 +20,17 @@ import ( // VertexSize is the size of the array to store a vertex. const VertexSize = 33 -// ErrNoRouteHopsProvided is returned when a caller attempts to construct a new -// sphinx packet, but provides an empty set of hops for each route. -var ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided") +var ( + // ErrNoRouteHopsProvided is returned when a caller attempts to + // construct a new sphinx packet, but provides an empty set of hops for + // each route. + ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided") + + // ErrIntermediateMPPHop is returned when a hop tries to deliver an MPP + // record to an intermediate hop, only final hops can receive MPP + // records. + ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate") +) // Vertex is a simple alias for the serialization of a compressed Bitcoin // public key. @@ -94,6 +103,10 @@ type Hop struct { // carries as a fee will be subtracted by the hop. AmtToForward lnwire.MilliSatoshi + // MPP encapsulates the data required for option_mpp. This field should + // only be set for the final hop. + MPP *record.MPP + // TLVRecords if non-nil are a set of additional TLV records that // should be included in the forwarding instructions for this node. TLVRecords []tlv.Record @@ -140,6 +153,17 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { ) } + // If an MPP record is destined for this hop, ensure that we only ever + // attach it to the final hop. Otherwise the route was constructed + // incorrectly. + if h.MPP != nil { + if nextChanID == 0 { + records = append(records, h.MPP.Record()) + } else { + return ErrIntermediateMPPHop + } + } + // Append any custom types destined for this hop. records = append(records, h.TLVRecords...) diff --git a/routing/route/route_test.go b/routing/route/route_test.go index 92b0ee0d..2894cbdc 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -1,9 +1,11 @@ package route import ( + "bytes" "testing" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" ) // TestRouteTotalFees checks that a route reports the expected total fee. @@ -56,3 +58,38 @@ func TestRouteTotalFees(t *testing.T) { } } + +var ( + testAmt = lnwire.MilliSatoshi(1000) + testAddr = [32]byte{0x01, 0x02} +) + +// TestMPPHop asserts that a Hop will encode a non-nil to final nodes, and fail +// when trying to send to intermediaries. +func TestMPPHop(t *testing.T) { + t.Parallel() + + hop := Hop{ + ChannelID: 1, + OutgoingTimeLock: 44, + AmtToForward: testAmt, + LegacyPayload: false, + MPP: record.NewMPP(testAmt, testAddr), + } + + // Encoding an MPP record to an intermediate hop should result in a + // failure. + var b bytes.Buffer + err := hop.PackHopPayload(&b, 2) + if err != ErrIntermediateMPPHop { + t.Fatalf("expected err: %v, got: %v", + ErrIntermediateMPPHop, err) + } + + // Encoding an MPP record to a final hop should be successful. + b.Reset() + err = hop.PackHopPayload(&b, 0) + if err != nil { + t.Fatalf("expected err: %v, got: %v", nil, err) + } +} diff --git a/tlv/record.go b/tlv/record.go index 66ae8f1c..fe774263 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -85,6 +85,14 @@ func (f *Record) Encode(w io.Writer) error { return f.encoder(w, f.value, &b) } +// Decode read in the TLV record from the passed reader. This is useful when a +// caller wants decode a *single* TLV record, outside the context of the Stream +// struct. +func (f *Record) Decode(r io.Reader, l uint64) error { + var b [8]byte + return f.decoder(r, f.value, &b, l) +} + // MakePrimitiveRecord creates a record for common types. func MakePrimitiveRecord(typ Type, val interface{}) Record { var (