diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 0a765fdb..95a7c36b 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -142,6 +142,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { if result.Preimage != preimg { t.Fatal("unexpected preimage") } + if !reflect.DeepEqual(result.Route, &attempt.Route) { t.Fatal("unexpected route") } diff --git a/routing/route/route.go b/routing/route/route.go index 50688c57..70712d04 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -1,14 +1,17 @@ package route import ( + "bytes" "encoding/binary" "fmt" + "io" "strconv" "strings" "github.com/btcsuite/btcd/btcec" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // VertexSize is the size of the array to store a vertex. @@ -72,6 +75,61 @@ type Hop struct { // hop. This value is less than the value that the incoming HTLC // carries as a fee will be subtracted by the hop. AmtToForward lnwire.MilliSatoshi + + // TLVRecords if non-nil are a set of additional TLV records that + // should be included in the forwarding instructions for this node. + TLVRecords []tlv.Record + + // LegacyPayload if true, then this signals that this node doesn't + // understand the new TLV payload, so we must instead use the legacy + // payload. + LegacyPayload bool +} + +// PackHopPayload writes to the passed io.Writer, the series of byes that can +// be placed directly into the per-hop payload (EOB) for this hop. This will +// include the required routing fields, as well as serializing any of the +// passed optional TLVRecords. nextChanID is the unique channel ID that +// references the _outgoing_ channel ID that follows this hop. This field +// follows the same semantics as the NextAddress field in the onion: it should +// be set to zero to indicate the terminal hop. +func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { + // If this is a legacy payload, then we'll exit here as this method + // shouldn't be called. + if h.LegacyPayload == true { + return fmt.Errorf("cannot pack hop payloads for legacy " + + "payloads") + } + + // Otherwise, we'll need to make a new stream that includes our + // required routing fields, as well as these optional values. + amt := uint64(h.AmtToForward) + combinedRecords := append(h.TLVRecords, + tlv.MakeDynamicRecord( + tlv.AmtOnionType, &amt, func() uint64 { + return tlv.SizeTUint64(amt) + }, + tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakeDynamicRecord( + tlv.LockTimeOnionType, &h.OutgoingTimeLock, func() uint64 { + return tlv.SizeTUint32(h.OutgoingTimeLock) + }, + tlv.ETUint32, tlv.DTUint32, + ), + tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &nextChanID), + ) + + // To ensure we produce a canonical stream, we'll sort the records + // before encoding them as a stream in the hop payload. + tlv.SortRecords(combinedRecords) + + tlvStream, err := tlv.NewStream(combinedRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) } // Route represents a path through the channel graph which runs over one or @@ -156,7 +214,8 @@ func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32, // ToSphinxPath converts a complete route into a sphinx PaymentPath that // contains the per-hop paylods used to encoding the HTLC routing data for each -// hop in the route. +// hop in the route. This method also accepts an optional EOB payload for the +// final hop. func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) { var path sphinx.PaymentPath @@ -171,17 +230,6 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) { return nil, err } - path[i] = sphinx.OnionHop{ - NodePub: *pub, - HopData: sphinx.HopData{ - // TODO(roasbeef): properly set realm, make - // sphinx type an enum actually? - Realm: [1]byte{0}, - ForwardAmount: uint64(hop.AmtToForward), - OutgoingCltv: hop.OutgoingTimeLock, - }, - } - // As a base case, the next hop is set to all zeroes in order // to indicate that the "last hop" as no further hops after it. nextHop := uint64(0) @@ -192,9 +240,50 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) { nextHop = r.Hops[i+1].ChannelID } - binary.BigEndian.PutUint64( - path[i].HopData.NextAddress[:], nextHop, - ) + var payload sphinx.HopPayload + + // If this is the legacy payload, then we can just include the + // hop data as normal. + if hop.LegacyPayload { + // Before we encode this value, we'll pack the next hop + // into the NextAddress field of the hop info to ensure + // we point to the right now. + hopData := sphinx.HopData{ + ForwardAmount: uint64(hop.AmtToForward), + OutgoingCltv: hop.OutgoingTimeLock, + } + binary.BigEndian.PutUint64( + hopData.NextAddress[:], nextHop, + ) + + payload, err = sphinx.NewHopPayload(&hopData, nil) + if err != nil { + return nil, err + } + } else { + // For non-legacy payloads, we'll need to pack the + // routing information, along with any extra TLV + // information into the new per-hop payload format. + // We'll also pass in the chan ID of the hop this + // channel should be forwarded to so we can construct a + // valid payload. + var b bytes.Buffer + err := hop.PackHopPayload(&b, nextHop) + if err != nil { + return nil, err + } + + // TODO(roasbeef): make better API for NewHopPayload? + payload, err = sphinx.NewHopPayload(nil, b.Bytes()) + if err != nil { + return nil, err + } + } + + path[i] = sphinx.OnionHop{ + NodePub: *pub, + HopPayload: payload, + } } return &path, nil diff --git a/tlv/onion_types.go b/tlv/onion_types.go new file mode 100644 index 00000000..65d5b42c --- /dev/null +++ b/tlv/onion_types.go @@ -0,0 +1,15 @@ +package tlv + +const ( + // AmtOnionType is the type used in the onion to refrence the amount to + // send to the next hop. + AmtOnionType Type = 2 + + // LockTimeTLV is the type used in the onion to refenernce the CLTV + // value that should be used for the next hop's HTLC. + LockTimeOnionType Type = 4 + + // NextHopOnionType is the type used in the onion to reference the ID + // of the next hop. + NextHopOnionType Type = 6 +) diff --git a/tlv/record.go b/tlv/record.go index cad51379..610ab6c1 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -1,8 +1,10 @@ package tlv import ( + "bytes" "fmt" "io" + "sort" "github.com/btcsuite/btcd/btcec" ) @@ -166,3 +168,63 @@ func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc, decoder: decoder, } } + +// RecordsToMap encodes a series of TLV records as raw key-value pairs in the +// form of a map. +func RecordsToMap(records []Record) (map[uint64][]byte, error) { + tlvMap := make(map[uint64][]byte, len(records)) + + for _, record := range records { + var b bytes.Buffer + if err := record.Encode(&b); err != nil { + return nil, err + } + + tlvMap[uint64(record.Type())] = b.Bytes() + } + + return tlvMap, nil +} + +// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw +// value. We can use this to make a record that can be encoded when we don't +// actually know it's true underlying value, and only it serialization. +func StubEncoder(v []byte) Encoder { + return func(w io.Writer, val interface{}, buf *[8]byte) error { + _, err := w.Write(v) + return err + } +} + +// MapToRecords encodes the passed TLV map as a series of regular tlv.Record +// instances. The resulting set of records will be returned in sorted order by +// their type. +func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) { + records := make([]Record, 0, len(tlvMap)) + for k, v := range tlvMap { + // We don't pass in a decoder here since we don't actually know + // the type, and only expect this Record to be used for display + // and encoding purposes. + record := MakeStaticRecord( + Type(k), nil, uint64(len(v)), StubEncoder(v), nil, + ) + + records = append(records, record) + } + + SortRecords(records) + + return records, nil +} + +// SortRecords is a helper function that will sort a slice of records in place +// according to their type. +func SortRecords(records []Record) { + if len(records) == 0 { + return + } + + sort.Slice(records, func(i, j int) bool { + return records[i].Type() < records[j].Type() + }) +} diff --git a/tlv/record_test.go b/tlv/record_test.go new file mode 100644 index 00000000..02d2e893 --- /dev/null +++ b/tlv/record_test.go @@ -0,0 +1,149 @@ +package tlv + +import ( + "bytes" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +// TestSortRecords tests that SortRecords is able to properly sort records in +// place. +func TestSortRecords(t *testing.T) { + t.Parallel() + + testCases := []struct { + preSort []Record + postSort []Record + }{ + // An empty slice requires no sorting. + { + preSort: []Record{}, + postSort: []Record{}, + }, + + // An already sorted slice should be passed through. + { + preSort: []Record{ + MakeStaticRecord(1, nil, 0, nil, nil), + MakeStaticRecord(2, nil, 0, nil, nil), + MakeStaticRecord(3, nil, 0, nil, nil), + }, + postSort: []Record{ + MakeStaticRecord(1, nil, 0, nil, nil), + MakeStaticRecord(2, nil, 0, nil, nil), + MakeStaticRecord(3, nil, 0, nil, nil), + }, + }, + + // We should be able to sort a randomized set of records . + { + preSort: []Record{ + MakeStaticRecord(9, nil, 0, nil, nil), + MakeStaticRecord(43, nil, 0, nil, nil), + MakeStaticRecord(1, nil, 0, nil, nil), + MakeStaticRecord(0, nil, 0, nil, nil), + }, + postSort: []Record{ + MakeStaticRecord(0, nil, 0, nil, nil), + MakeStaticRecord(1, nil, 0, nil, nil), + MakeStaticRecord(9, nil, 0, nil, nil), + MakeStaticRecord(43, nil, 0, nil, nil), + }, + }, + } + + for i, testCase := range testCases { + SortRecords(testCase.preSort) + + if !reflect.DeepEqual(testCase.preSort, testCase.postSort) { + t.Fatalf("#%v: wrong order: expected %v, got %v", i, + spew.Sdump(testCase.preSort), + spew.Sdump(testCase.postSort)) + } + } +} + +// TestRecordMapTransformation tests that we're able to properly morph a set of +// records into a map using TlvRecordsToMap, then the other way around using +// the MapToTlvRecords method. +func TestRecordMapTransformation(t *testing.T) { + t.Parallel() + + tlvBytes := []byte{1, 2, 3, 4} + encoder := StubEncoder(tlvBytes) + + testCases := []struct { + records []Record + + tlvMap map[uint64][]byte + }{ + // An empty set of records should yield an empty map, and the other + // way around. + { + records: []Record{}, + tlvMap: map[uint64][]byte{}, + }, + + // We should be able to transform this set of records, then obtain + // the records back in the same order. + { + records: []Record{ + MakeStaticRecord(1, nil, 4, encoder, nil), + MakeStaticRecord(2, nil, 4, encoder, nil), + MakeStaticRecord(3, nil, 4, encoder, nil), + }, + tlvMap: map[uint64][]byte{ + 1: tlvBytes, + 2: tlvBytes, + 3: tlvBytes, + }, + }, + } + + for i, testCase := range testCases { + mappedRecords, err := RecordsToMap(testCase.records) + if err != nil { + t.Fatalf("#%v: unable to map records: %v", i, err) + } + + if !reflect.DeepEqual(mappedRecords, testCase.tlvMap) { + t.Fatalf("#%v: incorrect record map: expected %v, got %v", + i, spew.Sdump(testCase.tlvMap), + spew.Sdump(mappedRecords)) + } + + unmappedRecords, err := MapToRecords(mappedRecords) + if err != nil { + t.Fatalf("#%v: unable to unmap records: %v", i, err) + } + + for i := 0; i < len(testCase.records); i++ { + if unmappedRecords[i].Type() != testCase.records[i].Type() { + t.Fatalf("#%v: wrong type: expected %v, got %v", + i, unmappedRecords[i].Type(), + testCase.records[i].Type()) + } + + var b bytes.Buffer + if err := unmappedRecords[i].Encode(&b); err != nil { + t.Fatalf("#%v: unable to encode record: %v", + i, err) + } + + if !bytes.Equal(b.Bytes(), tlvBytes) { + t.Fatalf("#%v: wrong raw record: "+ + "expected %x, got %x", + i, tlvBytes, b.Bytes()) + } + + if unmappedRecords[i].Size() != testCase.records[0].Size() { + t.Fatalf("#%v: wrong size: expected %v, "+ + "got %v", i, + unmappedRecords[i].Size(), + testCase.records[i].Size()) + } + } + } +}