From db1d671b1af6b1d1145d5287d1065e5ee5169b11 Mon Sep 17 00:00:00 2001 From: carla Date: Wed, 14 Apr 2021 09:19:23 +0200 Subject: [PATCH] multi: terminate SubscribeSingleInvoice once completed --- channeldb/invoices.go | 5 +++++ lnrpc/invoicesrpc/invoices_server.go | 6 ++++++ lntest/itest/lnd_hold_persistence_test.go | 8 ++++++++ 3 files changed, 19 insertions(+) diff --git a/channeldb/invoices.go b/channeldb/invoices.go index fbec1a39..c4ee0a47 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -362,6 +362,11 @@ func (c ContractState) String() string { return "Unknown" } +// IsFinal returns a boolean indicating whether an invoice state is final +func (c ContractState) IsFinal() bool { + return c == ContractSettled || c == ContractCanceled +} + // ContractTerm is a companion struct to the Invoice struct. This struct houses // the necessary conditions required before the invoice can be considered fully // settled by the payee. diff --git a/lnrpc/invoicesrpc/invoices_server.go b/lnrpc/invoicesrpc/invoices_server.go index 9c092e57..2fa881db 100644 --- a/lnrpc/invoicesrpc/invoices_server.go +++ b/lnrpc/invoicesrpc/invoices_server.go @@ -247,6 +247,12 @@ func (s *Server) SubscribeSingleInvoice(req *SubscribeSingleInvoiceRequest, return err } + // If we have reached a terminal state, close the + // stream with no error. + if newInvoice.State.IsFinal() { + return nil + } + case <-s.quit: return nil } diff --git a/lntest/itest/lnd_hold_persistence_test.go b/lntest/itest/lnd_hold_persistence_test.go index 387b363f..06c4daaf 100644 --- a/lntest/itest/lnd_hold_persistence_test.go +++ b/lntest/itest/lnd_hold_persistence_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "fmt" + "io" "sync" "time" @@ -403,4 +404,11 @@ func testHoldInvoicePersistence(net *lntest.NetworkHarness, t *harnessTest) { } } } + + // Check that all of our invoice streams are terminated by the server + // since the invoices have completed. + for _, stream := range invoiceStreams { + _, err = stream.Recv() + require.Equal(t.t, io.EOF, err) + } }