htlcswitch/hop: parse and validate AMP records

This commit is contained in:
Conner Fromknecht 2021-03-24 19:47:58 -07:00
parent 135a0a9f7f
commit c2729cbbbd
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 100 additions and 0 deletions

@ -119,6 +119,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
amt uint64 amt uint64
cltv uint32 cltv uint32
mpp = &record.MPP{} mpp = &record.MPP{}
amp = &record.AMP{}
) )
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
@ -126,6 +127,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
record.NewLockTimeRecord(&cltv), record.NewLockTimeRecord(&cltv),
record.NewNextHopIDRecord(&cid), record.NewNextHopIDRecord(&cid),
mpp.Record(), mpp.Record(),
amp.Record(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -160,6 +162,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
mpp = nil mpp = nil
} }
// If no AMP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.AMPOnionType]; !ok {
amp = nil
}
// Filter out the custom records. // Filter out the custom records.
customRecords := NewCustomRecords(parsedTypes) customRecords := NewCustomRecords(parsedTypes)
@ -171,6 +179,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
OutgoingCTLV: cltv, OutgoingCTLV: cltv,
}, },
MPP: mpp, MPP: mpp,
AMP: amp,
customRecords: customRecords, customRecords: customRecords,
}, nil }, nil
} }
@ -207,6 +216,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
_, hasLockTime := parsedTypes[record.LockTimeOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType] _, hasNextHop := parsedTypes[record.NextHopOnionType]
_, hasMPP := parsedTypes[record.MPPOnionType] _, hasMPP := parsedTypes[record.MPPOnionType]
_, hasAMP := parsedTypes[record.AMPOnionType]
switch { switch {
@ -243,6 +253,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
Violation: IncludedViolation, Violation: IncludedViolation,
FinalHop: isFinalHop, FinalHop: isFinalHop,
} }
// Intermediate nodes should never receive AMP fields.
case !isFinalHop && hasAMP:
return ErrInvalidPayload{
Type: record.AMPOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
} }
return nil return nil

@ -8,6 +8,7 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/stretchr/testify/require"
) )
const testUnknownRequiredType = 0x10 const testUnknownRequiredType = 0x10
@ -18,6 +19,7 @@ type decodePayloadTest struct {
expErr error expErr error
expCustomRecords map[uint64][]byte expCustomRecords map[uint64][]byte
shouldHaveMPP bool shouldHaveMPP bool
shouldHaveAMP bool
} }
var decodePayloadTests = []decodePayloadTest{ var decodePayloadTests = []decodePayloadTest{
@ -183,6 +185,37 @@ var decodePayloadTests = []decodePayloadTest{
FinalHop: false, FinalHop: false,
}, },
}, },
{
name: "intermediate hop with amp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// next hop id
0x06, 0x08,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// amp
0x0e, 0x41,
// amp.root_share
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
// amp.set_id
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
// amp.child_index
0x09,
},
expErr: hop.ErrInvalidPayload{
Type: record.AMPOnionType,
Violation: hop.IncludedViolation,
FinalHop: false,
},
},
{ {
name: "final hop with mpp", name: "final hop with mpp",
payload: []byte{ payload: []byte{
@ -201,6 +234,30 @@ var decodePayloadTests = []decodePayloadTest{
expErr: nil, expErr: nil,
shouldHaveMPP: true, shouldHaveMPP: true,
}, },
{
name: "final hop with amp",
payload: []byte{
// amount
0x02, 0x00,
// cltv
0x04, 0x00,
// amp
0x0e, 0x41,
// amp.root_share
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
// amp.set_id
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
// amp.child_index
0x09,
},
shouldHaveAMP: true,
},
} }
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
@ -223,6 +280,20 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
} }
testRootShare = [32]byte{
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
}
testSetID = [32]byte{
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13, 0x13,
}
testChildIndex = uint32(9)
) )
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
@ -249,6 +320,17 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
t.Fatalf("unexpected MPP payload") t.Fatalf("unexpected MPP payload")
} }
if test.shouldHaveAMP {
if p.AMP == nil {
t.Fatalf("payload should have AMP record")
}
require.Equal(t, testRootShare, p.AMP.RootShare())
require.Equal(t, testSetID, p.AMP.SetID())
require.Equal(t, testChildIndex, p.AMP.ChildIndex())
} else if p.AMP != nil {
t.Fatalf("unexpected AMP payload")
}
// Convert expected nil map to empty map, because we always expect an // Convert expected nil map to empty map, because we always expect an
// initiated map from the payload. // initiated map from the payload.
expCustomRecords := make(record.CustomSet) expCustomRecords := make(record.CustomSet)