diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 4f3ee5dd..e752c30b 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -314,3 +314,64 @@ func TestInvoiceAddTimeSeries(t *testing.T) { } } } + +// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it +// twice, then the second time we also receive the invoice that we settled as a +// return argument. +func TestDuplicateSettleInvoice(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // We'll start out by creating an invoice and writing it to the DB. + amt := lnwire.NewMSatFromSatoshis(1000) + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + if _, err := db.AddInvoice(invoice); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + // With the invoice in the DB, we'll now attempt to settle the invoice. + payHash := sha256.Sum256(invoice.Terms.PaymentPreimage[:]) + dbInvoice, err := db.SettleInvoice(payHash, amt) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + + // We'll update what we expect the settle invoice to be so that our + // comparison below has the correct assumption. + invoice.SettleIndex = 1 + invoice.Terms.Settled = true + invoice.AmtPaid = amt + invoice.SettleDate = dbInvoice.SettleDate + + // We should get back the exact same invoice that we just inserted. + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } + + // If we try to settle the invoice again, then we should get the very + // same invoice back. + dbInvoice, err = db.SettleInvoice(payHash, amt) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + + if dbInvoice == nil { + t.Fatalf("invoice from db is nil after settle!") + } + + invoice.SettleDate = dbInvoice.SettleDate + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after second settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index d315c4a0..feb07442 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -8,9 +8,9 @@ import ( "io" "time" + "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lnwire" - "github.com/btcsuite/btcd/wire" ) var ( @@ -704,7 +704,7 @@ func settleInvoice(invoices, settleIndex *bolt.Bucket, invoiceNum []byte, // Add idempotency to duplicate settles, return here to avoid // overwriting the previous info. if invoice.Terms.Settled { - return nil, nil + return &invoice, nil } // Now that we know the invoice hasn't already been settled, we'll