From 2a358327f4e7ed26961018c6fa81349f5c30b9f9 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 20 Oct 2020 16:18:40 +0200 Subject: [PATCH] multi: add reset closure to kvdb.View This commit adds a reset() closure to the kvdb.View function which will be called before each retry (including the first) of the view transaction. The reset() closure can be used to reset external state (eg slices or maps) where the view closure puts intermediate results. --- breacharbiter.go | 16 +++-- breacharbiter_test.go | 17 +++-- chainntnfs/height_hint_cache.go | 4 ++ channeldb/channel.go | 24 +++++-- channeldb/db.go | 26 ++++++-- channeldb/forwarding_log.go | 8 ++- channeldb/forwarding_package_test.go | 2 + channeldb/graph.go | 65 ++++++++++++++++--- channeldb/graph_test.go | 4 +- channeldb/invoices.go | 23 ++++--- channeldb/kvdb/etcd/db.go | 12 ++-- channeldb/kvdb/etcd/readwrite_cursor_test.go | 6 +- channeldb/kvdb/interface.go | 27 +++++--- channeldb/meta.go | 4 +- channeldb/meta_test.go | 2 +- channeldb/migration_01_to_11/db.go | 2 + channeldb/migration_01_to_11/graph.go | 2 + channeldb/migration_01_to_11/invoices.go | 2 + .../migration_09_legacy_serialization.go | 6 ++ .../migration_01_to_11/migrations_test.go | 4 ++ channeldb/migration_01_to_11/payments.go | 2 + channeldb/nodes.go | 4 ++ channeldb/payment_control.go | 4 ++ channeldb/payment_control_test.go | 5 +- channeldb/payments.go | 4 ++ channeldb/peers.go | 2 + channeldb/reports.go | 2 + channeldb/waitingproof.go | 15 +++-- channeldb/waitingproof_test.go | 2 +- channeldb/witness_cache.go | 2 + contractcourt/briefcase.go | 14 +++- discovery/gossiper_test.go | 27 ++++++++ discovery/message_store.go | 10 ++- fundingmanager.go | 2 +- htlcswitch/decayedlog.go | 2 + htlcswitch/payment_result.go | 4 ++ htlcswitch/switch.go | 2 + macaroons/store.go | 4 ++ nursery_store.go | 17 +++-- nursery_store_test.go | 2 + routing/missioncontrol_store.go | 2 + sweep/store.go | 6 ++ utxonursery.go | 4 +- utxonursery_test.go | 4 +- watchtower/wtdb/client_db.go | 14 +++- watchtower/wtdb/db_common.go | 2 + watchtower/wtdb/tower_db.go | 8 +++ 47 files changed, 340 insertions(+), 82 deletions(-) diff --git a/breacharbiter.go b/breacharbiter.go index b504d1b8..cf96f5ad 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -152,10 +152,12 @@ func (b *breachArbiter) start() error { brarLog.Tracef("Starting breach arbiter") // Load all retributions currently persisted in the retribution store. - breachRetInfos := make(map[wire.OutPoint]retributionInfo) + var breachRetInfos map[wire.OutPoint]retributionInfo if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error { breachRetInfos[ret.chanPoint] = *ret return nil + }, func() { + breachRetInfos = make(map[wire.OutPoint]retributionInfo) }); err != nil { return err } @@ -1223,7 +1225,7 @@ type RetributionStore interface { // ForAll iterates over the existing on-disk contents and applies a // chosen, read-only callback to each. This method should ensure that it // immediately propagate any errors generated by the callback. - ForAll(cb func(*retributionInfo) error) error + ForAll(cb func(*retributionInfo) error, reset func()) error } // retributionStore handles persistence of retribution states to disk and is @@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn( finalTxBytes = justiceBkt.Get(chanBuf.Bytes()) return nil + }, func() { + finalTxBytes = nil }); err != nil { return nil, err } @@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) { } return nil + }, func() { + found = false }) return found, err @@ -1395,7 +1401,9 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error { // ForAll iterates through all stored retributions and executes the passed // callback function on each retribution. -func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { +func (rs *retributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + return kvdb.View(rs.db, func(tx kvdb.RTx) error { // If the bucket does not exist, then there are no pending // retributions. @@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { return cb(ret) }) - }) + }, reset) } // Encode serializes the retribution into the passed byte stream. diff --git a/breacharbiter_test.go b/breacharbiter_test.go index 03c1cdb6..43831efb 100644 --- a/breacharbiter_test.go +++ b/breacharbiter_test.go @@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error { return frs.rs.Remove(key) } -func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error { +func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + frs.mu.Lock() defer frs.mu.Unlock() - return frs.rs.ForAll(cb) + return frs.rs.ForAll(cb, reset) } // Parse the pubkeys in the breached outputs. @@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error { return nil } -func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error { +func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + rs.mu.Lock() defer rs.mu.Unlock() + reset() for _, retInfo := range rs.state { if err := cb(copyRetInfo(retInfo)); err != nil { return err @@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int { err := rs.ForAll(func(_ *retributionInfo) error { count++ return nil + }, func() { + count = 0 }) if err != nil { t.Fatalf("unable to list retributions in db: %v", err) @@ -919,7 +926,7 @@ restartCheck: // Construct a set of all channel points presented by the store. Entries // are only be added to the set if their corresponding retribution // information matches the test vector. - var foundSet = make(map[wire.OutPoint]struct{}) + var foundSet map[wire.OutPoint]struct{} // Iterate through the stored retributions, checking to see if we have // an equivalent retribution in the test vector. This will return an @@ -948,6 +955,8 @@ restartCheck: } return nil + }, func() { + foundSet = make(map[wire.OutPoint]struct{}) }); err != nil { t.Fatalf("failed to iterate over persistent retributions: %v", err) diff --git a/chainntnfs/height_hint_cache.go b/chainntnfs/height_hint_cache.go index 018afbef..6de9bc3d 100644 --- a/chainntnfs/height_hint_cache.go +++ b/chainntnfs/height_hint_cache.go @@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err } return channeldb.ReadElement(bytes.NewReader(spendHint), &hint) + }, func() { + hint = 0 }) if err != nil { return 0, err @@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err } return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint) + }, func() { + hint = 0 }) if err != nil { return 0, err diff --git a/channeldb/channel.go b/channeldb/channel.go index 35a0700d..e30d6eb5 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error { } return nil - }) + }, func() {}) if err != nil { return err } @@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { } return nil + }, func() { + commitPoint = nil }) if err != nil { return nil, err @@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { } r := bytes.NewReader(bs) return ReadElement(r, &closeTx) + }, func() { + closeTx = nil }) if err != nil { return nil, err @@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { cd = dcd return nil + }, func() { + cd = nil }) if err != nil { return nil, err @@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { r := bytes.NewReader(updateBytes) updates, err = deserializeLogUpdates(r) return err + }, func() { + updates = nil }) if err != nil { return nil, err @@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { r := bytes.NewReader(updateBytes) updates, err = deserializeLogUpdates(r) return err + }, func() { + updates = nil }) if err != nil { return nil, err @@ -2365,6 +2375,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { var err error fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) return err + }, func() { + fwdPkgs = nil }); err != nil { return nil, err } @@ -2475,7 +2487,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { } return nil - }); err != nil { + }, func() {}); err != nil { return nil, err } @@ -2509,6 +2521,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { height = commit.CommitHeight return nil + }, func() { + height = 0 }) if err != nil { return 0, err @@ -2547,7 +2561,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e commit = c return nil - }) + }, func() {}) if err != nil { return nil, err } @@ -2870,7 +2884,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen } return fetchChanCommitments(chanBucket, c) - }) + }, func() {}) if err != nil { return nil, nil, err } @@ -2892,7 +2906,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { } return fetchChanRevocationState(chanBucket, c) - }) + }, func() {}) if err != nil { return nil, err } diff --git a/channeldb/db.go b/channeldb/db.go index 983b4fbb..02e1c697 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -201,12 +201,16 @@ func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error { // View is a wrapper around walletdb.View which calls into the extended // backend when available. This call is needed to be able to cast DB to -// ExtendedBackend. -func (db *DB) View(f func(tx walletdb.ReadTx) error) error { +// ExtendedBackend. The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error { if v, ok := db.Backend.(kvdb.ExtendedBackend); ok { - return v.View(f) + return v.View(f, reset) } + reset() return walletdb.View(db, f) } @@ -389,6 +393,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) var err error channels, err = d.fetchOpenChannels(tx, nodeID) return err + }, func() { + channels = nil }) return channels, err @@ -574,7 +580,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { }) } - err := kvdb.View(d, chanScan) + err := kvdb.View(d, chanScan, func() {}) if err != nil { return nil, err } @@ -741,6 +747,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error }) }) + }, func() { + channels = nil }) if err != nil { return nil, err @@ -781,6 +789,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro chanSummaries = append(chanSummaries, chanSummary) return nil }) + }, func() { + chanSummaries = nil }); err != nil { return nil, err } @@ -817,6 +827,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er chanSummary, err = deserializeCloseChannelSummary(summaryReader) return err + }, func() { + chanSummary = nil }); err != nil { return nil, err } @@ -865,6 +877,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( return nil } return ErrClosedChannelNotFound + }, func() { + chanSummary = nil }); err != nil { return nil, err } @@ -1052,6 +1066,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { } return nil + }, func() { + linkNode = nil }) if dbErr != nil { return nil, dbErr @@ -1261,6 +1277,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err channel, err = fetchOpenChannel(chanBucket, outPoint) return err + }, func() { + channel = nil }) if err != nil { return nil, err diff --git a/channeldb/forwarding_log.go b/channeldb/forwarding_log.go index d1216dc4..57e46f35 100644 --- a/channeldb/forwarding_log.go +++ b/channeldb/forwarding_log.go @@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct { // // TODO(roasbeef): rename? func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { - resp := ForwardingLogTimeSlice{ - ForwardingEventQuery: q, - } + var resp ForwardingLogTimeSlice // If the user provided an index offset, then we'll not know how many // records we need to skip. We'll also keep track of the record offset @@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e } return nil + }, func() { + resp = ForwardingLogTimeSlice{ + ForwardingEventQuery: q, + } }) if err != nil && err != ErrNoForwardingEvents { return ForwardingLogTimeSlice{}, err diff --git a/channeldb/forwarding_package_test.go b/channeldb/forwarding_package_test.go index 031a85f2..daeb4621 100644 --- a/channeldb/forwarding_package_test.go +++ b/channeldb/forwarding_package_test.go @@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend, var err error fwdPkgs, err = packager.LoadFwdPkgs(tx) return err + }, func() { + fwdPkgs = nil }); err != nil { t.Fatalf("unable to load fwd pkgs: %v", err) } diff --git a/channeldb/graph.go b/channeldb/graph.go index c55e238c..101f11a2 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli // be aborted. return cb(&edgeInfo, edge1, edge2) }) - }) + }, func() {}) } // ForEachNodeChannel iterates through all channels of a given node, executing the @@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte, // have their disabled bit on. func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { var disabledChanIDs []uint64 - chanEdgeFound := make(map[uint64]struct{}) + var chanEdgeFound map[uint64]struct{} err := kvdb.View(c.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { chanEdgeFound[chanID] = struct{}{} return nil }) + }, func() { + disabledChanIDs = nil + chanEdgeFound = make(map[uint64]struct{}) }) if err != nil { return nil, err @@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro }) } - return kvdb.View(c.db, traversal) + return kvdb.View(c.db, traversal, func() {}) } // SourceNode returns the source node of the graph. The source node is treated @@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { source = node return nil + }, func() { + source = nil }) if err != nil { return nil, err @@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { // package... alias = string(a) return nil + }, func() { + alias = "" }) if err != nil { return "", err @@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge( } return nil - }); err != nil { + }, func() {}); err != nil { return time.Time{}, time.Time{}, exists, isZombie, err } @@ -1235,7 +1242,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { tipHeight = byteOrder.Uint32(k[:]) return nil - }) + }, func() {}) if err != nil { return nil, 0, err } @@ -1312,6 +1319,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { var err error chanID, err = getChanID(tx, chanPoint) return err + }, func() { + chanID = 0 }); err != nil { return 0, err } @@ -1379,6 +1388,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) { // to the caller. cid = byteOrder.Uint64(lastChanID) return nil + }, func() { + cid = 0 }) if err != nil && err != ErrGraphNoEdgesFound { return 0, err @@ -1409,8 +1420,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha // To ensure we don't return duplicate ChannelEdges, we'll use an // additional map to keep track of the edges already seen to prevent // re-adding it. - edgesSeen := make(map[uint64]struct{}) - edgesToCache := make(map[uint64]ChannelEdge) + var edgesSeen map[uint64]struct{} + var edgesToCache map[uint64]ChannelEdge var edgesInHorizon []ChannelEdge c.cacheMu.Lock() @@ -1507,6 +1518,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha } return nil + }, func() { + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + edgesInHorizon = nil }) switch { case err == ErrGraphNoEdgesFound: @@ -1577,6 +1592,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig } return nil + }, func() { + nodesInHorizon = nil }) switch { case err == ErrGraphNoEdgesFound: @@ -1637,6 +1654,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { } return nil + }, func() { + newChanIDs = nil }) switch { // If we don't know of any edges yet, then we'll return the entire set @@ -1701,7 +1720,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint } return nil + }, func() { + chanIDs = nil }) + switch { // If we don't know of any channels yet, then there's nothing to // filter, so we'll return an empty slice. @@ -1775,6 +1797,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { }) } return nil + }, func() { + chanEdges = nil }) if err != nil { return nil, err @@ -2209,7 +2233,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( var err error if tx == nil { - err = kvdb.View(c.db, fetchNode) + err = kvdb.View(c.db, fetchNode, func() {}) } else { err = fetchNode(tx) } @@ -2259,6 +2283,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro exists = true updateTime = node.LastUpdate return nil + }, func() { + updateTime = time.Time{} + exists = false }) if err != nil { return time.Time{}, exists, err @@ -2346,7 +2373,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, // If no transaction was provided, then we'll create a new transaction // to execute the transaction within. if tx == nil { - return kvdb.View(db, traversal) + return kvdb.View(db, traversal, func() {}) } // Otherwise, we re-use the existing transaction to execute the graph @@ -2596,7 +2623,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh // otherwise we can use the existing db transaction. var err error if tx == nil { - err = kvdb.View(c.db, fetchNodeFunc) + err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil }) } else { err = fetchNodeFunc(tx) } @@ -2929,6 +2956,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, policy1 = e1 policy2 = e2 return nil + }, func() { + edgeInfo = nil + policy1 = nil + policy2 = nil }) if err != nil { return nil, nil, nil, err @@ -3030,6 +3061,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, policy1 = e1 policy2 = e2 return nil + }, func() { + edgeInfo = nil + policy1 = nil + policy2 = nil }) if err == ErrZombieEdge { return edgeInfo, nil, nil, err @@ -3062,6 +3097,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { nodeIsPublic, err = node.isPublic(tx, ourPubKey) return err + }, func() { + nodeIsPublic = false }) if err != nil { return false, err @@ -3183,6 +3220,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return nil }) + }, func() { + edgePoints = nil }); err != nil { return nil, err } @@ -3261,6 +3300,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) return nil + }, func() { + isZombie = false + pubKey1 = [33]byte{} + pubKey2 = [33]byte{} }) if err != nil { return false, [33]byte{}, [33]byte{} @@ -3307,6 +3350,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) { numZombies++ return nil }) + }, func() { + numZombies = 0 }) if err != nil { return 0, err diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 71edc8f8..43d786d9 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { return nil }) - }) + }, func() {}) if err != nil { t.Fatal(err) } @@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { } return nil - }) + }, func() {}) if err != nil { t.Fatalf("error reading db: %v", err) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 5f7b6462..c5209f1e 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -624,6 +624,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { } return nil + }, func() { + newInvoices = nil }) if err != nil { return nil, err @@ -669,7 +671,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { invoice = i return nil - }) + }, func() {}) if err != nil { return invoice, err } @@ -731,13 +733,6 @@ func (d *DB) ScanInvoices( scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { return kvdb.View(d, func(tx kvdb.RTx) error { - // Reset partial results. As transaction commit success is not - // guaranteed when using etcd, we need to be prepared to redo - // the whole view transaction. In order to be able to do that - // we need a way to reset existing results. This is also done - // upon first run for initialization. - reset() - invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -773,7 +768,7 @@ func (d *DB) ScanInvoices( return scanFunc(paymentHash, &invoice) }) - }) + }, reset) } // InvoiceQuery represents a query to the invoice database. The query allows a @@ -825,9 +820,7 @@ type InvoiceSlice struct { // QueryInvoices allows a caller to query the invoice database for invoices // within the specified add index range. func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { - resp := InvoiceSlice{ - InvoiceQuery: q, - } + var resp InvoiceSlice err := kvdb.View(d, func(tx kvdb.RTx) error { // If the bucket wasn't found, then there aren't any invoices @@ -892,6 +885,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { } return nil + }, func() { + resp = InvoiceSlice{ + InvoiceQuery: q, + } }) if err != nil && err != ErrNoInvoicesCreated { return resp, err @@ -1011,6 +1008,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { } return nil + }, func() { + settledInvoices = nil }) if err != nil { return nil, err diff --git a/channeldb/kvdb/etcd/db.go b/channeldb/kvdb/etcd/db.go index 9f52ad4e..c63bbdbd 100644 --- a/channeldb/kvdb/etcd/db.go +++ b/channeldb/kvdb/etcd/db.go @@ -218,11 +218,15 @@ func (db *db) getSTMOptions() []STMOptionFunc { } // View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -func (db *db) View(f func(tx walletdb.ReadTx) error) error { +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error { apply := func(stm STM) error { + reset() return f(newReadWriteTx(stm, db.config.Prefix)) } diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index 216b47c4..6b4f317f 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) { require.Nil(t, v) return nil - }) + }, func() {}) require.NoError(t, err) } @@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { require.Nil(t, v) return nil - }) + }, func() {}) require.NoError(t, err) } @@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) { require.Equal(t, []byte("val"), v) return nil - }) + }, func() {}) require.NoError(t, err) diff --git a/channeldb/kvdb/interface.go b/channeldb/kvdb/interface.go index 16e84285..7d44f56c 100644 --- a/channeldb/kvdb/interface.go +++ b/channeldb/kvdb/interface.go @@ -21,12 +21,19 @@ func Update(db Backend, f func(tx RwTx) error) error { // View opens a database read transaction and executes the function f with the // transaction passed as a parameter. After f exits, the transaction is rolled // back. If f errors, its error is returned, not a rollback error (if any -// occur). -func View(db Backend, f func(tx RTx) error) error { +// occur). The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func View(db Backend, f func(tx RTx) error, reset func()) error { if extendedDB, ok := db.(ExtendedBackend); ok { - return extendedDB.View(f) + return extendedDB.View(f, reset) } + // Since we know that walletdb simply calls into bbolt which never + // retries transactions, we'll call the reset function here before View. + reset() + return walletdb.View(db, f) } @@ -55,11 +62,15 @@ type ExtendedBackend interface { // PrintStats returns all collected stats pretty printed into a string. PrintStats() string - // View opens a database read transaction and executes the function f with - // the transaction passed as a parameter. After f exits, the transaction is - // rolled back. If f errors, its error is returned, not a rollback error - // (if any occur). - View(f func(tx walletdb.ReadTx) error) error + // View opens a database read transaction and executes the function f + // with the transaction passed as a parameter. After f exits, the + // transaction is rolled back. If f errors, its error is returned, not a + // rollback error (if any occur). The passed reset function is called + // before the start of the transaction and can be used to reset + // intermediate state. As callers may expect retries of the f closure + // (depending on the database backend used), the reset function will be + //called before each retry respectively. + View(f func(tx walletdb.ReadTx) error, reset func()) error // Update opens a database read/write transaction and executes the function // f with the transaction passed as a parameter. After f exits, if f did not diff --git a/channeldb/meta.go b/channeldb/meta.go index f7f9ae1e..c8ade44b 100644 --- a/channeldb/meta.go +++ b/channeldb/meta.go @@ -23,10 +23,12 @@ type Meta struct { // FetchMeta fetches the meta data from boltdb and returns filled meta // structure. func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) { - meta := &Meta{} + var meta *Meta err := kvdb.View(d, func(tx kvdb.RTx) error { return fetchMeta(meta, tx) + }, func() { + meta = &Meta{} }) if err != nil { return nil, err diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 98e9c88a..7eabfc2c 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) { } return nil - }) + }, func() {}) if err != nil { t.Fatalf("unable to apply after func: %v", err) } diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index f5890246..71128a11 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro chanSummaries = append(chanSummaries, chanSummary) return nil }) + }, func() { + chanSummaries = nil }); err != nil { return nil, err } diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go index f3b88539..0e34b405 100644 --- a/channeldb/migration_01_to_11/graph.go +++ b/channeldb/migration_01_to_11/graph.go @@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { source = node return nil + }, func() { + source = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go index dcba1d54..ceb21a33 100644 --- a/channeldb/migration_01_to_11/invoices.go +++ b/channeldb/migration_01_to_11/invoices.go @@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { return nil }) + }, func() { + invoices = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go index acd61b0a..b8f86d38 100644 --- a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go +++ b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go @@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { payments = append(payments, payment) return nil }) + }, func() { + payments = nil }) if err != nil { return nil, err @@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { var err error paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) return err + }, func() { + paymentStatus = StatusUnknown }) if err != nil { return StatusUnknown, err @@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) { return nil }) }) + }, func() { + payments = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 6cd855e8..010c10a1 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { "serialization") } return nil + }, func() { + dbSummary = nil }) if err != nil { t.Fatalf("unable to view DB: %v", err) @@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { } return nil + }, func() { + rawMsg = nil }) if err != nil { t.Fatal(err) diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go index 697da0e0..e44be003 100644 --- a/channeldb/migration_01_to_11/payments.go +++ b/channeldb/migration_01_to_11/payments.go @@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) { return nil }) }) + }, func() { + payments = nil }) if err != nil { return nil, err diff --git a/channeldb/nodes.go b/channeldb/nodes.go index dcbdf63a..03871264 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { linkNode = node return nil + }, func() { + linkNode = nil }) return linkNode, err @@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { linkNodes = nodes return nil + }, func() { + linkNodes = nil }) if err != nil { return nil, err diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index d67b0a8e..9ed5e639 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( payment, err = fetchPayment(bucket) return err + }, func() { + payment = nil }) if err != nil { return nil, err @@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { inFlights = append(inFlights, inFlight) return nil }) + }, func() { + inFlights = nil }) if err != nil { return nil, err diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 4f901462..e395a93f 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { indexCount++ return nil }) - }) + }, func() { indexCount = 0 }) require.NoError(t, err) require.Equal(t, 1, indexCount) @@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, var err error hash, err = deserializePaymentIndex(r) return err - + }, func() { + hash = lntypes.Hash{} }); err != nil { return nil, err } diff --git a/channeldb/payments.go b/channeldb/payments.go index 5c2475bd..3344451c 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) { payments = append(payments, duplicatePayments...) return nil }) + }, func() { + payments = nil }) if err != nil { return nil, err @@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { } return nil + }, func() { + resp = PaymentsResponse{} }); err != nil { return resp, err } diff --git a/channeldb/peers.go b/channeldb/peers.go index 4dea9f2f..55920b0d 100644 --- a/channeldb/peers.go +++ b/channeldb/peers.go @@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) { } return ReadElements(r, &flapCount.Count) + }, func() { + flapCount = FlapCount{} }); err != nil { return nil, err } diff --git a/channeldb/reports.go b/channeldb/reports.go index f71a1c4c..111e787e 100644 --- a/channeldb/reports.go +++ b/channeldb/reports.go @@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash, return nil }) + }, func() { + reports = nil }); err != nil { return nil, err } diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index 2ea706c8..43a4fd8a 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -42,13 +42,14 @@ type WaitingProofStore struct { // NewWaitingProofStore creates new instance of proofs storage. func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { s := &WaitingProofStore{ - db: db, - cache: make(map[WaitingProofKey]struct{}), + db: db, } if err := s.ForAll(func(proof *WaitingProof) error { s.cache[proof.Key()] = struct{}{} return nil + }, func() { + s.cache = make(map[WaitingProofKey]struct{}) }); err != nil && err != ErrWaitingProofNotFound { return nil, err } @@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error { // ForAll iterates thought all waiting proofs and passing the waiting proof // in the given callback. -func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { +func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error, + reset func()) error { + return kvdb.View(s.db, func(tx kvdb.RTx) error { bucket := tx.ReadBucket(waitingProofsBucketKey) if bucket == nil { @@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { return cb(proof) }) - }) + }, reset) } // Get returns the object which corresponds to the given index. func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { - proof := &WaitingProof{} + var proof *WaitingProof s.mu.RLock() defer s.mu.RUnlock() @@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { r := bytes.NewReader(v) return proof.Decode(r) + }, func() { + proof = &WaitingProof{} }) return proof, err diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index 12679b69..ed1bf050 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) { if err := store.ForAll(func(proof *WaitingProof) error { return errors.New("storage should be empty") - }); err != nil && err != ErrWaitingProofNotFound { + }, func() {}); err != nil && err != ErrWaitingProofNotFound { t.Fatal(err) } } diff --git a/channeldb/witness_cache.go b/channeldb/witness_cache.go index 3a91fc44..d4abfeed 100644 --- a/channeldb/witness_cache.go +++ b/channeldb/witness_cache.go @@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by copy(witness[:], dbWitness) return nil + }, func() { + witness = nil }) if err != nil { return nil, err diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index fbf2e96a..d1dc17dd 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { s = ArbitratorState(stateBytes[0]) return nil + }, func() { + s = 0 }) if err != nil && err != errScopeBucketNoExist { return s, err @@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro contracts = append(contracts, res) return nil }) + }, func() { + contracts = nil }) if err != nil && err != errScopeBucketNoExist && err != errNoContracts { return nil, err @@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error // // NOTE: Part of the ContractResolver interface. func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { - c := &ContractResolutions{} + var c *ContractResolutions err := kvdb.View(b.db, func(tx kvdb.RTx) error { scopeBucket := tx.ReadBucket(b.scopeKey[:]) if scopeBucket == nil { @@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er } return nil + }, func() { + c = &ContractResolutions{} }) if err != nil { return nil, err @@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er // // NOTE: Part of the ContractResolver interface. func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { - actionsMap := make(ChainActionMap) + var actionsMap ChainActionMap err := kvdb.View(b.db, func(tx kvdb.RTx) error { scopeBucket := tx.ReadBucket(b.scopeKey[:]) @@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { return nil }) + }, func() { + actionsMap = make(ChainActionMap) }) if err != nil { return nil, err @@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { c = commitSet return nil + }, func() { + c = nil }) if err != nil { return nil, err diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 9dfad6ba..f37caa8d 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1570,6 +1585,9 @@ out: number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } diff --git a/discovery/message_store.go b/discovery/message_store.go index f86ede20..6cce5494 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) { // Messages returns the total set of messages that exist within the store for // all peers. func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { - msgs := make(map[[33]byte][]lnwire.Message) + var msgs map[[33]byte][]lnwire.Message err := kvdb.View(s.db, func(tx kvdb.RTx) error { messageStore := tx.ReadBucket(messageStoreBucket) if messageStore == nil { @@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { msgs[pubKey] = append(msgs[pubKey], msg) return nil }) + }, func() { + msgs = make(map[[33]byte][]lnwire.Message) }) if err != nil { return nil, err @@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer( } return nil + }, func() { + msgs = nil }) if err != nil { return nil, err @@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer( // Peers returns the public key of all peers with messages within the store. func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { - peers := make(map[[33]byte]struct{}) + var peers map[[33]byte]struct{} err := kvdb.View(s.db, func(tx kvdb.RTx) error { messageStore := tx.ReadBucket(messageStoreBucket) if messageStore == nil { @@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { peers[pubKey] = struct{}{} return nil }) + }, func() { + peers = make(map[[33]byte]struct{}) }) if err != nil { return nil, err diff --git a/fundingmanager.go b/fundingmanager.go index ebe95d42..1bd018b1 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) ( state = channelOpeningState(byteOrder.Uint16(value[:2])) shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) return nil - }) + }, func() {}) if err != nil { return 0, nil, err } diff --git a/htlcswitch/decayedlog.go b/htlcswitch/decayedlog.go index db27ed39..8b8a78ef 100644 --- a/htlcswitch/decayedlog.go +++ b/htlcswitch/decayedlog.go @@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) { value = uint32(binary.BigEndian.Uint32(valueBytes)) return nil + }, func() { + value = 0 }) if err != nil { return value, err diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index e6a1e59f..d341ba12 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( default: return nil } + }, func() { + result = nil }) if err != nil { return nil, err @@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) ( var err error result, err = fetchResult(tx, pid) return err + }, func() { + result = nil }) if err != nil { return nil, err diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d03d7496..1f755bb0 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb. tx, source, ) return err + }, func() { + fwdPkgs = nil }); err != nil { return nil, err } diff --git a/macaroons/store.go b/macaroons/store.go index 97baae49..1e4447a8 100644 --- a/macaroons/store.go +++ b/macaroons/store.go @@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) { rootKey = make([]byte, len(decKey)) copy(rootKey[:], decKey) return nil + }, func() { + rootKey = nil }) if err != nil { return nil, err @@ -257,6 +259,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) { } return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey) + }, func() { + rootKeySlice = nil }) if err != nil { return nil, err diff --git a/nursery_store.go b/nursery_store.go index 42a13607..97fe08cc 100644 --- a/nursery_store.go +++ b/nursery_store.go @@ -129,7 +129,7 @@ type NurseryStore interface { // the caller to process each key-value pair. The key will be a prefixed // outpoint, and the value will be the serialized bytes for an output, // whose type should be inferred from the key's prefix. - ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error + ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error // ListChannels returns all channels the nursery is currently tracking. ListChannels() ([]wire.OutPoint, error) @@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass( }) + }, func() { + kids = nil + babies = nil }); err != nil { return nil, nil, err } @@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) { } return nil + }, func() { + kids = nil }); err != nil { return nil, err } @@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) { } return nil + }, func() { + activeHeights = nil }) if err != nil { return nil, err @@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) { // NOTE: The callback should not modify the provided byte slices and is // preferably non-blocking. func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint, - callback func([]byte, []byte) error) error { + callback func([]byte, []byte) error, reset func()) error { return kvdb.View(ns.db, func(tx kvdb.RTx) error { return ns.forChanOutputs(tx, chanPoint, callback) - }) + }, reset) } // ListChannels returns all channels the nursery is currently tracking. @@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) { return nil }) + }, func() { + activeChannels = nil }); err != nil { return nil, err } @@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error) return nil }) - }) + }, func() {}) if err != nil && err != ErrImmatureChannel { return false, err } diff --git a/nursery_store_test.go b/nursery_store_test.go index 052b2304..1d4caf51 100644 --- a/nursery_store_test.go +++ b/nursery_store_test.go @@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore, err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error { count++ return nil + }, func() { + count = 0 }) if count == 0 && err == ErrContractNotFound { diff --git a/routing/missioncontrol_store.go b/routing/missioncontrol_store.go index e0b2e82a..5113b0b0 100644 --- a/routing/missioncontrol_store.go +++ b/routing/missioncontrol_store.go @@ -103,6 +103,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) { return nil }) + }, func() { + results = nil }) if err != nil { return nil, err diff --git a/sweep/store.go b/sweep/store.go index 7809db96..c78759c4 100644 --- a/sweep/store.go +++ b/sweep/store.go @@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) { } return nil + }, func() { + sweepTx = nil }) if err != nil { return nil, err @@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) { ours = txHashesBucket.Get(hash[:]) != nil return nil + }, func() { + ours = false }) if err != nil { return false, err @@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) { return nil }) + }, func() { + sweepTxns = nil }); err != nil { return nil, err } diff --git a/utxonursery.go b/utxonursery.go index 217aa2da..d73ceba5 100644 --- a/utxonursery.go +++ b/utxonursery.go @@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport( utxnLog.Debugf("NurseryReport: building nursery report for channel %v", chanPoint) - report := &contractMaturityReport{} + var report *contractMaturityReport if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error { switch { @@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport( } return nil + }, func() { + report = &contractMaturityReport{} }); err != nil { return nil, err } diff --git a/utxonursery_test.go b/utxonursery_test.go index 5675c9ea..4e02b24d 100644 --- a/utxonursery_test.go +++ b/utxonursery_test.go @@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) ( } func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint, - callback func([]byte, []byte) error) error { + callback func([]byte, []byte) error, reset func()) error { - return i.ns.ForChanOutputs(chanPoint, callback) + return i.ns.ForChanOutputs(chanPoint, callback, reset) } func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) { diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 66aeb345..dc88d45c 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) { var err error version, err = getDBVersion(tx) return err + }, func() { + version = 0 }) if err != nil { return 0, err @@ -392,6 +394,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) { var err error tower, err = getTower(towers, towerID.Bytes()) return err + }, func() { + tower = nil }) if err != nil { return nil, err @@ -421,6 +425,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) { var err error tower, err = getTower(towers, towerIDBytes) return err + }, func() { + tower = nil }) if err != nil { return nil, err @@ -446,6 +452,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) { towers = append(towers, tower) return nil }) + }, func() { + towers = nil }) if err != nil { return nil, err @@ -566,6 +574,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession var err error clientSessions, err = listClientSessions(sessions, id) return err + }, func() { + clientSessions = nil }) if err != nil { return nil, err @@ -611,7 +621,7 @@ func listClientSessions(sessions kvdb.RBucket, // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { - summaries := make(map[lnwire.ChannelID]ClientChanSummary) + var summaries map[lnwire.ChannelID]ClientChanSummary err := kvdb.View(c.db, func(tx kvdb.RTx) error { chanSummaries := tx.ReadBucket(cChanSummaryBkt) if chanSummaries == nil { @@ -632,6 +642,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { return nil }) + }, func() { + summaries = make(map[lnwire.ChannelID]ClientChanSummary) }) if err != nil { return nil, err diff --git a/watchtower/wtdb/db_common.go b/watchtower/wtdb/db_common.go index d769f13e..c4f8bb17 100644 --- a/watchtower/wtdb/db_common.go +++ b/watchtower/wtdb/db_common.go @@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) { err = kvdb.View(bdb, func(tx kvdb.RTx) error { metadataExists = tx.ReadBucket(metadataBkt) != nil return nil + }, func() { + metadataExists = false }) if err != nil { return nil, false, err diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go index d98fdad5..a788a100 100644 --- a/watchtower/wtdb/tower_db.go +++ b/watchtower/wtdb/tower_db.go @@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) { var err error version, err = getDBVersion(tx) return err + }, func() { + version = 0 }) if err != nil { return 0, err @@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { var err error session, err = getSession(sessions, id[:]) return err + }, func() { + session = nil }) if err != nil { return nil, err @@ -460,6 +464,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) { } return nil + }, func() { + matches = nil }) if err != nil { return nil, err @@ -494,6 +500,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { epoch = getLookoutEpoch(lookoutTip) return nil + }, func() { + epoch = nil }) if err != nil { return nil, err