Merge pull request #3470 from cfromknecht/invalid-onion-payload

htlcswitch+lnwire: invalid onion payload
This commit is contained in:
Olaoluwa Osuntokun 2019-11-01 18:58:00 -07:00 committed by GitHub
commit acd8a6e302
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 457 additions and 66 deletions

@ -5,21 +5,56 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/lightningnetwork/lightning-onion" sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
// PayloadViolation is an enum encapsulating the possible invalid payload
// violations that can occur when processing or validating a payload.
type PayloadViolation byte
const (
// OmittedViolation indicates that a type was expected to be found the
// payload but was absent.
OmittedViolation PayloadViolation = iota
// IncludedViolation indicates that a type was expected to be omitted
// from the payload but was present.
IncludedViolation
// RequiredViolation indicates that an unknown even type was found in
// the payload that we could not process.
RequiredViolation
)
// String returns a human-readable description of the violation as a verb.
func (v PayloadViolation) String() string {
switch v {
case OmittedViolation:
return "omitted"
case IncludedViolation:
return "included"
case RequiredViolation:
return "required"
default:
return "unknown violation"
}
}
// ErrInvalidPayload is an error returned when a parsed onion payload either // ErrInvalidPayload is an error returned when a parsed onion payload either
// included or omitted incorrect records for a particular hop type. // included or omitted incorrect records for a particular hop type.
type ErrInvalidPayload struct { type ErrInvalidPayload struct {
// Type the record's type that cause the violation. // Type the record's type that cause the violation.
Type tlv.Type Type tlv.Type
// Ommitted if true, signals that the sender did not include the record. // Violation is an enum indicating the type of violation detected in
// Otherwise, the sender included the record when it shouldn't have. // processing Type.
Omitted bool Violation PayloadViolation
// FinalHop if true, indicates that the violation is for the final hop // FinalHop if true, indicates that the violation is for the final hop
// in the route (identified by next hop id), otherwise the violation is // in the route (identified by next hop id), otherwise the violation is
@ -33,13 +68,9 @@ func (e ErrInvalidPayload) Error() string {
if e.FinalHop { if e.FinalHop {
hopType = "final" hopType = "final"
} }
violation := "included"
if e.Omitted {
violation = "omitted"
}
return fmt.Sprintf("onion payload for %s hop %s record with type %d", return fmt.Sprintf("onion payload for %s hop %v record with type %d",
hopType, violation, e.Type) hopType, e.Violation, e.Type)
} }
// Payload encapsulates all information delivered to a hop in an onion payload. // Payload encapsulates all information delivered to a hop in an onion payload.
@ -87,13 +118,34 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r) parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil { if err != nil {
// Promote any required type failures into ErrInvalidPayload.
if e, required := err.(tlv.ErrUnknownRequiredType); required {
// If the parser returned an unknown required type
// failure, we'll first check that the payload is
// properly formed according to our known set of
// constraints. If an error is discovered, this
// overrides the required type failure.
nextHop := lnwire.NewShortChanIDFromInt(cid)
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
if err != nil {
return nil, err
}
// Otherwise the known constraints were applied
// successfully, report the invalid type failure
// returned by the parser.
return nil, ErrInvalidPayload{
Type: tlv.Type(e),
Violation: RequiredViolation,
FinalHop: nextHop == Exit,
}
}
return nil, err return nil, err
} }
nextHop := lnwire.NewShortChanIDFromInt(cid)
// Validate whether the sender properly included or omitted tlv records // Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04. // in accordance with BOLT 04.
nextHop := lnwire.NewShortChanIDFromInt(cid)
err = ValidateParsedPayloadTypes(parsedTypes, nextHop) err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
if err != nil { if err != nil {
return nil, err return nil, err
@ -133,17 +185,17 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
// All hops must include an amount to forward. // All hops must include an amount to forward.
case !hasAmt: case !hasAmt:
return ErrInvalidPayload{ return ErrInvalidPayload{
Type: record.AmtOnionType, Type: record.AmtOnionType,
Omitted: true, Violation: OmittedViolation,
FinalHop: isFinalHop, FinalHop: isFinalHop,
} }
// All hops must include a cltv expiry. // All hops must include a cltv expiry.
case !hasLockTime: case !hasLockTime:
return ErrInvalidPayload{ return ErrInvalidPayload{
Type: record.LockTimeOnionType, Type: record.LockTimeOnionType,
Omitted: true, Violation: OmittedViolation,
FinalHop: isFinalHop, FinalHop: isFinalHop,
} }
// The exit hop should omit the next hop id. If nextHop != Exit, the // The exit hop should omit the next hop id. If nextHop != Exit, the
@ -151,9 +203,9 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
// inclusion at intermediate hops directly. // inclusion at intermediate hops directly.
case isFinalHop && hasNextHop: case isFinalHop && hasNextHop:
return ErrInvalidPayload{ return ErrInvalidPayload{
Type: record.NextHopOnionType, Type: record.NextHopOnionType,
Omitted: false, Violation: IncludedViolation,
FinalHop: true, FinalHop: true,
} }
} }

