Merge pull request #3828 from joostjager/custom-records-sanity

channeldb: custom records sanity check
This commit is contained in:
Conner Fromknecht 2019-12-12 11:36:27 -08:00 committed by GitHub
commit 1901f59c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 28 deletions

@ -607,6 +607,15 @@ func serializeHop(w io.Writer, h *route.Hop) error {
records = append(records, h.MPP.Record()) records = append(records, h.MPP.Record())
} }
// Final sanity check to absolutely rule out custom records that are not
// custom and write into the standard range.
if err := h.CustomRecords.Validate(); err != nil {
return err
}
// Convert custom records to tlv and add to the record list.
// MapToRecords sorts the list, so adding it here will keep the list
// canonical.
tlvRecords := tlv.MapToRecords(h.CustomRecords) tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...) records = append(records, tlvRecords...)

@ -28,8 +28,8 @@ var (
OutgoingTimeLock: 111, OutgoingTimeLock: 111,
AmtToForward: 555, AmtToForward: 555,
CustomRecords: record.CustomSet{ CustomRecords: record.CustomSet{
1: []byte{}, 65536: []byte{},
2: []byte{}, 80001: []byte{},
}, },
MPP: record.NewMPP(32, [32]byte{0x42}), MPP: record.NewMPP(32, [32]byte{0x42}),
} }

@ -231,8 +231,8 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// If we have any TLV records destined for the final hop, then we'll // If we have any TLV records destined for the final hop, then we'll
// attempt to decode them now into a form that the router can more // attempt to decode them now into a form that the router can more
// easily manipulate. // easily manipulate.
err = ValidateCustomRecords(in.DestCustomRecords) customRecords := record.CustomSet(in.DestCustomRecords)
if err != nil { if err := customRecords.Validate(); err != nil {
return nil, err return nil, err
} }
@ -241,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// the route. // the route.
route, err := r.FindRoute( route, err := r.FindRoute(
sourcePubKey, targetPubKey, amt, restrictions, sourcePubKey, targetPubKey, amt, restrictions,
in.DestCustomRecords, finalCLTVDelta, customRecords, finalCLTVDelta,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -366,25 +366,13 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
return resp, nil return resp, nil
} }
// ValidateCustomRecords checks that all custom records are in the custom type
// range.
func ValidateCustomRecords(rpcRecords map[uint64][]byte) error {
for key := range rpcRecords {
if key < record.CustomTypeStart {
return fmt.Errorf("no custom records with types "+
"below %v allowed", record.CustomTypeStart)
}
}
return nil
}
// UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has // UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has
// already been extracted. // already been extracted.
func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop, func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop,
error) { error) {
err := ValidateCustomRecords(rpcHop.CustomRecords) customRecords := record.CustomSet(rpcHop.CustomRecords)
if err != nil { if err := customRecords.Validate(); err != nil {
return nil, err return nil, err
} }
@ -398,7 +386,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat), AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
PubKeyBytes: pubkey, PubKeyBytes: pubkey,
ChannelID: rpcHop.ChanId, ChannelID: rpcHop.ChanId,
CustomRecords: rpcHop.CustomRecords, CustomRecords: customRecords,
LegacyPayload: !rpcHop.TlvPayload, LegacyPayload: !rpcHop.TlvPayload,
MPP: mpp, MPP: mpp,
}, nil }, nil
@ -526,11 +514,11 @@ func (r *RouterBackend) extractIntentFromSendRequest(
return nil, errors.New("timeout_seconds must be specified") return nil, errors.New("timeout_seconds must be specified")
} }
err = ValidateCustomRecords(rpcPayReq.DestCustomRecords) customRecords := record.CustomSet(rpcPayReq.DestCustomRecords)
if err != nil { if err := customRecords.Validate(); err != nil {
return nil, err return nil, err
} }
payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords payIntent.DestCustomRecords = customRecords
payIntent.PayAttemptTimeout = time.Second * payIntent.PayAttemptTimeout = time.Second *
time.Duration(rpcPayReq.TimeoutSeconds) time.Duration(rpcPayReq.TimeoutSeconds)

@ -1,5 +1,7 @@
package record package record
import "fmt"
const ( const (
// CustomTypeStart is the start of the custom tlv type range as defined // CustomTypeStart is the start of the custom tlv type range as defined
// in BOLT 01. // in BOLT 01.
@ -8,3 +10,15 @@ const (
// CustomSet stores a set of custom key/value pairs. // CustomSet stores a set of custom key/value pairs.
type CustomSet map[uint64][]byte type CustomSet map[uint64][]byte
// Validate checks that all custom records are in the custom type range.
func (c CustomSet) Validate() error {
for key := range c {
if key < CustomTypeStart {
return fmt.Errorf("no custom records with types "+
"below %v allowed", CustomTypeStart)
}
}
return nil
}

@ -3162,13 +3162,11 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme
} }
payIntent.cltvLimit = cltvLimit payIntent.cltvLimit = cltvLimit
err = routerrpc.ValidateCustomRecords( customRecords := record.CustomSet(rpcPayReq.DestCustomRecords)
rpcPayReq.DestCustomRecords, if err := customRecords.Validate(); err != nil {
)
if err != nil {
return payIntent, err return payIntent, err
} }
payIntent.destCustomRecords = rpcPayReq.DestCustomRecords payIntent.destCustomRecords = customRecords
validateDest := func(dest route.Vertex) error { validateDest := func(dest route.Vertex) error {
if rpcPayReq.AllowSelfPayment { if rpcPayReq.AllowSelfPayment {