diff --git a/channeldb/control_tower_test.go b/channeldb/control_tower_test.go index 102b83e1..370300d7 100644 --- a/channeldb/control_tower_test.go +++ b/channeldb/control_tower_test.go @@ -275,6 +275,107 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { ) } +// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// deletes payments from the database that are not in-flight. +func TestPaymentControlDeleteNonInFligt(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + payments := []struct { + failed bool + success bool + }{ + { + failed: true, + success: false, + }, + { + failed: false, + success: true, + }, + { + failed: false, + success: false, + }, + } + + for _, p := range payments { + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + if p.failed { + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, &failReason, + ) + } else if p.success { + // Verifies that status was changed to StatusSucceeded. + err := pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, preimg, nil, + ) + } else { + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, nil, + ) + } + } + + // Delete payments. + if err := db.DeletePayments(); err != nil { + t.Fatal(err) + } + + // This should leave the in-flight payment. + dbPayments, err := db.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != 1 { + t.Fatalf("expected one payment, got %d", len(dbPayments)) + } + + status := dbPayments[0].Status + if status != StatusInFlight { + t.Fatalf("expected in-fligth status, got %v", status) + } +} + func assertPaymentStatus(t *testing.T, db *DB, hash [32]byte, expStatus PaymentStatus) {