routing/route+tlv: add new TLV-EOB awareness to Hop+Route

In this commit, we extend the Hop struct to carry an arbitrary set of
TLV values, and add a new field that allows us to distinguish between
the modern and legacy TLV payload.

We add a new `PackPayload` method that will be used to encode the
combined required routing TLV fields along any set of TLV fields that
were specified as part of path finding.

Finally, the `ToSphinxPath` has been extended to be able to recognize if
a hop needs the modern, or legacy payload.
This commit is contained in:
Olaoluwa Osuntokun 2019-07-30 21:38:43 -07:00
parent e60b36751c
commit 5b4c8ac232
No known key found for this signature in database
GPG Key ID: CE58F7F8E20FD9A2
5 changed files with 331 additions and 15 deletions

@ -142,6 +142,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
if result.Preimage != preimg {
t.Fatal("unexpected preimage")
}
if !reflect.DeepEqual(result.Route, &attempt.Route) {
t.Fatal("unexpected route")
}

@ -1,14 +1,17 @@
package route
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"strconv"
"strings"
"github.com/btcsuite/btcd/btcec"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// VertexSize is the size of the array to store a vertex.
@ -72,6 +75,61 @@ type Hop struct {
// hop. This value is less than the value that the incoming HTLC
// carries as a fee will be subtracted by the hop.
AmtToForward lnwire.MilliSatoshi
// TLVRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node.
TLVRecords []tlv.Record
// LegacyPayload if true, then this signals that this node doesn't
// understand the new TLV payload, so we must instead use the legacy
// payload.
LegacyPayload bool
}
// PackHopPayload writes to the passed io.Writer, the series of byes that can
// be placed directly into the per-hop payload (EOB) for this hop. This will
// include the required routing fields, as well as serializing any of the
// passed optional TLVRecords. nextChanID is the unique channel ID that
// references the _outgoing_ channel ID that follows this hop. This field
// follows the same semantics as the NextAddress field in the onion: it should
// be set to zero to indicate the terminal hop.
func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
// If this is a legacy payload, then we'll exit here as this method
// shouldn't be called.
if h.LegacyPayload == true {
return fmt.Errorf("cannot pack hop payloads for legacy " +
"payloads")
}
// Otherwise, we'll need to make a new stream that includes our
// required routing fields, as well as these optional values.
amt := uint64(h.AmtToForward)
combinedRecords := append(h.TLVRecords,
tlv.MakeDynamicRecord(
tlv.AmtOnionType, &amt, func() uint64 {
return tlv.SizeTUint64(amt)
},
tlv.ETUint64, tlv.DTUint64,
),
tlv.MakeDynamicRecord(
tlv.LockTimeOnionType, &h.OutgoingTimeLock, func() uint64 {
return tlv.SizeTUint32(h.OutgoingTimeLock)
},
tlv.ETUint32, tlv.DTUint32,
),
tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &nextChanID),
)
// To ensure we produce a canonical stream, we'll sort the records
// before encoding them as a stream in the hop payload.
tlv.SortRecords(combinedRecords)
tlvStream, err := tlv.NewStream(combinedRecords...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// Route represents a path through the channel graph which runs over one or
@ -156,7 +214,8 @@ func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32,
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
// contains the per-hop paylods used to encoding the HTLC routing data for each
// hop in the route.
// hop in the route. This method also accepts an optional EOB payload for the
// final hop.
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
var path sphinx.PaymentPath
@ -171,17 +230,6 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
return nil, err
}
path[i] = sphinx.OnionHop{
NodePub: *pub,
HopData: sphinx.HopData{
// TODO(roasbeef): properly set realm, make
// sphinx type an enum actually?
Realm: [1]byte{0},
ForwardAmount: uint64(hop.AmtToForward),
OutgoingCltv: hop.OutgoingTimeLock,
},
}
// As a base case, the next hop is set to all zeroes in order
// to indicate that the "last hop" as no further hops after it.
nextHop := uint64(0)
@ -192,9 +240,50 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
nextHop = r.Hops[i+1].ChannelID
}
var payload sphinx.HopPayload
// If this is the legacy payload, then we can just include the
// hop data as normal.
if hop.LegacyPayload {
// Before we encode this value, we'll pack the next hop
// into the NextAddress field of the hop info to ensure
// we point to the right now.
hopData := sphinx.HopData{
ForwardAmount: uint64(hop.AmtToForward),
OutgoingCltv: hop.OutgoingTimeLock,
}
binary.BigEndian.PutUint64(
path[i].HopData.NextAddress[:], nextHop,
hopData.NextAddress[:], nextHop,
)
payload, err = sphinx.NewHopPayload(&hopData, nil)
if err != nil {
return nil, err
}
} else {
// For non-legacy payloads, we'll need to pack the
// routing information, along with any extra TLV
// information into the new per-hop payload format.
// We'll also pass in the chan ID of the hop this
// channel should be forwarded to so we can construct a
// valid payload.
var b bytes.Buffer
err := hop.PackHopPayload(&b, nextHop)
if err != nil {
return nil, err
}
// TODO(roasbeef): make better API for NewHopPayload?
payload, err = sphinx.NewHopPayload(nil, b.Bytes())
if err != nil {
return nil, err
}
}
path[i] = sphinx.OnionHop{
NodePub: *pub,
HopPayload: payload,
}
}
return &path, nil