@ -16,13 +16,23 @@ type decodePayloadTest struct {
} }
var decodePayloadTests = []decodePayloadTest{ var decodePayloadTests = []decodePayloadTest{
{
name: "final hop valid",
payload: []byte{0x02, 0x00, 0x04, 0x00},
},
{
name: "intermediate hop valid",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
{ {
name: "final hop no amount", name: "final hop no amount",
payload: []byte{0x04, 0x00}, payload: []byte{0x04, 0x00},
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType, Type: record.AmtOnionType,
Omitted: true, Violation: hop.OmittedViolation,
FinalHop: true, FinalHop: true,
}, },
}, },
{ {
@ -31,18 +41,18 @@ var decodePayloadTests = []decodePayloadTest{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, },
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType, Type: record.AmtOnionType,
Omitted: true, Violation: hop.OmittedViolation,
FinalHop: false, FinalHop: false,
}, },
}, },
{ {
name: "final hop no expiry", name: "final hop no expiry",
payload: []byte{0x02, 0x00}, payload: []byte{0x02, 0x00},
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType, Type: record.LockTimeOnionType,
Omitted: true, Violation: hop.OmittedViolation,
FinalHop: true, FinalHop: true,
}, },
}, },
{ {
@ -51,9 +61,9 @@ var decodePayloadTests = []decodePayloadTest{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, },
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType, Type: record.LockTimeOnionType,
Omitted: true, Violation: hop.OmittedViolation,
FinalHop: false, FinalHop: false,
}, },
}, },
{ {
@ -62,9 +72,60 @@ var decodePayloadTests = []decodePayloadTest{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, },
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: record.NextHopOnionType, Type: record.NextHopOnionType,
Omitted: false, Violation: hop.IncludedViolation,
FinalHop: true, FinalHop: true,
},
},
{
name: "required type after omitted hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
expErr: hop.ErrInvalidPayload{
Type: 8,
Violation: hop.RequiredViolation,
FinalHop: true,
},
},
{
name: "required type after included hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: 8,
Violation: hop.RequiredViolation,
FinalHop: false,
},
},
{
name: "required type zero final hop",
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00},
expErr: hop.ErrInvalidPayload{
Type: 0,
Violation: hop.RequiredViolation,
FinalHop: true,
},
},
{
name: "required type zero final hop zero sid",
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: 6,
Violation: hop.IncludedViolation,
FinalHop: true,
},
},
{
name: "required type zero intermediate hop",
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: 0,
Violation: hop.RequiredViolation,
FinalHop: false,
}, },
}, },
} }

