Browse Source

tlv+hop: contain odd/even logic in payload parsing

Tlv is used more widely in lnd than just for the onion payload. This
commit isolated the protocol-specific odd/even logic, so that tlv can be
used freely elsewhere. An example of this use is db serialization.
master
Joost Jager 5 years ago
parent
commit
048971b40b
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
  1. 61
      htlcswitch/hop/payload.go
  2. 5
      tlv/record.go
  3. 43
      tlv/stream.go
  4. 61
      tlv/stream_test.go
  5. 32
      tlv/tlv_test.go

61
htlcswitch/hop/payload.go

@ -124,28 +124,6 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
// Promote any required type failures into ErrInvalidPayload.
if e, required := err.(tlv.ErrUnknownRequiredType); required {
// If the parser returned an unknown required type
// failure, we'll first check that the payload is
// properly formed according to our known set of
// constraints. If an error is discovered, this
// overrides the required type failure.
nextHop := lnwire.NewShortChanIDFromInt(cid)
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
if err != nil {
return nil, err
}
// Otherwise the known constraints were applied
// successfully, report the invalid type failure
// returned by the parser.
return nil, ErrInvalidPayload{
Type: tlv.Type(e),
Violation: RequiredViolation,
FinalHop: nextHop == Exit,
}
}
return nil, err
}
@ -157,6 +135,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return nil, ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: nextHop == Exit,
}
}
// If no MPP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
@ -239,3 +227,32 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
func (h *Payload) MultiPath() *record.MPP {
return h.MPP
}
// getMinRequiredViolation checks for unrecognized required (even) fields in the
// standard range and returns the lowest required type. Always returning the
// lowest required type allows a failure message to be deterministic.
func getMinRequiredViolation(set tlv.TypeSet) *tlv.Type {
var (
requiredViolation bool
minRequiredViolationType tlv.Type
)
for t, known := range set {
// If a type is even but not known to us, we cannot process the
// payload. We are required to understand a field that we don't
// support.
if known || t%2 != 0 {
continue
}
if !requiredViolation || t < minRequiredViolationType {
minRequiredViolationType = t
}
requiredViolation = true
}
if requiredViolation {
return &minRequiredViolationType
}
return nil
}

5
tlv/record.go

