lnrpc: add UnmarshalAMP decoding
This commit is contained in:
parent
00581efec6
commit
352ce10658
@ -129,7 +129,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
|
||||
rpcHtlc.Amp = &lnrpc.AMP{
|
||||
RootShare: rootShare[:],
|
||||
SetId: setID[:],
|
||||
ChildIndex: uint32(htlc.AMP.Record.ChildIndex()),
|
||||
ChildIndex: htlc.AMP.Record.ChildIndex(),
|
||||
Hash: htlc.AMP.Hash[:],
|
||||
Preimage: preimage,
|
||||
}
|
||||
|
@ -455,6 +455,11 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amp, err := UnmarshalAMP(rpcHop.AmpRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route.Hop{
|
||||
OutgoingTimeLock: rpcHop.Expiry,
|
||||
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
||||
@ -463,6 +468,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
||||
CustomRecords: customRecords,
|
||||
LegacyPayload: !rpcHop.TlvPayload,
|
||||
MPP: mpp,
|
||||
AMP: amp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -895,6 +901,32 @@ func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) {
|
||||
return record.NewMPP(total, addr), nil
|
||||
}
|
||||
|
||||
func UnmarshalAMP(reqAMP *lnrpc.AMPRecord) (*record.AMP, error) {
|
||||
if reqAMP == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reqRootShare := reqAMP.RootShare
|
||||
reqSetID := reqAMP.SetId
|
||||
|
||||
switch {
|
||||
case len(reqRootShare) != 32:
|
||||
return nil, errors.New("AMP root_share must be 32 bytes")
|
||||
|
||||
case len(reqSetID) != 32:
|
||||
return nil, errors.New("AMP set_id must be 32 bytes")
|
||||
}
|
||||
|
||||
var (
|
||||
rootShare [32]byte
|
||||
setID [32]byte
|
||||
)
|
||||
copy(rootShare[:], reqRootShare)
|
||||
copy(setID[:], reqSetID)
|
||||
|
||||
return record.NewAMP(rootShare, setID, reqAMP.ChildIndex), nil
|
||||
}
|
||||
|
||||
// MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation.
|
||||
func (r *RouterBackend) MarshalHTLCAttempt(
|
||||
htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) {
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
)
|
||||
@ -239,18 +240,18 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
|
||||
return routing.TimedPairResult{}
|
||||
}
|
||||
|
||||
type mppOutcome byte
|
||||
type recordParseOutcome byte
|
||||
|
||||
const (
|
||||
valid mppOutcome = iota
|
||||
valid recordParseOutcome = iota
|
||||
invalid
|
||||
nompp
|
||||
norecord
|
||||
)
|
||||
|
||||
type unmarshalMPPTest struct {
|
||||
name string
|
||||
mpp *lnrpc.MPPRecord
|
||||
outcome mppOutcome
|
||||
outcome recordParseOutcome
|
||||
}
|
||||
|
||||
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
|
||||
@ -262,7 +263,7 @@ func TestUnmarshalMPP(t *testing.T) {
|
||||
{
|
||||
name: "nil record",
|
||||
mpp: nil,
|
||||
outcome: nompp,
|
||||
outcome: norecord,
|
||||
},
|
||||
{
|
||||
name: "invalid total or addr",
|
||||
@ -346,7 +347,7 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
||||
|
||||
// Arguments that produce no MPP field should return no error and no MPP
|
||||
// record.
|
||||
case nompp:
|
||||
case norecord:
|
||||
if err != nil {
|
||||
t.Fatalf("failure for args resulting for no-mpp")
|
||||
}
|
||||
@ -358,3 +359,95 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
||||
t.Fatalf("test case has non-standard outcome")
|
||||
}
|
||||
}
|
||||
|
||||
type unmarshalAMPTest struct {
|
||||
name string
|
||||
amp *lnrpc.AMPRecord
|
||||
outcome recordParseOutcome
|
||||
}
|
||||
|
||||
// TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord.
|
||||
func TestUnmarshalAMP(t *testing.T) {
|
||||
rootShare := bytes.Repeat([]byte{0x01}, 32)
|
||||
setID := bytes.Repeat([]byte{0x02}, 32)
|
||||
|
||||
// All child indexes are valid.
|
||||
childIndex := uint32(3)
|
||||
|
||||
tests := []unmarshalAMPTest{
|
||||
{
|
||||
name: "nil record",
|
||||
amp: nil,
|
||||
outcome: norecord,
|
||||
},
|
||||
{
|
||||
name: "invalid root share invalid set id",
|
||||
amp: &lnrpc.AMPRecord{
|
||||
RootShare: []byte{0x01},
|
||||
SetId: []byte{0x02},
|
||||
ChildIndex: childIndex,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid root share invalid set id",
|
||||
amp: &lnrpc.AMPRecord{
|
||||
RootShare: rootShare,
|
||||
SetId: []byte{0x02},
|
||||
ChildIndex: childIndex,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "invalid root share valid set id",
|
||||
amp: &lnrpc.AMPRecord{
|
||||
RootShare: []byte{0x01},
|
||||
SetId: setID,
|
||||
ChildIndex: childIndex,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid root share valid set id",
|
||||
amp: &lnrpc.AMPRecord{
|
||||
RootShare: rootShare,
|
||||
SetId: setID,
|
||||
ChildIndex: childIndex,
|
||||
},
|
||||
outcome: valid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testUnmarshalAMP(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) {
|
||||
amp, err := UnmarshalAMP(test.amp)
|
||||
switch test.outcome {
|
||||
case valid:
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, amp)
|
||||
|
||||
rootShare := amp.RootShare()
|
||||
setID := amp.SetID()
|
||||
require.Equal(t, test.amp.RootShare, rootShare[:])
|
||||
require.Equal(t, test.amp.SetId, setID[:])
|
||||
require.Equal(t, test.amp.ChildIndex, amp.ChildIndex())
|
||||
|
||||
case invalid:
|
||||
require.Error(t, err)
|
||||
require.Nil(t, amp)
|
||||
|
||||
case norecord:
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, amp)
|
||||
|
||||
default:
|
||||
t.Fatalf("test case has non-standard outcome")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user