channeldb+invoice: add state filter to HTLCSet

This commit is contained in:
Conner Fromknecht 2021-05-06 09:16:15 -07:00
parent 620e426bc3
commit d93c3298b7
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
3 changed files with 16 additions and 17 deletions

View File

@ -1294,9 +1294,9 @@ func TestHTLCSet(t *testing.T) {
expSet2 := make(map[CircuitKey]*InvoiceHTLC) expSet2 := make(map[CircuitKey]*InvoiceHTLC)
checkHTLCSets := func() { checkHTLCSets := func() {
require.Equal(t, expSetNil, inv.HTLCSet(nil)) require.Equal(t, expSetNil, inv.HTLCSet(nil, HtlcStateAccepted))
require.Equal(t, expSet1, inv.HTLCSet(setID1)) require.Equal(t, expSet1, inv.HTLCSet(setID1, HtlcStateAccepted))
require.Equal(t, expSet2, inv.HTLCSet(setID2)) require.Equal(t, expSet2, inv.HTLCSet(setID2, HtlcStateAccepted))
} }
// All HTLC sets should be empty initially. // All HTLC sets should be empty initially.

View File

@ -471,17 +471,16 @@ type Invoice struct {
HodlInvoice bool HodlInvoice bool
} }
// HTLCSet returns the set of accepted HTLCs belonging to an invoice. Passing a // HTLCSet returns the set of HTLCs belonging to setID and in the provided
// nil setID will return all accepted HTLCs in the case of legacy or MPP, and no // state. Passing a nil setID will return all HTLCs in the provided state in the
// HTLCs in the case of AMP. Otherwise, the returned set will be filtered by // case of legacy or MPP, and no HTLCs in the case of AMP. Otherwise, the
// the populated setID which is used to retrieve AMP HTLC sets. // returned set will be filtered by the populated setID which is used to
func (i *Invoice) HTLCSet(setID *[32]byte) map[CircuitKey]*InvoiceHTLC { // retrieve AMP HTLC sets.
func (i *Invoice) HTLCSet(setID *[32]byte, state HtlcState) map[CircuitKey]*InvoiceHTLC {
htlcSet := make(map[CircuitKey]*InvoiceHTLC) htlcSet := make(map[CircuitKey]*InvoiceHTLC)
for key, htlc := range i.Htlcs { for key, htlc := range i.Htlcs {
// Only consider accepted mpp htlcs. It is possible that there // Only add HTLCs that are in the requested HtlcState.
// are htlcs registered in the invoice database that previously if htlc.State != state {
// timed out and are in the canceled state now.
if htlc.State != HtlcStateAccepted {
continue continue
} }
@ -2039,7 +2038,7 @@ func updateInvoiceState(invoice *Invoice, hash *lntypes.Hash,
// Sanity check that the user isn't trying to settle or accept a // Sanity check that the user isn't trying to settle or accept a
// non-existent HTLC set. // non-existent HTLC set.
if len(invoice.HTLCSet(update.SetID)) == 0 { if len(invoice.HTLCSet(update.SetID, HtlcStateAccepted)) == 0 {
return ErrEmptyHTLCSet return ErrEmptyHTLCSet
} }
@ -2329,8 +2328,8 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error {
// invoice key. // invoice key.
key := invoiceAddIndex.Get(addIndexKey[:]) key := invoiceAddIndex.Get(addIndexKey[:])
if !bytes.Equal(key, invoiceKey) { if !bytes.Equal(key, invoiceKey) {
return fmt.Errorf("unknown invoice in " + return fmt.Errorf("unknown invoice " +
"add index") "in add index")
} }
// Remove from the add index. // Remove from the add index.

View File

@ -168,7 +168,7 @@ func updateMpp(ctx *invoiceUpdateCtx,
return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil
} }
htlcSet := inv.HTLCSet(setID) htlcSet := inv.HTLCSet(setID, channeldb.HtlcStateAccepted)
// Check whether total amt matches other htlcs in the set. // Check whether total amt matches other htlcs in the set.
var newSetTotal lnwire.MilliSatoshi var newSetTotal lnwire.MilliSatoshi
@ -373,7 +373,7 @@ func updateLegacy(ctx *invoiceUpdateCtx,
// Don't allow settling the invoice with an old style // Don't allow settling the invoice with an old style
// htlc if we are already in the process of gathering an // htlc if we are already in the process of gathering an
// mpp set. // mpp set.
for _, htlc := range inv.HTLCSet(nil) { for _, htlc := range inv.HTLCSet(nil, channeldb.HtlcStateAccepted) {
if htlc.MppTotalAmt > 0 { if htlc.MppTotalAmt > 0 {
return nil, ctx.failRes(ResultMppInProgress), nil return nil, ctx.failRes(ResultMppInProgress), nil
} }