@ -12,8 +12,9 @@ import (
// Type is an 64-bit identifier for a TLV Record.
type Type uint64
// TypeSet is an unordered set of Types.
type TypeSet map[Type]struct{}
// TypeSet is an unordered set of Types. The map item boolean values indicate
// whether the type that we parsed was known.
type TypeSet map[Type]bool
// Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val.

43
tlv/stream.go

@ -2,7 +2,6 @@ package tlv
import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
@ -22,15 +21,6 @@ var ErrStreamNotCanonical = errors.New("tlv stream is not canonical")
// long to parse.
var ErrRecordTooLarge = errors.New("record is too large")
// ErrUnknownRequiredType is an error returned when decoding an unknown and even
// type from a Stream.
type ErrUnknownRequiredType Type
// Error returns a human-readable description of unknown required type.
func (t ErrUnknownRequiredType) Error() string {
return fmt.Sprintf("unknown required type: %d", t)
}
// Stream defines a TLV stream that can be used for encoding or decoding a set
// of TLV Records.
type Stream struct {
@ -162,7 +152,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var (
typ Type
min Type
firstFail *Type
recordIdx int
overflow bool
)
@ -177,10 +166,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded.
case err == io.EOF:
if firstFail == nil {
return parsedTypes, nil
}
return parsedTypes, ErrUnknownRequiredType(*firstFail)
return parsedTypes, nil
// Other unexpected errors.
case err != nil:
@ -244,31 +230,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
return nil, err
}
// This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it.
case typ%2 == 0:
// We'll fail immediately in the case that we aren't
// tracking the set of parsed types.
if parsedTypes == nil {
return nil, ErrUnknownRequiredType(typ)
}
// Otherwise, we'll track the first such failure and
// allow parsing to continue. If no other types of
// errors are encountered, the first failure will be
// returned as an ErrUnknownRequiredType so that the
// full set of included types can be returned.
if firstFail == nil {
failTyp := typ
firstFail = &failTyp
}
// With the failure type recorded, we'll simply discard
// the remainder of the record as if it were optional.
// The first failure will be returned after reaching the
// stopping condition.
fallthrough
// Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length.
default:
@ -289,7 +250,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// Record the successfully decoded or ignored type if the
// caller provided an initialized TypeSet.
if parsedTypes != nil {
parsedTypes[typ] = struct{}{}
parsedTypes[typ] = ok
}
// Update our record index so that we can begin our next search

61
tlv/stream_test.go

@ -9,49 +9,38 @@ import (
)
type parsedTypeTest struct {
name string
encode []tlv.Type
decode []tlv.Type
expErr error
name string
encode []tlv.Type
decode []tlv.Type
expParsedTypes tlv.TypeSet
}
// TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) {
const (
firstReqType = 0
knownType = 1
unknownType = 3
secondReqType = 4
knownType = 1
unknownType = 3
secondKnownType = 4
)
tests := []parsedTypeTest{
{
name: "known optional and unknown optional",
name: "known and unknown",
encode: []tlv.Type{knownType, unknownType},
decode: []tlv.Type{knownType},
expParsedTypes: tlv.TypeSet{
unknownType: false,
knownType: true,
},
},
{
name: "unknown required and known optional",
encode: []tlv.Type{firstReqType, knownType},
decode: []tlv.Type{knownType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "unknown required and unknown optional",
encode: []tlv.Type{unknownType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(secondReqType),
},
{
name: "unknown required and known required",
encode: []tlv.Type{firstReqType, secondReqType},
decode: []tlv.Type{secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "two unknown required",
encode: []tlv.Type{firstReqType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
name: "known and missing known",
encode: []tlv.Type{knownType},
decode: []tlv.Type{knownType, secondKnownType},
expParsedTypes: tlv.TypeSet{
knownType: true,
},
},
}
@ -92,16 +81,10 @@ func testParsedTypes(t *testing.T, test parsedTypeTest) {
parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()),
)
if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
if err != nil {
t.Fatalf("error decoding: %v", err)
}
// Assert that all encoded types are included in the set of parsed
// types.
for _, typ := range test.encode {
if _, ok := parsedTypes[typ]; !ok {
t.Fatalf("encoded type %d should be in parsed types",
typ)
}
if !reflect.DeepEqual(parsedTypes, test.expParsedTypes) {
t.Fatalf("error mismatch on parsed types")
}
}

32
tlv/tlv_test.go

@ -203,26 +203,6 @@ var tlvDecodingFailureTests = []struct {
},
expErr: io.ErrUnexpectedEOF,
},
{
name: "unknown even type",
bytes: []byte{0x12, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x12),
},
{
name: "unknown even type",
bytes: []byte{0xfd, 0x01, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x102),
},
{
name: "unknown even type",
bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x01000002),
},
{
name: "unknown even type",
bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x0100000000000002),
},
{
name: "greater than encoding length for n1's amt",
bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
@ -340,12 +320,6 @@ var tlvDecodingFailureTests = []struct {
expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49),
skipN2: true,
},
{
name: "unknown required type or n1",
bytes: []byte{0x00, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x00),
skipN2: true,
},
{
name: "less than encoding length for n1's cltvDelta",
bytes: []byte{0xfd, 0x00, 0x0fe, 0x00},
@ -364,12 +338,6 @@ var tlvDecodingFailureTests = []struct {
expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2),
skipN2: true,
},
{
name: "unknown even field for n1's namespace",
bytes: []byte{0x0a, 0x00},
expErr: tlv.ErrUnknownRequiredType(0x0a),
skipN2: true,
},
{
name: "valid records but invalid ordering",
bytes: []byte{0x02, 0x08,

Loading…
Cancel
Save