Merge pull request #3465 from cfromknecht/tlv-parsed-types
tlv+htlcswitch: validate presence/omission of parsed onion payload types
This commit is contained in:
commit
2cf10ada0f
@ -2,6 +2,7 @@ package hop
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lightning-onion"
|
"github.com/lightningnetwork/lightning-onion"
|
||||||
@ -10,6 +11,37 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrInvalidPayload is an error returned when a parsed onion payload either
|
||||||
|
// included or omitted incorrect records for a particular hop type.
|
||||||
|
type ErrInvalidPayload struct {
|
||||||
|
// Type the record's type that cause the violation.
|
||||||
|
Type tlv.Type
|
||||||
|
|
||||||
|
// Ommitted if true, signals that the sender did not include the record.
|
||||||
|
// Otherwise, the sender included the record when it shouldn't have.
|
||||||
|
Omitted bool
|
||||||
|
|
||||||
|
// FinalHop if true, indicates that the violation is for the final hop
|
||||||
|
// in the route (identified by next hop id), otherwise the violation is
|
||||||
|
// for an intermediate hop.
|
||||||
|
FinalHop bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a human-readable description of the invalid payload error.
|
||||||
|
func (e ErrInvalidPayload) Error() string {
|
||||||
|
hopType := "intermediate"
|
||||||
|
if e.FinalHop {
|
||||||
|
hopType = "final"
|
||||||
|
}
|
||||||
|
violation := "included"
|
||||||
|
if e.Omitted {
|
||||||
|
violation = "omitted"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
|
||||||
|
hopType, violation, e.Type)
|
||||||
|
}
|
||||||
|
|
||||||
// Payload encapsulates all information delivered to a hop in an onion payload.
|
// Payload encapsulates all information delivered to a hop in an onion payload.
|
||||||
// A Hop can represent either a TLV or legacy payload. The primary forwarding
|
// A Hop can represent either a TLV or legacy payload. The primary forwarding
|
||||||
// instruction can be accessed via ForwardingInfo, and additional records can be
|
// instruction can be accessed via ForwardingInfo, and additional records can be
|
||||||
@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tlvStream.Decode(r)
|
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||||
|
|
||||||
|
// Validate whether the sender properly included or omitted tlv records
|
||||||
|
// in accordance with BOLT 04.
|
||||||
|
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
return &Payload{
|
return &Payload{
|
||||||
FwdInfo: ForwardingInfo{
|
FwdInfo: ForwardingInfo{
|
||||||
Network: BitcoinNetwork,
|
Network: BitcoinNetwork,
|
||||||
NextHop: lnwire.NewShortChanIDFromInt(cid),
|
NextHop: nextHop,
|
||||||
AmountToForward: lnwire.MilliSatoshi(amt),
|
AmountToForward: lnwire.MilliSatoshi(amt),
|
||||||
OutgoingCTLV: cltv,
|
OutgoingCTLV: cltv,
|
||||||
},
|
},
|
||||||
@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
func (h *Payload) ForwardingInfo() ForwardingInfo {
|
func (h *Payload) ForwardingInfo() ForwardingInfo {
|
||||||
return h.FwdInfo
|
return h.FwdInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
|
||||||
|
// ensure that the proper fields are either included or omitted. The finalHop
|
||||||
|
// boolean should be true if the payload was parsed for an exit hop. The
|
||||||
|
// requirements for this method are described in BOLT 04.
|
||||||
|
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||||
|
nextHop lnwire.ShortChannelID) error {
|
||||||
|
|
||||||
|
isFinalHop := nextHop == Exit
|
||||||
|
|
||||||
|
_, hasAmt := parsedTypes[record.AmtOnionType]
|
||||||
|
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||||
|
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// All hops must include an amount to forward.
|
||||||
|
case !hasAmt:
|
||||||
|
return ErrInvalidPayload{
|
||||||
|
Type: record.AmtOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: isFinalHop,
|
||||||
|
}
|
||||||
|
|
||||||
|
// All hops must include a cltv expiry.
|
||||||
|
case !hasLockTime:
|
||||||
|
return ErrInvalidPayload{
|
||||||
|
Type: record.LockTimeOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: isFinalHop,
|
||||||
|
}
|
||||||
|
|
||||||
|
// The exit hop should omit the next hop id. If nextHop != Exit, the
|
||||||
|
// sender must have included a record, so we don't need to test for its
|
||||||
|
// inclusion at intermediate hops directly.
|
||||||
|
case isFinalHop && hasNextHop:
|
||||||
|
return ErrInvalidPayload{
|
||||||
|
Type: record.NextHopOnionType,
|
||||||
|
Omitted: false,
|
||||||
|
FinalHop: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
89
htlcswitch/hop/payload_test.go
Normal file
89
htlcswitch/hop/payload_test.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package hop_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
)
|
||||||
|
|
||||||
|
type decodePayloadTest struct {
|
||||||
|
name string
|
||||||
|
payload []byte
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
var decodePayloadTests = []decodePayloadTest{
|
||||||
|
{
|
||||||
|
name: "final hop no amount",
|
||||||
|
payload: []byte{0x04, 0x00},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.AmtOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "intermediate hop no amount",
|
||||||
|
payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.AmtOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "final hop no expiry",
|
||||||
|
payload: []byte{0x02, 0x00},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.LockTimeOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "intermediate hop no expiry",
|
||||||
|
payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.LockTimeOnionType,
|
||||||
|
Omitted: true,
|
||||||
|
FinalHop: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "final hop next sid present",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.NextHopOnionType,
|
||||||
|
Omitted: false,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||||
|
// tests yields the expected errors depending on whether the proper fields were
|
||||||
|
// included or omitted.
|
||||||
|
func TestDecodeHopPayloadRecordValidation(t *testing.T) {
|
||||||
|
for _, test := range decodePayloadTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testDecodeHopPayloadValidation(t, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||||
|
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||||
|
if !reflect.DeepEqual(test.expErr, err) {
|
||||||
|
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
||||||
|
test.expErr, err)
|
||||||
|
}
|
||||||
|
}
|
@ -12,6 +12,9 @@ import (
|
|||||||
// Type is an 64-bit identifier for a TLV Record.
|
// Type is an 64-bit identifier for a TLV Record.
|
||||||
type Type uint64
|
type Type uint64
|
||||||
|
|
||||||
|
// TypeSet is an unordered set of Types.
|
||||||
|
type TypeSet map[Type]struct{}
|
||||||
|
|
||||||
// Encoder is a signature for methods that can encode TLV values. An error
|
// Encoder is a signature for methods that can encode TLV values. An error
|
||||||
// should be returned if the Encoder cannot support the underlying type of val.
|
// should be returned if the Encoder cannot support the underlying type of val.
|
||||||
// The provided scratch buffer must be non-nil.
|
// The provided scratch buffer must be non-nil.
|
||||||
|
@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error {
|
|||||||
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
||||||
// or io.ErrUnexpectedEOF errors are returned.
|
// or io.ErrUnexpectedEOF errors are returned.
|
||||||
func (s *Stream) Decode(r io.Reader) error {
|
func (s *Stream) Decode(r io.Reader) error {
|
||||||
|
_, err := s.decode(r, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a
|
||||||
|
// TypeSet containing the types of all records that were decoded or ignored from
|
||||||
|
// the stream.
|
||||||
|
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
|
||||||
|
return s.decode(r, make(TypeSet))
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode is a helper function that performs the basis of stream decoding. If
|
||||||
|
// the caller needs the set of parsed types, it must provide an initialized
|
||||||
|
// parsedTypes, otherwise the returned TypeSet will be nil.
|
||||||
|
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||||
var (
|
var (
|
||||||
typ Type
|
typ Type
|
||||||
min Type
|
min Type
|
||||||
@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll silence an EOF when zero bytes remain, meaning the
|
// We'll silence an EOF when zero bytes remain, meaning the
|
||||||
// stream was cleanly encoded.
|
// stream was cleanly encoded.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return nil
|
return parsedTypes, nil
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
typ = Type(t)
|
typ = Type(t)
|
||||||
@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// encodings that have duplicate records or from accepting an
|
// encodings that have duplicate records or from accepting an
|
||||||
// unsorted series.
|
// unsorted series.
|
||||||
if overflow || typ < min {
|
if overflow || typ < min {
|
||||||
return ErrStreamNotCanonical
|
return nil, ErrStreamNotCanonical
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the varint length.
|
// Read the varint length.
|
||||||
@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Place a soft limit on the size of a sane record, which
|
// Place a soft limit on the size of a sane record, which
|
||||||
@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// unbounded amount of memory when decoding variable-sized
|
// unbounded amount of memory when decoding variable-sized
|
||||||
// fields.
|
// fields.
|
||||||
if length > MaxRecordSize {
|
if length > MaxRecordSize {
|
||||||
return ErrRecordTooLarge
|
return nil, ErrRecordTooLarge
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search the records known to the stream for this type. We'll
|
// Search the records known to the stream for this type. We'll
|
||||||
@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// This record type is unknown to the stream, fail if the type
|
// This record type is unknown to the stream, fail if the type
|
||||||
// is even meaning that we are required to understand it.
|
// is even meaning that we are required to understand it.
|
||||||
case typ%2 == 0:
|
case typ%2 == 0:
|
||||||
return ErrUnknownRequiredType(typ)
|
return nil, ErrUnknownRequiredType(typ)
|
||||||
|
|
||||||
// Otherwise, the record type is unknown and is odd, discard the
|
// Otherwise, the record type is unknown and is odd, discard the
|
||||||
// number of bytes specified by length.
|
// number of bytes specified by length.
|
||||||
@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error {
|
|||||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||||
// results in an invalid record.
|
// results in an invalid record.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record the successfully decoded or ignored type if the
|
||||||
|
// caller provided an initialized TypeSet.
|
||||||
|
if parsedTypes != nil {
|
||||||
|
parsedTypes[typ] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// Update our record index so that we can begin our next search
|
// Update our record index so that we can begin our next search
|
||||||
// from where we left off.
|
// from where we left off.
|
||||||
recordIdx = newIdx
|
recordIdx = newIdx
|
||||||
|
51
tlv/stream_test.go
Normal file
51
tlv/stream_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package tlv_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||||
|
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||||
|
func TestParsedTypes(t *testing.T) {
|
||||||
|
const (
|
||||||
|
knownType = 1
|
||||||
|
unknownType = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Construct a stream that will encode two types, one that will be known
|
||||||
|
// to the decoder and another that will be unknown.
|
||||||
|
encStream := tlv.MustNewStream(
|
||||||
|
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||||
|
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := encStream.Encode(&b); err != nil {
|
||||||
|
t.Fatalf("unable to encode stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a stream that will parse only the known type.
|
||||||
|
decStream := tlv.MustNewStream(
|
||||||
|
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||||
|
bytes.NewReader(b.Bytes()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that both the known and unknown types are included in the set
|
||||||
|
// of parsed types.
|
||||||
|
if _, ok := parsedTypes[knownType]; !ok {
|
||||||
|
t.Fatalf("known type %d should be in parsed types", knownType)
|
||||||
|
}
|
||||||
|
if _, ok := parsedTypes[unknownType]; !ok {
|
||||||
|
t.Fatalf("unknown type %d should be in parsed types",
|
||||||
|
unknownType)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user