diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index e05b4ee9..0304e2c7 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -124,28 +124,6 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { parsedTypes, err := tlvStream.DecodeWithParsedTypes(r) if err != nil { - // Promote any required type failures into ErrInvalidPayload. - if e, required := err.(tlv.ErrUnknownRequiredType); required { - // If the parser returned an unknown required type - // failure, we'll first check that the payload is - // properly formed according to our known set of - // constraints. If an error is discovered, this - // overrides the required type failure. - nextHop := lnwire.NewShortChanIDFromInt(cid) - err = ValidateParsedPayloadTypes(parsedTypes, nextHop) - if err != nil { - return nil, err - } - - // Otherwise the known constraints were applied - // successfully, report the invalid type failure - // returned by the parser. - return nil, ErrInvalidPayload{ - Type: tlv.Type(e), - Violation: RequiredViolation, - FinalHop: nextHop == Exit, - } - } return nil, err } @@ -157,6 +135,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } + // Check for violation of the rules for mandatory fields. + violatingType := getMinRequiredViolation(parsedTypes) + if violatingType != nil { + return nil, ErrInvalidPayload{ + Type: *violatingType, + Violation: RequiredViolation, + FinalHop: nextHop == Exit, + } + } + // If no MPP field was parsed, set the MPP field on the resulting // payload to nil. if _, ok := parsedTypes[record.MPPOnionType]; !ok { @@ -239,3 +227,32 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, func (h *Payload) MultiPath() *record.MPP { return h.MPP } + +// getMinRequiredViolation checks for unrecognized required (even) fields in the +// standard range and returns the lowest required type. Always returning the +// lowest required type allows a failure message to be deterministic. +func getMinRequiredViolation(set tlv.TypeSet) *tlv.Type { + var ( + requiredViolation bool + minRequiredViolationType tlv.Type + ) + for t, known := range set { + // If a type is even but not known to us, we cannot process the + // payload. We are required to understand a field that we don't + // support. + if known || t%2 != 0 { + continue + } + + if !requiredViolation || t < minRequiredViolationType { + minRequiredViolationType = t + } + requiredViolation = true + } + + if requiredViolation { + return &minRequiredViolationType + } + + return nil +} diff --git a/tlv/record.go b/tlv/record.go index fe774263..b9e7980d 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,8 +12,9 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 -// TypeSet is an unordered set of Types. -type TypeSet map[Type]struct{} +// TypeSet is an unordered set of Types. The map item boolean values indicate +// whether the type that we parsed was known. +type TypeSet map[Type]bool // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. diff --git a/tlv/stream.go b/tlv/stream.go index d104a206..ed2c0d00 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -2,7 +2,6 @@ package tlv import ( "errors" - "fmt" "io" "io/ioutil" "math" @@ -22,15 +21,6 @@ var ErrStreamNotCanonical = errors.New("tlv stream is not canonical") // long to parse. var ErrRecordTooLarge = errors.New("record is too large") -// ErrUnknownRequiredType is an error returned when decoding an unknown and even -// type from a Stream. -type ErrUnknownRequiredType Type - -// Error returns a human-readable description of unknown required type. -func (t ErrUnknownRequiredType) Error() string { - return fmt.Sprintf("unknown required type: %d", t) -} - // Stream defines a TLV stream that can be used for encoding or decoding a set // of TLV Records. type Stream struct { @@ -162,7 +152,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type - firstFail *Type recordIdx int overflow bool ) @@ -177,10 +166,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - if firstFail == nil { - return parsedTypes, nil - } - return parsedTypes, ErrUnknownRequiredType(*firstFail) + return parsedTypes, nil // Other unexpected errors. case err != nil: @@ -244,31 +230,6 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { return nil, err } - // This record type is unknown to the stream, fail if the type - // is even meaning that we are required to understand it. - case typ%2 == 0: - // We'll fail immediately in the case that we aren't - // tracking the set of parsed types. - if parsedTypes == nil { - return nil, ErrUnknownRequiredType(typ) - } - - // Otherwise, we'll track the first such failure and - // allow parsing to continue. If no other types of - // errors are encountered, the first failure will be - // returned as an ErrUnknownRequiredType so that the - // full set of included types can be returned. - if firstFail == nil { - failTyp := typ - firstFail = &failTyp - } - - // With the failure type recorded, we'll simply discard - // the remainder of the record as if it were optional. - // The first failure will be returned after reaching the - // stopping condition. - fallthrough - // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. default: @@ -289,7 +250,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { // Record the successfully decoded or ignored type if the // caller provided an initialized TypeSet. if parsedTypes != nil { - parsedTypes[typ] = struct{}{} + parsedTypes[typ] = ok } // Update our record index so that we can begin our next search diff --git a/tlv/stream_test.go b/tlv/stream_test.go index f7d54f73..8e4a33b7 100644 --- a/tlv/stream_test.go +++ b/tlv/stream_test.go @@ -9,49 +9,38 @@ import ( ) type parsedTypeTest struct { - name string - encode []tlv.Type - decode []tlv.Type - expErr error + name string + encode []tlv.Type + decode []tlv.Type + expParsedTypes tlv.TypeSet } // TestParsedTypes asserts that a Stream will properly return the set of types // that it encounters when the type is known-and-decoded or unknown-and-ignored. func TestParsedTypes(t *testing.T) { const ( - firstReqType = 0 - knownType = 1 - unknownType = 3 - secondReqType = 4 + knownType = 1 + unknownType = 3 + secondKnownType = 4 ) tests := []parsedTypeTest{ { - name: "known optional and unknown optional", + name: "known and unknown", encode: []tlv.Type{knownType, unknownType}, decode: []tlv.Type{knownType}, + expParsedTypes: tlv.TypeSet{ + unknownType: false, + knownType: true, + }, }, { - name: "unknown required and known optional", - encode: []tlv.Type{firstReqType, knownType}, - decode: []tlv.Type{knownType}, - expErr: tlv.ErrUnknownRequiredType(firstReqType), - }, - { - name: "unknown required and unknown optional", - encode: []tlv.Type{unknownType, secondReqType}, - expErr: tlv.ErrUnknownRequiredType(secondReqType), - }, - { - name: "unknown required and known required", - encode: []tlv.Type{firstReqType, secondReqType}, - decode: []tlv.Type{secondReqType}, - expErr: tlv.ErrUnknownRequiredType(firstReqType), - }, - { - name: "two unknown required", - encode: []tlv.Type{firstReqType, secondReqType}, - expErr: tlv.ErrUnknownRequiredType(firstReqType), + name: "known and missing known", + encode: []tlv.Type{knownType}, + decode: []tlv.Type{knownType, secondKnownType}, + expParsedTypes: tlv.TypeSet{ + knownType: true, + }, }, } @@ -92,16 +81,10 @@ func testParsedTypes(t *testing.T, test parsedTypeTest) { parsedTypes, err := decStream.DecodeWithParsedTypes( bytes.NewReader(b.Bytes()), ) - if !reflect.DeepEqual(err, test.expErr) { - t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr) + if err != nil { + t.Fatalf("error decoding: %v", err) } - - // Assert that all encoded types are included in the set of parsed - // types. - for _, typ := range test.encode { - if _, ok := parsedTypes[typ]; !ok { - t.Fatalf("encoded type %d should be in parsed types", - typ) - } + if !reflect.DeepEqual(parsedTypes, test.expParsedTypes) { + t.Fatalf("error mismatch on parsed types") } } diff --git a/tlv/tlv_test.go b/tlv/tlv_test.go index 13ef2468..e5d47341 100644 --- a/tlv/tlv_test.go +++ b/tlv/tlv_test.go @@ -203,26 +203,6 @@ var tlvDecodingFailureTests = []struct { }, expErr: io.ErrUnexpectedEOF, }, - { - name: "unknown even type", - bytes: []byte{0x12, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x12), - }, - { - name: "unknown even type", - bytes: []byte{0xfd, 0x01, 0x02, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x102), - }, - { - name: "unknown even type", - bytes: []byte{0xfe, 0x01, 0x00, 0x00, 0x02, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x01000002), - }, - { - name: "unknown even type", - bytes: []byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x0100000000000002), - }, { name: "greater than encoding length for n1's amt", bytes: []byte{0x01, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, @@ -340,12 +320,6 @@ var tlvDecodingFailureTests = []struct { expErr: tlv.NewTypeForDecodingErr(new(nodeAmts), "nodeAmts", 50, 49), skipN2: true, }, - { - name: "unknown required type or n1", - bytes: []byte{0x00, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x00), - skipN2: true, - }, { name: "less than encoding length for n1's cltvDelta", bytes: []byte{0xfd, 0x00, 0x0fe, 0x00}, @@ -364,12 +338,6 @@ var tlvDecodingFailureTests = []struct { expErr: tlv.NewTypeForDecodingErr(new(uint16), "uint16", 3, 2), skipN2: true, }, - { - name: "unknown even field for n1's namespace", - bytes: []byte{0x0a, 0x00}, - expErr: tlv.ErrUnknownRequiredType(0x0a), - skipN2: true, - }, { name: "valid records but invalid ordering", bytes: []byte{0x02, 0x08,