15
tlv/onion_types.go Normal file

@ -0,0 +1,15 @@
package tlv
const (
// AmtOnionType is the type used in the onion to refrence the amount to
// send to the next hop.
AmtOnionType Type = 2
// LockTimeTLV is the type used in the onion to refenernce the CLTV
// value that should be used for the next hop's HTLC.
LockTimeOnionType Type = 4
// NextHopOnionType is the type used in the onion to reference the ID
// of the next hop.
NextHopOnionType Type = 6
)

@ -1,8 +1,10 @@
package tlv
import (
"bytes"
"fmt"
"io"
"sort"
"github.com/btcsuite/btcd/btcec"
)
@ -166,3 +168,63 @@ func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
decoder: decoder,
}
}
// RecordsToMap encodes a series of TLV records as raw key-value pairs in the
// form of a map.
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
tlvMap := make(map[uint64][]byte, len(records))
for _, record := range records {
var b bytes.Buffer
if err := record.Encode(&b); err != nil {
return nil, err
}
tlvMap[uint64(record.Type())] = b.Bytes()
}
return tlvMap, nil
}
// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw
// value. We can use this to make a record that can be encoded when we don't
// actually know it's true underlying value, and only it serialization.
func StubEncoder(v []byte) Encoder {
return func(w io.Writer, val interface{}, buf *[8]byte) error {
_, err := w.Write(v)
return err
}
}
// MapToRecords encodes the passed TLV map as a series of regular tlv.Record
// instances. The resulting set of records will be returned in sorted order by
// their type.
func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) {
records := make([]Record, 0, len(tlvMap))
for k, v := range tlvMap {
// We don't pass in a decoder here since we don't actually know
// the type, and only expect this Record to be used for display
// and encoding purposes.
record := MakeStaticRecord(
Type(k), nil, uint64(len(v)), StubEncoder(v), nil,
)
records = append(records, record)
}
SortRecords(records)
return records, nil
}
// SortRecords is a helper function that will sort a slice of records in place
// according to their type.
func SortRecords(records []Record) {
if len(records) == 0 {
return
}
sort.Slice(records, func(i, j int) bool {
return records[i].Type() < records[j].Type()
})
}

149
tlv/record_test.go Normal file