@ -2645,12 +2645,23 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
fwdInfo, err := chanIterator.ForwardingInstructions() fwdInfo, err := chanIterator.ForwardingInstructions()
if err != nil { if err != nil {
// If we're unable to process the onion payload, or we // If we're unable to process the onion payload, or we
// we received malformed TLV stream, then we should // received invalid onion payload failure, then we
// send an error back to the caller so the HTLC can be // should send an error back to the caller so the HTLC
// canceled. // can be canceled.
var failedType uint64
if e, ok := err.(hop.ErrInvalidPayload); ok {
failedType = uint64(e.Type)
}
// TODO: currently none of the test unit infrastructure
// is setup to handle TLV payloads, so testing this
// would require implementing a separate mock iterator
// for TLV payloads that also supports injecting invalid
// payloads. Deferring this non-trival effort till a
// later date
l.sendHTLCError( l.sendHTLCError(
pd.HtlcIndex, pd.HtlcIndex,
lnwire.NewInvalidOnionVersion(onionBlob[:]), lnwire.NewInvalidOnionPayload(failedType, 0),
obfuscator, pd.SourceRef, obfuscator, pd.SourceRef,
) )
needUpdate = true needUpdate = true

@ -11,6 +11,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/tlv"
) )
// FailureMessage represents the onion failure object identified by its unique // FailureMessage represents the onion failure object identified by its unique
@ -78,6 +79,7 @@ const (
CodeFinalIncorrectCltvExpiry FailCode = 18 CodeFinalIncorrectCltvExpiry FailCode = 18
CodeFinalIncorrectHtlcAmount FailCode = 19 CodeFinalIncorrectHtlcAmount FailCode = 19
CodeExpiryTooFar FailCode = 21 CodeExpiryTooFar FailCode = 21
CodeInvalidOnionPayload = FlagPerm | 22
) )
// String returns the string representation of the failure code. // String returns the string representation of the failure code.
@ -149,6 +151,9 @@ func (c FailCode) String() string {
case CodeExpiryTooFar: case CodeExpiryTooFar:
return "ExpiryTooFar" return "ExpiryTooFar"
case CodeInvalidOnionPayload:
return "InvalidOnionPayload"
default: default:
return "<unknown>" return "<unknown>"
} }
@ -1117,6 +1122,66 @@ func (f *FailExpiryTooFar) Error() string {
return f.Code().String() return f.Code().String()
} }
// InvalidOnionPayload is returned if the hop could not process the TLV payload
// enclosed in the onion.
type InvalidOnionPayload struct {
// Type is the TLV type that caused the specific failure.
Type uint64
// Offset is the byte offset within the payload where the failure
// occurred.
Offset uint16
}
// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure.
func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload {
return &InvalidOnionPayload{
Type: typ,
Offset: offset,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *InvalidOnionPayload) Code() FailCode {
return CodeInvalidOnionPayload
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *InvalidOnionPayload) Error() string {
return fmt.Sprintf("%v(type=%v, offset=%d)",
f.Code(), f.Type, f.Offset)
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error {
var buf [8]byte
typ, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
f.Type = typ
return ReadElements(r, &f.Offset)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error {
var buf [8]byte
if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil {
return err
}
return WriteElements(w, f.Offset)
}
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for // DecodeFailure decodes, validates, and parses the lnwire onion failure, for
// the provided protocol version. // the provided protocol version.
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
@ -1298,6 +1363,9 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
case CodeExpiryTooFar: case CodeExpiryTooFar:
return &FailExpiryTooFar{}, nil return &FailExpiryTooFar{}, nil
case CodeInvalidOnionPayload:
return &InvalidOnionPayload{}, nil
default: default:
return nil, errors.Errorf("unknown error code: %v", code) return nil, errors.Errorf("unknown error code: %v", code)
} }

@ -16,6 +16,8 @@ var (
testAmount = MilliSatoshi(1) testAmount = MilliSatoshi(1)
testCtlvExpiry = uint32(2) testCtlvExpiry = uint32(2)
testFlags = uint16(2) testFlags = uint16(2)
testType = uint64(3)
testOffset = uint16(24)
sig, _ = NewSigFromSignature(testSig) sig, _ = NewSigFromSignature(testSig)
testChannelUpdate = ChannelUpdate{ testChannelUpdate = ChannelUpdate{
Signature: sig, Signature: sig,
@ -50,6 +52,7 @@ var onionFailures = []FailureMessage{
NewChannelDisabled(testFlags, testChannelUpdate), NewChannelDisabled(testFlags, testChannelUpdate),
NewFinalIncorrectCltvExpiry(testCtlvExpiry), NewFinalIncorrectCltvExpiry(testCtlvExpiry),
NewFinalIncorrectHtlcAmount(testAmount), NewFinalIncorrectHtlcAmount(testAmount),
NewInvalidOnionPayload(testType, testOffset),
} }
// TestEncodeDecodeCode tests the ability of onion errors to be properly encoded // TestEncodeDecodeCode tests the ability of onion errors to be properly encoded

@ -265,7 +265,17 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
// All nodes up to the failing pair must have forwarded // All nodes up to the failing pair must have forwarded
// successfully. // successfully.
if errorSourceIdx > 2 { if errorSourceIdx > 1 {
i.successPairRange(route, 0, errorSourceIdx-2)
}
}
reportNode := func() {
// Fail only the node that reported the failure.
i.failNode(route, errorSourceIdx)
// Other preceding channels in the route forwarded correctly.
if errorSourceIdx > 1 {
i.successPairRange(route, 0, errorSourceIdx-2) i.successPairRange(route, 0, errorSourceIdx-2)
} }
} }
@ -302,6 +312,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
reportOutgoing() reportOutgoing()
// If InvalidOnionPayload is received, we penalize only the reporting
// node. We know the preceding hop didn't corrupt the onion, since the
// reporting node is able to send the failure. We assume that we
// constructed a valid onion payload and that the failure is most likely
// an unknown required type or a bug in their implementation.
case *lnwire.InvalidOnionPayload:
reportNode()
// If the next hop in the route wasn't known or offline, we'll only // If the next hop in the route wasn't known or offline, we'll only
// penalize the channel set which we attempted to route over. This is // penalize the channel set which we attempted to route over. This is
// conservative, and it can handle faulty channels between nodes // conservative, and it can handle faulty channels between nodes

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@ -47,6 +48,11 @@ func getTestPair(from, to int) DirectedNodePair {
return NewDirectedNodePair(hops[from], hops[to]) return NewDirectedNodePair(hops[from], hops[to])
} }
func getPolicyFailure(from, to int) *DirectedNodePair {
pair := getTestPair(from, to)
return &pair
}
type resultTestCase struct { type resultTestCase struct {
name string name string
route *route.Route route *route.Route
@ -169,6 +175,97 @@ var resultTestCases = []resultTestCase{
}, },
}, },
}, },
// Tests that a fee insufficient failure to an intermediate hop with
// index 2 results in the first hop marked as success, and then a
// bidirectional failure for the incoming channel. It should also result
// in a policy failure for the outgoing hop.
{
name: "fail fee insufficient intermediate",
route: &routeFourHop,
failureSrcIdx: 2,
failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}),
expectedResult: &interpretedResult{
pairResults: map[DirectedNodePair]pairResult{
getTestPair(0, 1): {
success: true,
},
getTestPair(1, 2): {},
getTestPair(2, 1): {},
},
policyFailure: getPolicyFailure(2, 3),
},
},
// Tests an invalid onion payload from a final hop. The final hop should
// be failed while the proceeding hops are reproed as successes. The
// failure is terminal since the receiver can't process our onion.
{
name: "fail invalid onion payload final hop",
route: &routeFourHop,
failureSrcIdx: 4,
failure: lnwire.NewInvalidOnionPayload(0, 0),
expectedResult: &interpretedResult{
pairResults: map[DirectedNodePair]pairResult{
getTestPair(0, 1): {
success: true,
},
getTestPair(1, 2): {
success: true,
},
getTestPair(2, 3): {
success: true,
},
getTestPair(4, 3): {},
},
finalFailureReason: &reasonError,
nodeFailure: &hops[4],
},
},
// Tests an invalid onion payload from an intermediate hop. Only the
// reporting node should be failed. The failure is non-terminal since we
// can still try other paths.
{
name: "fail invalid onion payload intermediate",
route: &routeFourHop,
failureSrcIdx: 3,
failure: lnwire.NewInvalidOnionPayload(0, 0),
expectedResult: &interpretedResult{
pairResults: map[DirectedNodePair]pairResult{
getTestPair(0, 1): {
success: true,
},
getTestPair(1, 2): {
success: true,
},
getTestPair(3, 2): {},
getTestPair(3, 4): {},
},
nodeFailure: &hops[3],
},
},
// Tests an invalid onion payload in a direct peer that is also the
// final hop. The final node should be failed and the error is terminal
// since the remote node can't process our onion.
{
name: "fail invalid onion payload direct",
route: &routeOneHop,
failureSrcIdx: 1,
failure: lnwire.NewInvalidOnionPayload(0, 0),
expectedResult: &interpretedResult{
pairResults: map[DirectedNodePair]pairResult{
getTestPair(1, 0): {},
},
finalFailureReason: &reasonError,
nodeFailure: &hops[1],
},
},
} }
// TestResultInterpretation executes a list of test cases that test the result // TestResultInterpretation executes a list of test cases that test the result
@ -192,7 +289,8 @@ func TestResultInterpretation(t *testing.T) {
} }
if !reflect.DeepEqual(i, expected) { if !reflect.DeepEqual(i, expected) {
t.Fatal("unexpected result") t.Fatalf("unexpected result\nwant: %v\ngot: %v",
spew.Sdump(expected), spew.Sdump(i))
} }
}) })
} }

