lnrpc: add UnmarshalAMP decoding

This commit is contained in:
Conner Fromknecht 2021-03-24 19:53:13 -07:00
parent 00581efec6
commit 352ce10658
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
3 changed files with 132 additions and 7 deletions

@ -129,7 +129,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
rpcHtlc.Amp = &lnrpc.AMP{
RootShare: rootShare[:],
SetId: setID[:],
ChildIndex: uint32(htlc.AMP.Record.ChildIndex()),
ChildIndex: htlc.AMP.Record.ChildIndex(),
Hash: htlc.AMP.Hash[:],
Preimage: preimage,
}

@ -455,6 +455,11 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
return nil, err
}
amp, err := UnmarshalAMP(rpcHop.AmpRecord)
if err != nil {
return nil, err
}
return &route.Hop{
OutgoingTimeLock: rpcHop.Expiry,
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
@ -463,6 +468,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
CustomRecords: customRecords,
LegacyPayload: !rpcHop.TlvPayload,
MPP: mpp,
AMP: amp,
}, nil
}
@ -895,6 +901,32 @@ func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) {
return record.NewMPP(total, addr), nil
}
func UnmarshalAMP(reqAMP *lnrpc.AMPRecord) (*record.AMP, error) {
if reqAMP == nil {
return nil, nil
}
reqRootShare := reqAMP.RootShare
reqSetID := reqAMP.SetId
switch {
case len(reqRootShare) != 32:
return nil, errors.New("AMP root_share must be 32 bytes")
case len(reqSetID) != 32:
return nil, errors.New("AMP set_id must be 32 bytes")
}
var (
rootShare [32]byte
setID [32]byte
)
copy(rootShare[:], reqRootShare)
copy(setID[:], reqSetID)
return record.NewAMP(rootShare, setID, reqAMP.ChildIndex), nil
}
// MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation.
func (r *RouterBackend) MarshalHTLCAttempt(
htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) {

@ -12,6 +12,7 @@ import (
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
"github.com/lightningnetwork/lnd/lnrpc"
)
@ -239,18 +240,18 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
return routing.TimedPairResult{}
}
type mppOutcome byte
type recordParseOutcome byte
const (
valid mppOutcome = iota
valid recordParseOutcome = iota
invalid
nompp
norecord
)
type unmarshalMPPTest struct {
name string
mpp *lnrpc.MPPRecord
outcome mppOutcome
outcome recordParseOutcome
}
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
@ -262,7 +263,7 @@ func TestUnmarshalMPP(t *testing.T) {
{
name: "nil record",
mpp: nil,
outcome: nompp,
outcome: norecord,
},
{
name: "invalid total or addr",
@ -346,7 +347,7 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
// Arguments that produce no MPP field should return no error and no MPP
// record.
case nompp:
case norecord:
if err != nil {
t.Fatalf("failure for args resulting for no-mpp")
}
@ -358,3 +359,95 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
t.Fatalf("test case has non-standard outcome")
}
}
type unmarshalAMPTest struct {
name string
amp *lnrpc.AMPRecord
outcome recordParseOutcome
}
// TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord.
func TestUnmarshalAMP(t *testing.T) {
rootShare := bytes.Repeat([]byte{0x01}, 32)
setID := bytes.Repeat([]byte{0x02}, 32)
// All child indexes are valid.
childIndex := uint32(3)
tests := []unmarshalAMPTest{
{
name: "nil record",
amp: nil,
outcome: norecord,
},
{
name: "invalid root share invalid set id",
amp: &lnrpc.AMPRecord{
RootShare: []byte{0x01},
SetId: []byte{0x02},
ChildIndex: childIndex,
},
outcome: invalid,
},
{
name: "valid root share invalid set id",
amp: &lnrpc.AMPRecord{
RootShare: rootShare,
SetId: []byte{0x02},
ChildIndex: childIndex,
},
outcome: invalid,
},
{
name: "invalid root share valid set id",
amp: &lnrpc.AMPRecord{
RootShare: []byte{0x01},
SetId: setID,
ChildIndex: childIndex,
},
outcome: invalid,
},
{
name: "valid root share valid set id",
amp: &lnrpc.AMPRecord{
RootShare: rootShare,
SetId: setID,
ChildIndex: childIndex,
},
outcome: valid,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
testUnmarshalAMP(t, test)
})
}
}
func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) {
amp, err := UnmarshalAMP(test.amp)
switch test.outcome {
case valid:
require.NoError(t, err)
require.NotNil(t, amp)
rootShare := amp.RootShare()
setID := amp.SetID()
require.Equal(t, test.amp.RootShare, rootShare[:])
require.Equal(t, test.amp.SetId, setID[:])
require.Equal(t, test.amp.ChildIndex, amp.ChildIndex())
case invalid:
require.Error(t, err)
require.Nil(t, amp)
case norecord:
require.NoError(t, err)
require.Nil(t, amp)
default:
t.Fatalf("test case has non-standard outcome")
}
}