routing/route+tlv: add new TLV-EOB awareness to Hop+Route
In this commit, we extend the Hop struct to carry an arbitrary set of TLV values, and add a new field that allows us to distinguish between the modern and legacy TLV payload. We add a new `PackPayload` method that will be used to encode the combined required routing TLV fields along any set of TLV fields that were specified as part of path finding. Finally, the `ToSphinxPath` has been extended to be able to recognize if a hop needs the modern, or legacy payload.
This commit is contained in:
parent
e60b36751c
commit
5b4c8ac232
@ -142,6 +142,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
|||||||
if result.Preimage != preimg {
|
if result.Preimage != preimg {
|
||||||
t.Fatal("unexpected preimage")
|
t.Fatal("unexpected preimage")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(result.Route, &attempt.Route) {
|
if !reflect.DeepEqual(result.Route, &attempt.Route) {
|
||||||
t.Fatal("unexpected route")
|
t.Fatal("unexpected route")
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
package route
|
package route
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// VertexSize is the size of the array to store a vertex.
|
// VertexSize is the size of the array to store a vertex.
|
||||||
@ -72,6 +75,61 @@ type Hop struct {
|
|||||||
// hop. This value is less than the value that the incoming HTLC
|
// hop. This value is less than the value that the incoming HTLC
|
||||||
// carries as a fee will be subtracted by the hop.
|
// carries as a fee will be subtracted by the hop.
|
||||||
AmtToForward lnwire.MilliSatoshi
|
AmtToForward lnwire.MilliSatoshi
|
||||||
|
|
||||||
|
// TLVRecords if non-nil are a set of additional TLV records that
|
||||||
|
// should be included in the forwarding instructions for this node.
|
||||||
|
TLVRecords []tlv.Record
|
||||||
|
|
||||||
|
// LegacyPayload if true, then this signals that this node doesn't
|
||||||
|
// understand the new TLV payload, so we must instead use the legacy
|
||||||
|
// payload.
|
||||||
|
LegacyPayload bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackHopPayload writes to the passed io.Writer, the series of byes that can
|
||||||
|
// be placed directly into the per-hop payload (EOB) for this hop. This will
|
||||||
|
// include the required routing fields, as well as serializing any of the
|
||||||
|
// passed optional TLVRecords. nextChanID is the unique channel ID that
|
||||||
|
// references the _outgoing_ channel ID that follows this hop. This field
|
||||||
|
// follows the same semantics as the NextAddress field in the onion: it should
|
||||||
|
// be set to zero to indicate the terminal hop.
|
||||||
|
func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
|
||||||
|
// If this is a legacy payload, then we'll exit here as this method
|
||||||
|
// shouldn't be called.
|
||||||
|
if h.LegacyPayload == true {
|
||||||
|
return fmt.Errorf("cannot pack hop payloads for legacy " +
|
||||||
|
"payloads")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we'll need to make a new stream that includes our
|
||||||
|
// required routing fields, as well as these optional values.
|
||||||
|
amt := uint64(h.AmtToForward)
|
||||||
|
combinedRecords := append(h.TLVRecords,
|
||||||
|
tlv.MakeDynamicRecord(
|
||||||
|
tlv.AmtOnionType, &amt, func() uint64 {
|
||||||
|
return tlv.SizeTUint64(amt)
|
||||||
|
},
|
||||||
|
tlv.ETUint64, tlv.DTUint64,
|
||||||
|
),
|
||||||
|
tlv.MakeDynamicRecord(
|
||||||
|
tlv.LockTimeOnionType, &h.OutgoingTimeLock, func() uint64 {
|
||||||
|
return tlv.SizeTUint32(h.OutgoingTimeLock)
|
||||||
|
},
|
||||||
|
tlv.ETUint32, tlv.DTUint32,
|
||||||
|
),
|
||||||
|
tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &nextChanID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// To ensure we produce a canonical stream, we'll sort the records
|
||||||
|
// before encoding them as a stream in the hop payload.
|
||||||
|
tlv.SortRecords(combinedRecords)
|
||||||
|
|
||||||
|
tlvStream, err := tlv.NewStream(combinedRecords...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlvStream.Encode(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route represents a path through the channel graph which runs over one or
|
// Route represents a path through the channel graph which runs over one or
|
||||||
@ -156,7 +214,8 @@ func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32,
|
|||||||
|
|
||||||
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
|
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
|
||||||
// contains the per-hop paylods used to encoding the HTLC routing data for each
|
// contains the per-hop paylods used to encoding the HTLC routing data for each
|
||||||
// hop in the route.
|
// hop in the route. This method also accepts an optional EOB payload for the
|
||||||
|
// final hop.
|
||||||
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
||||||
var path sphinx.PaymentPath
|
var path sphinx.PaymentPath
|
||||||
|
|
||||||
@ -171,17 +230,6 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
path[i] = sphinx.OnionHop{
|
|
||||||
NodePub: *pub,
|
|
||||||
HopData: sphinx.HopData{
|
|
||||||
// TODO(roasbeef): properly set realm, make
|
|
||||||
// sphinx type an enum actually?
|
|
||||||
Realm: [1]byte{0},
|
|
||||||
ForwardAmount: uint64(hop.AmtToForward),
|
|
||||||
OutgoingCltv: hop.OutgoingTimeLock,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// As a base case, the next hop is set to all zeroes in order
|
// As a base case, the next hop is set to all zeroes in order
|
||||||
// to indicate that the "last hop" as no further hops after it.
|
// to indicate that the "last hop" as no further hops after it.
|
||||||
nextHop := uint64(0)
|
nextHop := uint64(0)
|
||||||
@ -192,9 +240,50 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
|
|||||||
nextHop = r.Hops[i+1].ChannelID
|
nextHop = r.Hops[i+1].ChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var payload sphinx.HopPayload
|
||||||
|
|
||||||
|
// If this is the legacy payload, then we can just include the
|
||||||
|
// hop data as normal.
|
||||||
|
if hop.LegacyPayload {
|
||||||
|
// Before we encode this value, we'll pack the next hop
|
||||||
|
// into the NextAddress field of the hop info to ensure
|
||||||
|
// we point to the right now.
|
||||||
|
hopData := sphinx.HopData{
|
||||||
|
ForwardAmount: uint64(hop.AmtToForward),
|
||||||
|
OutgoingCltv: hop.OutgoingTimeLock,
|
||||||
|
}
|
||||||
binary.BigEndian.PutUint64(
|
binary.BigEndian.PutUint64(
|
||||||
path[i].HopData.NextAddress[:], nextHop,
|
hopData.NextAddress[:], nextHop,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
payload, err = sphinx.NewHopPayload(&hopData, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For non-legacy payloads, we'll need to pack the
|
||||||
|
// routing information, along with any extra TLV
|
||||||
|
// information into the new per-hop payload format.
|
||||||
|
// We'll also pass in the chan ID of the hop this
|
||||||
|
// channel should be forwarded to so we can construct a
|
||||||
|
// valid payload.
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := hop.PackHopPayload(&b, nextHop)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): make better API for NewHopPayload?
|
||||||
|
payload, err = sphinx.NewHopPayload(nil, b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
path[i] = sphinx.OnionHop{
|
||||||
|
NodePub: *pub,
|
||||||
|
HopPayload: payload,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &path, nil
|
return &path, nil
|
||||||
|
15
tlv/onion_types.go
Normal file
15
tlv/onion_types.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package tlv
|
||||||
|
|
||||||
|
const (
|
||||||
|
// AmtOnionType is the type used in the onion to refrence the amount to
|
||||||
|
// send to the next hop.
|
||||||
|
AmtOnionType Type = 2
|
||||||
|
|
||||||
|
// LockTimeTLV is the type used in the onion to refenernce the CLTV
|
||||||
|
// value that should be used for the next hop's HTLC.
|
||||||
|
LockTimeOnionType Type = 4
|
||||||
|
|
||||||
|
// NextHopOnionType is the type used in the onion to reference the ID
|
||||||
|
// of the next hop.
|
||||||
|
NextHopOnionType Type = 6
|
||||||
|
)
|
@ -1,8 +1,10 @@
|
|||||||
package tlv
|
package tlv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
)
|
)
|
||||||
@ -166,3 +168,63 @@ func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordsToMap encodes a series of TLV records as raw key-value pairs in the
|
||||||
|
// form of a map.
|
||||||
|
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
|
||||||
|
tlvMap := make(map[uint64][]byte, len(records))
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := record.Encode(&b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlvMap[uint64(record.Type())] = b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlvMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StubEncoder is a factory function that makes a stub tlv.Encoder out of a raw
|
||||||
|
// value. We can use this to make a record that can be encoded when we don't
|
||||||
|
// actually know it's true underlying value, and only it serialization.
|
||||||
|
func StubEncoder(v []byte) Encoder {
|
||||||
|
return func(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||||
|
_, err := w.Write(v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapToRecords encodes the passed TLV map as a series of regular tlv.Record
|
||||||
|
// instances. The resulting set of records will be returned in sorted order by
|
||||||
|
// their type.
|
||||||
|
func MapToRecords(tlvMap map[uint64][]byte) ([]Record, error) {
|
||||||
|
records := make([]Record, 0, len(tlvMap))
|
||||||
|
for k, v := range tlvMap {
|
||||||
|
// We don't pass in a decoder here since we don't actually know
|
||||||
|
// the type, and only expect this Record to be used for display
|
||||||
|
// and encoding purposes.
|
||||||
|
record := MakeStaticRecord(
|
||||||
|
Type(k), nil, uint64(len(v)), StubEncoder(v), nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
records = append(records, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
SortRecords(records)
|
||||||
|
|
||||||
|
return records, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortRecords is a helper function that will sort a slice of records in place
|
||||||
|
// according to their type.
|
||||||
|
func SortRecords(records []Record) {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(records, func(i, j int) bool {
|
||||||
|
return records[i].Type() < records[j].Type()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
149
tlv/record_test.go
Normal file
149
tlv/record_test.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package tlv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSortRecords tests that SortRecords is able to properly sort records in
|
||||||
|
// place.
|
||||||
|
func TestSortRecords(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
preSort []Record
|
||||||
|
postSort []Record
|
||||||
|
}{
|
||||||
|
// An empty slice requires no sorting.
|
||||||
|
{
|
||||||
|
preSort: []Record{},
|
||||||
|
postSort: []Record{},
|
||||||
|
},
|
||||||
|
|
||||||
|
// An already sorted slice should be passed through.
|
||||||
|
{
|
||||||
|
preSort: []Record{
|
||||||
|
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(2, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(3, nil, 0, nil, nil),
|
||||||
|
},
|
||||||
|
postSort: []Record{
|
||||||
|
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(2, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(3, nil, 0, nil, nil),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// We should be able to sort a randomized set of records .
|
||||||
|
{
|
||||||
|
preSort: []Record{
|
||||||
|
MakeStaticRecord(9, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(43, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(0, nil, 0, nil, nil),
|
||||||
|
},
|
||||||
|
postSort: []Record{
|
||||||
|
MakeStaticRecord(0, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(1, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(9, nil, 0, nil, nil),
|
||||||
|
MakeStaticRecord(43, nil, 0, nil, nil),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
SortRecords(testCase.preSort)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(testCase.preSort, testCase.postSort) {
|
||||||
|
t.Fatalf("#%v: wrong order: expected %v, got %v", i,
|
||||||
|
spew.Sdump(testCase.preSort),
|
||||||
|
spew.Sdump(testCase.postSort))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRecordMapTransformation tests that we're able to properly morph a set of
|
||||||
|
// records into a map using TlvRecordsToMap, then the other way around using
|
||||||
|
// the MapToTlvRecords method.
|
||||||
|
func TestRecordMapTransformation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tlvBytes := []byte{1, 2, 3, 4}
|
||||||
|
encoder := StubEncoder(tlvBytes)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
records []Record
|
||||||
|
|
||||||
|
tlvMap map[uint64][]byte
|
||||||
|
}{
|
||||||
|
// An empty set of records should yield an empty map, and the other
|
||||||
|
// way around.
|
||||||
|
{
|
||||||
|
records: []Record{},
|
||||||
|
tlvMap: map[uint64][]byte{},
|
||||||
|
},
|
||||||
|
|
||||||
|
// We should be able to transform this set of records, then obtain
|
||||||
|
// the records back in the same order.
|
||||||
|
{
|
||||||
|
records: []Record{
|
||||||
|
MakeStaticRecord(1, nil, 4, encoder, nil),
|
||||||
|
MakeStaticRecord(2, nil, 4, encoder, nil),
|
||||||
|
MakeStaticRecord(3, nil, 4, encoder, nil),
|
||||||
|
},
|
||||||
|
tlvMap: map[uint64][]byte{
|
||||||
|
1: tlvBytes,
|
||||||
|
2: tlvBytes,
|
||||||
|
3: tlvBytes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
mappedRecords, err := RecordsToMap(testCase.records)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("#%v: unable to map records: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(mappedRecords, testCase.tlvMap) {
|
||||||
|
t.Fatalf("#%v: incorrect record map: expected %v, got %v",
|
||||||
|
i, spew.Sdump(testCase.tlvMap),
|
||||||
|
spew.Sdump(mappedRecords))
|
||||||
|
}
|
||||||
|
|
||||||
|
unmappedRecords, err := MapToRecords(mappedRecords)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("#%v: unable to unmap records: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(testCase.records); i++ {
|
||||||
|
if unmappedRecords[i].Type() != testCase.records[i].Type() {
|
||||||
|
t.Fatalf("#%v: wrong type: expected %v, got %v",
|
||||||
|
i, unmappedRecords[i].Type(),
|
||||||
|
testCase.records[i].Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := unmappedRecords[i].Encode(&b); err != nil {
|
||||||
|
t.Fatalf("#%v: unable to encode record: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
||||||
|
t.Fatalf("#%v: wrong raw record: "+
|
||||||
|
"expected %x, got %x",
|
||||||
|
i, tlvBytes, b.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
if unmappedRecords[i].Size() != testCase.records[0].Size() {
|
||||||
|
t.Fatalf("#%v: wrong size: expected %v, "+
|
||||||
|
"got %v", i,
|
||||||
|
unmappedRecords[i].Size(),
|
||||||
|
testCase.records[i].Size())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user