@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
var ( var (
typ Type typ Type
min Type min Type
firstFail *Type
recordIdx int recordIdx int
overflow bool overflow bool
) )
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// We'll silence an EOF when zero bytes remain, meaning the // We'll silence an EOF when zero bytes remain, meaning the
// stream was cleanly encoded. // stream was cleanly encoded.
case err == io.EOF: case err == io.EOF:
return parsedTypes, nil if firstFail == nil {
return parsedTypes, nil
}
return parsedTypes, ErrUnknownRequiredType(*firstFail)
// Other unexpected errors. // Other unexpected errors.
case err != nil: case err != nil:
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
// This record type is unknown to the stream, fail if the type // This record type is unknown to the stream, fail if the type
// is even meaning that we are required to understand it. // is even meaning that we are required to understand it.
case typ%2 == 0: case typ%2 == 0:
return nil, ErrUnknownRequiredType(typ) // We'll fail immediately in the case that we aren't
// tracking the set of parsed types.
if parsedTypes == nil {
return nil, ErrUnknownRequiredType(typ)
}
// Otherwise, we'll track the first such failure and
// allow parsing to continue. If no other types of
// errors are encountered, the first failure will be
// returned as an ErrUnknownRequiredType so that the
// full set of included types can be returned.
if firstFail == nil {
failTyp := typ
firstFail = &failTyp
}
// With the failure type recorded, we'll simply discard
// the remainder of the record as if it were optional.
// The first failure will be returned after reaching the
// stopping condition.
fallthrough
// Otherwise, the record type is unknown and is odd, discard the // Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length. // number of bytes specified by length.

