diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 1c7e563b..2fd1cd86 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -2,6 +2,7 @@ package hop import ( "encoding/binary" + "fmt" "io" "github.com/lightningnetwork/lightning-onion" @@ -10,6 +11,37 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +// ErrInvalidPayload is an error returned when a parsed onion payload either +// included or omitted incorrect records for a particular hop type. +type ErrInvalidPayload struct { + // Type the record's type that cause the violation. + Type tlv.Type + + // Ommitted if true, signals that the sender did not include the record. + // Otherwise, the sender included the record when it shouldn't have. + Omitted bool + + // FinalHop if true, indicates that the violation is for the final hop + // in the route (identified by next hop id), otherwise the violation is + // for an intermediate hop. + FinalHop bool +} + +// Error returns a human-readable description of the invalid payload error. +func (e ErrInvalidPayload) Error() string { + hopType := "intermediate" + if e.FinalHop { + hopType = "final" + } + violation := "included" + if e.Omitted { + violation = "omitted" + } + + return fmt.Sprintf("onion payload for %s hop %s record with type %d", + hopType, violation, e.Type) +} + // Payload encapsulates all information delivered to a hop in an onion payload. // A Hop can represent either a TLV or legacy payload. The primary forwarding // instruction can be accessed via ForwardingInfo, and additional records can be @@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } - err = tlvStream.Decode(r) + parsedTypes, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + nextHop := lnwire.NewShortChanIDFromInt(cid) + + // Validate whether the sender properly included or omitted tlv records + // in accordance with BOLT 04. + err = ValidateParsedPayloadTypes(parsedTypes, nextHop) if err != nil { return nil, err } @@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return &Payload{ FwdInfo: ForwardingInfo{ Network: BitcoinNetwork, - NextHop: lnwire.NewShortChanIDFromInt(cid), + NextHop: nextHop, AmountToForward: lnwire.MilliSatoshi(amt), OutgoingCTLV: cltv, }, @@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { func (h *Payload) ForwardingInfo() ForwardingInfo { return h.FwdInfo } + +// ValidateParsedPayloadTypes checks the types parsed from a hop payload to +// ensure that the proper fields are either included or omitted. The finalHop +// boolean should be true if the payload was parsed for an exit hop. The +// requirements for this method are described in BOLT 04. +func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, + nextHop lnwire.ShortChannelID) error { + + isFinalHop := nextHop == Exit + + _, hasAmt := parsedTypes[record.AmtOnionType] + _, hasLockTime := parsedTypes[record.LockTimeOnionType] + _, hasNextHop := parsedTypes[record.NextHopOnionType] + + switch { + + // All hops must include an amount to forward. + case !hasAmt: + return ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: isFinalHop, + } + + // All hops must include a cltv expiry. + case !hasLockTime: + return ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: isFinalHop, + } + + // The exit hop should omit the next hop id. If nextHop != Exit, the + // sender must have included a record, so we don't need to test for its + // inclusion at intermediate hops directly. + case isFinalHop && hasNextHop: + return ErrInvalidPayload{ + Type: record.NextHopOnionType, + Omitted: false, + FinalHop: true, + } + } + + return nil +} diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go new file mode 100644 index 00000000..0c9342d5 --- /dev/null +++ b/htlcswitch/hop/payload_test.go @@ -0,0 +1,89 @@ +package hop_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/record" +) + +type decodePayloadTest struct { + name string + payload []byte + expErr error +} + +var decodePayloadTests = []decodePayloadTest{ + { + name: "final hop no amount", + payload: []byte{0x04, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: true, + }, + }, + { + name: "intermediate hop no amount", + payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.AmtOnionType, + Omitted: true, + FinalHop: false, + }, + }, + { + name: "final hop no expiry", + payload: []byte{0x02, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: true, + }, + }, + { + name: "intermediate hop no expiry", + payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Omitted: true, + FinalHop: false, + }, + }, + { + name: "final hop next sid present", + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.NextHopOnionType, + Omitted: false, + FinalHop: true, + }, + }, +} + +// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the +// tests yields the expected errors depending on whether the proper fields were +// included or omitted. +func TestDecodeHopPayloadRecordValidation(t *testing.T) { + for _, test := range decodePayloadTests { + t.Run(test.name, func(t *testing.T) { + testDecodeHopPayloadValidation(t, test) + }) + } +} + +func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { + _, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) + if !reflect.DeepEqual(test.expErr, err) { + t.Fatalf("expected error mismatch, want: %v, got: %v", + test.expErr, err) + } +} diff --git a/tlv/record.go b/tlv/record.go index 610ab6c1..75647e03 100644 --- a/tlv/record.go +++ b/tlv/record.go @@ -12,6 +12,9 @@ import ( // Type is an 64-bit identifier for a TLV Record. type Type uint64 +// TypeSet is an unordered set of Types. +type TypeSet map[Type]struct{} + // Encoder is a signature for methods that can encode TLV values. An error // should be returned if the Encoder cannot support the underlying type of val. // The provided scratch buffer must be non-nil. diff --git a/tlv/stream.go b/tlv/stream.go index 159301fd..60b40919 100644 --- a/tlv/stream.go +++ b/tlv/stream.go @@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error { // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { + _, err := s.decode(r, nil) + return err +} + +// DecodeWithParsedTypes is identical to Decode, but if successful, returns a +// TypeSet containing the types of all records that were decoded or ignored from +// the stream. +func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { + return s.decode(r, make(TypeSet)) +} + +// decode is a helper function that performs the basis of stream decoding. If +// the caller needs the set of parsed types, it must provide an initialized +// parsedTypes, otherwise the returned TypeSet will be nil. +func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { var ( typ Type min Type @@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: - return nil + return parsedTypes, nil // Other unexpected errors. case err != nil: - return err + return nil, err } typ = Type(t) @@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error { // encodings that have duplicate records or from accepting an // unsorted series. if overflow || typ < min { - return ErrStreamNotCanonical + return nil, ErrStreamNotCanonical } // Read the varint length. @@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // Place a soft limit on the size of a sane record, which @@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error { // unbounded amount of memory when decoding variable-sized // fields. if length > MaxRecordSize { - return ErrRecordTooLarge + return nil, ErrRecordTooLarge } // Search the records known to the stream for this type. We'll @@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } // This record type is unknown to the stream, fail if the type // is even meaning that we are required to understand it. case typ%2 == 0: - return ErrUnknownRequiredType(typ) + return nil, ErrUnknownRequiredType(typ) // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. @@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: - return err + return nil, err } } + // Record the successfully decoded or ignored type if the + // caller provided an initialized TypeSet. + if parsedTypes != nil { + parsedTypes[typ] = struct{}{} + } + // Update our record index so that we can begin our next search // from where we left off. recordIdx = newIdx diff --git a/tlv/stream_test.go b/tlv/stream_test.go new file mode 100644 index 00000000..e9970e74 --- /dev/null +++ b/tlv/stream_test.go @@ -0,0 +1,51 @@ +package tlv_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestParsedTypes asserts that a Stream will properly return the set of types +// that it encounters when the type is known-and-decoded or unknown-and-ignored. +func TestParsedTypes(t *testing.T) { + const ( + knownType = 1 + unknownType = 3 + ) + + // Construct a stream that will encode two types, one that will be known + // to the decoder and another that will be unknown. + encStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + tlv.MakePrimitiveRecord(unknownType, new(uint64)), + ) + + var b bytes.Buffer + if err := encStream.Encode(&b); err != nil { + t.Fatalf("unable to encode stream: %v", err) + } + + // Create a stream that will parse only the known type. + decStream := tlv.MustNewStream( + tlv.MakePrimitiveRecord(knownType, new(uint64)), + ) + + parsedTypes, err := decStream.DecodeWithParsedTypes( + bytes.NewReader(b.Bytes()), + ) + if err != nil { + t.Fatalf("unable to decode stream: %v", err) + } + + // Assert that both the known and unknown types are included in the set + // of parsed types. + if _, ok := parsedTypes[knownType]; !ok { + t.Fatalf("known type %d should be in parsed types", knownType) + } + if _, ok := parsedTypes[unknownType]; !ok { + t.Fatalf("unknown type %d should be in parsed types", + unknownType) + } +}