From d16476e47745497deaa8c85c38230a4be9403d6d Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 12 Dec 2019 14:43:36 +0100 Subject: [PATCH] routerrpc+record: move custom set validation --- lnrpc/routerrpc/router_backend.go | 30 +++++++++--------------------- record/custom_records.go | 14 ++++++++++++++ rpcserver.go | 8 +++----- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 63bec96e..7722a3c8 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -231,8 +231,8 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // If we have any TLV records destined for the final hop, then we'll // attempt to decode them now into a form that the router can more // easily manipulate. - err = ValidateCustomRecords(in.DestCustomRecords) - if err != nil { + customRecords := record.CustomSet(in.DestCustomRecords) + if err := customRecords.Validate(); err != nil { return nil, err } @@ -241,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // the route. route, err := r.FindRoute( sourcePubKey, targetPubKey, amt, restrictions, - in.DestCustomRecords, finalCLTVDelta, + customRecords, finalCLTVDelta, ) if err != nil { return nil, err @@ -366,25 +366,13 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) return resp, nil } -// ValidateCustomRecords checks that all custom records are in the custom type -// range. -func ValidateCustomRecords(rpcRecords map[uint64][]byte) error { - for key := range rpcRecords { - if key < record.CustomTypeStart { - return fmt.Errorf("no custom records with types "+ - "below %v allowed", record.CustomTypeStart) - } - } - return nil -} - // UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has // already been extracted. func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop, error) { - err := ValidateCustomRecords(rpcHop.CustomRecords) - if err != nil { + customRecords := record.CustomSet(rpcHop.CustomRecords) + if err := customRecords.Validate(); err != nil { return nil, err } @@ -398,7 +386,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat), PubKeyBytes: pubkey, ChannelID: rpcHop.ChanId, - CustomRecords: rpcHop.CustomRecords, + CustomRecords: customRecords, LegacyPayload: !rpcHop.TlvPayload, MPP: mpp, }, nil @@ -526,11 +514,11 @@ func (r *RouterBackend) extractIntentFromSendRequest( return nil, errors.New("timeout_seconds must be specified") } - err = ValidateCustomRecords(rpcPayReq.DestCustomRecords) - if err != nil { + customRecords := record.CustomSet(rpcPayReq.DestCustomRecords) + if err := customRecords.Validate(); err != nil { return nil, err } - payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords + payIntent.DestCustomRecords = customRecords payIntent.PayAttemptTimeout = time.Second * time.Duration(rpcPayReq.TimeoutSeconds) diff --git a/record/custom_records.go b/record/custom_records.go index 36e9e5ac..f9e4f342 100644 --- a/record/custom_records.go +++ b/record/custom_records.go @@ -1,5 +1,7 @@ package record +import "fmt" + const ( // CustomTypeStart is the start of the custom tlv type range as defined // in BOLT 01. @@ -8,3 +10,15 @@ const ( // CustomSet stores a set of custom key/value pairs. type CustomSet map[uint64][]byte + +// Validate checks that all custom records are in the custom type range. +func (c CustomSet) Validate() error { + for key := range c { + if key < CustomTypeStart { + return fmt.Errorf("no custom records with types "+ + "below %v allowed", CustomTypeStart) + } + } + + return nil +} diff --git a/rpcserver.go b/rpcserver.go index bb4301cd..1339f4bf 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3162,13 +3162,11 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme } payIntent.cltvLimit = cltvLimit - err = routerrpc.ValidateCustomRecords( - rpcPayReq.DestCustomRecords, - ) - if err != nil { + customRecords := record.CustomSet(rpcPayReq.DestCustomRecords) + if err := customRecords.Validate(); err != nil { return payIntent, err } - payIntent.destCustomRecords = rpcPayReq.DestCustomRecords + payIntent.destCustomRecords = customRecords validateDest := func(dest route.Vertex) error { if rpcPayReq.AllowSelfPayment {