lnrpc: add UnmarshalAMP decoding
This commit is contained in:
parent
00581efec6
commit
352ce10658
@ -129,7 +129,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
|
|||||||
rpcHtlc.Amp = &lnrpc.AMP{
|
rpcHtlc.Amp = &lnrpc.AMP{
|
||||||
RootShare: rootShare[:],
|
RootShare: rootShare[:],
|
||||||
SetId: setID[:],
|
SetId: setID[:],
|
||||||
ChildIndex: uint32(htlc.AMP.Record.ChildIndex()),
|
ChildIndex: htlc.AMP.Record.ChildIndex(),
|
||||||
Hash: htlc.AMP.Hash[:],
|
Hash: htlc.AMP.Hash[:],
|
||||||
Preimage: preimage,
|
Preimage: preimage,
|
||||||
}
|
}
|
||||||
|
@ -455,6 +455,11 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
amp, err := UnmarshalAMP(rpcHop.AmpRecord)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &route.Hop{
|
return &route.Hop{
|
||||||
OutgoingTimeLock: rpcHop.Expiry,
|
OutgoingTimeLock: rpcHop.Expiry,
|
||||||
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
|
||||||
@ -463,6 +468,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
|
|||||||
CustomRecords: customRecords,
|
CustomRecords: customRecords,
|
||||||
LegacyPayload: !rpcHop.TlvPayload,
|
LegacyPayload: !rpcHop.TlvPayload,
|
||||||
MPP: mpp,
|
MPP: mpp,
|
||||||
|
AMP: amp,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -895,6 +901,32 @@ func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) {
|
|||||||
return record.NewMPP(total, addr), nil
|
return record.NewMPP(total, addr), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UnmarshalAMP(reqAMP *lnrpc.AMPRecord) (*record.AMP, error) {
|
||||||
|
if reqAMP == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reqRootShare := reqAMP.RootShare
|
||||||
|
reqSetID := reqAMP.SetId
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(reqRootShare) != 32:
|
||||||
|
return nil, errors.New("AMP root_share must be 32 bytes")
|
||||||
|
|
||||||
|
case len(reqSetID) != 32:
|
||||||
|
return nil, errors.New("AMP set_id must be 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
rootShare [32]byte
|
||||||
|
setID [32]byte
|
||||||
|
)
|
||||||
|
copy(rootShare[:], reqRootShare)
|
||||||
|
copy(setID[:], reqSetID)
|
||||||
|
|
||||||
|
return record.NewAMP(rootShare, setID, reqAMP.ChildIndex), nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation.
|
// MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation.
|
||||||
func (r *RouterBackend) MarshalHTLCAttempt(
|
func (r *RouterBackend) MarshalHTLCAttempt(
|
||||||
htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) {
|
htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) {
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
"github.com/lightningnetwork/lnd/routing"
|
"github.com/lightningnetwork/lnd/routing"
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/lnrpc"
|
"github.com/lightningnetwork/lnd/lnrpc"
|
||||||
)
|
)
|
||||||
@ -239,18 +240,18 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
|
|||||||
return routing.TimedPairResult{}
|
return routing.TimedPairResult{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mppOutcome byte
|
type recordParseOutcome byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
valid mppOutcome = iota
|
valid recordParseOutcome = iota
|
||||||
invalid
|
invalid
|
||||||
nompp
|
norecord
|
||||||
)
|
)
|
||||||
|
|
||||||
type unmarshalMPPTest struct {
|
type unmarshalMPPTest struct {
|
||||||
name string
|
name string
|
||||||
mpp *lnrpc.MPPRecord
|
mpp *lnrpc.MPPRecord
|
||||||
outcome mppOutcome
|
outcome recordParseOutcome
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
|
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
|
||||||
@ -262,7 +263,7 @@ func TestUnmarshalMPP(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "nil record",
|
name: "nil record",
|
||||||
mpp: nil,
|
mpp: nil,
|
||||||
outcome: nompp,
|
outcome: norecord,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid total or addr",
|
name: "invalid total or addr",
|
||||||
@ -346,7 +347,7 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
|||||||
|
|
||||||
// Arguments that produce no MPP field should return no error and no MPP
|
// Arguments that produce no MPP field should return no error and no MPP
|
||||||
// record.
|
// record.
|
||||||
case nompp:
|
case norecord:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failure for args resulting for no-mpp")
|
t.Fatalf("failure for args resulting for no-mpp")
|
||||||
}
|
}
|
||||||
@ -358,3 +359,95 @@ func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
|||||||
t.Fatalf("test case has non-standard outcome")
|
t.Fatalf("test case has non-standard outcome")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type unmarshalAMPTest struct {
|
||||||
|
name string
|
||||||
|
amp *lnrpc.AMPRecord
|
||||||
|
outcome recordParseOutcome
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord.
|
||||||
|
func TestUnmarshalAMP(t *testing.T) {
|
||||||
|
rootShare := bytes.Repeat([]byte{0x01}, 32)
|
||||||
|
setID := bytes.Repeat([]byte{0x02}, 32)
|
||||||
|
|
||||||
|
// All child indexes are valid.
|
||||||
|
childIndex := uint32(3)
|
||||||
|
|
||||||
|
tests := []unmarshalAMPTest{
|
||||||
|
{
|
||||||
|
name: "nil record",
|
||||||
|
amp: nil,
|
||||||
|
outcome: norecord,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid root share invalid set id",
|
||||||
|
amp: &lnrpc.AMPRecord{
|
||||||
|
RootShare: []byte{0x01},
|
||||||
|
SetId: []byte{0x02},
|
||||||
|
ChildIndex: childIndex,
|
||||||
|
},
|
||||||
|
outcome: invalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid root share invalid set id",
|
||||||
|
amp: &lnrpc.AMPRecord{
|
||||||
|
RootShare: rootShare,
|
||||||
|
SetId: []byte{0x02},
|
||||||
|
ChildIndex: childIndex,
|
||||||
|
},
|
||||||
|
outcome: invalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid root share valid set id",
|
||||||
|
amp: &lnrpc.AMPRecord{
|
||||||
|
RootShare: []byte{0x01},
|
||||||
|
SetId: setID,
|
||||||
|
ChildIndex: childIndex,
|
||||||
|
},
|
||||||
|
outcome: invalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid root share valid set id",
|
||||||
|
amp: &lnrpc.AMPRecord{
|
||||||
|
RootShare: rootShare,
|
||||||
|
SetId: setID,
|
||||||
|
ChildIndex: childIndex,
|
||||||
|
},
|
||||||
|
outcome: valid,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
testUnmarshalAMP(t, test)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) {
|
||||||
|
amp, err := UnmarshalAMP(test.amp)
|
||||||
|
switch test.outcome {
|
||||||
|
case valid:
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, amp)
|
||||||
|
|
||||||
|
rootShare := amp.RootShare()
|
||||||
|
setID := amp.SetID()
|
||||||
|
require.Equal(t, test.amp.RootShare, rootShare[:])
|
||||||
|
require.Equal(t, test.amp.SetId, setID[:])
|
||||||
|
require.Equal(t, test.amp.ChildIndex, amp.ChildIndex())
|
||||||
|
|
||||||
|
case invalid:
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, amp)
|
||||||
|
|
||||||
|
case norecord:
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, amp)
|
||||||
|
|
||||||
|
default:
|
||||||
|
t.Fatalf("test case has non-standard outcome")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user