diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 54130755..e6a1e59f 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -258,3 +258,48 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) { return deserializeNetworkResult(r) } + +// cleanStore removes all entries from the store, except the payment IDs given. +// NOTE: Since every result not listed in the keep map will be deleted, care +// should be taken to ensure no new payment attempts are being made +// concurrently while this process is ongoing, as its result might end up being +// deleted. +func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error { + return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error { + networkResults, err := tx.CreateTopLevelBucket( + networkResultStoreBucketKey, + ) + if err != nil { + return err + } + + // Iterate through the bucket, deleting all items not in the + // keep map. + var toClean [][]byte + if err := networkResults.ForEach(func(k, _ []byte) error { + pid := binary.BigEndian.Uint64(k) + if _, ok := keep[pid]; ok { + return nil + } + + toClean = append(toClean, k) + return nil + }); err != nil { + return err + } + + for _, k := range toClean { + err := networkResults.Delete(k) + if err != nil { + return err + } + } + + if len(toClean) > 0 { + log.Infof("Removed %d stale entries from network "+ + "result store", len(toClean)) + } + + return nil + }) +} diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index d0c48c4d..04ff57d8 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) // TestNetworkResultSerialization checks that NetworkResults are properly @@ -180,7 +181,6 @@ func TestNetworkResultStore(t *testing.T) { // Since we don't delete results from the store (yet), make sure we // will get subscriptions for all of them. - // TODO(halseth): check deletion when we have reliable handoff. for i := uint64(0); i < numResults; i++ { sub, err := store.subscribeResult(i) if err != nil { @@ -193,4 +193,25 @@ func TestNetworkResultStore(t *testing.T) { t.Fatalf("no result received") } } + + // Clean the store keeping the first two results. + toKeep := map[uint64]struct{}{ + 0: {}, + 1: {}, + } + // Finally, delete the result. + err = store.cleanStore(toKeep) + require.NoError(t, err) + + // Payment IDs 0 and 1 should be found, 2 and 3 should be deleted. + for i := uint64(0); i < numResults; i++ { + _, err = store.getResult(i) + if i <= 1 { + require.NoError(t, err, "unable to get result") + } + if i >= 2 && err != ErrPaymentIDNotFound { + t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) + } + + } }