Merge pull request #3465 from cfromknecht/tlv-parsed-types
tlv+htlcswitch: validate presence/omission of parsed onion payload types
This commit is contained in:
commit
2cf10ada0f
@ -2,6 +2,7 @@ package hop
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lightning-onion"
|
||||
@ -10,6 +11,37 @@ import (
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// ErrInvalidPayload is an error returned when a parsed onion payload either
|
||||
// included or omitted incorrect records for a particular hop type.
|
||||
type ErrInvalidPayload struct {
|
||||
// Type the record's type that cause the violation.
|
||||
Type tlv.Type
|
||||
|
||||
// Ommitted if true, signals that the sender did not include the record.
|
||||
// Otherwise, the sender included the record when it shouldn't have.
|
||||
Omitted bool
|
||||
|
||||
// FinalHop if true, indicates that the violation is for the final hop
|
||||
// in the route (identified by next hop id), otherwise the violation is
|
||||
// for an intermediate hop.
|
||||
FinalHop bool
|
||||
}
|
||||
|
||||
// Error returns a human-readable description of the invalid payload error.
|
||||
func (e ErrInvalidPayload) Error() string {
|
||||
hopType := "intermediate"
|
||||
if e.FinalHop {
|
||||
hopType = "final"
|
||||
}
|
||||
violation := "included"
|
||||
if e.Omitted {
|
||||
violation = "omitted"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
|
||||
hopType, violation, e.Type)
|
||||
}
|
||||
|
||||
// Payload encapsulates all information delivered to a hop in an onion payload.
|
||||
// A Hop can represent either a TLV or legacy payload. The primary forwarding
|
||||
// instruction can be accessed via ForwardingInfo, and additional records can be
|
||||
@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tlvStream.Decode(r)
|
||||
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||
|
||||
// Validate whether the sender properly included or omitted tlv records
|
||||
// in accordance with BOLT 04.
|
||||
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
return &Payload{
|
||||
FwdInfo: ForwardingInfo{
|
||||
Network: BitcoinNetwork,
|
||||
NextHop: lnwire.NewShortChanIDFromInt(cid),
|
||||
NextHop: nextHop,
|
||||
AmountToForward: lnwire.MilliSatoshi(amt),
|
||||
OutgoingCTLV: cltv,
|
||||
},
|
||||
@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
func (h *Payload) ForwardingInfo() ForwardingInfo {
|
||||
return h.FwdInfo
|
||||
}
|
||||
|
||||
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
|
||||
// ensure that the proper fields are either included or omitted. The finalHop
|
||||
// boolean should be true if the payload was parsed for an exit hop. The
|
||||
// requirements for this method are described in BOLT 04.
|
||||
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
nextHop lnwire.ShortChannelID) error {
|
||||
|
||||
isFinalHop := nextHop == Exit
|
||||
|
||||
_, hasAmt := parsedTypes[record.AmtOnionType]
|
||||
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||
|
||||
switch {
|
||||
|
||||
// All hops must include an amount to forward.
|
||||
case !hasAmt:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
|
||||
// All hops must include a cltv expiry.
|
||||
case !hasLockTime:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
|
||||
// The exit hop should omit the next hop id. If nextHop != Exit, the
|
||||
// sender must have included a record, so we don't need to test for its
|
||||
// inclusion at intermediate hops directly.
|
||||
case isFinalHop && hasNextHop:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.NextHopOnionType,
|
||||
Omitted: false,
|
||||
FinalHop: true,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
89
htlcswitch/hop/payload_test.go
Normal file
89
htlcswitch/hop/payload_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
package hop_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
)
|
||||
|
||||
type decodePayloadTest struct {
|
||||
name string
|
||||
payload []byte
|
||||
expErr error
|
||||
}
|
||||
|
||||
var decodePayloadTests = []decodePayloadTest{
|
||||
{
|
||||
name: "final hop no amount",
|
||||
payload: []byte{0x04, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "intermediate hop no amount",
|
||||
payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "final hop no expiry",
|
||||
payload: []byte{0x02, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "intermediate hop no expiry",
|
||||
payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "final hop next sid present",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.NextHopOnionType,
|
||||
Omitted: false,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||
// tests yields the expected errors depending on whether the proper fields were
|
||||
// included or omitted.
|
||||
func TestDecodeHopPayloadRecordValidation(t *testing.T) {
|
||||
for _, test := range decodePayloadTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testDecodeHopPayloadValidation(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||
if !reflect.DeepEqual(test.expErr, err) {
|
||||
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
||||
test.expErr, err)
|
||||
}
|
||||
}
|
@ -12,6 +12,9 @@ import (
|
||||
// Type is an 64-bit identifier for a TLV Record.
|
||||
type Type uint64
|
||||
|
||||
// TypeSet is an unordered set of Types.
|
||||
type TypeSet map[Type]struct{}
|
||||
|
||||
// Encoder is a signature for methods that can encode TLV values. An error
|
||||
// should be returned if the Encoder cannot support the underlying type of val.
|
||||
// The provided scratch buffer must be non-nil.
|
||||
|
@ -144,6 +144,21 @@ func (s *Stream) Encode(w io.Writer) error {
|
||||
// the last record was read cleanly and we should stop parsing. All other io.EOF
|
||||
// or io.ErrUnexpectedEOF errors are returned.
|
||||
func (s *Stream) Decode(r io.Reader) error {
|
||||
_, err := s.decode(r, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a
|
||||
// TypeSet containing the types of all records that were decoded or ignored from
|
||||
// the stream.
|
||||
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) {
|
||||
return s.decode(r, make(TypeSet))
|
||||
}
|
||||
|
||||
// decode is a helper function that performs the basis of stream decoding. If
|
||||
// the caller needs the set of parsed types, it must provide an initialized
|
||||
// parsedTypes, otherwise the returned TypeSet will be nil.
|
||||
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
var (
|
||||
typ Type
|
||||
min Type
|
||||
@ -161,11 +176,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// We'll silence an EOF when zero bytes remain, meaning the
|
||||
// stream was cleanly encoded.
|
||||
case err == io.EOF:
|
||||
return nil
|
||||
return parsedTypes, nil
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
typ = Type(t)
|
||||
@ -176,7 +191,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// encodings that have duplicate records or from accepting an
|
||||
// unsorted series.
|
||||
if overflow || typ < min {
|
||||
return ErrStreamNotCanonical
|
||||
return nil, ErrStreamNotCanonical
|
||||
}
|
||||
|
||||
// Read the varint length.
|
||||
@ -186,11 +201,11 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Place a soft limit on the size of a sane record, which
|
||||
@ -198,7 +213,7 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// unbounded amount of memory when decoding variable-sized
|
||||
// fields.
|
||||
if length > MaxRecordSize {
|
||||
return ErrRecordTooLarge
|
||||
return nil, ErrRecordTooLarge
|
||||
}
|
||||
|
||||
// Search the records known to the stream for this type. We'll
|
||||
@ -218,17 +233,17 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This record type is unknown to the stream, fail if the type
|
||||
// is even meaning that we are required to understand it.
|
||||
case typ%2 == 0:
|
||||
return ErrUnknownRequiredType(typ)
|
||||
return nil, ErrUnknownRequiredType(typ)
|
||||
|
||||
// Otherwise, the record type is unknown and is odd, discard the
|
||||
// number of bytes specified by length.
|
||||
@ -239,14 +254,20 @@ func (s *Stream) Decode(r io.Reader) error {
|
||||
// We'll convert any EOFs to ErrUnexpectedEOF, since this
|
||||
// results in an invalid record.
|
||||
case err == io.EOF:
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Record the successfully decoded or ignored type if the
|
||||
// caller provided an initialized TypeSet.
|
||||
if parsedTypes != nil {
|
||||
parsedTypes[typ] = struct{}{}
|
||||
}
|
||||
|
||||
// Update our record index so that we can begin our next search
|
||||
// from where we left off.
|
||||
recordIdx = newIdx
|
||||
|
51
tlv/stream_test.go
Normal file
51
tlv/stream_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||
func TestParsedTypes(t *testing.T) {
|
||||
const (
|
||||
knownType = 1
|
||||
unknownType = 3
|
||||
)
|
||||
|
||||
// Construct a stream that will encode two types, one that will be known
|
||||
// to the decoder and another that will be unknown.
|
||||
encStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
||||
)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := encStream.Encode(&b); err != nil {
|
||||
t.Fatalf("unable to encode stream: %v", err)
|
||||
}
|
||||
|
||||
// Create a stream that will parse only the known type.
|
||||
decStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
)
|
||||
|
||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||
bytes.NewReader(b.Bytes()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode stream: %v", err)
|
||||
}
|
||||
|
||||
// Assert that both the known and unknown types are included in the set
|
||||
// of parsed types.
|
||||
if _, ok := parsedTypes[knownType]; !ok {
|
||||
t.Fatalf("known type %d should be in parsed types", knownType)
|
||||
}
|
||||
if _, ok := parsedTypes[unknownType]; !ok {
|
||||
t.Fatalf("unknown type %d should be in parsed types",
|
||||
unknownType)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user