diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 014f8230..68512607 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -106,7 +106,7 @@ func TestInvoiceWorkflow(t *testing.T) { if err != nil { t.Fatalf("unable to fetch invoice: %v", err) } - if !dbInvoice2.Terms.Settled { + if dbInvoice2.Terms.State != ContractSettled { t.Fatalf("invoice should now be settled but isn't") } if dbInvoice2.SettleDate.IsZero() { @@ -348,7 +348,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { // We'll update what we expect the settle invoice to be so that our // comparison below has the correct assumption. invoice.SettleIndex = 1 - invoice.Terms.Settled = true + invoice.Terms.State = ContractSettled invoice.AmtPaid = amt invoice.SettleDate = dbInvoice.SettleDate diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 428b6c11..5915cb4c 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -74,6 +74,30 @@ const ( MaxPaymentRequestSize = 4096 ) +// ContractState describes the state the invoice is in. +type ContractState uint8 + +const ( + // ContractOpen means the invoice has only been created. + ContractOpen ContractState = 0 + + // ContractSettled means the htlc is settled and the invoice has been + // paid. + ContractSettled ContractState = 1 +) + +// String returns a human readable identifier for the ContractState type. +func (c ContractState) String() string { + switch c { + case ContractOpen: + return "Open" + case ContractSettled: + return "Settled" + } + + return "Unknown" +} + // ContractTerm is a companion struct to the Invoice struct. This struct houses // the necessary conditions required before the invoice can be considered fully // settled by the payee. @@ -87,9 +111,8 @@ type ContractTerm struct { // which can be satisfied by the above preimage. Value lnwire.MilliSatoshi - // Settled indicates if this particular contract term has been fully - // settled by the payer. - Settled bool + // State describes the state the invoice is in. + State ContractState } // Invoice is a payment invoice generated by a payee in order to request @@ -380,7 +403,9 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { return err } - if pendingOnly && invoice.Terms.Settled { + if pendingOnly && + invoice.Terms.State == ContractSettled { + return nil } @@ -528,7 +553,9 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { // Skip any settled invoices if the caller is only // interested in unsettled. - if q.PendingOnly && invoice.Terms.Settled { + if q.PendingOnly && + invoice.Terms.State == ContractSettled { + continue } @@ -773,7 +800,7 @@ func serializeInvoice(w io.Writer, i *Invoice) error { return err } - if err := binary.Write(w, byteOrder, i.Terms.Settled); err != nil { + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { return err } @@ -845,7 +872,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { } invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) - if err := binary.Read(r, byteOrder, &invoice.Terms.Settled); err != nil { + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { return invoice, err } @@ -872,7 +899,7 @@ func settleInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte, // Add idempotency to duplicate settles, return here to avoid // overwriting the previous info. - if invoice.Terms.Settled { + if invoice.Terms.State == ContractSettled { return &invoice, nil } @@ -891,7 +918,7 @@ func settleInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte, } invoice.AmtPaid = amtPaid - invoice.Terms.Settled = true + invoice.Terms.State = ContractSettled invoice.SettleDate = time.Now() invoice.SettleIndex = nextSettleSeqNo diff --git a/channeldb/migrations.go b/channeldb/migrations.go index ffac8a51..647502c1 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -189,7 +189,7 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { // Next, we'll check if the invoice has been settled or not. If // so, then we'll also add it to the settle index. var nextSettleSeqNo uint64 - if invoice.Terms.Settled { + if invoice.Terms.State == ContractSettled { nextSettleSeqNo, err = settleIndex.NextSequence() if err != nil { return err diff --git a/htlcswitch/link.go b/htlcswitch/link.go index fe2e71a5..3f59b89b 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2334,7 +2334,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // TODO(conner): track ownership of settlements to // properly recover from failures? or add batch invoice // settlement - if invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractOpen { log.Warnf("Accepting duplicate payment for "+ "hash=%x", pd.RHash[:]) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 5fbaa765..d4fe2a30 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -236,7 +236,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if !invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractSettled { t.Fatal("alice invoice wasn't settled") } @@ -467,7 +467,7 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if !invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractSettled { t.Fatal("carol invoice haven't been settled") } @@ -818,7 +818,7 @@ func TestUpdateForwardingPolicy(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if !invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractSettled { t.Fatal("carol invoice haven't been settled") } @@ -937,7 +937,7 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if invoice.Terms.Settled { + if invoice.Terms.State == channeldb.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -1026,7 +1026,7 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. - if invoice.Terms.Settled { + if invoice.Terms.State == channeldb.ContractSettled { t.Fatal("alice invoice was settled") } @@ -1112,7 +1112,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if invoice.Terms.Settled { + if invoice.Terms.State == channeldb.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -1227,7 +1227,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if invoice.Terms.Settled { + if invoice.Terms.State == channeldb.ContractSettled { t.Fatal("carol invoice have been settled") } @@ -3332,7 +3332,7 @@ func TestChannelRetransmission(t *testing.T) { err = errors.Errorf("unable to get invoice: %v", err) continue } - if !invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractSettled { err = errors.Errorf("alice invoice haven't been settled") continue } @@ -3828,7 +3828,7 @@ func TestChannelLinkAcceptOverpay(t *testing.T) { if err != nil { t.Fatalf("unable to get invoice: %v", err) } - if !invoice.Terms.Settled { + if invoice.Terms.State != channeldb.ContractSettled { t.Fatal("carol invoice haven't been settled") } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index bb260156..d6c1d820 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -720,11 +720,11 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash, return fmt.Errorf("can't find mock invoice: %x", rhash[:]) } - if invoice.Terms.Settled { + if invoice.Terms.State == channeldb.ContractSettled { return nil } - invoice.Terms.Settled = true + invoice.Terms.State = channeldb.ContractSettled invoice.AmtPaid = amt i.invoices[rhash] = invoice diff --git a/invoiceregistry.go b/invoiceregistry.go index 235460f3..d8a0fd5f 100644 --- a/invoiceregistry.go +++ b/invoiceregistry.go @@ -89,8 +89,7 @@ func (i *invoiceRegistry) Stop() { // Only two event types are currently supported: newly created invoices, and // instance where invoices are settled. type invoiceEvent struct { - isSettle bool - + state channeldb.ContractState invoice *channeldb.Invoice } @@ -143,27 +142,27 @@ func (i *invoiceRegistry) invoiceEventNotifier() { switch { // If we've already sent this settle event to // the client, then we can skip this. - case event.isSettle && + case event.state == channeldb.ContractSettled && client.settleIndex >= invoice.SettleIndex: continue // Similarly, if we've already sent this add to // the client then we can skip this one. - case !event.isSettle && + case event.state == channeldb.ContractOpen && client.addIndex >= invoice.AddIndex: continue // These two states should never happen, but we // log them just in case so we can detect this // instance. - case !event.isSettle && + case event.state == channeldb.ContractOpen && client.addIndex+1 != invoice.AddIndex: ltndLog.Warnf("client=%v for invoice "+ "notifications missed an update, "+ "add_index=%v, new add event index=%v", clientID, client.addIndex, invoice.AddIndex) - case event.isSettle && + case event.state == channeldb.ContractSettled && client.settleIndex+1 != invoice.SettleIndex: ltndLog.Warnf("client=%v for invoice "+ "notifications missed an update, "+ @@ -174,8 +173,8 @@ func (i *invoiceRegistry) invoiceEventNotifier() { select { case client.ntfnQueue.ChanIn() <- &invoiceEvent{ - isSettle: event.isSettle, - invoice: invoice, + state: event.state, + invoice: invoice, }: case <-i.quit: return @@ -187,10 +186,14 @@ func (i *invoiceRegistry) invoiceEventNotifier() { // don't send a notification twice, which can // happen if a new event is added while we're // catching up a new client. - if event.isSettle { + switch event.state { + case channeldb.ContractSettled: client.settleIndex = invoice.SettleIndex - } else { + case channeldb.ContractOpen: client.addIndex = invoice.AddIndex + default: + ltndLog.Errorf("unknown invoice "+ + "state: %v", event.state) } } @@ -225,8 +228,8 @@ func (i *invoiceRegistry) deliverBacklogEvents(client *invoiceSubscription) erro select { case client.ntfnQueue.ChanIn() <- &invoiceEvent{ - isSettle: false, - invoice: &addEvent, + state: channeldb.ContractOpen, + invoice: &addEvent, }: case <-i.quit: return fmt.Errorf("registry shutting down") @@ -239,8 +242,8 @@ func (i *invoiceRegistry) deliverBacklogEvents(client *invoiceSubscription) erro select { case client.ntfnQueue.ChanIn() <- &invoiceEvent{ - isSettle: true, - invoice: &settleEvent, + state: channeldb.ContractSettled, + invoice: &settleEvent, }: case <-i.quit: return fmt.Errorf("registry shutting down") @@ -296,7 +299,7 @@ func (i *invoiceRegistry) AddInvoice(invoice *channeldb.Invoice) (uint64, error) // Now that we've added the invoice, we'll send dispatch a message to // notify the clients of this new invoice. - i.notifyClients(invoice, false) + i.notifyClients(invoice, channeldb.ContractOpen) return addIndex, nil } @@ -365,17 +368,19 @@ func (i *invoiceRegistry) SettleInvoice(rHash chainhash.Hash, ltndLog.Infof("Payment received: %v", spew.Sdump(invoice)) - i.notifyClients(invoice, true) + i.notifyClients(invoice, channeldb.ContractSettled) return nil } // notifyClients notifies all currently registered invoice notification clients // of a newly added/settled invoice. -func (i *invoiceRegistry) notifyClients(invoice *channeldb.Invoice, settle bool) { +func (i *invoiceRegistry) notifyClients(invoice *channeldb.Invoice, + state channeldb.ContractState) { + event := &invoiceEvent{ - isSettle: settle, - invoice: invoice, + state: state, + invoice: invoice, } select { @@ -483,9 +488,17 @@ func (i *invoiceRegistry) SubscribeNotifications(addIndex, settleIndex uint64) * case ntfn := <-client.ntfnQueue.ChanOut(): invoiceEvent := ntfn.(*invoiceEvent) - targetChan := client.NewInvoices - if invoiceEvent.isSettle { + var targetChan chan *channeldb.Invoice + switch invoiceEvent.state { + case channeldb.ContractOpen: + targetChan = client.NewInvoices + case channeldb.ContractSettled: targetChan = client.SettledInvoices + default: + ltndLog.Errorf("unknown invoice "+ + "state: %v", invoiceEvent.state) + + continue } select { diff --git a/rpcserver.go b/rpcserver.go index ca22c939..1f1a2de6 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3303,6 +3303,8 @@ func createRPCInvoice(invoice *channeldb.Invoice) (*lnrpc.Invoice, error) { satAmt := invoice.Terms.Value.ToSatoshis() satAmtPaid := invoice.AmtPaid.ToSatoshis() + isSettled := invoice.Terms.State == channeldb.ContractSettled + return &lnrpc.Invoice{ Memo: string(invoice.Memo[:]), Receipt: invoice.Receipt[:], @@ -3311,7 +3313,7 @@ func createRPCInvoice(invoice *channeldb.Invoice) (*lnrpc.Invoice, error) { Value: int64(satAmt), CreationDate: invoice.CreationDate.Unix(), SettleDate: settleDate, - Settled: invoice.Terms.Settled, + Settled: isSettled, PaymentRequest: paymentRequest, DescriptionHash: descHash, Expiry: expiry,