diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 45d072b5..4e01f52e 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1216,6 +1216,94 @@ func TestInvoiceRef(t *testing.T) { require.Equal(t, &payAddr, refByHashAndAddr.PayAddr()) } +// TestHTLCSet asserts that HTLCSet returns the proper set of accepted HTLCs +// that can be considered for settlement. It asserts that MPP and AMP HTLCs do +// not comingle, and also that HTLCs with disjoint set ids appear in different +// sets. +func TestHTLCSet(t *testing.T) { + inv := &Invoice{ + Htlcs: make(map[CircuitKey]*InvoiceHTLC), + } + + // Construct two distinct set id's, in this test we'll also track the + // nil set id as a third group. + setID1 := &[32]byte{1} + setID2 := &[32]byte{2} + + // Create the expected htlc sets for each group, these will be updated + // as the invoice is modified. + expSetNil := make(map[CircuitKey]*InvoiceHTLC) + expSet1 := make(map[CircuitKey]*InvoiceHTLC) + expSet2 := make(map[CircuitKey]*InvoiceHTLC) + + checkHTLCSets := func() { + require.Equal(t, expSetNil, inv.HTLCSet(nil)) + require.Equal(t, expSet1, inv.HTLCSet(setID1)) + require.Equal(t, expSet2, inv.HTLCSet(setID2)) + } + + // All HTLC sets should be empty initially. + checkHTLCSets() + + // Add the following sequence of HTLCs to the invoice, sanity checking + // all three HTLC sets after each transition. This sequence asserts: + // - both nil and non-nil set ids can have multiple htlcs. + // - there may be distinct htlc sets with non-nil set ids. + // - only accepted htlcs are returned as part of the set. + htlcs := []struct { + setID *[32]byte + state HtlcState + }{ + {nil, HtlcStateAccepted}, + {nil, HtlcStateAccepted}, + {setID1, HtlcStateAccepted}, + {setID1, HtlcStateAccepted}, + {setID2, HtlcStateAccepted}, + {setID2, HtlcStateAccepted}, + {nil, HtlcStateCanceled}, + {setID1, HtlcStateCanceled}, + {setID2, HtlcStateCanceled}, + {nil, HtlcStateSettled}, + {setID1, HtlcStateSettled}, + {setID2, HtlcStateSettled}, + } + + for i, h := range htlcs { + var ampData *InvoiceHtlcAMPData + if h.setID != nil { + ampData = &InvoiceHtlcAMPData{ + Record: *record.NewAMP([32]byte{0}, *h.setID, 0), + } + + } + + // Add the HTLC to the invoice's set of HTLCs. + key := CircuitKey{HtlcID: uint64(i)} + htlc := &InvoiceHTLC{ + AMP: ampData, + State: h.state, + } + inv.Htlcs[key] = htlc + + // Update our expected htlc set if the htlc is accepted, + // otherwise it shouldn't be reflected. + if h.state == HtlcStateAccepted { + switch h.setID { + case nil: + expSetNil[key] = htlc + case setID1: + expSet1[key] = htlc + case setID2: + expSet2[key] = htlc + default: + t.Fatalf("unexpected set id") + } + } + + checkHTLCSets() + } +} + // TestDeleteInvoices tests that deleting a list of invoices will succeed // if all delete references are valid, or will fail otherwise. func TestDeleteInvoices(t *testing.T) { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index cd8ffc83..bcff0629 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -362,6 +362,30 @@ type Invoice struct { HodlInvoice bool } +// HTLCSet returns the set of accepted HTLCs belonging to an invoice. Passing a +// nil setID will return all accepted HTLCs in the case of legacy or MPP, and no +// HTLCs in the case of AMP. Otherwise, the returned set will be filtered by +// the populated setID which is used to retrieve AMP HTLC sets. +func (i *Invoice) HTLCSet(setID *[32]byte) map[CircuitKey]*InvoiceHTLC { + htlcSet := make(map[CircuitKey]*InvoiceHTLC) + for key, htlc := range i.Htlcs { + // Only consider accepted mpp htlcs. It is possible that there + // are htlcs registered in the invoice database that previously + // timed out and are in the canceled state now. + if htlc.State != HtlcStateAccepted { + continue + } + + if !htlc.IsInHTLCSet(setID) { + continue + } + + htlcSet[key] = htlc + } + + return htlcSet +} + // HtlcState defines the states an htlc paying to an invoice can be in. type HtlcState uint8 @@ -420,6 +444,26 @@ type InvoiceHTLC struct { AMP *InvoiceHtlcAMPData } +// IsInHTLCSet returns true if this HTLC is part an HTLC set. If nil is passed, +// this method returns true if this is an MPP HTLC. Otherwise, it only returns +// true if the AMP HTLC's set id matches the populated setID. +func (h *InvoiceHTLC) IsInHTLCSet(setID *[32]byte) bool { + wantAMPSet := setID != nil + isAMPHtlc := h.AMP != nil + + // Non-AMP HTLCs cannot be part of AMP HTLC sets, and vice versa. + if wantAMPSet != isAMPHtlc { + return false + } + + // Skip AMP HTLCs that have differing set ids. + if isAMPHtlc && *setID != h.AMP.Record.SetID() { + return false + } + + return true +} + // InvoiceHtlcAMPData is a struct hodling the additional metadata stored for // each received AMP HTLC. This includes the AMP onion record, in addition to // the HTLC's payment hash and preimage. diff --git a/invoices/update.go b/invoices/update.go index f7eaf4e4..41bd690b 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -143,14 +143,7 @@ func updateMpp(ctx *invoiceUpdateCtx, // Check whether total amt matches other htlcs in the set. var newSetTotal lnwire.MilliSatoshi - for _, htlc := range inv.Htlcs { - // Only consider accepted mpp htlcs. It is possible that there - // are htlcs registered in the invoice database that previously - // timed out and are in the canceled state now. - if htlc.State != channeldb.HtlcStateAccepted { - continue - } - + for _, htlc := range inv.HTLCSet(nil) { if ctx.mpp.TotalMsat() != htlc.MppTotalAmt { return nil, ctx.failRes(ResultHtlcSetTotalMismatch), nil } @@ -250,10 +243,8 @@ func updateLegacy(ctx *invoiceUpdateCtx, // Don't allow settling the invoice with an old style // htlc if we are already in the process of gathering an // mpp set. - for _, htlc := range inv.Htlcs { - if htlc.State == channeldb.HtlcStateAccepted && - htlc.MppTotalAmt > 0 { - + for _, htlc := range inv.HTLCSet(nil) { + if htlc.MppTotalAmt > 0 { return nil, ctx.failRes(ResultMppInProgress), nil } }