htlcswitch/hop: parse and validate AMP records
This commit is contained in:
parent
135a0a9f7f
commit
c2729cbbbd
@ -119,6 +119,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
amt uint64
|
amt uint64
|
||||||
cltv uint32
|
cltv uint32
|
||||||
mpp = &record.MPP{}
|
mpp = &record.MPP{}
|
||||||
|
amp = &record.AMP{}
|
||||||
)
|
)
|
||||||
|
|
||||||
tlvStream, err := tlv.NewStream(
|
tlvStream, err := tlv.NewStream(
|
||||||
@ -126,6 +127,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
record.NewLockTimeRecord(&cltv),
|
record.NewLockTimeRecord(&cltv),
|
||||||
record.NewNextHopIDRecord(&cid),
|
record.NewNextHopIDRecord(&cid),
|
||||||
mpp.Record(),
|
mpp.Record(),
|
||||||
|
amp.Record(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -160,6 +162,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
mpp = nil
|
mpp = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no AMP field was parsed, set the MPP field on the resulting
|
||||||
|
// payload to nil.
|
||||||
|
if _, ok := parsedTypes[record.AMPOnionType]; !ok {
|
||||||
|
amp = nil
|
||||||
|
}
|
||||||
|
|
||||||
// Filter out the custom records.
|
// Filter out the custom records.
|
||||||
customRecords := NewCustomRecords(parsedTypes)
|
customRecords := NewCustomRecords(parsedTypes)
|
||||||
|
|
||||||
@ -171,6 +179,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
|||||||
OutgoingCTLV: cltv,
|
OutgoingCTLV: cltv,
|
||||||
},
|
},
|
||||||
MPP: mpp,
|
MPP: mpp,
|
||||||
|
AMP: amp,
|
||||||
customRecords: customRecords,
|
customRecords: customRecords,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -207,6 +216,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
|
|||||||
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||||
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||||
_, hasMPP := parsedTypes[record.MPPOnionType]
|
_, hasMPP := parsedTypes[record.MPPOnionType]
|
||||||
|
_, hasAMP := parsedTypes[record.AMPOnionType]
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
|
||||||
@ -243,6 +253,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
|
|||||||
Violation: IncludedViolation,
|
Violation: IncludedViolation,
|
||||||
FinalHop: isFinalHop,
|
FinalHop: isFinalHop,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Intermediate nodes should never receive AMP fields.
|
||||||
|
case !isFinalHop && hasAMP:
|
||||||
|
return ErrInvalidPayload{
|
||||||
|
Type: record.AMPOnionType,
|
||||||
|
Violation: IncludedViolation,
|
||||||
|
FinalHop: isFinalHop,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testUnknownRequiredType = 0x10
|
const testUnknownRequiredType = 0x10
|
||||||
@ -18,6 +19,7 @@ type decodePayloadTest struct {
|
|||||||
expErr error
|
expErr error
|
||||||
expCustomRecords map[uint64][]byte
|
expCustomRecords map[uint64][]byte
|
||||||
shouldHaveMPP bool
|
shouldHaveMPP bool
|
||||||
|
shouldHaveAMP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var decodePayloadTests = []decodePayloadTest{
|
var decodePayloadTests = []decodePayloadTest{
|
||||||
@ -183,6 +185,37 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
FinalHop: false,
|
FinalHop: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "intermediate hop with amp",
|
||||||
|
payload: []byte{
|
||||||
|
// amount
|
||||||
|
0x02, 0x00,
|
||||||
|
// cltv
|
||||||
|
0x04, 0x00,
|
||||||
|
// next hop id
|
||||||
|
0x06, 0x08,
|
||||||
|
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
// amp
|
||||||
|
0x0e, 0x41,
|
||||||
|
// amp.root_share
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
// amp.set_id
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
// amp.child_index
|
||||||
|
0x09,
|
||||||
|
},
|
||||||
|
expErr: hop.ErrInvalidPayload{
|
||||||
|
Type: record.AMPOnionType,
|
||||||
|
Violation: hop.IncludedViolation,
|
||||||
|
FinalHop: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "final hop with mpp",
|
name: "final hop with mpp",
|
||||||
payload: []byte{
|
payload: []byte{
|
||||||
@ -201,6 +234,30 @@ var decodePayloadTests = []decodePayloadTest{
|
|||||||
expErr: nil,
|
expErr: nil,
|
||||||
shouldHaveMPP: true,
|
shouldHaveMPP: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "final hop with amp",
|
||||||
|
payload: []byte{
|
||||||
|
// amount
|
||||||
|
0x02, 0x00,
|
||||||
|
// cltv
|
||||||
|
0x04, 0x00,
|
||||||
|
// amp
|
||||||
|
0x0e, 0x41,
|
||||||
|
// amp.root_share
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
// amp.set_id
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
// amp.child_index
|
||||||
|
0x09,
|
||||||
|
},
|
||||||
|
shouldHaveAMP: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||||
@ -223,6 +280,20 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
|||||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testRootShare = [32]byte{
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
|
||||||
|
}
|
||||||
|
testSetID = [32]byte{
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
|
||||||
|
}
|
||||||
|
testChildIndex = uint32(9)
|
||||||
)
|
)
|
||||||
|
|
||||||
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||||
@ -249,6 +320,17 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
|||||||
t.Fatalf("unexpected MPP payload")
|
t.Fatalf("unexpected MPP payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if test.shouldHaveAMP {
|
||||||
|
if p.AMP == nil {
|
||||||
|
t.Fatalf("payload should have AMP record")
|
||||||
|
}
|
||||||
|
require.Equal(t, testRootShare, p.AMP.RootShare())
|
||||||
|
require.Equal(t, testSetID, p.AMP.SetID())
|
||||||
|
require.Equal(t, testChildIndex, p.AMP.ChildIndex())
|
||||||
|
} else if p.AMP != nil {
|
||||||
|
t.Fatalf("unexpected AMP payload")
|
||||||
|
}
|
||||||
|
|
||||||
// Convert expected nil map to empty map, because we always expect an
|
// Convert expected nil map to empty map, because we always expect an
|
||||||
// initiated map from the payload.
|
// initiated map from the payload.
|
||||||
expCustomRecords := make(record.CustomSet)
|
expCustomRecords := make(record.CustomSet)
|
||||||
|
Loading…
Reference in New Issue
Block a user