diff --git a/channeldb/db.go b/channeldb/db.go index 76d0c37a..9f593513 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -189,6 +189,12 @@ var ( number: 21, migration: migration21.MigrateDatabaseWireMessages, }, + { + // Initialize set id index so that invoices can be + // queried by individual htlc sets. + number: 22, + migration: mig.CreateTLB(setIDIndexBucket), + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -319,6 +325,7 @@ var topLevelBuckets = [][]byte{ fwdPackagesKey, invoiceBucket, payAddrIndexBucket, + setIDIndexBucket, paymentsIndexBucket, peersBucket, nodeInfoBucket, diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index a66d9f25..52b39742 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -1202,18 +1203,28 @@ func testInvoiceHtlcAMPFields(t *testing.T, isAMP bool) { func TestInvoiceRef(t *testing.T) { payHash := lntypes.Hash{0x01} payAddr := [32]byte{0x02} + setID := [32]byte{0x03} // An InvoiceRef by hash should return the provided hash and a nil // payment addr. refByHash := InvoiceRefByHash(payHash) require.Equal(t, payHash, refByHash.PayHash()) require.Equal(t, (*[32]byte)(nil), refByHash.PayAddr()) + require.Equal(t, (*[32]byte)(nil), refByHash.SetID()) // An InvoiceRef by hash and addr should return the payment hash and // payment addr passed to the constructor. refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr) require.Equal(t, payHash, refByHashAndAddr.PayHash()) require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) + require.Equal(t, (*[32]byte)(nil), refByHashAndAddr.SetID()) + + // An InvoiceRef by set id should return an empty pay hash, a nil pay + // addr, and a reference to the given set id. + refBySetID := InvoiceRefBySetID(setID) + require.Equal(t, lntypes.Hash{}, refBySetID.PayHash()) + require.Equal(t, (*[32]byte)(nil), refBySetID.PayAddr()) + require.Equal(t, &setID, refBySetID.SetID()) } // TestHTLCSet asserts that HTLCSet returns the proper set of accepted HTLCs @@ -1322,6 +1333,157 @@ func TestAddInvoiceWithHTLCs(t *testing.T) { require.Equal(t, ErrInvoiceHasHtlcs, err) } +// TestSetIDIndex asserts that the set id index properly adds new invoices as we +// accept HTLCs, that they can be queried by their set id after accepting, and +// that invoices with duplicate set ids are disallowed. +func TestSetIDIndex(t *testing.T) { + testClock := clock.NewTestClock(testNow) + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) + defer cleanUp() + require.Nil(t, err) + + // We'll start out by creating an invoice and writing it to the DB. + amt := lnwire.NewMSatFromSatoshis(1000) + invoice, err := randInvoice(amt) + require.Nil(t, err) + + preimage := *invoice.Terms.PaymentPreimage + payHash := preimage.Hash() + _, err = db.AddInvoice(invoice, payHash) + require.Nil(t, err) + + setID := &[32]byte{1} + + // Update the invoice with an accepted HTLC that also accepts the + // invoice. + ref := InvoiceRefByHashAndAddr(payHash, invoice.Terms.PaymentAddr) + dbInvoice, err := db.UpdateInvoice(ref, updateAcceptAMPHtlc(0, amt, setID, true)) + require.Nil(t, err) + + // We'll update what we expect the accepted invoice to be so that our + // comparison below has the correct assumption. + invoice.State = ContractAccepted + invoice.AmtPaid = amt + invoice.SettleDate = dbInvoice.SettleDate + invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ + {HtlcID: 0}: makeAMPInvoiceHTLC(amt, *setID, preimage), + } + + // We should get back the exact same invoice that we just inserted. + require.Equal(t, invoice, dbInvoice) + + // Now lookup the invoice by set id and see that we get the same one. + refBySetID := InvoiceRefBySetID(*setID) + dbInvoiceBySetID, err := db.LookupInvoice(refBySetID) + require.Nil(t, err) + require.Equal(t, invoice, &dbInvoiceBySetID) + + // Trying to accept an HTLC to a different invoice, but using the same + // set id should fail. + invoice2, err := randInvoice(amt) + require.Nil(t, err) + + payHash2 := invoice2.Terms.PaymentPreimage.Hash() + _, err = db.AddInvoice(invoice2, payHash2) + require.Nil(t, err) + + ref2 := InvoiceRefByHashAndAddr(payHash2, invoice2.Terms.PaymentAddr) + _, err = db.UpdateInvoice(ref2, updateAcceptAMPHtlc(0, amt, setID, true)) + require.Equal(t, ErrDuplicateSetID{setID: *setID}, err) + + // Now, begin constructing a second htlc set under a different set id. + // This set will contain two distinct HTLCs. + setID2 := &[32]byte{2} + + _, err = db.UpdateInvoice(ref, updateAcceptAMPHtlc(1, amt, setID2, false)) + require.Nil(t, err) + dbInvoice, err = db.UpdateInvoice(ref, updateAcceptAMPHtlc(2, amt, setID2, false)) + require.Nil(t, err) + + // We'll update what we expect the settle invoice to be so that our + // comparison below has the correct assumption. + invoice.State = ContractAccepted + invoice.AmtPaid += 2 * amt + invoice.SettleDate = dbInvoice.SettleDate + invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ + {HtlcID: 0}: makeAMPInvoiceHTLC(amt, *setID, preimage), + {HtlcID: 1}: makeAMPInvoiceHTLC(amt, *setID2, preimage), + {HtlcID: 2}: makeAMPInvoiceHTLC(amt, *setID2, preimage), + } + + // We should get back the exact same invoice that we just inserted. + require.Equal(t, invoice, dbInvoice) + + // Now lookup the invoice by second set id and see that we get the same + // index, including the htlcs under the first set id. + refBySetID = InvoiceRefBySetID(*setID2) + dbInvoiceBySetID, err = db.LookupInvoice(refBySetID) + require.Nil(t, err) + require.Equal(t, invoice, &dbInvoiceBySetID) + + // Lastly, querying for an unknown set id should fail. + refUnknownSetID := InvoiceRefBySetID([32]byte{}) + _, err = db.LookupInvoice(refUnknownSetID) + require.Equal(t, ErrInvoiceNotFound, err) +} + +func makeAMPInvoiceHTLC(amt lnwire.MilliSatoshi, setID [32]byte, + preimage lntypes.Preimage) *InvoiceHTLC { + + return &InvoiceHTLC{ + Amt: amt, + AcceptTime: testNow, + ResolveTime: time.Time{}, + State: HtlcStateAccepted, + CustomRecords: make(record.CustomSet), + AMP: &InvoiceHtlcAMPData{ + Record: *record.NewAMP([32]byte{}, setID, 0), + Hash: preimage.Hash(), + Preimage: &preimage, + }, + } +} + +// updateAcceptAMPHtlc returns an invoice update callback that, when called, +// settles the invoice with the given amount. +func updateAcceptAMPHtlc(id uint64, amt lnwire.MilliSatoshi, + setID *[32]byte, accept bool) InvoiceUpdateCallback { + + return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + if invoice.State == ContractSettled { + return nil, ErrInvoiceAlreadySettled + } + + noRecords := make(record.CustomSet) + + var state *InvoiceStateUpdateDesc + if accept { + state = &InvoiceStateUpdateDesc{ + NewState: ContractAccepted, + SetID: setID, + } + } + + ampData := &InvoiceHtlcAMPData{ + Record: *record.NewAMP([32]byte{}, *setID, 0), + Hash: invoice.Terms.PaymentPreimage.Hash(), + Preimage: invoice.Terms.PaymentPreimage, + } + update := &InvoiceUpdateDesc{ + State: state, + AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ + {HtlcID: id}: { + Amt: amt, + CustomRecords: noRecords, + AMP: ampData, + }, + }, + } + + return update, nil + } +} + // TestDeleteInvoices tests that deleting a list of invoices will succeed // if all delete references are valid, or will fail otherwise. func TestDeleteInvoices(t *testing.T) { @@ -1413,4 +1575,5 @@ func TestDeleteInvoices(t *testing.T) { // Delete should succeed with all the valid references. require.NoError(t, db.DeleteInvoice(invoicesToDelete)) assertInvoiceCount(0) + } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 48f01b83..e81a8fa8 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -53,6 +53,14 @@ var ( // maps: payAddr => invoiceKey payAddrIndexBucket = []byte("pay-addr-index") + // setIDIndexBucket is the name of the top-level bucket that maps set + // ids to their invoice number. This can be used to efficiently query or + // update AMP invoice. Note that legacy or MPP invoices will not be + // included in this index, since their HTLCs do not have a set id. + // + // maps: setID => invoiceKey + setIDIndexBucket = []byte("set-id-index") + // numInvoicesKey is the name of key which houses the auto-incrementing // invoice ID which is essentially used as a primary key. With each // invoice inserted, the primary key is incremented by one. This key is @@ -112,6 +120,17 @@ var ( ErrInvoiceHasHtlcs = errors.New("cannot add invoice with htlcs") ) +// ErrDuplicateSetID is an error returned when attempting to adding an AMP HTLC +// to an invoice, but another invoice is already indexed by the same set id. +type ErrDuplicateSetID struct { + setID [32]byte +} + +// Error returns a human-readable description of ErrDuplicateSetID. +func (e ErrDuplicateSetID) Error() string { + return fmt.Sprintf("invoice with set_id=%x already exists", e.setID) +} + const ( // MaxMemoSize is maximum size of the memo field within invoices stored // in the database. @@ -183,6 +202,14 @@ type InvoiceRef struct { // known it will be used as the primary identifier, falling back to // payHash if no value is known. payAddr *[32]byte + + // setID is the optional set id for an AMP payment. This can be used to + // lookup or update the invoice knowing only this value. Queries by set + // id are only used to facilitate user-facing requests, e.g. lookup, + // settle or cancel an AMP invoice. The regular update flow from the + // invoice registry will always query for the invoice by + // payHash+payAddr. + setID *[32]byte } // InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by @@ -205,6 +232,15 @@ func InvoiceRefByHashAndAddr(payHash lntypes.Hash, } } +// InvoiceRefBySetID creates an InvoiceRef that queries the set id index for an +// invoice with the provided setID. If the invoice is not found, the query will +// not fallback to payHash or payAddr. +func InvoiceRefBySetID(setID [32]byte) InvoiceRef { + return InvoiceRef{ + setID: &setID, + } +} + // PayHash returns the target invoice's payment hash. func (r InvoiceRef) PayHash() lntypes.Hash { return r.payHash @@ -221,6 +257,17 @@ func (r InvoiceRef) PayAddr() *[32]byte { return nil } +// SetID returns the optional set id of the target invoice. +// +// NOTE: This value may be nil. +func (r InvoiceRef) SetID() *[32]byte { + if r.setID != nil { + id := *r.setID + return &id + } + return nil +} + // String returns a human-readable representation of an InvoiceRef. func (r InvoiceRef) String() string { if r.payAddr != nil { @@ -564,6 +611,11 @@ type InvoiceStateUpdateDesc struct { // Preimage must be set to the preimage when NewState is settled. Preimage *lntypes.Preimage + + // SetID identifies a specific set of HTLCs destined for the same + // invoice as part of a larger AMP payment. This value will be nil for + // legacy or MPP payments. + SetID *[32]byte } // InvoiceUpdateCallback is a callback used in the db transaction to update the @@ -772,11 +824,12 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { return ErrNoInvoicesCreated } payAddrIndex := tx.ReadBucket(payAddrIndexBucket) + setIDIndex := tx.ReadBucket(setIDIndexBucket) - // Retrieve the invoice number for this invoice using the - // provided invoice reference. + // Retrieve the invoice number for this invoice using + // the provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, payAddrIndex, ref, + invoiceIndex, payAddrIndex, setIDIndex, ref, ) if err != nil { return err @@ -803,9 +856,22 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { // reference. The payment address will be treated as the primary key, falling // back to the payment hash if nothing is found for the payment address. An // error is returned if the invoice is not found. -func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.RBucket, +func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket, ref InvoiceRef) ([]byte, error) { + // If the set id is present, we only consult the set id index for this + // invoice. This type of query is only used to facilitate user-facing + // requests to lookup, settle or cancel an AMP invoice. + setID := ref.SetID() + if setID != nil { + invoiceNumBySetID := setIDIndex.Get(setID[:]) + if invoiceNumBySetID == nil { + return nil, ErrInvoiceNotFound + } + + return invoiceNumBySetID, nil + } + payHash := ref.PayHash() payAddr := ref.PayAddr() @@ -1053,20 +1119,21 @@ func (d *DB) UpdateInvoice(ref InvoiceRef, return err } payAddrIndex := tx.ReadBucket(payAddrIndexBucket) + setIDIndex := tx.ReadWriteBucket(setIDIndexBucket) // Retrieve the invoice number for this invoice using the // provided invoice reference. invoiceNum, err := fetchInvoiceNumByRef( - invoiceIndex, payAddrIndex, ref, + invoiceIndex, payAddrIndex, setIDIndex, ref, ) if err != nil { return err - } + payHash := ref.PayHash() updatedInvoice, err = d.updateInvoice( - payHash, invoices, settleIndex, invoiceNum, - callback, + payHash, invoices, settleIndex, setIDIndex, + invoiceNum, callback, ) return err @@ -1662,8 +1729,9 @@ func copyInvoice(src *Invoice) *Invoice { // updateInvoice fetches the invoice, obtains the update descriptor from the // callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex kvdb.RwBucket, - invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { +func (d *DB) updateInvoice(hash lntypes.Hash, invoices, + settleIndex, setIDIndex kvdb.RwBucket, invoiceNum []byte, + callback InvoiceUpdateCallback) (*Invoice, error) { invoice, err := fetchInvoice(invoiceNum, invoices) if err != nil { @@ -1717,6 +1785,22 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex kvdb.RwBucke return nil, errors.New("nil custom records map") } + // If a newly added HTLC has an associated set id, use it to + // index this invoice in the set id index. An error is returned + // if we find the index already points to a different invoice. + if htlcUpdate.AMP != nil { + setID := htlcUpdate.AMP.Record.SetID() + setIDInvNum := setIDIndex.Get(setID[:]) + if setIDInvNum == nil { + err = setIDIndex.Put(setID[:], invoiceNum) + if err != nil { + return nil, err + } + } else if !bytes.Equal(setIDInvNum, invoiceNum) { + return nil, ErrDuplicateSetID{setID: setID} + } + } + htlc := &InvoiceHTLC{ Amt: htlcUpdate.Amt, MppTotalAmt: htlcUpdate.MppTotalAmt,