@ -0,0 +1,149 @@
package tlv
import (
"bytes"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
)
// TestSortRecords tests that SortRecords is able to properly sort records in
// place.
func TestSortRecords(t *testing.T) {
t.Parallel()
testCases := []struct {
preSort []Record
postSort []Record
}{
// An empty slice requires no sorting.
{
preSort: []Record{},
postSort: []Record{},
},
// An already sorted slice should be passed through.
{
preSort: []Record{
MakeStaticRecord(1, nil, 0, nil, nil),
MakeStaticRecord(2, nil, 0, nil, nil),
MakeStaticRecord(3, nil, 0, nil, nil),
},
postSort: []Record{
MakeStaticRecord(1, nil, 0, nil, nil),
MakeStaticRecord(2, nil, 0, nil, nil),
MakeStaticRecord(3, nil, 0, nil, nil),
},
},
// We should be able to sort a randomized set of records .
{
preSort: []Record{
MakeStaticRecord(9, nil, 0, nil, nil),
MakeStaticRecord(43, nil, 0, nil, nil),
MakeStaticRecord(1, nil, 0, nil, nil),
MakeStaticRecord(0, nil, 0, nil, nil),
},
postSort: []Record{
MakeStaticRecord(0, nil, 0, nil, nil),
MakeStaticRecord(1, nil, 0, nil, nil),
MakeStaticRecord(9, nil, 0, nil, nil),
MakeStaticRecord(43, nil, 0, nil, nil),
},
},
}
for i, testCase := range testCases {
SortRecords(testCase.preSort)
if !reflect.DeepEqual(testCase.preSort, testCase.postSort) {
t.Fatalf("#%v: wrong order: expected %v, got %v", i,
spew.Sdump(testCase.preSort),
spew.Sdump(testCase.postSort))
}
}
}
// TestRecordMapTransformation tests that we're able to properly morph a set of
// records into a map using TlvRecordsToMap, then the other way around using
// the MapToTlvRecords method.
func TestRecordMapTransformation(t *testing.T) {
t.Parallel()
tlvBytes := []byte{1, 2, 3, 4}
encoder := StubEncoder(tlvBytes)
testCases := []struct {
records []Record
tlvMap map[uint64][]byte
}{
// An empty set of records should yield an empty map, and the other
// way around.
{
records: []Record{},
tlvMap: map[uint64][]byte{},
},
// We should be able to transform this set of records, then obtain
// the records back in the same order.
{
records: []Record{
MakeStaticRecord(1, nil, 4, encoder, nil),
MakeStaticRecord(2, nil, 4, encoder, nil),
MakeStaticRecord(3, nil, 4, encoder, nil),
},
tlvMap: map[uint64][]byte{
1: tlvBytes,
2: tlvBytes,
3: tlvBytes,
},
},
}
for i, testCase := range testCases {
mappedRecords, err := RecordsToMap(testCase.records)
if err != nil {
t.Fatalf("#%v: unable to map records: %v", i, err)
}
if !reflect.DeepEqual(mappedRecords, testCase.tlvMap) {
t.Fatalf("#%v: incorrect record map: expected %v, got %v",
i, spew.Sdump(testCase.tlvMap),
spew.Sdump(mappedRecords))
}
unmappedRecords, err := MapToRecords(mappedRecords)
if err != nil {
t.Fatalf("#%v: unable to unmap records: %v", i, err)
}
for i := 0; i < len(testCase.records); i++ {
if unmappedRecords[i].Type() != testCase.records[i].Type() {
t.Fatalf("#%v: wrong type: expected %v, got %v",
i, unmappedRecords[i].Type(),
testCase.records[i].Type())
}
var b bytes.Buffer
if err := unmappedRecords[i].Encode(&b); err != nil {
t.Fatalf("#%v: unable to encode record: %v",
i, err)
}
if !bytes.Equal(b.Bytes(), tlvBytes) {
t.Fatalf("#%v: wrong raw record: "+
"expected %x, got %x",
i, tlvBytes, b.Bytes())
}
if unmappedRecords[i].Size() != testCase.records[0].Size() {
t.Fatalf("#%v: wrong size: expected %v, "+
"got %v", i,
unmappedRecords[i].Size(),
testCase.records[i].Size())
}
}
}
}