diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 4f594ee4..a9cf82a8 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -781,6 +781,67 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error { return nil } +// processAMP just-in-time inserts an invoice if this htlc is a keysend +// htlc. +func (i *InvoiceRegistry) processAMP(ctx invoiceUpdateCtx) error { + // AMP payments MUST also include an MPP record. + if ctx.mpp == nil { + return errors.New("no MPP record for AMP") + } + + // Create an invoice for the total amount expected, provided in the MPP + // record. + amt := ctx.mpp.TotalMsat() + + // Set the TLV and MPP optional features on the invoice. We'll also make + // the AMP features required so that it can't be paid by legacy or MPP + // htlcs. + rawFeatures := lnwire.NewRawFeatureVector( + lnwire.TLVOnionPayloadOptional, + lnwire.PaymentAddrOptional, + lnwire.AMPRequired, + ) + features := lnwire.NewFeatureVector(rawFeatures, lnwire.Features) + + // Use the minimum block delta that we require for settling htlcs. + finalCltvDelta := i.cfg.FinalCltvRejectDelta + + // Pre-check expiry here to prevent inserting an invoice that will not + // be settled. + if ctx.expiry < uint32(ctx.currentHeight+finalCltvDelta) { + return errors.New("final expiry too soon") + } + + // We'll use the sender-generated payment address provided in the HTLC + // to create our AMP invoice. + payAddr := ctx.mpp.PaymentAddr() + + // Create placeholder invoice. + invoice := &channeldb.Invoice{ + CreationDate: i.cfg.Clock.Now(), + Terms: channeldb.ContractTerm{ + FinalCltvDelta: finalCltvDelta, + Value: amt, + PaymentPreimage: nil, + PaymentAddr: payAddr, + Features: features, + }, + } + + // Insert invoice into database. Ignore duplicates payment hashes and + // payment addrs, this may be a replay or a different HTLC for the AMP + // invoice. + _, err := i.AddInvoice(invoice, ctx.hash) + switch { + case err == channeldb.ErrDuplicateInvoice: + return nil + case err == channeldb.ErrDuplicatePayAddr: + return nil + default: + return err + } +} + // NotifyExitHopHtlc attempts to mark an invoice as settled. The return value // describes how the htlc should be resolved. // @@ -819,13 +880,24 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, // AddInvoice obtains its own lock. This is no problem, because the // operation is idempotent. if i.cfg.AcceptKeySend { - err := i.processKeySend(ctx) - if err != nil { - ctx.log(fmt.Sprintf("keysend error: %v", err)) + if ctx.amp != nil { + err := i.processAMP(ctx) + if err != nil { + ctx.log(fmt.Sprintf("amp error: %v", err)) - return NewFailResolution( - circuitKey, currentHeight, ResultKeySendError, - ), nil + return NewFailResolution( + circuitKey, currentHeight, ResultAmpError, + ), nil + } + } else { + err := i.processKeySend(ctx) + if err != nil { + ctx.log(fmt.Sprintf("keysend error: %v", err)) + + return NewFailResolution( + circuitKey, currentHeight, ResultKeySendError, + ), nil + } } } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 2f4e811a..e27a60c7 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -1273,3 +1273,36 @@ func TestSettleInvoicePaymentAddrRequiredOptionalGrace(t *testing.T) { t.Fatal("no update received") } } + +// TestAMPWithoutMPPPayload asserts that we correctly reject an AMP HTLC that +// does not include an MPP record. +func TestAMPWithoutMPPPayload(t *testing.T) { + defer timeout()() + + ctx := newTestContext(t) + defer ctx.cleanup() + + ctx.registry.cfg.AcceptKeySend = true + + const ( + shardAmt = lnwire.MilliSatoshi(10) + expiry = uint32(testCurrentHeight + 20) + ) + + // Create payload with missing MPP field. + payload := &mockPayload{ + amp: record.NewAMP([32]byte{}, [32]byte{}, 0), + } + + hodlChan := make(chan interface{}, 1) + resolution, err := ctx.registry.NotifyExitHopHtlc( + lntypes.Hash{}, shardAmt, expiry, + testCurrentHeight, getCircuitKey(uint64(10)), hodlChan, + payload, + ) + require.NoError(t, err) + + // We should receive the ResultAmpError failure. + require.NotNil(t, resolution) + checkFailResolution(t, resolution, ResultAmpError) +} diff --git a/invoices/resolution_result.go b/invoices/resolution_result.go index b979d3ac..0aa2f646 100644 --- a/invoices/resolution_result.go +++ b/invoices/resolution_result.go @@ -105,6 +105,9 @@ const ( // ResultMppInProgress is returned when we are busy receiving a mpp // payment. ResultMppInProgress + + // ResultAmpError is returned when we receive invalid AMP parameters. + ResultAmpError ) // String returns a string representation of the result. @@ -162,6 +165,9 @@ func (f FailResolutionResult) FailureString() string { case ResultMppInProgress: return "mpp reception in progress" + case ResultAmpError: + return "invalid amp parameters" + default: return "unknown failure resolution result" } diff --git a/invoices/test_utils_test.go b/invoices/test_utils_test.go index c51d3e7b..6a454a9e 100644 --- a/invoices/test_utils_test.go +++ b/invoices/test_utils_test.go @@ -46,6 +46,16 @@ func (p *mockPayload) CustomRecords() record.CustomSet { return p.customRecords } +const ( + testHtlcExpiry = uint32(5) + + testInvoiceCltvDelta = uint32(4) + + testFinalCltvRejectDelta = int32(4) + + testCurrentHeight = int32(1) +) + var ( testTimeout = 5 * time.Second @@ -55,14 +65,6 @@ var ( testInvoicePaymentHash = testInvoicePreimage.Hash() - testHtlcExpiry = uint32(5) - - testInvoiceCltvDelta = uint32(4) - - testFinalCltvRejectDelta = int32(4) - - testCurrentHeight = int32(1) - testPrivKeyBytes, _ = hex.DecodeString( "e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")