htlcswitch/hop/payload: parse option_mpp
This commit is contained in:
parent
6d971e5113
commit
4a6f5d8d3d
@ -81,6 +81,10 @@ type Payload struct {
|
|||||||
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
|
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
|
||||||
// amount, cltv, and next hop.
|
// amount, cltv, and next hop.
|
||||||
FwdInfo ForwardingInfo
|
FwdInfo ForwardingInfo
|
||||||
|
|
||||||
|
// MPP holds the info provided in an option_mpp record when parsed from
|
||||||
|
// a TLV onion payload.
|
||||||
|
MPP *record.MPP
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
|
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
|
||||||
@ -105,12 +109,14 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
cid uint64
|
cid uint64
|
||||||
amt uint64
|
amt uint64
|
||||||
cltv uint32
|
cltv uint32
|
||||||
|
mpp = &record.MPP{}
|
||||||
)
|
)
|
||||||
|
|
||||||
tlvStream, err := tlv.NewStream(
|
tlvStream, err := tlv.NewStream(
|
||||||
record.NewAmtToFwdRecord(&amt),
|
record.NewAmtToFwdRecord(&amt),
|
||||||
record.NewLockTimeRecord(&cltv),
|
record.NewLockTimeRecord(&cltv),
|
||||||
record.NewNextHopIDRecord(&cid),
|
record.NewNextHopIDRecord(&cid),
|
||||||
|
mpp.Record(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -151,6 +157,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no MPP field was parsed, set the MPP field on the resulting
|
||||||
|
// payload to nil.
|
||||||
|
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
|
||||||
|
mpp = nil
|
||||||
|
}
|
||||||
|
|
||||||
return &Payload{
|
return &Payload{
|
||||||
FwdInfo: ForwardingInfo{
|
FwdInfo: ForwardingInfo{
|
||||||
Network: BitcoinNetwork,
|
Network: BitcoinNetwork,
|
||||||
@ -158,6 +170,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
AmountToForward: lnwire.MilliSatoshi(amt),
|
AmountToForward: lnwire.MilliSatoshi(amt),
|
||||||
OutgoingCTLV: cltv,
|
OutgoingCTLV: cltv,
|
||||||
},
|
},
|
||||||
|
MPP: mpp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,6 +192,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
|||||||
_, hasAmt := parsedTypes[record.AmtOnionType]
|
_, hasAmt := parsedTypes[record.AmtOnionType]
|
||||||
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||||
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||||
|
_, hasMPP := parsedTypes[record.MPPOnionType]
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
|
||||||
@ -207,6 +221,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
|||||||
Violation: IncludedViolation,
|
Violation: IncludedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Intermediate nodes should never receive MPP fields.
|
||||||
|
case !isFinalHop && hasMPP:
|
||||||
|
return ErrInvalidPayload{
|
||||||
|
Type: record.MPPOnionType,
|
||||||
|
Violation: IncludedViolation,
|
||||||
|
FinalHop: isFinalHop,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ type decodePayloadTest struct {
|
|||||||
name string
|
name string
|
||||||
payload []byte
|
payload []byte
|
||||||
expErr error
|
expErr error
|
||||||
|
shouldHaveMPP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var decodePayloadTests = []decodePayloadTest{
|
var decodePayloadTests = []decodePayloadTest{
|
||||||
@ -79,9 +81,9 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "required type after omitted hop id",
|
name: "required type after omitted hop id",
|
||||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x0a, 0x00},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: 8,
|
Type: 10,
|
||||||
Violation: hop.RequiredViolation,
|
Violation: hop.RequiredViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
},
|
},
|
||||||
@ -89,10 +91,10 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
{
|
{
|
||||||
name: "required type after included hop id",
|
name: "required type after included hop id",
|
||||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00,
|
||||||
},
|
},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: 8,
|
Type: 10,
|
||||||
Violation: hop.RequiredViolation,
|
Violation: hop.RequiredViolation,
|
||||||
FinalHop: false,
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
@ -112,7 +114,7 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
},
|
},
|
||||||
expErr: hop.ErrInvalidPayload{
|
expErr: hop.ErrInvalidPayload{
|
||||||
Type: 6,
|
Type: record.NextHopOnionType,
|
||||||
Violation: hop.IncludedViolation,
|
Violation: hop.IncludedViolation,
|
||||||
FinalHop: true,
|
FinalHop: true,
|
||||||
},
|
},
|
||||||
@ -128,6 +130,60 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
FinalHop: false,
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "valid intermediate hop",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
},
|
||||||
|
expErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid final hop",
|
||||||
|
payload: []byte{0x02, 0x00, 0x04, 0x00},
|
||||||
|
expErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "intermediate hop with mpp",
|
||||||
|
payload: []byte{
|
||||||
|
// amount
|
||||||
|
0x02, 0x00,
|
||||||
|
// cltv
|
||||||
|
0x04, 0x00,
|
||||||
|
// next hop id
|
||||||
|
0x06, 0x08,
|
||||||
|
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
// mpp
|
||||||
|
0x08, 0x21,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x08,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.MPPOnionType,
|
||||||
|
Violation: hop.IncludedViolation,
|
||||||
|
FinalHop: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "final hop with mpp",
|
||||||
|
payload: []byte{
|
||||||
|
// amount
|
||||||
|
0x02, 0x00,
|
||||||
|
// cltv
|
||||||
|
0x04, 0x00,
|
||||||
|
// mpp
|
||||||
|
0x08, 0x21,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x08,
|
||||||
|
},
|
||||||
|
expErr: nil,
|
||||||
|
shouldHaveMPP: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||||
@ -142,9 +198,37 @@ func TestDecodeHopPayloadRecordValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||||
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
var (
|
||||||
|
testTotalMsat = lnwire.MilliSatoshi(8)
|
||||||
|
testAddr = [32]byte{
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||||
if !reflect.DeepEqual(test.expErr, err) {
|
if !reflect.DeepEqual(test.expErr, err) {
|
||||||
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
||||||
test.expErr, err)
|
test.expErr, err)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert MPP fields if we expect them.
|
||||||
|
if test.shouldHaveMPP {
|
||||||
|
if p.MPP == nil {
|
||||||
|
t.Fatalf("payload should have MPP record")
|
||||||
|
}
|
||||||
|
if p.MPP.TotalMsat() != testTotalMsat {
|
||||||
|
t.Fatalf("invalid total msat")
|
||||||
|
}
|
||||||
|
if p.MPP.PaymentAddr() != testAddr {
|
||||||
|
t.Fatalf("invalid payment addr")
|
||||||
|
}
|
||||||
|
} else if p.MPP != nil {
|
||||||
|
t.Fatalf("unexpected MPP payload")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user