diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index ec54353d..fc61736a 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1294,9 +1294,9 @@ func TestHTLCSet(t *testing.T) { expSet2 := make(map[CircuitKey]*InvoiceHTLC) checkHTLCSets := func() { - require.Equal(t, expSetNil, inv.HTLCSet(nil)) - require.Equal(t, expSet1, inv.HTLCSet(setID1)) - require.Equal(t, expSet2, inv.HTLCSet(setID2)) + require.Equal(t, expSetNil, inv.HTLCSet(nil, HtlcStateAccepted)) + require.Equal(t, expSet1, inv.HTLCSet(setID1, HtlcStateAccepted)) + require.Equal(t, expSet2, inv.HTLCSet(setID2, HtlcStateAccepted)) } // All HTLC sets should be empty initially. diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 146e8d88..292a972e 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -471,17 +471,16 @@ type Invoice struct { HodlInvoice bool } -// HTLCSet returns the set of accepted HTLCs belonging to an invoice. Passing a -// nil setID will return all accepted HTLCs in the case of legacy or MPP, and no -// HTLCs in the case of AMP. Otherwise, the returned set will be filtered by -// the populated setID which is used to retrieve AMP HTLC sets. -func (i *Invoice) HTLCSet(setID *[32]byte) map[CircuitKey]*InvoiceHTLC { +// HTLCSet returns the set of HTLCs belonging to setID and in the provided +// state. Passing a nil setID will return all HTLCs in the provided state in the +// case of legacy or MPP, and no HTLCs in the case of AMP. Otherwise, the +// returned set will be filtered by the populated setID which is used to +// retrieve AMP HTLC sets. +func (i *Invoice) HTLCSet(setID *[32]byte, state HtlcState) map[CircuitKey]*InvoiceHTLC { htlcSet := make(map[CircuitKey]*InvoiceHTLC) for key, htlc := range i.Htlcs { - // Only consider accepted mpp htlcs. It is possible that there - // are htlcs registered in the invoice database that previously - // timed out and are in the canceled state now. - if htlc.State != HtlcStateAccepted { + // Only add HTLCs that are in the requested HtlcState. + if htlc.State != state { continue } @@ -2039,7 +2038,7 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash, // Sanity check that the user isn't trying to settle or accept a // non-existent HTLC set. - if len(invoice.HTLCSet(update.SetID)) == 0 { + if len(invoice.HTLCSet(update.SetID, HtlcStateAccepted)) == 0 { return ErrEmptyHTLCSet } @@ -2329,8 +2328,8 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error { // invoice key. key := invoiceAddIndex.Get(addIndexKey[:]) if !bytes.Equal(key, invoiceKey) { - return fmt.Errorf("unknown invoice in " + - "add index") + return fmt.Errorf("unknown invoice " + + "in add index") } // Remove from the add index. diff --git a/invoices/update.go b/invoices/update.go index e10b92c6..b41bd1a5 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -168,7 +168,7 @@ func updateMpp(ctx *invoiceUpdateCtx, return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil } - htlcSet := inv.HTLCSet(setID) + htlcSet := inv.HTLCSet(setID, channeldb.HtlcStateAccepted) // Check whether total amt matches other htlcs in the set. var newSetTotal lnwire.MilliSatoshi @@ -373,7 +373,7 @@ func updateLegacy(ctx *invoiceUpdateCtx, // Don't allow settling the invoice with an old style // htlc if we are already in the process of gathering an // mpp set. - for _, htlc := range inv.HTLCSet(nil) { + for _, htlc := range inv.HTLCSet(nil, channeldb.HtlcStateAccepted) { if htlc.MppTotalAmt > 0 { return nil, ctx.failRes(ResultMppInProgress), nil }