diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 63499d71..e13f9d72 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -480,7 +480,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { } // Delete all failed payments. - if err := db.DeletePayments(true); err != nil { + if err := db.DeletePayments(true, false); err != nil { t.Fatal(err) } @@ -516,7 +516,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { } // Now delete all payments except in-flight. - if err := db.DeletePayments(false); err != nil { + if err := db.DeletePayments(false, false); err != nil { t.Fatal(err) } @@ -553,6 +553,223 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { require.Equal(t, 1, indexCount) } +// TestPaymentControlDeletePayments tests that DeletePayments correcly deletes +// information about completed payments from the database. +func TestPaymentControlDeletePayments(t *testing.T) { + t.Parallel() + + db, cleanup, err := MakeTestDB() + defer cleanup() + + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + // Register three payments: + // 1. A payment with two failed attempts. + // 2. A Payment with one failed and one settled attempt. + // 3. A payment with one failed and one in-flight attempt. + attemptID := uint64(0) + for i := 0; i < 3; i++ { + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + attempt.AttemptID = attemptID + attemptID++ + + // Init the payment. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + // Register and fail the first attempt for all three payments. + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + htlcFailure := HTLCFailUnreadable + _, err = pControl.FailAttempt( + info.PaymentHash, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcFailure, + }, + ) + if err != nil { + t.Fatalf("unable to fail htlc: %v", err) + } + + // Depending on the test case, fail or succeed the next + // attempt. + attempt.AttemptID = attemptID + attemptID++ + + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + switch i { + + // Fail the attempt and the payment overall. + case 0: + htlcFailure := HTLCFailUnreadable + _, err = pControl.FailAttempt( + info.PaymentHash, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcFailure, + }, + ) + if err != nil { + t.Fatalf("unable to fail htlc: %v", err) + } + + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Settle the attempt + case 1: + _, err := pControl.SettleAttempt( + info.PaymentHash, attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + // We leave the attmpet in-flight by doing nothing. + case 2: + } + } + + type fetchedPayment struct { + status PaymentStatus + htlcs int + } + + assertPayments := func(expPayments []fetchedPayment) { + t.Helper() + + dbPayments, err := db.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != len(expPayments) { + t.Fatalf("expected %d payments, got %d", + len(expPayments), len(dbPayments)) + } + + for i := range dbPayments { + if dbPayments[i].Status != expPayments[i].status { + t.Fatalf("unexpected payment status") + } + + if len(dbPayments[i].HTLCs) != expPayments[i].htlcs { + t.Fatalf("unexpected number of htlcs") + } + + } + } + + // Check that all payments are there as we added them. + assertPayments([]fetchedPayment{ + { + status: StatusFailed, + htlcs: 2, + }, + { + status: StatusSucceeded, + htlcs: 2, + }, + { + status: StatusInFlight, + htlcs: 2, + }, + }) + + // Delete HTLC attempts for failed payments only. + if err := db.DeletePayments(true, true); err != nil { + t.Fatal(err) + } + + // The failed payment is the only altered one. + assertPayments([]fetchedPayment{ + { + status: StatusFailed, + htlcs: 0, + }, + { + status: StatusSucceeded, + htlcs: 2, + }, + { + status: StatusInFlight, + htlcs: 2, + }, + }) + + // Delete failed attempts for all payments. + if err := db.DeletePayments(false, true); err != nil { + t.Fatal(err) + } + + // The failed attempts should be deleted, except for the in-flight + // payment, that shouldn't be altered until it has completed. + assertPayments([]fetchedPayment{ + { + status: StatusFailed, + htlcs: 0, + }, + { + status: StatusSucceeded, + htlcs: 1, + }, + { + status: StatusInFlight, + htlcs: 2, + }, + }) + + // Now delete all failed payments. + if err := db.DeletePayments(true, false); err != nil { + t.Fatal(err) + } + + assertPayments([]fetchedPayment{ + { + status: StatusSucceeded, + htlcs: 1, + }, + { + status: StatusInFlight, + htlcs: 2, + }, + }) + + // Finally delete all completed payments. + if err := db.DeletePayments(false, false); err != nil { + t.Fatal(err) + } + + assertPayments([]fetchedPayment{ + { + status: StatusInFlight, + htlcs: 2, + }, + }) +} + // TestPaymentControlMultiShard checks the ability of payment control to // have multiple in-flight HTLCs for a single payment. func TestPaymentControlMultiShard(t *testing.T) { diff --git a/channeldb/payments.go b/channeldb/payments.go index 93556438..7e659d05 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -676,8 +676,11 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, return duplicatePayment, nil } -// DeletePayments deletes all completed and failed payments from the DB. -func (db *DB) DeletePayments(failedOnly bool) error { +// DeletePayments deletes all completed and failed payments from the DB. If +// failedOnly is set, only failed payments will be considered for deletion. If +// failedHtlsOnly is set, the payment itself won't be deleted, only failed HTLC +// attempts. +func (db *DB) DeletePayments(failedOnly, failedHtlcsOnly bool) error { return kvdb.Update(db, func(tx kvdb.RwTx) error { payments := tx.ReadWriteBucket(paymentsRootBucket) if payments == nil { @@ -692,6 +695,10 @@ func (db *DB) DeletePayments(failedOnly bool) error { // deleteIndexes is the set of indexes pointing to these // payments that need to be deleted. deleteIndexes [][]byte + + // deleteHtlcs maps a payment hash to the HTLC IDs we + // want to delete for that payment. + deleteHtlcs = make(map[lntypes.Hash][][]byte) ) err := payments.ForEach(func(k, _ []byte) error { bucket := payments.NestedReadBucket(k) @@ -721,6 +728,49 @@ func (db *DB) DeletePayments(failedOnly bool) error { return nil } + // If we are only deleting failed HTLCs, fetch them. + if failedHtlcsOnly { + htlcsBucket := bucket.NestedReadBucket( + paymentHtlcsBucket, + ) + + var htlcs []HTLCAttempt + if htlcsBucket != nil { + htlcs, err = fetchHtlcAttempts( + htlcsBucket, + ) + if err != nil { + return err + } + } + + // Now iterate though them and save the bucket + // keys for the failed HTLCs. + var toDelete [][]byte + for _, h := range htlcs { + if h.Failure == nil { + continue + } + + htlcIDBytes := make([]byte, 8) + binary.BigEndian.PutUint64( + htlcIDBytes, h.AttemptID, + ) + + toDelete = append(toDelete, htlcIDBytes) + } + + hash, err := lntypes.MakeHash(k) + if err != nil { + return err + } + + deleteHtlcs[hash] = toDelete + + // We return, we are only deleting attempts. + return nil + } + // Add the bucket to the set of buckets we can delete. deleteBuckets = append(deleteBuckets, k) @@ -732,13 +782,27 @@ func (db *DB) DeletePayments(failedOnly bool) error { } deleteIndexes = append(deleteIndexes, seqNrs...) - return nil }) if err != nil { return err } + // Delete the failed HTLC attempts we found. + for hash, htlcIDs := range deleteHtlcs { + bucket := payments.NestedReadWriteBucket(hash[:]) + htlcsBucket := bucket.NestedReadWriteBucket( + paymentHtlcsBucket, + ) + + for _, aid := range htlcIDs { + err := htlcsBucket.DeleteNestedBucket(aid) + if err != nil { + return err + } + } + } + for _, k := range deleteBuckets { if err := payments.DeleteNestedBucket(k); err != nil { return err diff --git a/rpcserver.go b/rpcserver.go index 7a09f3d7..d8883351 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5713,11 +5713,12 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, req *lnrpc.DeleteAllPaymentsRequest) ( *lnrpc.DeleteAllPaymentsResponse, error) { - rpcsLog.Debugf("[DeleteAllPayments] failed_payments_only=%v", - req.FailedPaymentsOnly) + rpcsLog.Infof("[DeleteAllPayments] failed_payments_only=%v, "+ + "failed_htlcs_only=%v", req.FailedPaymentsOnly, + req.FailedHtlcsOnly) err := r.server.remoteChanDB.DeletePayments( - req.FailedPaymentsOnly, + req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { return nil, err