diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 29784ab8..afc31308 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -178,7 +178,7 @@ func (h *Payload) ForwardingInfo() ForwardingInfo { // ensure that the proper fields are either included or omitted. The finalHop // boolean should be true if the payload was parsed for an exit hop. The // requirements for this method are described in BOLT 04. -func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, +func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, nextHop lnwire.ShortChannelID) error { isFinalHop := nextHop == Exit @@ -237,19 +237,19 @@ func (h *Payload) MultiPath() *record.MPP { // getMinRequiredViolation checks for unrecognized required (even) fields in the // standard range and returns the lowest required type. Always returning the // lowest required type allows a failure message to be deterministic. -func getMinRequiredViolation(set tlv.TypeSet) *tlv.Type { +func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type { var ( requiredViolation bool minRequiredViolationType tlv.Type ) - for t, known := range set { + for t, parseResult := range set { // If a type is even but not known to us, we cannot process the // payload. We are required to understand a field that we don't // support. // // We always accept custom fields, because a higher level // application may understand them. - if known || t%2 != 0 || t >= CustomTypeStart { + if parseResult == nil || t%2 != 0 || t >= CustomTypeStart { continue } diff --git a/tlv/record.go b/tlv/record.go index 6159412c..38070956 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,9 +12,10 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 -// TypeSet is an unordered set of Types. The map item boolean values indicate -// whether the type that we parsed was known. -type TypeSet map[Type]bool +// TypeMap is a map of parsed Types. The map values are byte slices. If the byte +// slice is nil, the type was successfully parsed. Otherwise the value is byte +// slice containing the encoded data. +type TypeMap map[Type][]byte // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. diff --git a/tlv/stream.go b/tlv/stream.go index ed2c0d00..4a8eb722 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -1,6 +1,7 @@ package tlv import ( + "bytes" "errors" "io" "io/ioutil" @@ -139,16 +140,16 @@ func (s *Stream) Decode(r io.Reader) error { } // DecodeWithParsedTypes is identical to Decode, but if successful, returns a -// TypeSet containing the types of all records that were decoded or ignored from +// TypeMap containing the types of all records that were decoded or ignored from // the stream. -func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { - return s.decode(r, make(TypeSet)) +func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeMap, error) { + return s.decode(r, make(TypeMap)) } // decode is a helper function that performs the basis of stream decoding. If // the caller needs the set of parsed types, it must provide an initialized -// parsedTypes, otherwise the returned TypeSet will be nil. -func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { +// parsedTypes, otherwise the returned TypeMap will be nil. +func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) { var ( typ Type min Type @@ -230,10 +231,25 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { return nil, err } + // Record the successfully decoded type if the caller + // provided an initialized TypeMap. + if parsedTypes != nil { + parsedTypes[typ] = nil + } + // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. default: - _, err := io.CopyN(ioutil.Discard, r, int64(length)) + // If the caller provided an initialized TypeMap, record + // the encoded bytes. + var b *bytes.Buffer + writer := ioutil.Discard + if parsedTypes != nil { + b = bytes.NewBuffer(make([]byte, 0, length)) + writer = b + } + + _, err := io.CopyN(writer, r, int64(length)) switch { // We'll convert any EOFs to ErrUnexpectedEOF, since this @@ -245,12 +261,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { case err != nil: return nil, err } - } - // Record the successfully decoded or ignored type if the - // caller provided an initialized TypeSet. - if parsedTypes != nil { - parsedTypes[typ] = ok + if parsedTypes != nil { + parsedTypes[typ] = b.Bytes() + } } // Update our record index so that we can begin our next search diff --git a/tlv/stream_test.go b/tlv/stream_test.go index 8e4a33b7..8f67a316 100644 --- a/tlv/stream_test.go +++ b/tlv/stream_test.go @@ -12,7 +12,7 @@ type parsedTypeTest struct { name string encode []tlv.Type decode []tlv.Type - expParsedTypes tlv.TypeSet + expParsedTypes tlv.TypeMap } // TestParsedTypes asserts that a Stream will properly return the set of types @@ -29,17 +29,17 @@ func TestParsedTypes(t *testing.T) { name: "known and unknown", encode: []tlv.Type{knownType, unknownType}, decode: []tlv.Type{knownType}, - expParsedTypes: tlv.TypeSet{ - unknownType: false, - knownType: true, + expParsedTypes: tlv.TypeMap{ + unknownType: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + knownType: nil, }, }, { name: "known and missing known", encode: []tlv.Type{knownType}, decode: []tlv.Type{knownType, secondKnownType}, - expParsedTypes: tlv.TypeSet{ - knownType: true, + expParsedTypes: tlv.TypeMap{ + knownType: nil, }, }, }