Merge pull request #3465 from cfromknecht/tlv-parsed-types

tlv+htlcswitch: validate presence/omission of parsed onion payload types
This commit is contained in:
Olaoluwa Osuntokun 2019-09-09 05:42:03 -07:00 committed by GitHub
commit 2cf10ada0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 263 additions and 13 deletions

@ -2,6 +2,7 @@ package hop
import (
"encoding/binary"
"fmt"
"io"
"github.com/lightningnetwork/lightning-onion"
@ -10,6 +11,37 @@ import (
"github.com/lightningnetwork/lnd/tlv"
)
// ErrInvalidPayload is an error returned when a parsed onion payload either
// included or omitted incorrect records for a particular hop type.
type ErrInvalidPayload struct {
// Type the record's type that cause the violation.
Type tlv.Type
// Ommitted if true, signals that the sender did not include the record.
// Otherwise, the sender included the record when it shouldn't have.
Omitted bool
// FinalHop if true, indicates that the violation is for the final hop
// in the route (identified by next hop id), otherwise the violation is
// for an intermediate hop.
FinalHop bool
}
// Error returns a human-readable description of the invalid payload error.
func (e ErrInvalidPayload) Error() string {
hopType := "intermediate"
if e.FinalHop {
hopType = "final"
}
violation := "included"
if e.Omitted {
violation = "omitted"
}
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
hopType, violation, e.Type)
}
// Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be
@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err
}
err = tlvStream.Decode(r)
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return nil, err
}
nextHop := lnwire.NewShortChanIDFromInt(cid)
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
if err != nil {
return nil, err
}
@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return &Payload{
FwdInfo: ForwardingInfo{
Network: BitcoinNetwork,
NextHop: lnwire.NewShortChanIDFromInt(cid),
NextHop: nextHop,
AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv,
},
@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
func (h *Payload) ForwardingInfo() ForwardingInfo {
return h.FwdInfo
}
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
// ensure that the proper fields are either included or omitted. The finalHop
// boolean should be true if the payload was parsed for an exit hop. The
// requirements for this method are described in BOLT 04.
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
nextHop lnwire.ShortChannelID) error {
isFinalHop := nextHop == Exit
_, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType]
switch {
// All hops must include an amount to forward.
case !hasAmt:
return ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: isFinalHop,
}
// All hops must include a cltv expiry.
case !hasLockTime:
return ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: isFinalHop,
}
// The exit hop should omit the next hop id. If nextHop != Exit, the
// sender must have included a record, so we don't need to test for its
// inclusion at intermediate hops directly.
case isFinalHop && hasNextHop:
return ErrInvalidPayload{
Type: record.NextHopOnionType,
Omitted: false,
FinalHop: true,
}
}
return nil
}

@ -0,0 +1,89 @@
package hop_test
import (
"bytes"
"reflect"
"testing"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/record"
)
type decodePayloadTest struct {
name string
payload []byte
expErr error
}
var decodePayloadTests = []decodePayloadTest{
{
name: "final hop no amount",
payload: []byte{0x04, 0x00},
expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: true,
},
},
{
name: "intermediate hop no amount",
payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: false,
},
},
{
name: "final hop no expiry",
payload: []byte{0x02, 0x00},
expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: true,
},
},
{
name: "intermediate hop no expiry",
payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: false,
},
},
{
name: "final hop next sid present",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.NextHopOnionType,
Omitted: false,
FinalHop: true,
},
},
}
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
// tests yields the expected errors depending on whether the proper fields were
// included or omitted.
func TestDecodeHopPayloadRecordValidation(t *testing.T) {
for _, test := range decodePayloadTests {
t.Run(test.name, func(t *testing.T) {
testDecodeHopPayloadValidation(t, test)
})
}
}
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v",
test.expErr, err)
}
}

@ -12,6 +12,9 @@ import (
// Type is an 64-bit identifier for a TLV Record.
type Type uint64
// TypeSet is an unordered set of Types.
type TypeSet map[Type]struct{}
// Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val.
// The provided scratch buffer must be non-nil.

@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error {
// the last record was read cleanly and we should stop parsing. All other io.EOF
// or io.ErrUnexpectedEOF errors are returned.
func (s *Stream) Decode(r io.Reader) error {
_, err := s.decode(r, nil)
return err
}
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a
// TypeSet containing the types of all records that were decoded or ignored from
// the stream.
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
return s.decode(r, make(TypeSet))
}
// decode is a helper function that performs the basis of stream decoding. If
// the caller needs the set of parsed types, it must provide an initialized
// parsedTypes, otherwise the returned TypeSet will be nil.
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var (
typ Type
min Type
@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded.
case err == io.EOF:
return nil
return parsedTypes, nil
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
typ = Type(t)
@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error {
// encodings that have duplicate records or from accepting an
// unsorted series.
if overflow || typ < min {
return ErrStreamNotCanonical
return nil, ErrStreamNotCanonical
}
// Read the varint length.
@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
// Place a soft limit on the size of a sane record, which
@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error {
// unbounded amount of memory when decoding variable-sized
// fields.
if length > MaxRecordSize {
return ErrRecordTooLarge
return nil, ErrRecordTooLarge
}
// Search the records known to the stream for this type. We'll
@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
// This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it.
case typ%2 == 0:
return ErrUnknownRequiredType(typ)
return nil, ErrUnknownRequiredType(typ)
// Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length.
@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error {
// We'll convert any EOFs to ErrUnexpectedEOF, since this
// results in an invalid record.
case err == io.EOF:
return io.ErrUnexpectedEOF
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return err
return nil, err
}
}
// Record the successfully decoded or ignored type if the
// caller provided an initialized TypeSet.
if parsedTypes != nil {
parsedTypes[typ] = struct{}{}
}
// Update our record index so that we can begin our next search
// from where we left off.
recordIdx = newIdx

51
tlv/stream_test.go Normal file

@ -0,0 +1,51 @@
package tlv_test
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/tlv"
)
// TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) {
const (
knownType = 1
unknownType = 3
)
// Construct a stream that will encode two types, one that will be known
// to the decoder and another that will be unknown.
encStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
)
var b bytes.Buffer
if err := encStream.Encode(&b); err != nil {
t.Fatalf("unable to encode stream: %v", err)
}
// Create a stream that will parse only the known type.
decStream := tlv.MustNewStream(
tlv.MakePrimitiveRecord(knownType, new(uint64)),
)
parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()),
)
if err != nil {
t.Fatalf("unable to decode stream: %v", err)
}
// Assert that both the known and unknown types are included in the set
// of parsed types.
if _, ok := parsedTypes[knownType]; !ok {
t.Fatalf("known type %d should be in parsed types", knownType)
}
if _, ok := parsedTypes[unknownType]; !ok {
t.Fatalf("unknown type %d should be in parsed types",
unknownType)
}
}