htlcswitch/hop/payload: parse option_mpp

This commit is contained in:
Conner Fromknecht 2019-11-04 15:10:00 -08:00
parent 6d971e5113
commit 4a6f5d8d3d
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 115 additions and 9 deletions

@ -81,6 +81,10 @@ type Payload struct {
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
// amount, cltv, and next hop.
FwdInfo ForwardingInfo
// MPP holds the info provided in an option_mpp record when parsed from
// a TLV onion payload.
MPP *record.MPP
}
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
@ -105,12 +109,14 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
cid uint64
amt uint64
cltv uint32
mpp = &record.MPP{}
)
tlvStream, err := tlv.NewStream(
record.NewAmtToFwdRecord(&amt),
record.NewLockTimeRecord(&cltv),
record.NewNextHopIDRecord(&cid),
mpp.Record(),
)
if err != nil {
return nil, err
@ -151,6 +157,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err
}
// If no MPP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
mpp = nil
}
return &Payload{
FwdInfo: ForwardingInfo{
Network: BitcoinNetwork,
@ -158,6 +170,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv,
},
MPP: mpp,
}, nil
}
@ -179,6 +192,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
_, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType]
_, hasMPP := parsedTypes[record.MPPOnionType]
switch {
@ -207,6 +221,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
Violation: IncludedViolation,
FinalHop: true,
}
// Intermediate nodes should never receive MPP fields.
case !isFinalHop && hasMPP:
return ErrInvalidPayload{
Type: record.MPPOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
}
return nil

@ -6,13 +6,15 @@ import (
"testing"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
type decodePayloadTest struct {
name string
payload []byte
expErr error
name string
payload []byte
expErr error
shouldHaveMPP bool
}
var decodePayloadTests = []decodePayloadTest{
@ -79,9 +81,9 @@ var decodePayloadTests = []decodePayloadTest{
},
{
name: "required type after omitted hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x0a, 0x00},
expErr: hop.ErrInvalidPayload{
Type: 8,
Type: 10,
Violation: hop.RequiredViolation,
FinalHop: true,
},
@ -89,10 +91,10 @@ var decodePayloadTests = []decodePayloadTest{
{
name: "required type after included hop id",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: 8,
Type: 10,
Violation: hop.RequiredViolation,
FinalHop: false,
},
@ -112,7 +114,7 @@ var decodePayloadTests = []decodePayloadTest{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: 6,
Type: record.NextHopOnionType,
Violation: hop.IncludedViolation,
FinalHop: true,
},
@ -128,6 +130,60 @@ var decodePayloadTests = []decodePayloadTest{
FinalHop: false,
},
},
{
name: "valid intermediate hop",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: nil,
},
{
name: "valid final hop",
payload: []byte{0x02, 0x00, 0x04, 0x00},
expErr: nil,
},
{
name: "intermediate hop with mpp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// next hop id
0x06, 0x08,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// mpp
0x08, 0x21,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x08,
},
expErr: hop.ErrInvalidPayload{
Type: record.MPPOnionType,
Violation: hop.IncludedViolation,
FinalHop: false,
},
},
{
name: "final hop with mpp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// mpp
0x08, 0x21,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x08,
},
expErr: nil,
shouldHaveMPP: true,
},
}
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
@ -142,9 +198,37 @@ func TestDecodeHopPayloadRecordValidation(t *testing.T) {
}
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
var (
testTotalMsat = lnwire.MilliSatoshi(8)
testAddr = [32]byte{
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
}
)
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v",
test.expErr, err)
}
if err != nil {
return
}
// Assert MPP fields if we expect them.
if test.shouldHaveMPP {
if p.MPP == nil {
t.Fatalf("payload should have MPP record")
}
if p.MPP.TotalMsat() != testTotalMsat {
t.Fatalf("invalid total msat")
}
if p.MPP.PaymentAddr() != testAddr {
t.Fatalf("invalid payment addr")
}
} else if p.MPP != nil {
t.Fatalf("unexpected MPP payload")
}
}