multi: do not use tlv.Record outside wire format handling
This commit prepares for more manipulation of custom records. A list of tlv.Record types is more difficult to use than the more basic map[uint64][]byte. Furthermore fields and variables are renamed to make them more consistent.
This commit is contained in:
parent
8b5bb0ac63
commit
d02de70d20
@ -606,7 +606,9 @@ func serializeHop(w io.Writer, h *route.Hop) error {
|
||||
if h.MPP != nil {
|
||||
records = append(records, h.MPP.Record())
|
||||
}
|
||||
records = append(records, h.TLVRecords...)
|
||||
|
||||
tlvRecords := tlv.MapToRecords(h.CustomRecords)
|
||||
records = append(records, tlvRecords...)
|
||||
|
||||
// Otherwise, we'll transform our slice of records into a map of the
|
||||
// raw bytes, then serialize them in-line with a length (number of
|
||||
@ -710,7 +712,7 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
|
||||
h.MPP = mpp
|
||||
}
|
||||
|
||||
h.TLVRecords = tlv.MapToRecords(tlvMap)
|
||||
h.CustomRecords = tlvMap
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
@ -28,9 +27,9 @@ var (
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
AmtToForward: 555,
|
||||
TLVRecords: []tlv.Record{
|
||||
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
|
||||
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
|
||||
CustomRecords: record.CustomSet{
|
||||
1: []byte{},
|
||||
2: []byte{},
|
||||
},
|
||||
MPP: record.NewMPP(32, [32]byte{0x42}),
|
||||
}
|
||||
@ -144,25 +143,7 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
// assertRouteEquals compares to routes for equality and returns an error if
|
||||
// they are not equal.
|
||||
func assertRouteEqual(a, b *route.Route) error {
|
||||
err := assertRouteHopRecordsEqual(a, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TLV records have already been compared and need to be cleared to
|
||||
// properly compare the remaining fields using DeepEqual.
|
||||
copyRouteNoHops := func(r *route.Route) *route.Route {
|
||||
copy := *r
|
||||
copy.Hops = make([]*route.Hop, len(r.Hops))
|
||||
for i, hop := range r.Hops {
|
||||
hopCopy := *hop
|
||||
hopCopy.TLVRecords = nil
|
||||
copy.Hops[i] = &hopCopy
|
||||
}
|
||||
return ©
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
|
||||
spew.Sdump(a), spew.Sdump(b))
|
||||
}
|
||||
@ -170,57 +151,6 @@ func assertRouteEqual(a, b *route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
|
||||
if len(r1.Hops) != len(r2.Hops) {
|
||||
return errors.New("route hop count mismatch")
|
||||
}
|
||||
|
||||
for i := 0; i < len(r1.Hops); i++ {
|
||||
records1 := r1.Hops[i].TLVRecords
|
||||
records2 := r2.Hops[i].TLVRecords
|
||||
if len(records1) != len(records2) {
|
||||
return fmt.Errorf("route record count for hop %v "+
|
||||
"mismatch", i)
|
||||
}
|
||||
|
||||
for j := 0; j < len(records1); j++ {
|
||||
expectedRecord := records1[j]
|
||||
newRecord := records2[j]
|
||||
|
||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route record mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
|
||||
if h1.Type() != h2.Type() {
|
||||
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
|
||||
h2.Type())
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := h2.Encode(&b); err != nil {
|
||||
return fmt.Errorf("unable to encode record: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
||||
return fmt.Errorf("wrong raw record: expected %x, got %x",
|
||||
tlvBytes, b.Bytes())
|
||||
}
|
||||
|
||||
if h1.Size() != h2.Size() {
|
||||
return fmt.Errorf("wrong size: expected %v, "+
|
||||
"got %v", h1.Size(), h2.Size())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRouteSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -19,7 +19,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
)
|
||||
|
||||
@ -45,7 +44,7 @@ type RouterBackend struct {
|
||||
// routes.
|
||||
FindRoute func(source, target route.Vertex,
|
||||
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
|
||||
destTlvRecords []tlv.Record,
|
||||
destCustomRecords record.CustomSet,
|
||||
finalExpiry ...uint16) (*route.Route, error)
|
||||
|
||||
MissionControl MissionControl
|
||||
@ -232,7 +231,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
|
||||
// If we have any TLV records destined for the final hop, then we'll
|
||||
// attempt to decode them now into a form that the router can more
|
||||
// easily manipulate.
|
||||
destTlvRecords, err := UnmarshallCustomRecords(in.DestCustomRecords)
|
||||
err = ValidateCustomRecords(in.DestCustomRecords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -242,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
|
||||
// the route.
|
||||
route, err := r.FindRoute(
|
||||
sourcePubKey, targetPubKey, amt, restrictions,
|
||||
destTlvRecords, finalCLTVDelta,
|
||||
in.DestCustomRecords, finalCLTVDelta,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -346,11 +345,6 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
}
|
||||
}
|
||||
|
||||
tlvMap, err := tlv.RecordsToMap(hop.TLVRecords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Hops[i] = &lnrpc.Hop{
|
||||
ChanId: hop.ChannelID,
|
||||
ChanCapacity: int64(chanCapacity),
|
||||
@ -362,7 +356,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
PubKey: hex.EncodeToString(
|
||||
hop.PubKeyBytes[:],
|
||||
),
|
||||
CustomRecords: tlvMap,
|
||||
CustomRecords: hop.CustomRecords,
|
||||
TlvPayload: !hop.LegacyPayload,
|
||||
MppRecord: mpp,
|
||||
}
|
||||
@ -372,24 +366,16 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// UnmarshallCustomRecords unmarshall rpc custom records to tlv records.
|
||||
func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record,
|
||||
error) {
|
||||
|
||||
if len(rpcRecords) == 0 {
|
||||
return nil, nil
|
||||
// ValidateCustomRecords checks that all custom records are in the custom type
|
||||
// range.
|
||||
func ValidateCustomRecords(rpcRecords map[uint64][]byte) error {
|
||||
for key := range rpcRecords {
|
||||
if key < record.CustomTypeStart {
|
||||
return fmt.Errorf("no custom records with types "+
|
||||
"below %v allowed", record.CustomTypeStart)
|
||||
}
|
||||
}
|
||||
|
||||
tlvRecords := tlv.MapToRecords(rpcRecords)
|
||||
|
||||
// tlvRecords is sorted, so we only need to check that the first
|
||||
// element is within the custom range.
|
||||
if uint64(tlvRecords[0].Type()) < record.CustomTypeStart {
|
||||
return nil, fmt.Errorf("no custom records with types "+
|
||||
"below %v allowed", record.CustomTypeStart)
|
||||
}
|
||||
|
||||
return tlvRecords, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has
|
||||
@ -397,7 +383,7 @@ func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record,
|
||||
func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop,
|
||||
error) {
|
||||
|
||||
tlvRecords, err := UnmarshallCustomRecords(rpcHop.CustomRecords)
|
||||
err := ValidateCustomRecords(rpcHop.CustomRecords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -412,7 +398,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
||||
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
||||
PubKeyBytes: pubkey,
|
||||
ChannelID: rpcHop.ChanId,
|
||||
TLVRecords: tlvRecords,
|
||||
CustomRecords: rpcHop.CustomRecords,
|
||||
LegacyPayload: !rpcHop.TlvPayload,
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
@ -540,12 +526,11 @@ func (r *RouterBackend) extractIntentFromSendRequest(
|
||||
return nil, errors.New("timeout_seconds must be specified")
|
||||
}
|
||||
|
||||
payIntent.FinalDestRecords, err = UnmarshallCustomRecords(
|
||||
rpcPayReq.DestCustomRecords,
|
||||
)
|
||||
err = ValidateCustomRecords(rpcPayReq.DestCustomRecords)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords
|
||||
|
||||
payIntent.PayAttemptTimeout = time.Second *
|
||||
time.Duration(rpcPayReq.TimeoutSeconds)
|
||||
|
@ -8,9 +8,9 @@ import (
|
||||
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
)
|
||||
@ -92,7 +92,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool) {
|
||||
|
||||
findRoute := func(source, target route.Vertex,
|
||||
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
|
||||
_ []tlv.Record,
|
||||
_ record.CustomSet,
|
||||
finalExpiry ...uint16) (*route.Route, error) {
|
||||
|
||||
if int64(amt) != amtSat*1000 {
|
||||
|
@ -11,8 +11,8 @@ import (
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -100,7 +100,7 @@ type edgePolicyWithSource struct {
|
||||
func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex,
|
||||
pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32,
|
||||
finalCLTVDelta uint16,
|
||||
finalDestRecords []tlv.Record) (*route.Route, error) {
|
||||
destCustomRecords record.CustomSet) (*route.Route, error) {
|
||||
|
||||
var (
|
||||
hops []*route.Hop
|
||||
@ -198,8 +198,8 @@ func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex,
|
||||
|
||||
// If this is the last hop, then we'll populate any TLV records
|
||||
// destined for it.
|
||||
if i == len(pathEdges)-1 && len(finalDestRecords) != 0 {
|
||||
currentHop.TLVRecords = finalDestRecords
|
||||
if i == len(pathEdges)-1 && len(destCustomRecords) != 0 {
|
||||
currentHop.CustomRecords = destCustomRecords
|
||||
}
|
||||
|
||||
hops = append([]*route.Hop{currentHop}, hops...)
|
||||
|
@ -129,7 +129,7 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment,
|
||||
sourceVertex := route.Vertex(ss.SelfNode.PubKeyBytes)
|
||||
route, err := newRoute(
|
||||
payment.Amount, sourceVertex, path, height, finalCltvDelta,
|
||||
payment.FinalDestRecords,
|
||||
payment.DestCustomRecords,
|
||||
)
|
||||
if err != nil {
|
||||
// TODO(roasbeef): return which edge/vertex didn't work
|
||||
|
@ -107,9 +107,9 @@ type Hop struct {
|
||||
// only be set for the final hop.
|
||||
MPP *record.MPP
|
||||
|
||||
// TLVRecords if non-nil are a set of additional TLV records that
|
||||
// CustomRecords if non-nil are a set of additional TLV records that
|
||||
// should be included in the forwarding instructions for this node.
|
||||
TLVRecords []tlv.Record
|
||||
CustomRecords record.CustomSet
|
||||
|
||||
// LegacyPayload if true, then this signals that this node doesn't
|
||||
// understand the new TLV payload, so we must instead use the legacy
|
||||
@ -165,7 +165,8 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
|
||||
}
|
||||
|
||||
// Append any custom types destined for this hop.
|
||||
records = append(records, h.TLVRecords...)
|
||||
tlvRecords := tlv.MapToRecords(h.CustomRecords)
|
||||
records = append(records, tlvRecords...)
|
||||
|
||||
// To ensure we produce a canonical stream, we'll sort the records
|
||||
// before encoding them as a stream in the hop payload.
|
||||
|
@ -24,10 +24,10 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/multimutex"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/chainview"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/ticker"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
)
|
||||
|
||||
@ -1401,7 +1401,7 @@ type routingMsg struct {
|
||||
// factoring in channel capacities and cumulative fees along the route.
|
||||
func (r *ChannelRouter) FindRoute(source, target route.Vertex,
|
||||
amt lnwire.MilliSatoshi, restrictions *RestrictParams,
|
||||
destTlvRecords []tlv.Record,
|
||||
destCustomRecords record.CustomSet,
|
||||
finalExpiry ...uint16) (*route.Route, error) {
|
||||
|
||||
var finalCLTVDelta uint16
|
||||
@ -1455,7 +1455,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
|
||||
// Create the route with absolute time lock values.
|
||||
route, err := newRoute(
|
||||
amt, source, path, uint32(currentHeight), finalCLTVDelta,
|
||||
destTlvRecords,
|
||||
destCustomRecords,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1608,11 +1608,11 @@ type LightningPayment struct {
|
||||
// attempting to complete.
|
||||
PaymentRequest []byte
|
||||
|
||||
// FinalDestRecords are TLV records that are to be sent to the final
|
||||
// DestCustomRecords are TLV records that are to be sent to the final
|
||||
// hop in the new onion payload format. If the destination does not
|
||||
// understand this new onion payload format, then the payment will
|
||||
// fail.
|
||||
FinalDestRecords []tlv.Record
|
||||
DestCustomRecords record.CustomSet
|
||||
}
|
||||
|
||||
// SendPayment attempts to send a payment as described within the passed
|
||||
|
@ -50,11 +50,11 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/macaroons"
|
||||
"github.com/lightningnetwork/lnd/monitoring"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/signal"
|
||||
"github.com/lightningnetwork/lnd/sweep"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/lightningnetwork/lnd/watchtower"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
"github.com/tv42/zbase32"
|
||||
@ -3099,7 +3099,7 @@ type rpcPaymentIntent struct {
|
||||
lastHop *route.Vertex
|
||||
payReq []byte
|
||||
|
||||
destTLV []tlv.Record
|
||||
destCustomRecords record.CustomSet
|
||||
|
||||
route *route.Route
|
||||
}
|
||||
@ -3158,12 +3158,13 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme
|
||||
}
|
||||
payIntent.cltvLimit = cltvLimit
|
||||
|
||||
payIntent.destTLV, err = routerrpc.UnmarshallCustomRecords(
|
||||
err = routerrpc.ValidateCustomRecords(
|
||||
rpcPayReq.DestCustomRecords,
|
||||
)
|
||||
if err != nil {
|
||||
return payIntent, err
|
||||
}
|
||||
payIntent.destCustomRecords = rpcPayReq.DestCustomRecords
|
||||
|
||||
validateDest := func(dest route.Vertex) error {
|
||||
if rpcPayReq.AllowSelfPayment {
|
||||
@ -3348,7 +3349,7 @@ func (r *rpcServer) dispatchPaymentIntent(
|
||||
LastHop: payIntent.lastHop,
|
||||
PaymentRequest: payIntent.payReq,
|
||||
PayAttemptTimeout: routing.DefaultPayAttemptTimeout,
|
||||
FinalDestRecords: payIntent.destTLV,
|
||||
DestCustomRecords: payIntent.destCustomRecords,
|
||||
}
|
||||
|
||||
preImage, route, routerErr = r.server.chanRouter.SendPayment(
|
||||
|
Loading…
Reference in New Issue
Block a user