From aefec9b10faf8ba9ee858ab568d7a17339ae4026 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 4 Sep 2019 11:46:28 -0700 Subject: [PATCH 1/2] tlv: return parsed types from DecodeWithParsedTypes This commit adds an additional return value to Stream.DecodeWithParsedTypes, which returns the set of types that were encountered during decoding. The set will contain all known types that were decoded, as well as unknown odd types that were ignored. The rationale for the return value (rather than an internal member) is so that the stream remains stateless. This return value can be used by callers during decoding to make assertions as to whether specific types were included in the stream. This is need, for example, when parsing onion payloads where certain fields must be included/omitted depending on the hop type. The original Decode method would incur the additional performance hit of needing to track the parsed types, so we can selectively enable this functionality when a decoder requires it by using a helper which conditionally tracks the parsed types. --- htlcswitch/hop/payload.go | 2 +- tlv/record.go | 3 +++ tlv/stream.go | 43 ++++++++++++++++++++++++--------- tlv/stream_test.go | 51 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 tlv/stream_test.go diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 1c7e563b..5577d76d 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -53,7 +53,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } - err = tlvStream.Decode(r) + _, err = tlvStream.DecodeWithParsedTypes(r) if err != nil { return nil, err } diff --git a/tlv/record.go b/tlv/record.go index 610ab6c1..75647e03 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,6 +12,9 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 +// TypeSet is an unordered set of Types. +type TypeSet map[Type]struct{} + // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. // The provided scratch buffer must be non-nil. diff --git a/tlv/stream.go b/tlv/stream.go index 159301fd..60b40919 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error { // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { + _, err := s.decode(r, nil) + return err +} + +// DecodeWithParsedTypes is identical to Decode, but if successful, returns a +// TypeSet containing the types of all records that were decoded or ignored from +// the stream. +func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { + return s.decode(r, make(TypeSet)) +} + +// decode is a helper function that performs the basis of stream decoding. If +// the caller needs the set of parsed types, it must provide an initialized +// parsedTypes, otherwise the returned TypeSet will be nil. +func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type @@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - return nil + return parsedTypes, nil // Other unexpected errors. case err != nil: - return err + return nil, err } typ = Type(t) @@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error { // encodings that have duplicate records or from accepting an // unsorted series. if overflow || typ < min { - return ErrStreamNotCanonical + return nil, ErrStreamNotCanonical } // Read the varint length. @@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // Place a soft limit on the size of a sane record, which @@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error { // unbounded amount of memory when decoding variable-sized // fields. if length > MaxRecordSize { - return ErrRecordTooLarge + return nil, ErrRecordTooLarge } // Search the records known to the stream for this type. We'll @@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // This record type is unknown to the stream, fail if the type // is even meaning that we are required to understand it. case typ%2 == 0: - return ErrUnknownRequiredType(typ) + return nil, ErrUnknownRequiredType(typ) // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. @@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } } + // Record the successfully decoded or ignored type if the + // caller provided an initialized TypeSet. + if parsedTypes != nil { + parsedTypes[typ] = struct{}{} + } + // Update our record index so that we can begin our next search // from where we left off. recordIdx = newIdx diff --git a/tlv/stream_test.go b/tlv/stream_test.go new file mode 100644 index 00000000..e9970e74 --- /dev/null +++ b/tlv/stream_test.go @@ -0,0 +1,51 @@ +package tlv_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestParsedTypes asserts that a Stream will properly return the set of types +// that it encounters when the type is known-and-decoded or unknown-and-ignored. +func TestParsedTypes(t *testing.T) { + const ( + knownType = 1 + unknownType = 3 + ) + + // Construct a stream that will encode two types, one that will be known + // to the decoder and another that will be unknown. + encStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + tlv.MakePrimitiveRecord(unknownType, new(uint64)), + ) + + var b bytes.Buffer + if err := encStream.Encode(&b); err != nil { + t.Fatalf("unable to encode stream: %v", err) + } + + // Create a stream that will parse only the known type. + decStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + ) + + parsedTypes, err := decStream.DecodeWithParsedTypes( + bytes.NewReader(b.Bytes()), + ) + if err != nil { + t.Fatalf("unable to decode stream: %v", err) + } + + // Assert that both the known and unknown types are included in the set + // of parsed types. + if _, ok := parsedTypes[knownType]; !ok { + t.Fatalf("known type %d should be in parsed types", knownType) + } + if _, ok := parsedTypes[unknownType]; !ok { + t.Fatalf("unknown type %d should be in parsed types", + unknownType) + } +} From 60155679274091aacb2a4c78b10d4ea8e015d531 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 5 Sep 2019 06:05:38 -0700 Subject: [PATCH 2/2] htlcswitch/iterator: validate presence/omission of payload types From BOLT 04: The writer: - MUST include amt_to_forward and outgoing_cltv_value for every node. - MUST include short_channel_id for every non-final node. - MUST NOT include short_channel_id for the final node. --- htlcswitch/hop/payload.go | 90 +++++++++++++++++++++++++++++++++- htlcswitch/hop/payload_test.go | 89 +++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 htlcswitch/hop/payload_test.go diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 5577d76d..2fd1cd86 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -2,6 +2,7 @@ package hop import ( "encoding/binary" + "fmt" "io" "github.com/lightningnetwork/lightning-onion" @@ -10,6 +11,37 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +// ErrInvalidPayload is an error returned when a parsed onion payload either +// included or omitted incorrect records for a particular hop type. +type ErrInvalidPayload struct { + // Type the record's type that cause the violation. + Type tlv.Type + + // Ommitted if true, signals that the sender did not include the record. + // Otherwise, the sender included the record when it shouldn't have. + Omitted bool + + // FinalHop if true, indicates that the violation is for the final hop + // in the route (identified by next hop id), otherwise the violation is + // for an intermediate hop. + FinalHop bool +} + +// Error returns a human-readable description of the invalid payload error. +func (e ErrInvalidPayload) Error() string { + hopType := "intermediate" + if e.FinalHop { + hopType = "final" + } + violation := "included" + if e.Omitted { + violation = "omitted" + } + + return fmt.Sprintf("onion payload for %s hop %s record with type %d", + hopType, violation, e.Type) +} + // Payload encapsulates all information delivered to a hop in an onion payload. // A Hop can represent either a TLV or legacy payload. The primary forwarding // instruction can be accessed via ForwardingInfo, and additional records can be @@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } - _, err = tlvStream.DecodeWithParsedTypes(r) + parsedTypes, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + nextHop := lnwire.NewShortChanIDFromInt(cid) + + // Validate whether the sender properly included or omitted tlv records + // in accordance with BOLT 04. + err = ValidateParsedPayloadTypes(parsedTypes, nextHop) if err != nil { return nil, err } @@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return &Payload{ FwdInfo: ForwardingInfo{ Network: BitcoinNetwork, - NextHop: lnwire.NewShortChanIDFromInt(cid), + NextHop: nextHop, AmountToForward: lnwire.MilliSatoshi(amt), OutgoingCTLV: cltv, }, @@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { func (h *Payload) ForwardingInfo() ForwardingInfo { return h.FwdInfo } + +// ValidateParsedPayloadTypes checks the types parsed from a hop payload to +// ensure that the proper fields are either included or omitted. The finalHop +// boolean should be true if the payload was parsed for an exit hop. The +// requirements for this method are described in BOLT 04. +func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, + nextHop lnwire.ShortChannelID) error { + + isFinalHop := nextHop == Exit + + _, hasAmt := parsedTypes[record.AmtOnionType] + _, hasLockTime := parsedTypes[record.LockTimeOnionType] + _, hasNextHop := parsedTypes[record.NextHopOnionType] + + switch { + + // All hops must include an amount to forward. + case !hasAmt: + return ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: isFinalHop, + } + + // All hops must include a cltv expiry. + case !hasLockTime: + return ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: isFinalHop, + } + + // The exit hop should omit the next hop id. If nextHop != Exit, the + // sender must have included a record, so we don't need to test for its + // inclusion at intermediate hops directly. + case isFinalHop && hasNextHop: + return ErrInvalidPayload{ + Type: record.NextHopOnionType, + Omitted: false, + FinalHop: true, + } + } + + return nil +} diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go new file mode 100644 index 00000000..0c9342d5 --- /dev/null +++ b/htlcswitch/hop/payload_test.go @@ -0,0 +1,89 @@ +package hop_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/record" +) + +type decodePayloadTest struct { + name string + payload []byte + expErr error +} + +var decodePayloadTests = []decodePayloadTest{ + { + name: "final hop no amount", + payload: []byte{0x04, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: true, + }, + }, + { + name: "intermediate hop no amount", + payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: false, + }, + }, + { + name: "final hop no expiry", + payload: []byte{0x02, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: true, + }, + }, + { + name: "intermediate hop no expiry", + payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: false, + }, + }, + { + name: "final hop next sid present", + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.NextHopOnionType, + Omitted: false, + FinalHop: true, + }, + }, +} + +// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the +// tests yields the expected errors depending on whether the proper fields were +// included or omitted. +func TestDecodeHopPayloadRecordValidation(t *testing.T) { + for _, test := range decodePayloadTests { + t.Run(test.name, func(t *testing.T) { + testDecodeHopPayloadValidation(t, test) + }) + } +} + +func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { + _, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) + if !reflect.DeepEqual(test.expErr, err) { + t.Fatalf("expected error mismatch, want: %v, got: %v", + test.expErr, err) + } +}