diff --git a/invoices/update.go b/invoices/update.go index 41bd690b..45435ca8 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -33,6 +33,17 @@ func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { return channeldb.InvoiceRefByHash(i.hash) } +// setID returns an identifier that identifies other possible HTLCs that this +// particular one is related to. If nil is returned this means the HTLC is an +// MPP or legacy payment, otherwise the HTLC belongs AMP payment. +func (i invoiceUpdateCtx) setID() *[32]byte { + if i.amp != nil { + setID := i.amp.SetID() + return &setID + } + return nil +} + // log logs a message specific to this update context. func (i *invoiceUpdateCtx) log(s string) { log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v, "+ @@ -108,6 +119,8 @@ func updateMpp(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (*channeldb.InvoiceUpdateDesc, HtlcResolution, error) { + setID := ctx.setID() + // Start building the accept descriptor. acceptDesc := &channeldb.HtlcAcceptDesc{ Amt: ctx.amtPaid, @@ -143,7 +156,7 @@ func updateMpp(ctx *invoiceUpdateCtx, // Check whether total amt matches other htlcs in the set. var newSetTotal lnwire.MilliSatoshi - for _, htlc := range inv.HTLCSet(nil) { + for _, htlc := range inv.HTLCSet(setID) { if ctx.mpp.TotalMsat() != htlc.MppTotalAmt { return nil, ctx.failRes(ResultHtlcSetTotalMismatch), nil } @@ -188,6 +201,7 @@ func updateMpp(ctx *invoiceUpdateCtx, if inv.HodlInvoice { update.State = &channeldb.InvoiceStateUpdateDesc{ NewState: channeldb.ContractAccepted, + SetID: setID, } return &update, ctx.acceptRes(resultAccepted), nil } @@ -195,6 +209,7 @@ func updateMpp(ctx *invoiceUpdateCtx, update.State = &channeldb.InvoiceStateUpdateDesc{ NewState: channeldb.ContractSettled, Preimage: inv.Terms.PaymentPreimage, + SetID: setID, } return &update, ctx.settleRes(