@ -2,50 +2,106 @@ package tlv_test
import ( import (
"bytes" "bytes"
"reflect"
"testing" "testing"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
type parsedTypeTest struct {
name string
encode []tlv.Type
decode []tlv.Type
expErr error
}
// TestParsedTypes asserts that a Stream will properly return the set of types // TestParsedTypes asserts that a Stream will properly return the set of types
// that it encounters when the type is known-and-decoded or unknown-and-ignored. // that it encounters when the type is known-and-decoded or unknown-and-ignored.
func TestParsedTypes(t *testing.T) { func TestParsedTypes(t *testing.T) {
const ( const (
knownType = 1 firstReqType = 0
unknownType = 3 knownType = 1
unknownType = 3
secondReqType = 4
) )
// Construct a stream that will encode two types, one that will be known tests := []parsedTypeTest{
// to the decoder and another that will be unknown. {
encStream := tlv.MustNewStream( name: "known optional and unknown optional",
tlv.MakePrimitiveRecord(knownType, new(uint64)), encode: []tlv.Type{knownType, unknownType},
tlv.MakePrimitiveRecord(unknownType, new(uint64)), decode: []tlv.Type{knownType},
) },
{
name: "unknown required and known optional",
encode: []tlv.Type{firstReqType, knownType},
decode: []tlv.Type{knownType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "unknown required and unknown optional",
encode: []tlv.Type{unknownType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(secondReqType),
},
{
name: "unknown required and known required",
encode: []tlv.Type{firstReqType, secondReqType},
decode: []tlv.Type{secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
{
name: "two unknown required",
encode: []tlv.Type{firstReqType, secondReqType},
expErr: tlv.ErrUnknownRequiredType(firstReqType),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testParsedTypes(t, test)
})
}
}
func testParsedTypes(t *testing.T, test parsedTypeTest) {
encRecords := make([]tlv.Record, 0, len(test.encode))
for _, typ := range test.encode {
encRecords = append(
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
decRecords := make([]tlv.Record, 0, len(test.decode))
for _, typ := range test.decode {
decRecords = append(
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
)
}
// Construct a stream that will encode the test's set of types.
encStream := tlv.MustNewStream(encRecords...)
var b bytes.Buffer var b bytes.Buffer
if err := encStream.Encode(&b); err != nil { if err := encStream.Encode(&b); err != nil {
t.Fatalf("unable to encode stream: %v", err) t.Fatalf("unable to encode stream: %v", err)
} }
// Create a stream that will parse only the known type. // Create a stream that will parse a subset of the test's types.
decStream := tlv.MustNewStream( decStream := tlv.MustNewStream(decRecords...)
tlv.MakePrimitiveRecord(knownType, new(uint64)),
)
parsedTypes, err := decStream.DecodeWithParsedTypes( parsedTypes, err := decStream.DecodeWithParsedTypes(
bytes.NewReader(b.Bytes()), bytes.NewReader(b.Bytes()),
) )
if err != nil { if !reflect.DeepEqual(err, test.expErr) {
t.Fatalf("unable to decode stream: %v", err) t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
} }
// Assert that both the known and unknown types are included in the set // Assert that all encoded types are included in the set of parsed
// of parsed types. // types.
if _, ok := parsedTypes[knownType]; !ok { for _, typ := range test.encode {
t.Fatalf("known type %d should be in parsed types", knownType) if _, ok := parsedTypes[typ]; !ok {
} t.Fatalf("encoded type %d should be in parsed types",
if _, ok := parsedTypes[unknownType]; !ok { typ)
t.Fatalf("unknown type %d should be in parsed types", }
unknownType)
} }
} }