diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 075f9c7d..2edd9aa8 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -5,7 +5,7 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lightning-onion" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/tlv" @@ -120,34 +120,32 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { if err != nil { // Promote any required type failures into ErrInvalidPayload. if e, required := err.(tlv.ErrUnknownRequiredType); required { - // NOTE: Sigh. If the sender included a next hop whose - // value is zero, this would be considered invalid by - // our validation rules below. It's not totally clear - // whether this required failure should take precedence - // over the constraints applied by known types. - // Unfortunately this is an artifact of the layering - // violation in placing the even/odd rule in the parsing - // logic and not at a higher level of validation like - // the other presence/omission checks. - // - // As a result, this may need to be revisted if it is - // decided that the checks below overrule an unknown - // required type failure, in which case an - // IncludedViolation should be returned instead of the - // RequiredViolation. + // If the parser returned an unknown required type + // failure, we'll first check that the payload is + // properly formed according to our known set of + // constraints. If an error is discovered, this + // overrides the required type failure. + nextHop := lnwire.NewShortChanIDFromInt(cid) + err = ValidateParsedPayloadTypes(parsedTypes, nextHop) + if err != nil { + return nil, err + } + + // Otherwise the known constraints were applied + // successfully, report the invalid type failure + // returned by the parser. return nil, ErrInvalidPayload{ Type: tlv.Type(e), Violation: RequiredViolation, - FinalHop: cid == 0, + FinalHop: nextHop == Exit, } } return nil, err } - nextHop := lnwire.NewShortChanIDFromInt(cid) - // Validate whether the sender properly included or omitted tlv records // in accordance with BOLT 04. + nextHop := lnwire.NewShortChanIDFromInt(cid) err = ValidateParsedPayloadTypes(parsedTypes, nextHop) if err != nil { return nil, err diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 0e442c7f..a49cd2b1 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -79,7 +79,7 @@ var decodePayloadTests = []decodePayloadTest{ }, { name: "required type after omitted hop id", - payload: []byte{0x08, 0x00}, + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00}, expErr: hop.ErrInvalidPayload{ Type: 8, Violation: hop.RequiredViolation, @@ -88,8 +88,8 @@ var decodePayloadTests = []decodePayloadTest{ }, { name: "required type after included hop id", - payload: []byte{0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x08, 0x00, + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, }, expErr: hop.ErrInvalidPayload{ Type: 8, @@ -99,7 +99,7 @@ var decodePayloadTests = []decodePayloadTest{ }, { name: "required type zero final hop", - payload: []byte{0x00, 0x00}, + payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00}, expErr: hop.ErrInvalidPayload{ Type: 0, Violation: hop.RequiredViolation, @@ -108,19 +108,19 @@ var decodePayloadTests = []decodePayloadTest{ }, { name: "required type zero final hop zero sid", - payload: []byte{0x00, 0x00, 0x06, 0x08, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, + payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: 0, - Violation: hop.RequiredViolation, + Type: 6, + Violation: hop.IncludedViolation, FinalHop: true, }, }, { name: "required type zero intermediate hop", - payload: []byte{0x00, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, + payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ Type: 0,