diff --git a/lntest/itest/lnd_amp_test.go b/lntest/itest/lnd_amp_test.go index 66308bb0..092629a2 100644 --- a/lntest/itest/lnd_amp_test.go +++ b/lntest/itest/lnd_amp_test.go @@ -3,6 +3,8 @@ package itest import ( "context" "crypto/rand" + "sort" + "testing" "time" "github.com/btcsuite/btcutil" @@ -172,6 +174,13 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { ctx.waitForChannels() + // Subscribe to bob's invoices. + req := &lnrpc.InvoiceSubscription{} + ctxc, cancelSubscription := context.WithCancel(ctxb) + bobInvoiceSubscription, err := ctx.bob.SubscribeInvoices(ctxc, req) + require.NoError(t.t, err) + defer cancelSubscription() + // We'll send shards along three routes from Alice. sendRoutes := [numShards][]*lntest.HarnessNode{ {ctx.carol, ctx.bob}, @@ -180,7 +189,7 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { } payAddr := make([]byte, 32) - _, err := rand.Read(payAddr) + _, err = rand.Read(payAddr) require.NoError(t.t, err) setID := make([]byte, 32) @@ -193,7 +202,9 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { childPreimages := make(map[lntypes.Preimage]uint32) responses := make(chan *lnrpc.HTLCAttempt, len(sendRoutes)) - for i, hops := range sendRoutes { + + // Define a closure for sending each of the three shards. + sendShard := func(i int, hops []*lntest.HarnessNode) { // Build a route for the specified hops. r, err := ctx.buildRoute(ctxb, shardAmt, net.Alice, hops) if err != nil { @@ -245,6 +256,24 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { }() } + // Send the first shard, this cause Bob to JIT add an invoice. + sendShard(0, sendRoutes[0]) + + // Ensure we get a notification of the invoice being added by Bob. + rpcInvoice, err := bobInvoiceSubscription.Recv() + require.NoError(t.t, err) + + require.False(t.t, rpcInvoice.Settled) // nolint:staticcheck + require.Equal(t.t, lnrpc.Invoice_OPEN, rpcInvoice.State) + require.Equal(t.t, int64(0), rpcInvoice.AmtPaidSat) + require.Equal(t.t, int64(0), rpcInvoice.AmtPaidMsat) + require.Equal(t.t, payAddr, rpcInvoice.PaymentAddr) + + require.Equal(t.t, 0, len(rpcInvoice.Htlcs)) + + sendShard(1, sendRoutes[1]) + sendShard(2, sendRoutes[2]) + // Assert that all of the child preimages are unique. require.Equal(t.t, len(sendRoutes), len(childPreimages)) @@ -282,15 +311,18 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { } childPreimages = childPreimagesCopy - // Fetch Bob's invoices. + // There should now be a settle event for the invoice. + rpcInvoice, err = bobInvoiceSubscription.Recv() + require.NoError(t.t, err) + + // Also fetch Bob's invoice from ListInvoices and assert it is equal to + // the one recevied via the subscription. invoiceResp, err := net.Bob.ListInvoices( ctxb, &lnrpc.ListInvoiceRequest{}, ) require.NoError(t.t, err) - - // There should only be one invoice. require.Equal(t.t, 1, len(invoiceResp.Invoices)) - rpcInvoice := invoiceResp.Invoices[0] + assertInvoiceEqual(t.t, rpcInvoice, invoiceResp.Invoices[0]) // Assert that the invoice is settled for the total payment amount and // has the correct payment address. @@ -330,3 +362,60 @@ func testSendToRouteAMP(net *lntest.NetworkHarness, t *harnessTest) { delete(childPreimages, childPreimage) } } + +// assertInvoiceEqual asserts that two lnrpc.Invoices are equivalent. A custom +// comparison function is defined for these tests, since proto message returned +// from unary and streaming RPCs (as of protobuf 1.23.0 and grpc 1.29.1) aren't +// consistent with the private fields set on the messages. As a result, we avoid +// using require.Equal and test only the actual data members. +func assertInvoiceEqual(t *testing.T, a, b *lnrpc.Invoice) { + t.Helper() + + // Ensure the HTLCs are sorted properly before attempting to compare. + sort.Slice(a.Htlcs, func(i, j int) bool { + return a.Htlcs[i].ChanId < a.Htlcs[j].ChanId + }) + sort.Slice(b.Htlcs, func(i, j int) bool { + return b.Htlcs[i].ChanId < b.Htlcs[j].ChanId + }) + + require.Equal(t, a.Memo, b.Memo) + require.Equal(t, a.RPreimage, b.RPreimage) + require.Equal(t, a.RHash, b.RHash) + require.Equal(t, a.Value, b.Value) + require.Equal(t, a.ValueMsat, b.ValueMsat) + require.Equal(t, a.CreationDate, b.CreationDate) + require.Equal(t, a.SettleDate, b.SettleDate) + require.Equal(t, a.PaymentRequest, b.PaymentRequest) + require.Equal(t, a.DescriptionHash, b.DescriptionHash) + require.Equal(t, a.Expiry, b.Expiry) + require.Equal(t, a.FallbackAddr, b.FallbackAddr) + require.Equal(t, a.CltvExpiry, b.CltvExpiry) + require.Equal(t, a.RouteHints, b.RouteHints) + require.Equal(t, a.Private, b.Private) + require.Equal(t, a.AddIndex, b.AddIndex) + require.Equal(t, a.SettleIndex, b.SettleIndex) + require.Equal(t, a.AmtPaidSat, b.AmtPaidSat) + require.Equal(t, a.AmtPaidMsat, b.AmtPaidMsat) + require.Equal(t, a.State, b.State) + require.Equal(t, a.Features, b.Features) + require.Equal(t, a.IsKeysend, b.IsKeysend) + require.Equal(t, a.PaymentAddr, b.PaymentAddr) + require.Equal(t, a.IsAmp, b.IsAmp) + + require.Equal(t, len(a.Htlcs), len(b.Htlcs)) + for i := range a.Htlcs { + htlcA, htlcB := a.Htlcs[i], b.Htlcs[i] + require.Equal(t, htlcA.ChanId, htlcB.ChanId) + require.Equal(t, htlcA.HtlcIndex, htlcB.HtlcIndex) + require.Equal(t, htlcA.AmtMsat, htlcB.AmtMsat) + require.Equal(t, htlcA.AcceptHeight, htlcB.AcceptHeight) + require.Equal(t, htlcA.AcceptTime, htlcB.AcceptTime) + require.Equal(t, htlcA.ResolveTime, htlcB.ResolveTime) + require.Equal(t, htlcA.ExpiryHeight, htlcB.ExpiryHeight) + require.Equal(t, htlcA.State, htlcB.State) + require.Equal(t, htlcA.CustomRecords, htlcB.CustomRecords) + require.Equal(t, htlcA.MppTotalAmtMsat, htlcB.MppTotalAmtMsat) + require.Equal(t, htlcA.Amp, htlcB.Amp) + } +}