Merge pull request #3470 from cfromknecht/invalid-onion-payload
htlcswitch+lnwire: invalid onion payload
This commit is contained in:
commit
acd8a6e302
@ -5,21 +5,56 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lightning-onion"
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// PayloadViolation is an enum encapsulating the possible invalid payload
|
||||
// violations that can occur when processing or validating a payload.
|
||||
type PayloadViolation byte
|
||||
|
||||
const (
|
||||
// OmittedViolation indicates that a type was expected to be found the
|
||||
// payload but was absent.
|
||||
OmittedViolation PayloadViolation = iota
|
||||
|
||||
// IncludedViolation indicates that a type was expected to be omitted
|
||||
// from the payload but was present.
|
||||
IncludedViolation
|
||||
|
||||
// RequiredViolation indicates that an unknown even type was found in
|
||||
// the payload that we could not process.
|
||||
RequiredViolation
|
||||
)
|
||||
|
||||
// String returns a human-readable description of the violation as a verb.
|
||||
func (v PayloadViolation) String() string {
|
||||
switch v {
|
||||
case OmittedViolation:
|
||||
return "omitted"
|
||||
|
||||
case IncludedViolation:
|
||||
return "included"
|
||||
|
||||
case RequiredViolation:
|
||||
return "required"
|
||||
|
||||
default:
|
||||
return "unknown violation"
|
||||
}
|
||||
}
|
||||
|
||||
// ErrInvalidPayload is an error returned when a parsed onion payload either
|
||||
// included or omitted incorrect records for a particular hop type.
|
||||
type ErrInvalidPayload struct {
|
||||
// Type the record's type that cause the violation.
|
||||
Type tlv.Type
|
||||
|
||||
// Ommitted if true, signals that the sender did not include the record.
|
||||
// Otherwise, the sender included the record when it shouldn't have.
|
||||
Omitted bool
|
||||
// Violation is an enum indicating the type of violation detected in
|
||||
// processing Type.
|
||||
Violation PayloadViolation
|
||||
|
||||
// FinalHop if true, indicates that the violation is for the final hop
|
||||
// in the route (identified by next hop id), otherwise the violation is
|
||||
@ -33,13 +68,9 @@ func (e ErrInvalidPayload) Error() string {
|
||||
if e.FinalHop {
|
||||
hopType = "final"
|
||||
}
|
||||
violation := "included"
|
||||
if e.Omitted {
|
||||
violation = "omitted"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
|
||||
hopType, violation, e.Type)
|
||||
return fmt.Sprintf("onion payload for %s hop %v record with type %d",
|
||||
hopType, e.Violation, e.Type)
|
||||
}
|
||||
|
||||
// Payload encapsulates all information delivered to a hop in an onion payload.
|
||||
@ -86,14 +117,35 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
}
|
||||
|
||||
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
||||
if err != nil {
|
||||
// Promote any required type failures into ErrInvalidPayload.
|
||||
if e, required := err.(tlv.ErrUnknownRequiredType); required {
|
||||
// If the parser returned an unknown required type
|
||||
// failure, we'll first check that the payload is
|
||||
// properly formed according to our known set of
|
||||
// constraints. If an error is discovered, this
|
||||
// overrides the required type failure.
|
||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||
// Otherwise the known constraints were applied
|
||||
// successfully, report the invalid type failure
|
||||
// returned by the parser.
|
||||
return nil, ErrInvalidPayload{
|
||||
Type: tlv.Type(e),
|
||||
Violation: RequiredViolation,
|
||||
FinalHop: nextHop == Exit,
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate whether the sender properly included or omitted tlv records
|
||||
// in accordance with BOLT 04.
|
||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -134,7 +186,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
case !hasAmt:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
Violation: OmittedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
|
||||
@ -142,7 +194,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
case !hasLockTime:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
Violation: OmittedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
|
||||
@ -152,7 +204,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
case isFinalHop && hasNextHop:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.NextHopOnionType,
|
||||
Omitted: false,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: true,
|
||||
}
|
||||
}
|
||||
|
@ -16,12 +16,22 @@ type decodePayloadTest struct {
|
||||
}
|
||||
|
||||
var decodePayloadTests = []decodePayloadTest{
|
||||
{
|
||||
name: "final hop valid",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00},
|
||||
},
|
||||
{
|
||||
name: "intermediate hop valid",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "final hop no amount",
|
||||
payload: []byte{0x04, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
Violation: hop.OmittedViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
@ -32,7 +42,7 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.AmtOnionType,
|
||||
Omitted: true,
|
||||
Violation: hop.OmittedViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
@ -41,7 +51,7 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
payload: []byte{0x02, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
Violation: hop.OmittedViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
@ -52,7 +62,7 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.LockTimeOnionType,
|
||||
Omitted: true,
|
||||
Violation: hop.OmittedViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
@ -63,10 +73,61 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.NextHopOnionType,
|
||||
Omitted: false,
|
||||
Violation: hop.IncludedViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type after omitted hop id",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 8,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type after included hop id",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 8,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type zero final hop",
|
||||
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 0,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type zero final hop zero sid",
|
||||
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 6,
|
||||
Violation: hop.IncludedViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required type zero intermediate hop",
|
||||
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 0,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||
|
@ -2645,12 +2645,23 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
fwdInfo, err := chanIterator.ForwardingInstructions()
|
||||
if err != nil {
|
||||
// If we're unable to process the onion payload, or we
|
||||
// we received malformed TLV stream, then we should
|
||||
// send an error back to the caller so the HTLC can be
|
||||
// canceled.
|
||||
// received invalid onion payload failure, then we
|
||||
// should send an error back to the caller so the HTLC
|
||||
// can be canceled.
|
||||
var failedType uint64
|
||||
if e, ok := err.(hop.ErrInvalidPayload); ok {
|
||||
failedType = uint64(e.Type)
|
||||
}
|
||||
|
||||
// TODO: currently none of the test unit infrastructure
|
||||
// is setup to handle TLV payloads, so testing this
|
||||
// would require implementing a separate mock iterator
|
||||
// for TLV payloads that also supports injecting invalid
|
||||
// payloads. Deferring this non-trival effort till a
|
||||
// later date
|
||||
l.sendHTLCError(
|
||||
pd.HtlcIndex,
|
||||
lnwire.NewInvalidOnionVersion(onionBlob[:]),
|
||||
lnwire.NewInvalidOnionPayload(failedType, 0),
|
||||
obfuscator, pd.SourceRef,
|
||||
)
|
||||
needUpdate = true
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// FailureMessage represents the onion failure object identified by its unique
|
||||
@ -78,6 +79,7 @@ const (
|
||||
CodeFinalIncorrectCltvExpiry FailCode = 18
|
||||
CodeFinalIncorrectHtlcAmount FailCode = 19
|
||||
CodeExpiryTooFar FailCode = 21
|
||||
CodeInvalidOnionPayload = FlagPerm | 22
|
||||
)
|
||||
|
||||
// String returns the string representation of the failure code.
|
||||
@ -149,6 +151,9 @@ func (c FailCode) String() string {
|
||||
case CodeExpiryTooFar:
|
||||
return "ExpiryTooFar"
|
||||
|
||||
case CodeInvalidOnionPayload:
|
||||
return "InvalidOnionPayload"
|
||||
|
||||
default:
|
||||
return "<unknown>"
|
||||
}
|
||||
@ -1117,6 +1122,66 @@ func (f *FailExpiryTooFar) Error() string {
|
||||
return f.Code().String()
|
||||
}
|
||||
|
||||
// InvalidOnionPayload is returned if the hop could not process the TLV payload
|
||||
// enclosed in the onion.
|
||||
type InvalidOnionPayload struct {
|
||||
// Type is the TLV type that caused the specific failure.
|
||||
Type uint64
|
||||
|
||||
// Offset is the byte offset within the payload where the failure
|
||||
// occurred.
|
||||
Offset uint16
|
||||
}
|
||||
|
||||
// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure.
|
||||
func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload {
|
||||
return &InvalidOnionPayload{
|
||||
Type: typ,
|
||||
Offset: offset,
|
||||
}
|
||||
}
|
||||
|
||||
// Code returns the failure unique code.
|
||||
//
|
||||
// NOTE: Part of the FailureMessage interface.
|
||||
func (f *InvalidOnionPayload) Code() FailCode {
|
||||
return CodeInvalidOnionPayload
|
||||
}
|
||||
|
||||
// Returns a human readable string describing the target FailureMessage.
|
||||
//
|
||||
// NOTE: Implements the error interface.
|
||||
func (f *InvalidOnionPayload) Error() string {
|
||||
return fmt.Sprintf("%v(type=%v, offset=%d)",
|
||||
f.Code(), f.Type, f.Offset)
|
||||
}
|
||||
|
||||
// Decode decodes the failure from bytes stream.
|
||||
//
|
||||
// NOTE: Part of the Serializable interface.
|
||||
func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error {
|
||||
var buf [8]byte
|
||||
typ, err := tlv.ReadVarInt(r, &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Type = typ
|
||||
|
||||
return ReadElements(r, &f.Offset)
|
||||
}
|
||||
|
||||
// Encode writes the failure in bytes stream.
|
||||
//
|
||||
// NOTE: Part of the Serializable interface.
|
||||
func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error {
|
||||
var buf [8]byte
|
||||
if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteElements(w, f.Offset)
|
||||
}
|
||||
|
||||
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for
|
||||
// the provided protocol version.
|
||||
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
|
||||
@ -1298,6 +1363,9 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
|
||||
case CodeExpiryTooFar:
|
||||
return &FailExpiryTooFar{}, nil
|
||||
|
||||
case CodeInvalidOnionPayload:
|
||||
return &InvalidOnionPayload{}, nil
|
||||
|
||||
default:
|
||||
return nil, errors.Errorf("unknown error code: %v", code)
|
||||
}
|
||||
|
@ -16,6 +16,8 @@ var (
|
||||
testAmount = MilliSatoshi(1)
|
||||
testCtlvExpiry = uint32(2)
|
||||
testFlags = uint16(2)
|
||||
testType = uint64(3)
|
||||
testOffset = uint16(24)
|
||||
sig, _ = NewSigFromSignature(testSig)
|
||||
testChannelUpdate = ChannelUpdate{
|
||||
Signature: sig,
|
||||
@ -50,6 +52,7 @@ var onionFailures = []FailureMessage{
|
||||
NewChannelDisabled(testFlags, testChannelUpdate),
|
||||
NewFinalIncorrectCltvExpiry(testCtlvExpiry),
|
||||
NewFinalIncorrectHtlcAmount(testAmount),
|
||||
NewInvalidOnionPayload(testType, testOffset),
|
||||
}
|
||||
|
||||
// TestEncodeDecodeCode tests the ability of onion errors to be properly encoded
|
||||
|
@ -265,7 +265,17 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
|
||||
|
||||
// All nodes up to the failing pair must have forwarded
|
||||
// successfully.
|
||||
if errorSourceIdx > 2 {
|
||||
if errorSourceIdx > 1 {
|
||||
i.successPairRange(route, 0, errorSourceIdx-2)
|
||||
}
|
||||
}
|
||||
|
||||
reportNode := func() {
|
||||
// Fail only the node that reported the failure.
|
||||
i.failNode(route, errorSourceIdx)
|
||||
|
||||
// Other preceding channels in the route forwarded correctly.
|
||||
if errorSourceIdx > 1 {
|
||||
i.successPairRange(route, 0, errorSourceIdx-2)
|
||||
}
|
||||
}
|
||||
@ -302,6 +312,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
|
||||
|
||||
reportOutgoing()
|
||||
|
||||
// If InvalidOnionPayload is received, we penalize only the reporting
|
||||
// node. We know the preceding hop didn't corrupt the onion, since the
|
||||
// reporting node is able to send the failure. We assume that we
|
||||
// constructed a valid onion payload and that the failure is most likely
|
||||
// an unknown required type or a bug in their implementation.
|
||||
case *lnwire.InvalidOnionPayload:
|
||||
reportNode()
|
||||
|
||||
// If the next hop in the route wasn't known or offline, we'll only
|
||||
// penalize the channel set which we attempted to route over. This is
|
||||
// conservative, and it can handle faulty channels between nodes
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
@ -47,6 +48,11 @@ func getTestPair(from, to int) DirectedNodePair {
|
||||
return NewDirectedNodePair(hops[from], hops[to])
|
||||
}
|
||||
|
||||
func getPolicyFailure(from, to int) *DirectedNodePair {
|
||||
pair := getTestPair(from, to)
|
||||
return &pair
|
||||
}
|
||||
|
||||
type resultTestCase struct {
|
||||
name string
|
||||
route *route.Route
|
||||
@ -169,6 +175,97 @@ var resultTestCases = []resultTestCase{
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Tests that a fee insufficient failure to an intermediate hop with
|
||||
// index 2 results in the first hop marked as success, and then a
|
||||
// bidirectional failure for the incoming channel. It should also result
|
||||
// in a policy failure for the outgoing hop.
|
||||
{
|
||||
name: "fail fee insufficient intermediate",
|
||||
route: &routeFourHop,
|
||||
failureSrcIdx: 2,
|
||||
failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}),
|
||||
|
||||
expectedResult: &interpretedResult{
|
||||
pairResults: map[DirectedNodePair]pairResult{
|
||||
getTestPair(0, 1): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(1, 2): {},
|
||||
getTestPair(2, 1): {},
|
||||
},
|
||||
policyFailure: getPolicyFailure(2, 3),
|
||||
},
|
||||
},
|
||||
|
||||
// Tests an invalid onion payload from a final hop. The final hop should
|
||||
// be failed while the proceeding hops are reproed as successes. The
|
||||
// failure is terminal since the receiver can't process our onion.
|
||||
{
|
||||
name: "fail invalid onion payload final hop",
|
||||
route: &routeFourHop,
|
||||
failureSrcIdx: 4,
|
||||
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||
|
||||
expectedResult: &interpretedResult{
|
||||
pairResults: map[DirectedNodePair]pairResult{
|
||||
getTestPair(0, 1): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(1, 2): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(2, 3): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(4, 3): {},
|
||||
},
|
||||
finalFailureReason: &reasonError,
|
||||
nodeFailure: &hops[4],
|
||||
},
|
||||
},
|
||||
|
||||
// Tests an invalid onion payload from an intermediate hop. Only the
|
||||
// reporting node should be failed. The failure is non-terminal since we
|
||||
// can still try other paths.
|
||||
{
|
||||
name: "fail invalid onion payload intermediate",
|
||||
route: &routeFourHop,
|
||||
failureSrcIdx: 3,
|
||||
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||
|
||||
expectedResult: &interpretedResult{
|
||||
pairResults: map[DirectedNodePair]pairResult{
|
||||
getTestPair(0, 1): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(1, 2): {
|
||||
success: true,
|
||||
},
|
||||
getTestPair(3, 2): {},
|
||||
getTestPair(3, 4): {},
|
||||
},
|
||||
nodeFailure: &hops[3],
|
||||
},
|
||||
},
|
||||
|
||||
// Tests an invalid onion payload in a direct peer that is also the
|
||||
// final hop. The final node should be failed and the error is terminal
|
||||
// since the remote node can't process our onion.
|
||||
{
|
||||
name: "fail invalid onion payload direct",
|
||||
route: &routeOneHop,
|
||||
failureSrcIdx: 1,
|
||||
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||
|
||||
expectedResult: &interpretedResult{
|
||||
pairResults: map[DirectedNodePair]pairResult{
|
||||
getTestPair(1, 0): {},
|
||||
},
|
||||
finalFailureReason: &reasonError,
|
||||
nodeFailure: &hops[1],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestResultInterpretation executes a list of test cases that test the result
|
||||
@ -192,7 +289,8 @@ func TestResultInterpretation(t *testing.T) {
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(i, expected) {
|
||||
t.Fatal("unexpected result")
|
||||
t.Fatalf("unexpected result\nwant: %v\ngot: %v",
|
||||
spew.Sdump(expected), spew.Sdump(i))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
var (
|
||||
typ Type
|
||||
min Type
|
||||
firstFail *Type
|
||||
recordIdx int
|
||||
overflow bool
|
||||
)
|
||||
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// We'll silence an EOF when zero bytes remain, meaning the
|
||||
// stream was cleanly encoded.
|
||||
case err == io.EOF:
|
||||
if firstFail == nil {
|
||||
return parsedTypes, nil
|
||||
}
|
||||
return parsedTypes, ErrUnknownRequiredType(*firstFail)
|
||||
|
||||
// Other unexpected errors.
|
||||
case err != nil:
|
||||
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
||||
// This record type is unknown to the stream, fail if the type
|
||||
// is even meaning that we are required to understand it.
|
||||
case typ%2 == 0:
|
||||
// We'll fail immediately in the case that we aren't
|
||||
// tracking the set of parsed types.
|
||||
if parsedTypes == nil {
|
||||
return nil, ErrUnknownRequiredType(typ)
|
||||
}
|
||||
|
||||
// Otherwise, we'll track the first such failure and
|
||||
// allow parsing to continue. If no other types of
|
||||
// errors are encountered, the first failure will be
|
||||
// returned as an ErrUnknownRequiredType so that the
|
||||
// full set of included types can be returned.
|
||||
if firstFail == nil {
|
||||
failTyp := typ
|
||||
firstFail = &failTyp
|
||||
}
|
||||
|
||||
// With the failure type recorded, we'll simply discard
|
||||
// the remainder of the record as if it were optional.
|
||||
// The first failure will be returned after reaching the
|
||||
// stopping condition.
|
||||
fallthrough
|
||||
|
||||
// Otherwise, the record type is unknown and is odd, discard the
|
||||
// number of bytes specified by length.
|
||||
|
@ -2,50 +2,106 @@ package tlv_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type parsedTypeTest struct {
|
||||
name string
|
||||
encode []tlv.Type
|
||||
decode []tlv.Type
|
||||
expErr error
|
||||
}
|
||||
|
||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||
func TestParsedTypes(t *testing.T) {
|
||||
const (
|
||||
firstReqType = 0
|
||||
knownType = 1
|
||||
unknownType = 3
|
||||
secondReqType = 4
|
||||
)
|
||||
|
||||
// Construct a stream that will encode two types, one that will be known
|
||||
// to the decoder and another that will be unknown.
|
||||
encStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
||||
tests := []parsedTypeTest{
|
||||
{
|
||||
name: "known optional and unknown optional",
|
||||
encode: []tlv.Type{knownType, unknownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
},
|
||||
{
|
||||
name: "unknown required and known optional",
|
||||
encode: []tlv.Type{firstReqType, knownType},
|
||||
decode: []tlv.Type{knownType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
{
|
||||
name: "unknown required and unknown optional",
|
||||
encode: []tlv.Type{unknownType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
||||
},
|
||||
{
|
||||
name: "unknown required and known required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
decode: []tlv.Type{secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
{
|
||||
name: "two unknown required",
|
||||
encode: []tlv.Type{firstReqType, secondReqType},
|
||||
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testParsedTypes(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
||||
encRecords := make([]tlv.Record, 0, len(test.encode))
|
||||
for _, typ := range test.encode {
|
||||
encRecords = append(
|
||||
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||
)
|
||||
}
|
||||
|
||||
decRecords := make([]tlv.Record, 0, len(test.decode))
|
||||
for _, typ := range test.decode {
|
||||
decRecords = append(
|
||||
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||
)
|
||||
}
|
||||
|
||||
// Construct a stream that will encode the test's set of types.
|
||||
encStream := tlv.MustNewStream(encRecords...)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := encStream.Encode(&b); err != nil {
|
||||
t.Fatalf("unable to encode stream: %v", err)
|
||||
}
|
||||
|
||||
// Create a stream that will parse only the known type.
|
||||
decStream := tlv.MustNewStream(
|
||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
||||
)
|
||||
// Create a stream that will parse a subset of the test's types.
|
||||
decStream := tlv.MustNewStream(decRecords...)
|
||||
|
||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||
bytes.NewReader(b.Bytes()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode stream: %v", err)
|
||||
if !reflect.DeepEqual(err, test.expErr) {
|
||||
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
||||
}
|
||||
|
||||
// Assert that both the known and unknown types are included in the set
|
||||
// of parsed types.
|
||||
if _, ok := parsedTypes[knownType]; !ok {
|
||||
t.Fatalf("known type %d should be in parsed types", knownType)
|
||||
// Assert that all encoded types are included in the set of parsed
|
||||
// types.
|
||||
for _, typ := range test.encode {
|
||||
if _, ok := parsedTypes[typ]; !ok {
|
||||
t.Fatalf("encoded type %d should be in parsed types",
|
||||
typ)
|
||||
}
|
||||
if _, ok := parsedTypes[unknownType]; !ok {
|
||||
t.Fatalf("unknown type %d should be in parsed types",
|
||||
unknownType)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user