diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 2fd1cd86..dd595e11 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -11,15 +11,50 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +// PayloadViolation is an enum encapsulating the possible invalid payload +// violations that can occur when processing or validating a payload. +type PayloadViolation byte + +const ( + // OmittedViolation indicates that a type was expected to be found the + // payload but was absent. + OmittedViolation PayloadViolation = iota + + // IncludedViolation indicates that a type was expected to be omitted + // from the payload but was present. + IncludedViolation + + // RequiredViolation indicates that an unknown even type was found in + // the payload that we could not process. + RequiredViolation +) + +// String returns a human-readable description of the violation as a verb. +func (v PayloadViolation) String() string { + switch v { + case OmittedViolation: + return "omitted" + + case IncludedViolation: + return "included" + + case RequiredViolation: + return "required" + + default: + return "unknown violation" + } +} + // ErrInvalidPayload is an error returned when a parsed onion payload either // included or omitted incorrect records for a particular hop type. type ErrInvalidPayload struct { // Type the record's type that cause the violation. Type tlv.Type - // Ommitted if true, signals that the sender did not include the record. - // Otherwise, the sender included the record when it shouldn't have. - Omitted bool + // Violation is an enum indicating the type of violation detected in + // processing Type. + Violation PayloadViolation // FinalHop if true, indicates that the violation is for the final hop // in the route (identified by next hop id), otherwise the violation is @@ -33,13 +68,9 @@ func (e ErrInvalidPayload) Error() string { if e.FinalHop { hopType = "final" } - violation := "included" - if e.Omitted { - violation = "omitted" - } - return fmt.Sprintf("onion payload for %s hop %s record with type %d", - hopType, violation, e.Type) + return fmt.Sprintf("onion payload for %s hop %v record with type %d", + hopType, e.Violation, e.Type) } // Payload encapsulates all information delivered to a hop in an onion payload. @@ -87,6 +118,18 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { parsedTypes, err := tlvStream.DecodeWithParsedTypes(r) if err != nil { + // Promote any required type failures into ErrInvalidPayload. + if e, required := err.(tlv.ErrUnknownRequiredType); required { + // NOTE: FinalHop will be incorrect if the unknown + // required was type 0. Otherwise, the failure must have + // occurred after type 6 and cid should contain an + // accurate value. + return nil, ErrInvalidPayload{ + Type: tlv.Type(e), + Violation: RequiredViolation, + FinalHop: cid == 0, + } + } return nil, err } @@ -133,17 +176,17 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, // All hops must include an amount to forward. case !hasAmt: return ErrInvalidPayload{ - Type: record.AmtOnionType, - Omitted: true, - FinalHop: isFinalHop, + Type: record.AmtOnionType, + Violation: OmittedViolation, + FinalHop: isFinalHop, } // All hops must include a cltv expiry. case !hasLockTime: return ErrInvalidPayload{ - Type: record.LockTimeOnionType, - Omitted: true, - FinalHop: isFinalHop, + Type: record.LockTimeOnionType, + Violation: OmittedViolation, + FinalHop: isFinalHop, } // The exit hop should omit the next hop id. If nextHop != Exit, the @@ -151,9 +194,9 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, // inclusion at intermediate hops directly. case isFinalHop && hasNextHop: return ErrInvalidPayload{ - Type: record.NextHopOnionType, - Omitted: false, - FinalHop: true, + Type: record.NextHopOnionType, + Violation: IncludedViolation, + FinalHop: true, } } diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 0c9342d5..f11a6087 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -16,13 +16,23 @@ type decodePayloadTest struct { } var decodePayloadTests = []decodePayloadTest{ + { + name: "final hop valid", + payload: []byte{0x02, 0x00, 0x04, 0x00}, + }, + { + name: "intermediate hop valid", + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, { name: "final hop no amount", payload: []byte{0x04, 0x00}, expErr: hop.ErrInvalidPayload{ - Type: record.AmtOnionType, - Omitted: true, - FinalHop: true, + Type: record.AmtOnionType, + Violation: hop.OmittedViolation, + FinalHop: true, }, }, { @@ -31,18 +41,18 @@ var decodePayloadTests = []decodePayloadTest{ 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: record.AmtOnionType, - Omitted: true, - FinalHop: false, + Type: record.AmtOnionType, + Violation: hop.OmittedViolation, + FinalHop: false, }, }, { name: "final hop no expiry", payload: []byte{0x02, 0x00}, expErr: hop.ErrInvalidPayload{ - Type: record.LockTimeOnionType, - Omitted: true, - FinalHop: true, + Type: record.LockTimeOnionType, + Violation: hop.OmittedViolation, + FinalHop: true, }, }, { @@ -51,9 +61,9 @@ var decodePayloadTests = []decodePayloadTest{ 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: record.LockTimeOnionType, - Omitted: true, - FinalHop: false, + Type: record.LockTimeOnionType, + Violation: hop.OmittedViolation, + FinalHop: false, }, }, { @@ -62,9 +72,38 @@ var decodePayloadTests = []decodePayloadTest{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: record.NextHopOnionType, - Omitted: false, - FinalHop: true, + Type: record.NextHopOnionType, + Violation: hop.IncludedViolation, + FinalHop: true, + }, + }, + { + name: "required type after omitted hop id", + payload: []byte{0x08, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: 8, + Violation: hop.RequiredViolation, + FinalHop: true, + }, + }, + { + name: "required type after included hop id", + payload: []byte{0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x08, 0x00, + }, + expErr: hop.ErrInvalidPayload{ + Type: 8, + Violation: hop.RequiredViolation, + FinalHop: false, + }, + }, + { + name: "required type zero", + payload: []byte{0x00, 0x00}, + expErr: hop.ErrInvalidPayload{ + Type: 0, + Violation: hop.RequiredViolation, + FinalHop: true, }, }, }