Merge pull request #3470 from cfromknecht/invalid-onion-payload
htlcswitch+lnwire: invalid onion payload
This commit is contained in:
commit
acd8a6e302
@ -5,21 +5,56 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lightning-onion"
|
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PayloadViolation is an enum encapsulating the possible invalid payload
|
||||||
|
// violations that can occur when processing or validating a payload.
|
||||||
|
type PayloadViolation byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// OmittedViolation indicates that a type was expected to be found the
|
||||||
|
// payload but was absent.
|
||||||
|
OmittedViolation PayloadViolation = iota
|
||||||
|
|
||||||
|
// IncludedViolation indicates that a type was expected to be omitted
|
||||||
|
// from the payload but was present.
|
||||||
|
IncludedViolation
|
||||||
|
|
||||||
|
// RequiredViolation indicates that an unknown even type was found in
|
||||||
|
// the payload that we could not process.
|
||||||
|
RequiredViolation
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns a human-readable description of the violation as a verb.
|
||||||
|
func (v PayloadViolation) String() string {
|
||||||
|
switch v {
|
||||||
|
case OmittedViolation:
|
||||||
|
return "omitted"
|
||||||
|
|
||||||
|
case IncludedViolation:
|
||||||
|
return "included"
|
||||||
|
|
||||||
|
case RequiredViolation:
|
||||||
|
return "required"
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "unknown violation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ErrInvalidPayload is an error returned when a parsed onion payload either
|
// ErrInvalidPayload is an error returned when a parsed onion payload either
|
||||||
// included or omitted incorrect records for a particular hop type.
|
// included or omitted incorrect records for a particular hop type.
|
||||||
type ErrInvalidPayload struct {
|
type ErrInvalidPayload struct {
|
||||||
// Type the record's type that cause the violation.
|
// Type the record's type that cause the violation.
|
||||||
Type tlv.Type
|
Type tlv.Type
|
||||||
|
|
||||||
// Ommitted if true, signals that the sender did not include the record.
|
// Violation is an enum indicating the type of violation detected in
|
||||||
// Otherwise, the sender included the record when it shouldn't have.
|
// processing Type.
|
||||||
Omitted bool
|
Violation PayloadViolation
|
||||||
|
|
||||||
// FinalHop if true, indicates that the violation is for the final hop
|
// FinalHop if true, indicates that the violation is for the final hop
|
||||||
// in the route (identified by next hop id), otherwise the violation is
|
// in the route (identified by next hop id), otherwise the violation is
|
||||||
@ -33,13 +68,9 @@ func (e ErrInvalidPayload) Error() string {
|
|||||||
if e.FinalHop {
|
if e.FinalHop {
|
||||||
hopType = "final"
|
hopType = "final"
|
||||||
}
|
}
|
||||||
violation := "included"
|
|
||||||
if e.Omitted {
|
|
||||||
violation = "omitted"
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
|
return fmt.Sprintf("onion payload for %s hop %v record with type %d",
|
||||||
hopType, violation, e.Type)
|
hopType, e.Violation, e.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Payload encapsulates all information delivered to a hop in an onion payload.
|
// Payload encapsulates all information delivered to a hop in an onion payload.
|
||||||
@ -87,13 +118,34 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
|
|
||||||
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Promote any required type failures into ErrInvalidPayload.
|
||||||
|
if e, required := err.(tlv.ErrUnknownRequiredType); required {
|
||||||
|
// If the parser returned an unknown required type
|
||||||
|
// failure, we'll first check that the payload is
|
||||||
|
// properly formed according to our known set of
|
||||||
|
// constraints. If an error is discovered, this
|
||||||
|
// overrides the required type failure.
|
||||||
|
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||||
|
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the known constraints were applied
|
||||||
|
// successfully, report the invalid type failure
|
||||||
|
// returned by the parser.
|
||||||
|
return nil, ErrInvalidPayload{
|
||||||
|
Type: tlv.Type(e),
|
||||||
|
Violation: RequiredViolation,
|
||||||
|
FinalHop: nextHop == Exit,
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
|
||||||
|
|
||||||
// Validate whether the sender properly included or omitted tlv records
|
// Validate whether the sender properly included or omitted tlv records
|
||||||
// in accordance with BOLT 04.
|
// in accordance with BOLT 04.
|
||||||
|
nextHop := lnwire.NewShortChanIDFromInt(cid)
|
||||||
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -133,17 +185,17 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
|||||||
// All hops must include an amount to forward.
|
// All hops must include an amount to forward.
|
||||||
case !hasAmt:
|
case !hasAmt:
|
||||||
return ErrInvalidPayload{
|
return ErrInvalidPayload{
|
||||||
Type: record.AmtOnionType,
|
Type: record.AmtOnionType,
|
||||||
Omitted: true,
|
Violation: OmittedViolation,
|
||||||
FinalHop: isFinalHop,
|
FinalHop: isFinalHop,
|
||||||
}
|
}
|
||||||
|
|
||||||
// All hops must include a cltv expiry.
|
// All hops must include a cltv expiry.
|
||||||
case !hasLockTime:
|
case !hasLockTime:
|
||||||
return ErrInvalidPayload{
|
return ErrInvalidPayload{
|
||||||
Type: record.LockTimeOnionType,
|
Type: record.LockTimeOnionType,
|
||||||
Omitted: true,
|
Violation: OmittedViolation,
|
||||||
FinalHop: isFinalHop,
|
FinalHop: isFinalHop,
|
||||||
}
|
}
|
||||||
|
|
||||||
// The exit hop should omit the next hop id. If nextHop != Exit, the
|
// The exit hop should omit the next hop id. If nextHop != Exit, the
|
||||||
@ -151,9 +203,9 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
|||||||
// inclusion at intermediate hops directly.
|
// inclusion at intermediate hops directly.
|
||||||
case isFinalHop && hasNextHop:
|
case isFinalHop && hasNextHop:
|
||||||
return ErrInvalidPayload{
|
return ErrInvalidPayload{
|
||||||
Type: record.NextHopOnionType,
|
Type: record.NextHopOnionType,
|
||||||
Omitted: false,
|
Violation: IncludedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,13 +16,23 @@ type decodePayloadTest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var decodePayloadTests = []decodePayloadTest{
|
var decodePayloadTests = []decodePayloadTest{
|
||||||
|
{
|
||||||
|
name: "final hop valid",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "intermediate hop valid",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "final hop no amount",
|
name: "final hop no amount",
|
||||||
payload: []byte{0x04, 0x00},
|
payload: []byte{0x04, 0x00},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: record.AmtOnionType,
|
Type: record.AmtOnionType,
|
||||||
Omitted: true,
|
Violation: hop.OmittedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -31,18 +41,18 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
0x00, 0x00, 0x00, 0x00,
|
0x00, 0x00, 0x00, 0x00,
|
||||||
},
|
},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: record.AmtOnionType,
|
Type: record.AmtOnionType,
|
||||||
Omitted: true,
|
Violation: hop.OmittedViolation,
|
||||||
FinalHop: false,
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "final hop no expiry",
|
name: "final hop no expiry",
|
||||||
payload: []byte{0x02, 0x00},
|
payload: []byte{0x02, 0x00},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: record.LockTimeOnionType,
|
Type: record.LockTimeOnionType,
|
||||||
Omitted: true,
|
Violation: hop.OmittedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -51,9 +61,9 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
0x00, 0x00, 0x00, 0x00,
|
0x00, 0x00, 0x00, 0x00,
|
||||||
},
|
},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: record.LockTimeOnionType,
|
Type: record.LockTimeOnionType,
|
||||||
Omitted: true,
|
Violation: hop.OmittedViolation,
|
||||||
FinalHop: false,
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -62,9 +72,60 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
},
|
},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: record.NextHopOnionType,
|
Type: record.NextHopOnionType,
|
||||||
Omitted: false,
|
Violation: hop.IncludedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required type after omitted hop id",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: 8,
|
||||||
|
Violation: hop.RequiredViolation,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required type after included hop id",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: 8,
|
||||||
|
Violation: hop.RequiredViolation,
|
||||||
|
FinalHop: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required type zero final hop",
|
||||||
|
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: 0,
|
||||||
|
Violation: hop.RequiredViolation,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required type zero final hop zero sid",
|
||||||
|
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: 6,
|
||||||
|
Violation: hop.IncludedViolation,
|
||||||
|
FinalHop: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required type zero intermediate hop",
|
||||||
|
payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08,
|
||||||
|
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: 0,
|
||||||
|
Violation: hop.RequiredViolation,
|
||||||
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -2645,12 +2645,23 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
|||||||
fwdInfo, err := chanIterator.ForwardingInstructions()
|
fwdInfo, err := chanIterator.ForwardingInstructions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If we're unable to process the onion payload, or we
|
// If we're unable to process the onion payload, or we
|
||||||
// we received malformed TLV stream, then we should
|
// received invalid onion payload failure, then we
|
||||||
// send an error back to the caller so the HTLC can be
|
// should send an error back to the caller so the HTLC
|
||||||
// canceled.
|
// can be canceled.
|
||||||
|
var failedType uint64
|
||||||
|
if e, ok := err.(hop.ErrInvalidPayload); ok {
|
||||||
|
failedType = uint64(e.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: currently none of the test unit infrastructure
|
||||||
|
// is setup to handle TLV payloads, so testing this
|
||||||
|
// would require implementing a separate mock iterator
|
||||||
|
// for TLV payloads that also supports injecting invalid
|
||||||
|
// payloads. Deferring this non-trival effort till a
|
||||||
|
// later date
|
||||||
l.sendHTLCError(
|
l.sendHTLCError(
|
||||||
pd.HtlcIndex,
|
pd.HtlcIndex,
|
||||||
lnwire.NewInvalidOnionVersion(onionBlob[:]),
|
lnwire.NewInvalidOnionPayload(failedType, 0),
|
||||||
obfuscator, pd.SourceRef,
|
obfuscator, pd.SourceRef,
|
||||||
)
|
)
|
||||||
needUpdate = true
|
needUpdate = true
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FailureMessage represents the onion failure object identified by its unique
|
// FailureMessage represents the onion failure object identified by its unique
|
||||||
@ -78,6 +79,7 @@ const (
|
|||||||
CodeFinalIncorrectCltvExpiry FailCode = 18
|
CodeFinalIncorrectCltvExpiry FailCode = 18
|
||||||
CodeFinalIncorrectHtlcAmount FailCode = 19
|
CodeFinalIncorrectHtlcAmount FailCode = 19
|
||||||
CodeExpiryTooFar FailCode = 21
|
CodeExpiryTooFar FailCode = 21
|
||||||
|
CodeInvalidOnionPayload = FlagPerm | 22
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns the string representation of the failure code.
|
// String returns the string representation of the failure code.
|
||||||
@ -149,6 +151,9 @@ func (c FailCode) String() string {
|
|||||||
case CodeExpiryTooFar:
|
case CodeExpiryTooFar:
|
||||||
return "ExpiryTooFar"
|
return "ExpiryTooFar"
|
||||||
|
|
||||||
|
case CodeInvalidOnionPayload:
|
||||||
|
return "InvalidOnionPayload"
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return "<unknown>"
|
return "<unknown>"
|
||||||
}
|
}
|
||||||
@ -1117,6 +1122,66 @@ func (f *FailExpiryTooFar) Error() string {
|
|||||||
return f.Code().String()
|
return f.Code().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvalidOnionPayload is returned if the hop could not process the TLV payload
|
||||||
|
// enclosed in the onion.
|
||||||
|
type InvalidOnionPayload struct {
|
||||||
|
// Type is the TLV type that caused the specific failure.
|
||||||
|
Type uint64
|
||||||
|
|
||||||
|
// Offset is the byte offset within the payload where the failure
|
||||||
|
// occurred.
|
||||||
|
Offset uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure.
|
||||||
|
func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload {
|
||||||
|
return &InvalidOnionPayload{
|
||||||
|
Type: typ,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code returns the failure unique code.
|
||||||
|
//
|
||||||
|
// NOTE: Part of the FailureMessage interface.
|
||||||
|
func (f *InvalidOnionPayload) Code() FailCode {
|
||||||
|
return CodeInvalidOnionPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a human readable string describing the target FailureMessage.
|
||||||
|
//
|
||||||
|
// NOTE: Implements the error interface.
|
||||||
|
func (f *InvalidOnionPayload) Error() string {
|
||||||
|
return fmt.Sprintf("%v(type=%v, offset=%d)",
|
||||||
|
f.Code(), f.Type, f.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode decodes the failure from bytes stream.
|
||||||
|
//
|
||||||
|
// NOTE: Part of the Serializable interface.
|
||||||
|
func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error {
|
||||||
|
var buf [8]byte
|
||||||
|
typ, err := tlv.ReadVarInt(r, &buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.Type = typ
|
||||||
|
|
||||||
|
return ReadElements(r, &f.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes the failure in bytes stream.
|
||||||
|
//
|
||||||
|
// NOTE: Part of the Serializable interface.
|
||||||
|
func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error {
|
||||||
|
var buf [8]byte
|
||||||
|
if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return WriteElements(w, f.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for
|
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for
|
||||||
// the provided protocol version.
|
// the provided protocol version.
|
||||||
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
|
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
|
||||||
@ -1298,6 +1363,9 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
|
|||||||
case CodeExpiryTooFar:
|
case CodeExpiryTooFar:
|
||||||
return &FailExpiryTooFar{}, nil
|
return &FailExpiryTooFar{}, nil
|
||||||
|
|
||||||
|
case CodeInvalidOnionPayload:
|
||||||
|
return &InvalidOnionPayload{}, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.Errorf("unknown error code: %v", code)
|
return nil, errors.Errorf("unknown error code: %v", code)
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,8 @@ var (
|
|||||||
testAmount = MilliSatoshi(1)
|
testAmount = MilliSatoshi(1)
|
||||||
testCtlvExpiry = uint32(2)
|
testCtlvExpiry = uint32(2)
|
||||||
testFlags = uint16(2)
|
testFlags = uint16(2)
|
||||||
|
testType = uint64(3)
|
||||||
|
testOffset = uint16(24)
|
||||||
sig, _ = NewSigFromSignature(testSig)
|
sig, _ = NewSigFromSignature(testSig)
|
||||||
testChannelUpdate = ChannelUpdate{
|
testChannelUpdate = ChannelUpdate{
|
||||||
Signature: sig,
|
Signature: sig,
|
||||||
@ -50,6 +52,7 @@ var onionFailures = []FailureMessage{
|
|||||||
NewChannelDisabled(testFlags, testChannelUpdate),
|
NewChannelDisabled(testFlags, testChannelUpdate),
|
||||||
NewFinalIncorrectCltvExpiry(testCtlvExpiry),
|
NewFinalIncorrectCltvExpiry(testCtlvExpiry),
|
||||||
NewFinalIncorrectHtlcAmount(testAmount),
|
NewFinalIncorrectHtlcAmount(testAmount),
|
||||||
|
NewInvalidOnionPayload(testType, testOffset),
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestEncodeDecodeCode tests the ability of onion errors to be properly encoded
|
// TestEncodeDecodeCode tests the ability of onion errors to be properly encoded
|
||||||
|
@ -265,7 +265,17 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
|
|||||||
|
|
||||||
// All nodes up to the failing pair must have forwarded
|
// All nodes up to the failing pair must have forwarded
|
||||||
// successfully.
|
// successfully.
|
||||||
if errorSourceIdx > 2 {
|
if errorSourceIdx > 1 {
|
||||||
|
i.successPairRange(route, 0, errorSourceIdx-2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reportNode := func() {
|
||||||
|
// Fail only the node that reported the failure.
|
||||||
|
i.failNode(route, errorSourceIdx)
|
||||||
|
|
||||||
|
// Other preceding channels in the route forwarded correctly.
|
||||||
|
if errorSourceIdx > 1 {
|
||||||
i.successPairRange(route, 0, errorSourceIdx-2)
|
i.successPairRange(route, 0, errorSourceIdx-2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -302,6 +312,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(
|
|||||||
|
|
||||||
reportOutgoing()
|
reportOutgoing()
|
||||||
|
|
||||||
|
// If InvalidOnionPayload is received, we penalize only the reporting
|
||||||
|
// node. We know the preceding hop didn't corrupt the onion, since the
|
||||||
|
// reporting node is able to send the failure. We assume that we
|
||||||
|
// constructed a valid onion payload and that the failure is most likely
|
||||||
|
// an unknown required type or a bug in their implementation.
|
||||||
|
case *lnwire.InvalidOnionPayload:
|
||||||
|
reportNode()
|
||||||
|
|
||||||
// If the next hop in the route wasn't known or offline, we'll only
|
// If the next hop in the route wasn't known or offline, we'll only
|
||||||
// penalize the channel set which we attempted to route over. This is
|
// penalize the channel set which we attempted to route over. This is
|
||||||
// conservative, and it can handle faulty channels between nodes
|
// conservative, and it can handle faulty channels between nodes
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
@ -47,6 +48,11 @@ func getTestPair(from, to int) DirectedNodePair {
|
|||||||
return NewDirectedNodePair(hops[from], hops[to])
|
return NewDirectedNodePair(hops[from], hops[to])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getPolicyFailure(from, to int) *DirectedNodePair {
|
||||||
|
pair := getTestPair(from, to)
|
||||||
|
return &pair
|
||||||
|
}
|
||||||
|
|
||||||
type resultTestCase struct {
|
type resultTestCase struct {
|
||||||
name string
|
name string
|
||||||
route *route.Route
|
route *route.Route
|
||||||
@ -169,6 +175,97 @@ var resultTestCases = []resultTestCase{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// Tests that a fee insufficient failure to an intermediate hop with
|
||||||
|
// index 2 results in the first hop marked as success, and then a
|
||||||
|
// bidirectional failure for the incoming channel. It should also result
|
||||||
|
// in a policy failure for the outgoing hop.
|
||||||
|
{
|
||||||
|
name: "fail fee insufficient intermediate",
|
||||||
|
route: &routeFourHop,
|
||||||
|
failureSrcIdx: 2,
|
||||||
|
failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}),
|
||||||
|
|
||||||
|
expectedResult: &interpretedResult{
|
||||||
|
pairResults: map[DirectedNodePair]pairResult{
|
||||||
|
getTestPair(0, 1): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(1, 2): {},
|
||||||
|
getTestPair(2, 1): {},
|
||||||
|
},
|
||||||
|
policyFailure: getPolicyFailure(2, 3),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Tests an invalid onion payload from a final hop. The final hop should
|
||||||
|
// be failed while the proceeding hops are reproed as successes. The
|
||||||
|
// failure is terminal since the receiver can't process our onion.
|
||||||
|
{
|
||||||
|
name: "fail invalid onion payload final hop",
|
||||||
|
route: &routeFourHop,
|
||||||
|
failureSrcIdx: 4,
|
||||||
|
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||||
|
|
||||||
|
expectedResult: &interpretedResult{
|
||||||
|
pairResults: map[DirectedNodePair]pairResult{
|
||||||
|
getTestPair(0, 1): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(1, 2): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(2, 3): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(4, 3): {},
|
||||||
|
},
|
||||||
|
finalFailureReason: &reasonError,
|
||||||
|
nodeFailure: &hops[4],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Tests an invalid onion payload from an intermediate hop. Only the
|
||||||
|
// reporting node should be failed. The failure is non-terminal since we
|
||||||
|
// can still try other paths.
|
||||||
|
{
|
||||||
|
name: "fail invalid onion payload intermediate",
|
||||||
|
route: &routeFourHop,
|
||||||
|
failureSrcIdx: 3,
|
||||||
|
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||||
|
|
||||||
|
expectedResult: &interpretedResult{
|
||||||
|
pairResults: map[DirectedNodePair]pairResult{
|
||||||
|
getTestPair(0, 1): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(1, 2): {
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
getTestPair(3, 2): {},
|
||||||
|
getTestPair(3, 4): {},
|
||||||
|
},
|
||||||
|
nodeFailure: &hops[3],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Tests an invalid onion payload in a direct peer that is also the
|
||||||
|
// final hop. The final node should be failed and the error is terminal
|
||||||
|
// since the remote node can't process our onion.
|
||||||
|
{
|
||||||
|
name: "fail invalid onion payload direct",
|
||||||
|
route: &routeOneHop,
|
||||||
|
failureSrcIdx: 1,
|
||||||
|
failure: lnwire.NewInvalidOnionPayload(0, 0),
|
||||||
|
|
||||||
|
expectedResult: &interpretedResult{
|
||||||
|
pairResults: map[DirectedNodePair]pairResult{
|
||||||
|
getTestPair(1, 0): {},
|
||||||
|
},
|
||||||
|
finalFailureReason: &reasonError,
|
||||||
|
nodeFailure: &hops[1],
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestResultInterpretation executes a list of test cases that test the result
|
// TestResultInterpretation executes a list of test cases that test the result
|
||||||
@ -192,7 +289,8 @@ func TestResultInterpretation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(i, expected) {
|
if !reflect.DeepEqual(i, expected) {
|
||||||
t.Fatal("unexpected result")
|
t.Fatalf("unexpected result\nwant: %v\ngot: %v",
|
||||||
|
spew.Sdump(expected), spew.Sdump(i))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -162,6 +162,7 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
var (
|
var (
|
||||||
typ Type
|
typ Type
|
||||||
min Type
|
min Type
|
||||||
|
firstFail *Type
|
||||||
recordIdx int
|
recordIdx int
|
||||||
overflow bool
|
overflow bool
|
||||||
)
|
)
|
||||||
@ -176,7 +177,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
// We'll silence an EOF when zero bytes remain, meaning the
|
// We'll silence an EOF when zero bytes remain, meaning the
|
||||||
// stream was cleanly encoded.
|
// stream was cleanly encoded.
|
||||||
case err == io.EOF:
|
case err == io.EOF:
|
||||||
return parsedTypes, nil
|
if firstFail == nil {
|
||||||
|
return parsedTypes, nil
|
||||||
|
}
|
||||||
|
return parsedTypes, ErrUnknownRequiredType(*firstFail)
|
||||||
|
|
||||||
// Other unexpected errors.
|
// Other unexpected errors.
|
||||||
case err != nil:
|
case err != nil:
|
||||||
@ -243,7 +247,27 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
|
|||||||
// This record type is unknown to the stream, fail if the type
|
// This record type is unknown to the stream, fail if the type
|
||||||
// is even meaning that we are required to understand it.
|
// is even meaning that we are required to understand it.
|
||||||
case typ%2 == 0:
|
case typ%2 == 0:
|
||||||
return nil, ErrUnknownRequiredType(typ)
|
// We'll fail immediately in the case that we aren't
|
||||||
|
// tracking the set of parsed types.
|
||||||
|
if parsedTypes == nil {
|
||||||
|
return nil, ErrUnknownRequiredType(typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we'll track the first such failure and
|
||||||
|
// allow parsing to continue. If no other types of
|
||||||
|
// errors are encountered, the first failure will be
|
||||||
|
// returned as an ErrUnknownRequiredType so that the
|
||||||
|
// full set of included types can be returned.
|
||||||
|
if firstFail == nil {
|
||||||
|
failTyp := typ
|
||||||
|
firstFail = &failTyp
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the failure type recorded, we'll simply discard
|
||||||
|
// the remainder of the record as if it were optional.
|
||||||
|
// The first failure will be returned after reaching the
|
||||||
|
// stopping condition.
|
||||||
|
fallthrough
|
||||||
|
|
||||||
// Otherwise, the record type is unknown and is odd, discard the
|
// Otherwise, the record type is unknown and is odd, discard the
|
||||||
// number of bytes specified by length.
|
// number of bytes specified by length.
|
||||||
|
@ -2,50 +2,106 @@ package tlv_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type parsedTypeTest struct {
|
||||||
|
name string
|
||||||
|
encode []tlv.Type
|
||||||
|
decode []tlv.Type
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
// TestParsedTypes asserts that a Stream will properly return the set of types
|
// TestParsedTypes asserts that a Stream will properly return the set of types
|
||||||
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
// that it encounters when the type is known-and-decoded or unknown-and-ignored.
|
||||||
func TestParsedTypes(t *testing.T) {
|
func TestParsedTypes(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
knownType = 1
|
firstReqType = 0
|
||||||
unknownType = 3
|
knownType = 1
|
||||||
|
unknownType = 3
|
||||||
|
secondReqType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
// Construct a stream that will encode two types, one that will be known
|
tests := []parsedTypeTest{
|
||||||
// to the decoder and another that will be unknown.
|
{
|
||||||
encStream := tlv.MustNewStream(
|
name: "known optional and unknown optional",
|
||||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
encode: []tlv.Type{knownType, unknownType},
|
||||||
tlv.MakePrimitiveRecord(unknownType, new(uint64)),
|
decode: []tlv.Type{knownType},
|
||||||
)
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and known optional",
|
||||||
|
encode: []tlv.Type{firstReqType, knownType},
|
||||||
|
decode: []tlv.Type{knownType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and unknown optional",
|
||||||
|
encode: []tlv.Type{unknownType, secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(secondReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown required and known required",
|
||||||
|
encode: []tlv.Type{firstReqType, secondReqType},
|
||||||
|
decode: []tlv.Type{secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two unknown required",
|
||||||
|
encode: []tlv.Type{firstReqType, secondReqType},
|
||||||
|
expErr: tlv.ErrUnknownRequiredType(firstReqType),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testParsedTypes(t, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testParsedTypes(t *testing.T, test parsedTypeTest) {
|
||||||
|
encRecords := make([]tlv.Record, 0, len(test.encode))
|
||||||
|
for _, typ := range test.encode {
|
||||||
|
encRecords = append(
|
||||||
|
encRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
decRecords := make([]tlv.Record, 0, len(test.decode))
|
||||||
|
for _, typ := range test.decode {
|
||||||
|
decRecords = append(
|
||||||
|
decRecords, tlv.MakePrimitiveRecord(typ, new(uint64)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct a stream that will encode the test's set of types.
|
||||||
|
encStream := tlv.MustNewStream(encRecords...)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := encStream.Encode(&b); err != nil {
|
if err := encStream.Encode(&b); err != nil {
|
||||||
t.Fatalf("unable to encode stream: %v", err)
|
t.Fatalf("unable to encode stream: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a stream that will parse only the known type.
|
// Create a stream that will parse a subset of the test's types.
|
||||||
decStream := tlv.MustNewStream(
|
decStream := tlv.MustNewStream(decRecords...)
|
||||||
tlv.MakePrimitiveRecord(knownType, new(uint64)),
|
|
||||||
)
|
|
||||||
|
|
||||||
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
parsedTypes, err := decStream.DecodeWithParsedTypes(
|
||||||
bytes.NewReader(b.Bytes()),
|
bytes.NewReader(b.Bytes()),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if !reflect.DeepEqual(err, test.expErr) {
|
||||||
t.Fatalf("unable to decode stream: %v", err)
|
t.Fatalf("error mismatch, want: %v got: %v", err, test.expErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert that both the known and unknown types are included in the set
|
// Assert that all encoded types are included in the set of parsed
|
||||||
// of parsed types.
|
// types.
|
||||||
if _, ok := parsedTypes[knownType]; !ok {
|
for _, typ := range test.encode {
|
||||||
t.Fatalf("known type %d should be in parsed types", knownType)
|
if _, ok := parsedTypes[typ]; !ok {
|
||||||
}
|
t.Fatalf("encoded type %d should be in parsed types",
|
||||||
if _, ok := parsedTypes[unknownType]; !ok {
|
typ)
|
||||||
t.Fatalf("unknown type %d should be in parsed types",
|
}
|
||||||
unknownType)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user