diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index a34f7877..16c28f2d 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -37,13 +37,15 @@ type InvoiceRegistry struct { cdb *channeldb.DB - clientMtx sync.Mutex - nextClientID uint32 - notificationClients map[uint32]*InvoiceSubscription + clientMtx sync.Mutex + nextClientID uint32 + notificationClients map[uint32]*InvoiceSubscription + singleNotificationClients map[uint32]*SingleInvoiceSubscription - newSubscriptions chan *InvoiceSubscription - subscriptionCancels chan uint32 - invoiceEvents chan *invoiceEvent + newSubscriptions chan *InvoiceSubscription + newSingleSubscriptions chan *SingleInvoiceSubscription + subscriptionCancels chan uint32 + invoiceEvents chan *invoiceEvent // debugInvoices is a map which stores special "debug" invoices which // should be only created/used when manual tests require an invoice @@ -64,14 +66,16 @@ func NewRegistry(cdb *channeldb.DB, activeNetParams *chaincfg.Params) *InvoiceRegistry { return &InvoiceRegistry{ - cdb: cdb, - debugInvoices: make(map[lntypes.Hash]*channeldb.Invoice), - notificationClients: make(map[uint32]*InvoiceSubscription), - newSubscriptions: make(chan *InvoiceSubscription), - subscriptionCancels: make(chan uint32), - invoiceEvents: make(chan *invoiceEvent, 100), - activeNetParams: activeNetParams, - quit: make(chan struct{}), + cdb: cdb, + debugInvoices: make(map[lntypes.Hash]*channeldb.Invoice), + notificationClients: make(map[uint32]*InvoiceSubscription), + singleNotificationClients: make(map[uint32]*SingleInvoiceSubscription), + newSubscriptions: make(chan *InvoiceSubscription), + newSingleSubscriptions: make(chan *SingleInvoiceSubscription), + subscriptionCancels: make(chan uint32), + invoiceEvents: make(chan *invoiceEvent, 100), + activeNetParams: activeNetParams, + quit: make(chan struct{}), } } @@ -96,6 +100,7 @@ func (i *InvoiceRegistry) Stop() { // instance where invoices are settled. type invoiceEvent struct { state channeldb.ContractState + hash lntypes.Hash invoice *channeldb.Invoice } @@ -107,9 +112,9 @@ func (i *InvoiceRegistry) invoiceEventNotifier() { for { select { - // A new invoice subscription has just arrived! We'll query for - // any backlog notifications, then add it to the set of - // clients. + // A new invoice subscription for all invoices has just arrived! + // We'll query for any backlog notifications, then add it to the + // set of clients. case newClient := <-i.newSubscriptions: // Before we add the client to our set of active // clients, we'll first attempt to deliver any backlog @@ -128,6 +133,23 @@ func (i *InvoiceRegistry) invoiceEventNotifier() { // continue. i.notificationClients[newClient.id] = newClient + // A new single invoice subscription has arrived. We'll query + // for any backlog notifications, then add it to the set of + // clients. + case newClient := <-i.newSingleSubscriptions: + err := i.deliverSingleBacklogEvents(newClient) + if err != nil { + log.Errorf("Unable to deliver backlog invoice "+ + "notifications: %v", err) + } + + log.Infof("New single invoice subscription "+ + "client: id=%v, hash=%v", + newClient.id, newClient.hash, + ) + + i.singleNotificationClients[newClient.id] = newClient + // A client no longer wishes to receive invoice notifications. // So we'll remove them from the set of active clients. case clientID := <-i.subscriptionCancels: @@ -135,11 +157,13 @@ func (i *InvoiceRegistry) invoiceEventNotifier() { "client=%v", clientID) delete(i.notificationClients, clientID) + delete(i.singleNotificationClients, clientID) // A sub-systems has just modified the invoice state, so we'll // dispatch notifications to all registered clients. case event := <-i.invoiceEvents: i.dispatchToClients(event) + i.dispatchToSingleClients(event) case <-i.quit: return @@ -147,6 +171,26 @@ func (i *InvoiceRegistry) invoiceEventNotifier() { } } +// dispatchToSingleClients passes the supplied event to all notification clients +// that subscribed to all the invoice this event applies to. +func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) { + // Dispatch to single invoice subscribers. + for _, client := range i.singleNotificationClients { + if client.hash != event.hash { + continue + } + + select { + case client.ntfnQueue.ChanIn() <- &invoiceEvent{ + state: event.state, + invoice: event.invoice, + }: + case <-i.quit: + return + } + } +} + // dispatchToClients passes the supplied event to all notification clients that // subscribed to all invoices. Add and settle indices are used to make sure that // clients don't receive duplicate or unwanted events. @@ -227,6 +271,7 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro if err != nil { return err } + settleEvents, err := i.cdb.InvoicesSettledSince(client.settleIndex) if err != nil { return err @@ -249,6 +294,7 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro return fmt.Errorf("registry shutting down") } } + for _, settleEvent := range settleEvents { // We re-bind the loop variable to ensure we don't hold onto // the loop reference causing is to point to the same item. @@ -267,6 +313,37 @@ func (i *InvoiceRegistry) deliverBacklogEvents(client *InvoiceSubscription) erro return nil } +// deliverSingleBacklogEvents will attempt to query the invoice database to +// retrieve the current invoice state and deliver this to the subscriber. Single +// invoice subscribers will always receive the current state right after +// subscribing. Only in case the invoice does not yet exist, nothing is sent +// yet. +func (i *InvoiceRegistry) deliverSingleBacklogEvents( + client *SingleInvoiceSubscription) error { + + invoice, err := i.cdb.LookupInvoice(client.hash) + + // It is possible that the invoice does not exist yet, but the client is + // already watching it in anticipation. + if err == channeldb.ErrInvoiceNotFound { + return nil + } + if err != nil { + return err + } + + err = client.notify(&invoiceEvent{ + hash: client.hash, + invoice: &invoice, + state: invoice.Terms.State, + }) + if err != nil { + return err + } + + return nil +} + // AddDebugInvoice adds a debug invoice for the specified amount, identified // by the passed preimage. Once this invoice is added, subsystems within the // daemon add/forward HTLCs that are able to obtain the proper preimage @@ -300,7 +377,9 @@ func (i *InvoiceRegistry) AddDebugInvoice(amt btcutil.Amount, // redemption in the case that we're the final destination. We also return the // addIndex of the newly created invoice which monotonically increases for each // new invoice added. -func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice) (uint64, error) { +func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice, + paymentHash lntypes.Hash) (uint64, error) { + i.Lock() defer i.Unlock() @@ -315,7 +394,7 @@ func (i *InvoiceRegistry) AddInvoice(invoice *channeldb.Invoice) (uint64, error) // Now that we've added the invoice, we'll send dispatch a message to // notify the clients of this new invoice. - i.notifyClients(invoice, channeldb.ContractOpen) + i.notifyClients(paymentHash, invoice, channeldb.ContractOpen) return addIndex, nil } @@ -392,19 +471,21 @@ func (i *InvoiceRegistry) SettleInvoice(rHash lntypes.Hash, log.Infof("Payment received: %v", spew.Sdump(invoice)) - i.notifyClients(invoice, channeldb.ContractSettled) + i.notifyClients(rHash, invoice, channeldb.ContractSettled) return nil } // notifyClients notifies all currently registered invoice notification clients // of a newly added/settled invoice. -func (i *InvoiceRegistry) notifyClients(invoice *channeldb.Invoice, +func (i *InvoiceRegistry) notifyClients(hash lntypes.Hash, + invoice *channeldb.Invoice, state channeldb.ContractState) { event := &invoiceEvent{ state: state, invoice: invoice, + hash: hash, } select { @@ -413,13 +494,25 @@ func (i *InvoiceRegistry) notifyClients(invoice *channeldb.Invoice, } } +// invoiceSubscriptionKit defines that are common to both all invoice +// subscribers and single invoice subscribers. +type invoiceSubscriptionKit struct { + id uint32 + inv *InvoiceRegistry + ntfnQueue *queue.ConcurrentQueue + + cancelled uint32 // To be used atomically. + cancelChan chan struct{} + wg sync.WaitGroup +} + // InvoiceSubscription represents an intent to receive updates for newly added // or settled invoices. For each newly added invoice, a copy of the invoice // will be sent over the NewInvoices channel. Similarly, for each newly settled // invoice, a copy of the invoice will be sent over the SettledInvoices // channel. type InvoiceSubscription struct { - cancelled uint32 // To be used atomically. + invoiceSubscriptionKit // NewInvoices is a channel that we'll use to send all newly created // invoices with an invoice index greater than the specified @@ -443,21 +536,23 @@ type InvoiceSubscription struct { // greater than this will be dispatched before any new notifications // are sent out. settleIndex uint64 +} - ntfnQueue *queue.ConcurrentQueue +// SingleInvoiceSubscription represents an intent to receive updates for a +// specific invoice. +type SingleInvoiceSubscription struct { + invoiceSubscriptionKit - id uint32 + hash lntypes.Hash - inv *InvoiceRegistry - - cancelChan chan struct{} - - wg sync.WaitGroup + // Updates is a channel that we'll use to send all invoice events for + // the invoice that is subscribed to. + Updates chan *channeldb.Invoice } // Cancel unregisters the InvoiceSubscription, freeing any previously allocated // resources. -func (i *InvoiceSubscription) Cancel() { +func (i *invoiceSubscriptionKit) Cancel() { if !atomic.CompareAndSwapUint32(&i.cancelled, 0, 1) { return } @@ -473,6 +568,16 @@ func (i *InvoiceSubscription) Cancel() { i.wg.Wait() } +func (i *invoiceSubscriptionKit) notify(event *invoiceEvent) error { + select { + case i.ntfnQueue.ChanIn() <- event: + case <-i.inv.quit: + return fmt.Errorf("registry shutting down") + } + + return nil +} + // SubscribeNotifications returns an InvoiceSubscription which allows the // caller to receive async notifications when any invoices are settled or // added. The invoiceIndex parameter is a streaming "checkpoint". We'll start @@ -484,9 +589,11 @@ func (i *InvoiceRegistry) SubscribeNotifications(addIndex, settleIndex uint64) * SettledInvoices: make(chan *channeldb.Invoice), addIndex: addIndex, settleIndex: settleIndex, - inv: i, - ntfnQueue: queue.NewConcurrentQueue(20), - cancelChan: make(chan struct{}), + invoiceSubscriptionKit: invoiceSubscriptionKit{ + inv: i, + ntfnQueue: queue.NewConcurrentQueue(20), + cancelChan: make(chan struct{}), + }, } client.ntfnQueue.Start() @@ -551,3 +658,67 @@ func (i *InvoiceRegistry) SubscribeNotifications(addIndex, settleIndex uint64) * return client } + +// SubscribeSingleInvoice returns an SingleInvoiceSubscription which allows the +// caller to receive async notifications for a specific invoice. +func (i *InvoiceRegistry) SubscribeSingleInvoice( + hash lntypes.Hash) *SingleInvoiceSubscription { + + client := &SingleInvoiceSubscription{ + Updates: make(chan *channeldb.Invoice), + invoiceSubscriptionKit: invoiceSubscriptionKit{ + inv: i, + ntfnQueue: queue.NewConcurrentQueue(20), + cancelChan: make(chan struct{}), + }, + hash: hash, + } + client.ntfnQueue.Start() + + i.clientMtx.Lock() + client.id = i.nextClientID + i.nextClientID++ + i.clientMtx.Unlock() + + // Before we register this new invoice subscription, we'll launch a new + // goroutine that will proxy all notifications appended to the end of + // the concurrent queue to the two client-side channels the caller will + // feed off of. + i.wg.Add(1) + go func() { + defer i.wg.Done() + + for { + select { + // A new invoice event has been sent by the + // invoiceRegistry. We will dispatch the event to the + // client. + case ntfn := <-client.ntfnQueue.ChanOut(): + invoiceEvent := ntfn.(*invoiceEvent) + + select { + case client.Updates <- invoiceEvent.invoice: + + case <-client.cancelChan: + return + + case <-i.quit: + return + } + + case <-client.cancelChan: + return + + case <-i.quit: + return + } + } + }() + + select { + case i.newSingleSubscriptions <- client: + case <-i.quit: + } + + return client +} diff --git a/rpcserver.go b/rpcserver.go index 813675c3..934c54e4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3320,7 +3320,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context, ) // With all sanity checks passed, write the invoice to the database. - addIndex, err := r.server.invoices.AddInvoice(newInvoice) + addIndex, err := r.server.invoices.AddInvoice(newInvoice, rHash) if err != nil { return nil, err }