routing/route+tlv: add new TLV-EOB awareness to Hop+Route
In this commit, we extend the Hop struct to carry an arbitrary set of TLV values, and add a new field that allows us to distinguish between the modern and legacy TLV payload. We add a new `PackPayload` method that will be used to encode the combined required routing TLV fields along any set of TLV fields that were specified as part of path finding. Finally, the `ToSphinxPath` has been extended to be able to recognize if a hop needs the modern, or legacy payload.
This commit is contained in:
parent
e60b36751c
commit
5b4c8ac232
@ -142,6 +142,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
||||
if result.Preimage != preimg {
|
||||
t.Fatal("unexpected preimage")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result.Route, &attempt.Route) {
|
||||
t.Fatal("unexpected route")
|
||||
}
|
||||
|
@ -1,14 +1,17 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// VertexSize is the size of the array to store a vertex.
|
||||
@ -72,6 +75,61 @@ type Hop struct {
|
||||
// hop. This value is less than the value that the incoming HTLC
|
||||
// carries as a fee will be subtracted by the hop.
|
||||
AmtToForward lnwire.MilliSatoshi
|
||||
|
||||
// TLVRecords if non-nil are a set of additional TLV records that
|
||||
// should be included in the forwarding instructions for this node.
|
||||
TLVRecords []tlv.Record
|
||||
|
||||
// LegacyPayload if true, then this signals that this node doesn't
|
||||
// understand the new TLV payload, so we must instead use the legacy
|
||||
// payload.
|
||||
LegacyPayload bool
|
||||
}
|
||||
|
||||
// PackHopPayload writes to the passed io.Writer, the series of byes that can
|
||||
// be placed directly into the per-hop payload (EOB) for this hop. This will
|
||||
// include the required routing fields, as well as serializing any of the
|
||||
// passed optional TLVRecords. nextChanID is the unique channel ID that
|
||||
// references the _outgoing_ channel ID that follows this hop. This field
|
||||
// follows the same semantics as the NextAddress field in the onion: it should
|
||||
// be set to zero to indicate the terminal hop.
|
||||
func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
|
||||
// If this is a legacy payload, then we'll exit here as this method
|
||||
// shouldn't be called.
|
||||
if h.LegacyPayload == true {
|
||||
return fmt.Errorf("cannot pack hop payloads for legacy " +
|
||||
"payloads")
|
||||
}
|
||||
|
||||
// Otherwise, we'll need to make a new stream that includes our
|
||||
// required routing fields, as well as these optional values.
|
||||
amt := uint64(h.AmtToForward)
|
||||
combinedRecords := append(h.TLVRecords,
|
||||
tlv.MakeDynamicRecord(
|
||||
tlv.AmtOnionType, &amt, func() uint64 {
|
||||
return tlv.SizeTUint64(amt)
|
||||
},
|
||||
tlv.ETUint64, tlv.DTUint64,
|
||||
),
|
||||
tlv.MakeDynamicRecord(
|
||||
tlv.LockTimeOnionType, &h.OutgoingTimeLock, func() uint64 {
|
||||
return tlv.SizeTUint32(h.OutgoingTimeLock)
|
||||
},
|
||||
tlv.ETUint32, tlv.DTUint32,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &nextChanID),
|
||||
)
|
||||
|
||||
// To ensure we produce a canonical stream, we'll sort the records
|
||||
// before encoding them as a stream in the hop payload.
|
||||
tlv.SortRecords(combinedRecords)
|
||||
|
||||
tlvStream, err := tlv.NewStream(combinedRecords...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tlvStream.Encode(w)
|
||||
}
|
||||
|
||||
// Route represents a path through the channel graph which runs over one or
|
||||
@ -156,7 +214,8 @@ func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32,
|
||||
|
||||
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
|
||||
// contains the per-hop paylods used to encoding the HTLC routing data for each
|
||||
// hop in the route.
|
||||
// hop in the route. This method also accepts an optional EOB payload for the
|
||||
// final hop.
|
||||
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
||||
var path sphinx.PaymentPath
|
||||
|
||||
@ -171,17 +230,6 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path[i] = sphinx.OnionHop{
|
||||
NodePub: *pub,
|
||||
HopData: sphinx.HopData{
|
||||
// TODO(roasbeef): properly set realm, make
|
||||
// sphinx type an enum actually?
|
||||
Realm: [1]byte{0},
|
||||
ForwardAmount: uint64(hop.AmtToForward),
|
||||
OutgoingCltv: hop.OutgoingTimeLock,
|
||||
},
|
||||
}
|
||||
|
||||
// As a base case, the next hop is set to all zeroes in order
|
||||
// to indicate that the "last hop" as no further hops after it.
|
||||
nextHop := uint64(0)
|
||||
@ -192,9 +240,50 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
||||
nextHop = r.Hops[i+1].ChannelID
|
||||
}
|
||||
|
||||
var payload sphinx.HopPayload
|
||||
|
||||
// If this is the legacy payload, then we can just include the
|
||||
// hop data as normal.
|
||||
if hop.LegacyPayload {
|
||||
// Before we encode this value, we'll pack the next hop
|
||||
// into the NextAddress field of the hop info to ensure
|
||||
// we point to the right now.
|
||||
hopData := sphinx.HopData{
|
||||
ForwardAmount: uint64(hop.AmtToForward),
|
||||
OutgoingCltv: hop.OutgoingTimeLock,
|
||||
}
|
||||
binary.BigEndian.PutUint64(
|
||||
path[i].HopData.NextAddress[:], nextHop,
|
||||
hopData.NextAddress[:], nextHop,
|
||||
)
|
||||
|
||||
payload, err = sphinx.NewHopPayload(&hopData, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// For non-legacy payloads, we'll need to pack the
|
||||
// routing information, along with any extra TLV
|
||||
// information into the new per-hop payload format.
|
||||
// We'll also pass in the chan ID of the hop this
|
||||
// channel should be forwarded to so we can construct a
|
||||
// valid payload.
|
||||
var b bytes.Buffer
|
||||
err := hop.PackHopPayload(&b, nextHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(roasbeef): make better API for NewHopPayload?
|
||||
payload, err = sphinx.NewHopPayload(nil, b.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
path[i] = sphinx.OnionHop{
|
||||
NodePub: *pub,
|
||||
HopPayload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
return &path, nil
|
||||
|
15
tlv/onion_types.go
Normal file
15
tlv/onion_types.go
Normal file
@ -0,0 +1,15 @@
|
||||
package tlv
|
||||
|
||||
const (
|
||||
// AmtOnionType is the type used in the onion to refrence the amount to
|
||||
// send to the next hop.
|
||||
AmtOnionType Type = 2
|
||||
|
||||
// LockTimeTLV is the type used in the onion to refenernce the CLTV
|
||||
// value that should be used for the next hop's HTLC.
|
||||
LockTimeOnionType Type = 4
|
||||
|
||||
// NextHopOnionType is the type used in the onion to reference the ID
|
||||
// of the next hop.
|
||||
NextHopOnionType Type = 6
|
||||
)
|
@ -1,8 +1,10 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
)
|
||||
@ -166,3 +168,63 @@ func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
|
||||
decoder: decoder,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordsToMap encodes a series of TLV records as raw key-value pairs in the
|
||||
// form of a map.
|
||||
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
|
||||
tlvMap := make(map[uint64][]byte, len(records))
|
||||
|
||||
for _, record := range records {
|
||||
var b bytes.Buffer
|
||||
if err := record.Encode(&b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlvMap[uint64(record.Type())] = b.Bytes()
|
||||
}
|
||||
|
||||
return tlvMap, nil
|
||||
}
|
||||
|
||||
// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw
|
||||
// value. We can use this to make a record that can be encoded when we don't
|
||||
// actually know it's true underlying value, and only it serialization.
|
||||
func StubEncoder(v []byte) Encoder {
|
||||
return func(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
_, err := w.Write(v)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// MapToRecords encodes the passed TLV map as a series of regular tlv.Record
|
||||
// instances. The resulting set of records will be returned in sorted order by
|
||||
// their type.
|
||||
func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) {
|
||||
records := make([]Record, 0, len(tlvMap))
|
||||
for k, v := range tlvMap {
|
||||
// We don't pass in a decoder here since we don't actually know
|
||||
// the type, and only expect this Record to be used for display
|
||||
// and encoding purposes.
|
||||
record := MakeStaticRecord(
|
||||
Type(k), nil, uint64(len(v)), StubEncoder(v), nil,
|
||||
)
|
||||
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
SortRecords(records)
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// SortRecords is a helper function that will sort a slice of records in place
|
||||
// according to their type.
|
||||
func SortRecords(records []Record) {
|
||||
if len(records) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].Type() < records[j].Type()
|
||||
})
|
||||
}
|
||||
|
149
tlv/record_test.go
Normal file
149
tlv/record_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
package tlv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
// TestSortRecords tests that SortRecords is able to properly sort records in
|
||||
// place.
|
||||
func TestSortRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
preSort []Record
|
||||
postSort []Record
|
||||
}{
|
||||
// An empty slice requires no sorting.
|
||||
{
|
||||
preSort: []Record{},
|
||||
postSort: []Record{},
|
||||
},
|
||||
|
||||
// An already sorted slice should be passed through.
|
||||
{
|
||||
preSort: []Record{
|
||||
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||
MakeStaticRecord(2, nil, 0, nil, nil),
|
||||
MakeStaticRecord(3, nil, 0, nil, nil),
|
||||
},
|
||||
postSort: []Record{
|
||||
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||
MakeStaticRecord(2, nil, 0, nil, nil),
|
||||
MakeStaticRecord(3, nil, 0, nil, nil),
|
||||
},
|
||||
},
|
||||
|
||||
// We should be able to sort a randomized set of records .
|
||||
{
|
||||
preSort: []Record{
|
||||
MakeStaticRecord(9, nil, 0, nil, nil),
|
||||
MakeStaticRecord(43, nil, 0, nil, nil),
|
||||
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||
MakeStaticRecord(0, nil, 0, nil, nil),
|
||||
},
|
||||
postSort: []Record{
|
||||
MakeStaticRecord(0, nil, 0, nil, nil),
|
||||
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||
MakeStaticRecord(9, nil, 0, nil, nil),
|
||||
MakeStaticRecord(43, nil, 0, nil, nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
SortRecords(testCase.preSort)
|
||||
|
||||
if !reflect.DeepEqual(testCase.preSort, testCase.postSort) {
|
||||
t.Fatalf("#%v: wrong order: expected %v, got %v", i,
|
||||
spew.Sdump(testCase.preSort),
|
||||
spew.Sdump(testCase.postSort))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordMapTransformation tests that we're able to properly morph a set of
|
||||
// records into a map using TlvRecordsToMap, then the other way around using
|
||||
// the MapToTlvRecords method.
|
||||
func TestRecordMapTransformation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tlvBytes := []byte{1, 2, 3, 4}
|
||||
encoder := StubEncoder(tlvBytes)
|
||||
|
||||
testCases := []struct {
|
||||
records []Record
|
||||
|
||||
tlvMap map[uint64][]byte
|
||||
}{
|
||||
// An empty set of records should yield an empty map, and the other
|
||||
// way around.
|
||||
{
|
||||
records: []Record{},
|
||||
tlvMap: map[uint64][]byte{},
|
||||
},
|
||||
|
||||
// We should be able to transform this set of records, then obtain
|
||||
// the records back in the same order.
|
||||
{
|
||||
records: []Record{
|
||||
MakeStaticRecord(1, nil, 4, encoder, nil),
|
||||
MakeStaticRecord(2, nil, 4, encoder, nil),
|
||||
MakeStaticRecord(3, nil, 4, encoder, nil),
|
||||
},
|
||||
tlvMap: map[uint64][]byte{
|
||||
1: tlvBytes,
|
||||
2: tlvBytes,
|
||||
3: tlvBytes,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
mappedRecords, err := RecordsToMap(testCase.records)
|
||||
if err != nil {
|
||||
t.Fatalf("#%v: unable to map records: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(mappedRecords, testCase.tlvMap) {
|
||||
t.Fatalf("#%v: incorrect record map: expected %v, got %v",
|
||||
i, spew.Sdump(testCase.tlvMap),
|
||||
spew.Sdump(mappedRecords))
|
||||
}
|
||||
|
||||
unmappedRecords, err := MapToRecords(mappedRecords)
|
||||
if err != nil {
|
||||
t.Fatalf("#%v: unable to unmap records: %v", i, err)
|
||||
}
|
||||
|
||||
for i := 0; i < len(testCase.records); i++ {
|
||||
if unmappedRecords[i].Type() != testCase.records[i].Type() {
|
||||
t.Fatalf("#%v: wrong type: expected %v, got %v",
|
||||
i, unmappedRecords[i].Type(),
|
||||
testCase.records[i].Type())
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := unmappedRecords[i].Encode(&b); err != nil {
|
||||
t.Fatalf("#%v: unable to encode record: %v",
|
||||
i, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
||||
t.Fatalf("#%v: wrong raw record: "+
|
||||
"expected %x, got %x",
|
||||
i, tlvBytes, b.Bytes())
|
||||
}
|
||||
|
||||
if unmappedRecords[i].Size() != testCase.records[0].Size() {
|
||||
t.Fatalf("#%v: wrong size: expected %v, "+
|
||||
"got %v", i,
|
||||
unmappedRecords[i].Size(),
|
||||
testCase.records[i].Size())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user