routerrpc+record: move custom set validation
This commit is contained in:
parent
75b94dec2b
commit
d16476e477
@ -231,8 +231,8 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
|
|||||||
// If we have any TLV records destined for the final hop, then we'll
|
// If we have any TLV records destined for the final hop, then we'll
|
||||||
// attempt to decode them now into a form that the router can more
|
// attempt to decode them now into a form that the router can more
|
||||||
// easily manipulate.
|
// easily manipulate.
|
||||||
err = ValidateCustomRecords(in.DestCustomRecords)
|
customRecords := record.CustomSet(in.DestCustomRecords)
|
||||||
if err != nil {
|
if err := customRecords.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
|
|||||||
// the route.
|
// the route.
|
||||||
route, err := r.FindRoute(
|
route, err := r.FindRoute(
|
||||||
sourcePubKey, targetPubKey, amt, restrictions,
|
sourcePubKey, targetPubKey, amt, restrictions,
|
||||||
in.DestCustomRecords, finalCLTVDelta,
|
customRecords, finalCLTVDelta,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -366,25 +366,13 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateCustomRecords checks that all custom records are in the custom type
|
|
||||||
// range.
|
|
||||||
func ValidateCustomRecords(rpcRecords map[uint64][]byte) error {
|
|
||||||
for key := range rpcRecords {
|
|
||||||
if key < record.CustomTypeStart {
|
|
||||||
return fmt.Errorf("no custom records with types "+
|
|
||||||
"below %v allowed", record.CustomTypeStart)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has
|
// UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has
|
||||||
// already been extracted.
|
// already been extracted.
|
||||||
func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop,
|
func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop,
|
||||||
error) {
|
error) {
|
||||||
|
|
||||||
err := ValidateCustomRecords(rpcHop.CustomRecords)
|
customRecords := record.CustomSet(rpcHop.CustomRecords)
|
||||||
if err != nil {
|
if err := customRecords.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -398,7 +386,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
|||||||
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
||||||
PubKeyBytes: pubkey,
|
PubKeyBytes: pubkey,
|
||||||
ChannelID: rpcHop.ChanId,
|
ChannelID: rpcHop.ChanId,
|
||||||
CustomRecords: rpcHop.CustomRecords,
|
CustomRecords: customRecords,
|
||||||
LegacyPayload: !rpcHop.TlvPayload,
|
LegacyPayload: !rpcHop.TlvPayload,
|
||||||
MPP: mpp,
|
MPP: mpp,
|
||||||
}, nil
|
}, nil
|
||||||
@ -526,11 +514,11 @@ func (r *RouterBackend) extractIntentFromSendRequest(
|
|||||||
return nil, errors.New("timeout_seconds must be specified")
|
return nil, errors.New("timeout_seconds must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ValidateCustomRecords(rpcPayReq.DestCustomRecords)
|
customRecords := record.CustomSet(rpcPayReq.DestCustomRecords)
|
||||||
if err != nil {
|
if err := customRecords.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords
|
payIntent.DestCustomRecords = customRecords
|
||||||
|
|
||||||
payIntent.PayAttemptTimeout = time.Second *
|
payIntent.PayAttemptTimeout = time.Second *
|
||||||
time.Duration(rpcPayReq.TimeoutSeconds)
|
time.Duration(rpcPayReq.TimeoutSeconds)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package record
|
package record
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// CustomTypeStart is the start of the custom tlv type range as defined
|
// CustomTypeStart is the start of the custom tlv type range as defined
|
||||||
// in BOLT 01.
|
// in BOLT 01.
|
||||||
@ -8,3 +10,15 @@ const (
|
|||||||
|
|
||||||
// CustomSet stores a set of custom key/value pairs.
|
// CustomSet stores a set of custom key/value pairs.
|
||||||
type CustomSet map[uint64][]byte
|
type CustomSet map[uint64][]byte
|
||||||
|
|
||||||
|
// Validate checks that all custom records are in the custom type range.
|
||||||
|
func (c CustomSet) Validate() error {
|
||||||
|
for key := range c {
|
||||||
|
if key < CustomTypeStart {
|
||||||
|
return fmt.Errorf("no custom records with types "+
|
||||||
|
"below %v allowed", CustomTypeStart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -3162,13 +3162,11 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme
|
|||||||
}
|
}
|
||||||
payIntent.cltvLimit = cltvLimit
|
payIntent.cltvLimit = cltvLimit
|
||||||
|
|
||||||
err = routerrpc.ValidateCustomRecords(
|
customRecords := record.CustomSet(rpcPayReq.DestCustomRecords)
|
||||||
rpcPayReq.DestCustomRecords,
|
if err := customRecords.Validate(); err != nil {
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return payIntent, err
|
return payIntent, err
|
||||||
}
|
}
|
||||||
payIntent.destCustomRecords = rpcPayReq.DestCustomRecords
|
payIntent.destCustomRecords = customRecords
|
||||||
|
|
||||||
validateDest := func(dest route.Vertex) error {
|
validateDest := func(dest route.Vertex) error {
|
||||||
if rpcPayReq.AllowSelfPayment {
|
if rpcPayReq.AllowSelfPayment {
|
||||||
|
Loading…
Reference in New Issue
Block a user