htlcswitch/hop/payload: parse option_mpp

This commit is contained in:
Conner Fromknecht 2019-11-04 15:10:00 -08:00
parent 6d971e5113
commit 4a6f5d8d3d
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 115 additions and 9 deletions

@ -81,6 +81,10 @@ type Payload struct {
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g. // FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
// amount, cltv, and next hop. // amount, cltv, and next hop.
FwdInfo ForwardingInfo FwdInfo ForwardingInfo
// MPP holds the info provided in an option_mpp record when parsed from
// a TLV onion payload.
MPP *record.MPP
} }
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop // NewLegacyPayload builds a Payload from the amount, cltv, and next hop
@ -105,12 +109,14 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
cid uint64 cid uint64
amt uint64 amt uint64
cltv uint32 cltv uint32
mpp = &record.MPP{}
) )
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
record.NewAmtToFwdRecord(&amt), record.NewAmtToFwdRecord(&amt),
record.NewLockTimeRecord(&cltv), record.NewLockTimeRecord(&cltv),
record.NewNextHopIDRecord(&cid), record.NewNextHopIDRecord(&cid),
mpp.Record(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -151,6 +157,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err return nil, err
} }
// If no MPP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
mpp = nil
}
return &Payload{ return &Payload{
FwdInfo: ForwardingInfo{ FwdInfo: ForwardingInfo{
Network: BitcoinNetwork, Network: BitcoinNetwork,
@ -158,6 +170,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
AmountToForward: lnwire.MilliSatoshi(amt), AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv, OutgoingCTLV: cltv,
}, },
MPP: mpp,
}, nil }, nil
} }
@ -179,6 +192,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
_, hasAmt := parsedTypes[record.AmtOnionType] _, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType] _, hasNextHop := parsedTypes[record.NextHopOnionType]
_, hasMPP := parsedTypes[record.MPPOnionType]
switch { switch {
@ -207,6 +221,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
Violation: IncludedViolation, Violation: IncludedViolation,
FinalHop: true, FinalHop: true,
} }
// Intermediate nodes should never receive MPP fields.
case !isFinalHop && hasMPP:
return ErrInvalidPayload{
Type: record.MPPOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
} }
return nil return nil

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
) )
@ -13,6 +14,7 @@ type decodePayloadTest struct {
name string name string
payload []byte payload []byte
expErr error expErr error
shouldHaveMPP bool
} }
var decodePayloadTests = []decodePayloadTest{ var decodePayloadTests = []decodePayloadTest{
@ -79,9 +81,9 @@ var decodePayloadTests = []decodePayloadTest{
}, },
{ {
name: "required type after omitted hop id", name: "required type after omitted hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00}, payload: []byte{0x02, 0x00, 0x04, 0x00, 0x0a, 0x00},
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: 8, Type: 10,
Violation: hop.RequiredViolation, Violation: hop.RequiredViolation,
FinalHop: true, FinalHop: true,
}, },
@ -89,10 +91,10 @@ var decodePayloadTests = []decodePayloadTest{
{ {
name: "required type after included hop id", name: "required type after included hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00,
}, },
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: 8, Type: 10,
Violation: hop.RequiredViolation, Violation: hop.RequiredViolation,
FinalHop: false, FinalHop: false,
}, },
@ -112,7 +114,7 @@ var decodePayloadTests = []decodePayloadTest{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, },
expErr: hop.ErrInvalidPayload{ expErr: hop.ErrInvalidPayload{
Type: 6, Type: record.NextHopOnionType,
Violation: hop.IncludedViolation, Violation: hop.IncludedViolation,
FinalHop: true, FinalHop: true,
}, },
@ -128,6 +130,60 @@ var decodePayloadTests = []decodePayloadTest{
FinalHop: false, FinalHop: false,
}, },
}, },
{
name: "valid intermediate hop",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: nil,
},
{
name: "valid final hop",
payload: []byte{0x02, 0x00, 0x04, 0x00},
expErr: nil,
},
{
name: "intermediate hop with mpp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// next hop id
0x06, 0x08,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// mpp
0x08, 0x21,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x08,
},
expErr: hop.ErrInvalidPayload{
Type: record.MPPOnionType,
Violation: hop.IncludedViolation,
FinalHop: false,
},
},
{
name: "final hop with mpp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// mpp
0x08, 0x21,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x08,
},
expErr: nil,
shouldHaveMPP: true,
},
} }
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
@ -142,9 +198,37 @@ func TestDecodeHopPayloadRecordValidation(t *testing.T) {
} }
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) var (
testTotalMsat = lnwire.MilliSatoshi(8)
testAddr = [32]byte{
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
}
)
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
if !reflect.DeepEqual(test.expErr, err) { if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v", t.Fatalf("expected error mismatch, want: %v, got: %v",
test.expErr, err) test.expErr, err)
} }
if err != nil {
return
}
// Assert MPP fields if we expect them.
if test.shouldHaveMPP {
if p.MPP == nil {
t.Fatalf("payload should have MPP record")
}
if p.MPP.TotalMsat() != testTotalMsat {
t.Fatalf("invalid total msat")
}
if p.MPP.PaymentAddr() != testAddr {
t.Fatalf("invalid payment addr")
}
} else if p.MPP != nil {
t.Fatalf("unexpected MPP payload")
}
} }