diff --git a/breacharbiter.go b/breacharbiter.go index b504d1b8..cf96f5ad 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -152,10 +152,12 @@ func (b *breachArbiter) start() error { brarLog.Tracef("Starting breach arbiter") // Load all retributions currently persisted in the retribution store. - breachRetInfos := make(map[wire.OutPoint]retributionInfo) + var breachRetInfos map[wire.OutPoint]retributionInfo if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error { breachRetInfos[ret.chanPoint] = *ret return nil + }, func() { + breachRetInfos = make(map[wire.OutPoint]retributionInfo) }); err != nil { return err } @@ -1223,7 +1225,7 @@ type RetributionStore interface { // ForAll iterates over the existing on-disk contents and applies a // chosen, read-only callback to each. This method should ensure that it // immediately propagate any errors generated by the callback. - ForAll(cb func(*retributionInfo) error) error + ForAll(cb func(*retributionInfo) error, reset func()) error } // retributionStore handles persistence of retribution states to disk and is @@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn( finalTxBytes = justiceBkt.Get(chanBuf.Bytes()) return nil + }, func() { + finalTxBytes = nil }); err != nil { return nil, err } @@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) { } return nil + }, func() { + found = false }) return found, err @@ -1395,7 +1401,9 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error { // ForAll iterates through all stored retributions and executes the passed // callback function on each retribution. -func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { +func (rs *retributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + return kvdb.View(rs.db, func(tx kvdb.RTx) error { // If the bucket does not exist, then there are no pending // retributions. @@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { return cb(ret) }) - }) + }, reset) } // Encode serializes the retribution into the passed byte stream. diff --git a/breacharbiter_test.go b/breacharbiter_test.go index 03c1cdb6..43831efb 100644 --- a/breacharbiter_test.go +++ b/breacharbiter_test.go @@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error { return frs.rs.Remove(key) } -func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error { +func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + frs.mu.Lock() defer frs.mu.Unlock() - return frs.rs.ForAll(cb) + return frs.rs.ForAll(cb, reset) } // Parse the pubkeys in the breached outputs. @@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error { return nil } -func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error { +func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error, + reset func()) error { + rs.mu.Lock() defer rs.mu.Unlock() + reset() for _, retInfo := range rs.state { if err := cb(copyRetInfo(retInfo)); err != nil { return err @@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int { err := rs.ForAll(func(_ *retributionInfo) error { count++ return nil + }, func() { + count = 0 }) if err != nil { t.Fatalf("unable to list retributions in db: %v", err) @@ -919,7 +926,7 @@ restartCheck: // Construct a set of all channel points presented by the store. Entries // are only be added to the set if their corresponding retribution // information matches the test vector. - var foundSet = make(map[wire.OutPoint]struct{}) + var foundSet map[wire.OutPoint]struct{} // Iterate through the stored retributions, checking to see if we have // an equivalent retribution in the test vector. This will return an @@ -948,6 +955,8 @@ restartCheck: } return nil + }, func() { + foundSet = make(map[wire.OutPoint]struct{}) }); err != nil { t.Fatalf("failed to iterate over persistent retributions: %v", err) diff --git a/chainntnfs/height_hint_cache.go b/chainntnfs/height_hint_cache.go index 018afbef..6de9bc3d 100644 --- a/chainntnfs/height_hint_cache.go +++ b/chainntnfs/height_hint_cache.go @@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err } return channeldb.ReadElement(bytes.NewReader(spendHint), &hint) + }, func() { + hint = 0 }) if err != nil { return 0, err @@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err } return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint) + }, func() { + hint = 0 }) if err != nil { return 0, err diff --git a/channeldb/channel.go b/channeldb/channel.go index 35a0700d..e30d6eb5 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error { } return nil - }) + }, func() {}) if err != nil { return err } @@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { } return nil + }, func() { + commitPoint = nil }) if err != nil { return nil, err @@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { } r := bytes.NewReader(bs) return ReadElement(r, &closeTx) + }, func() { + closeTx = nil }) if err != nil { return nil, err @@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { cd = dcd return nil + }, func() { + cd = nil }) if err != nil { return nil, err @@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { r := bytes.NewReader(updateBytes) updates, err = deserializeLogUpdates(r) return err + }, func() { + updates = nil }) if err != nil { return nil, err @@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { r := bytes.NewReader(updateBytes) updates, err = deserializeLogUpdates(r) return err + }, func() { + updates = nil }) if err != nil { return nil, err @@ -2365,6 +2375,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { var err error fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) return err + }, func() { + fwdPkgs = nil }); err != nil { return nil, err } @@ -2475,7 +2487,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { } return nil - }); err != nil { + }, func() {}); err != nil { return nil, err } @@ -2509,6 +2521,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { height = commit.CommitHeight return nil + }, func() { + height = 0 }) if err != nil { return 0, err @@ -2547,7 +2561,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e commit = c return nil - }) + }, func() {}) if err != nil { return nil, err } @@ -2870,7 +2884,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen } return fetchChanCommitments(chanBucket, c) - }) + }, func() {}) if err != nil { return nil, nil, err } @@ -2892,7 +2906,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { } return fetchChanRevocationState(chanBucket, c) - }) + }, func() {}) if err != nil { return nil, err } diff --git a/channeldb/db.go b/channeldb/db.go index 983b4fbb..02e1c697 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -201,12 +201,16 @@ func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error { // View is a wrapper around walletdb.View which calls into the extended // backend when available. This call is needed to be able to cast DB to -// ExtendedBackend. -func (db *DB) View(f func(tx walletdb.ReadTx) error) error { +// ExtendedBackend. The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error { if v, ok := db.Backend.(kvdb.ExtendedBackend); ok { - return v.View(f) + return v.View(f, reset) } + reset() return walletdb.View(db, f) } @@ -389,6 +393,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) var err error channels, err = d.fetchOpenChannels(tx, nodeID) return err + }, func() { + channels = nil }) return channels, err @@ -574,7 +580,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { }) } - err := kvdb.View(d, chanScan) + err := kvdb.View(d, chanScan, func() {}) if err != nil { return nil, err } @@ -741,6 +747,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error }) }) + }, func() { + channels = nil }) if err != nil { return nil, err @@ -781,6 +789,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro chanSummaries = append(chanSummaries, chanSummary) return nil }) + }, func() { + chanSummaries = nil }); err != nil { return nil, err } @@ -817,6 +827,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er chanSummary, err = deserializeCloseChannelSummary(summaryReader) return err + }, func() { + chanSummary = nil }); err != nil { return nil, err } @@ -865,6 +877,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( return nil } return ErrClosedChannelNotFound + }, func() { + chanSummary = nil }); err != nil { return nil, err } @@ -1052,6 +1066,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { } return nil + }, func() { + linkNode = nil }) if dbErr != nil { return nil, dbErr @@ -1261,6 +1277,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err channel, err = fetchOpenChannel(chanBucket, outPoint) return err + }, func() { + channel = nil }) if err != nil { return nil, err diff --git a/channeldb/forwarding_log.go b/channeldb/forwarding_log.go index d1216dc4..57e46f35 100644 --- a/channeldb/forwarding_log.go +++ b/channeldb/forwarding_log.go @@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct { // // TODO(roasbeef): rename? func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { - resp := ForwardingLogTimeSlice{ - ForwardingEventQuery: q, - } + var resp ForwardingLogTimeSlice // If the user provided an index offset, then we'll not know how many // records we need to skip. We'll also keep track of the record offset @@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e } return nil + }, func() { + resp = ForwardingLogTimeSlice{ + ForwardingEventQuery: q, + } }) if err != nil && err != ErrNoForwardingEvents { return ForwardingLogTimeSlice{}, err diff --git a/channeldb/forwarding_package_test.go b/channeldb/forwarding_package_test.go index 031a85f2..daeb4621 100644 --- a/channeldb/forwarding_package_test.go +++ b/channeldb/forwarding_package_test.go @@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend, var err error fwdPkgs, err = packager.LoadFwdPkgs(tx) return err + }, func() { + fwdPkgs = nil }); err != nil { t.Fatalf("unable to load fwd pkgs: %v", err) } diff --git a/channeldb/graph.go b/channeldb/graph.go index c55e238c..101f11a2 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli // be aborted. return cb(&edgeInfo, edge1, edge2) }) - }) + }, func() {}) } // ForEachNodeChannel iterates through all channels of a given node, executing the @@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte, // have their disabled bit on. func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { var disabledChanIDs []uint64 - chanEdgeFound := make(map[uint64]struct{}) + var chanEdgeFound map[uint64]struct{} err := kvdb.View(c.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { chanEdgeFound[chanID] = struct{}{} return nil }) + }, func() { + disabledChanIDs = nil + chanEdgeFound = make(map[uint64]struct{}) }) if err != nil { return nil, err @@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro }) } - return kvdb.View(c.db, traversal) + return kvdb.View(c.db, traversal, func() {}) } // SourceNode returns the source node of the graph. The source node is treated @@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { source = node return nil + }, func() { + source = nil }) if err != nil { return nil, err @@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { // package... alias = string(a) return nil + }, func() { + alias = "" }) if err != nil { return "", err @@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge( } return nil - }); err != nil { + }, func() {}); err != nil { return time.Time{}, time.Time{}, exists, isZombie, err } @@ -1235,7 +1242,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { tipHeight = byteOrder.Uint32(k[:]) return nil - }) + }, func() {}) if err != nil { return nil, 0, err } @@ -1312,6 +1319,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { var err error chanID, err = getChanID(tx, chanPoint) return err + }, func() { + chanID = 0 }); err != nil { return 0, err } @@ -1379,6 +1388,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) { // to the caller. cid = byteOrder.Uint64(lastChanID) return nil + }, func() { + cid = 0 }) if err != nil && err != ErrGraphNoEdgesFound { return 0, err @@ -1409,8 +1420,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha // To ensure we don't return duplicate ChannelEdges, we'll use an // additional map to keep track of the edges already seen to prevent // re-adding it. - edgesSeen := make(map[uint64]struct{}) - edgesToCache := make(map[uint64]ChannelEdge) + var edgesSeen map[uint64]struct{} + var edgesToCache map[uint64]ChannelEdge var edgesInHorizon []ChannelEdge c.cacheMu.Lock() @@ -1507,6 +1518,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha } return nil + }, func() { + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + edgesInHorizon = nil }) switch { case err == ErrGraphNoEdgesFound: @@ -1577,6 +1592,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig } return nil + }, func() { + nodesInHorizon = nil }) switch { case err == ErrGraphNoEdgesFound: @@ -1637,6 +1654,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { } return nil + }, func() { + newChanIDs = nil }) switch { // If we don't know of any edges yet, then we'll return the entire set @@ -1701,7 +1720,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint } return nil + }, func() { + chanIDs = nil }) + switch { // If we don't know of any channels yet, then there's nothing to // filter, so we'll return an empty slice. @@ -1775,6 +1797,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { }) } return nil + }, func() { + chanEdges = nil }) if err != nil { return nil, err @@ -2209,7 +2233,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( var err error if tx == nil { - err = kvdb.View(c.db, fetchNode) + err = kvdb.View(c.db, fetchNode, func() {}) } else { err = fetchNode(tx) } @@ -2259,6 +2283,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro exists = true updateTime = node.LastUpdate return nil + }, func() { + updateTime = time.Time{} + exists = false }) if err != nil { return time.Time{}, exists, err @@ -2346,7 +2373,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, // If no transaction was provided, then we'll create a new transaction // to execute the transaction within. if tx == nil { - return kvdb.View(db, traversal) + return kvdb.View(db, traversal, func() {}) } // Otherwise, we re-use the existing transaction to execute the graph @@ -2596,7 +2623,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh // otherwise we can use the existing db transaction. var err error if tx == nil { - err = kvdb.View(c.db, fetchNodeFunc) + err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil }) } else { err = fetchNodeFunc(tx) } @@ -2929,6 +2956,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, policy1 = e1 policy2 = e2 return nil + }, func() { + edgeInfo = nil + policy1 = nil + policy2 = nil }) if err != nil { return nil, nil, nil, err @@ -3030,6 +3061,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, policy1 = e1 policy2 = e2 return nil + }, func() { + edgeInfo = nil + policy1 = nil + policy2 = nil }) if err == ErrZombieEdge { return edgeInfo, nil, nil, err @@ -3062,6 +3097,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { nodeIsPublic, err = node.isPublic(tx, ourPubKey) return err + }, func() { + nodeIsPublic = false }) if err != nil { return false, err @@ -3183,6 +3220,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return nil }) + }, func() { + edgePoints = nil }); err != nil { return nil, err } @@ -3261,6 +3300,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) return nil + }, func() { + isZombie = false + pubKey1 = [33]byte{} + pubKey2 = [33]byte{} }) if err != nil { return false, [33]byte{}, [33]byte{} @@ -3307,6 +3350,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) { numZombies++ return nil }) + }, func() { + numZombies = 0 }) if err != nil { return 0, err diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 71edc8f8..43d786d9 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { return nil }) - }) + }, func() {}) if err != nil { t.Fatal(err) } @@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { } return nil - }) + }, func() {}) if err != nil { t.Fatalf("error reading db: %v", err) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 5f7b6462..c5209f1e 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -624,6 +624,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { } return nil + }, func() { + newInvoices = nil }) if err != nil { return nil, err @@ -669,7 +671,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) { invoice = i return nil - }) + }, func() {}) if err != nil { return invoice, err } @@ -731,13 +733,6 @@ func (d *DB) ScanInvoices( scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { return kvdb.View(d, func(tx kvdb.RTx) error { - // Reset partial results. As transaction commit success is not - // guaranteed when using etcd, we need to be prepared to redo - // the whole view transaction. In order to be able to do that - // we need a way to reset existing results. This is also done - // upon first run for initialization. - reset() - invoices := tx.ReadBucket(invoiceBucket) if invoices == nil { return ErrNoInvoicesCreated @@ -773,7 +768,7 @@ func (d *DB) ScanInvoices( return scanFunc(paymentHash, &invoice) }) - }) + }, reset) } // InvoiceQuery represents a query to the invoice database. The query allows a @@ -825,9 +820,7 @@ type InvoiceSlice struct { // QueryInvoices allows a caller to query the invoice database for invoices // within the specified add index range. func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { - resp := InvoiceSlice{ - InvoiceQuery: q, - } + var resp InvoiceSlice err := kvdb.View(d, func(tx kvdb.RTx) error { // If the bucket wasn't found, then there aren't any invoices @@ -892,6 +885,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { } return nil + }, func() { + resp = InvoiceSlice{ + InvoiceQuery: q, + } }) if err != nil && err != ErrNoInvoicesCreated { return resp, err @@ -1011,6 +1008,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { } return nil + }, func() { + settledInvoices = nil }) if err != nil { return nil, err diff --git a/channeldb/kvdb/etcd/db.go b/channeldb/kvdb/etcd/db.go index 9f52ad4e..c63bbdbd 100644 --- a/channeldb/kvdb/etcd/db.go +++ b/channeldb/kvdb/etcd/db.go @@ -218,11 +218,15 @@ func (db *db) getSTMOptions() []STMOptionFunc { } // View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -func (db *db) View(f func(tx walletdb.ReadTx) error) error { +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error { apply := func(stm STM) error { + reset() return f(newReadWriteTx(stm, db.config.Prefix)) } diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index 216b47c4..6b4f317f 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) { require.Nil(t, v) return nil - }) + }, func() {}) require.NoError(t, err) } @@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { require.Nil(t, v) return nil - }) + }, func() {}) require.NoError(t, err) } @@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) { require.Equal(t, []byte("val"), v) return nil - }) + }, func() {}) require.NoError(t, err) diff --git a/channeldb/kvdb/interface.go b/channeldb/kvdb/interface.go index 16e84285..7d44f56c 100644 --- a/channeldb/kvdb/interface.go +++ b/channeldb/kvdb/interface.go @@ -21,12 +21,19 @@ func Update(db Backend, f func(tx RwTx) error) error { // View opens a database read transaction and executes the function f with the // transaction passed as a parameter. After f exits, the transaction is rolled // back. If f errors, its error is returned, not a rollback error (if any -// occur). -func View(db Backend, f func(tx RTx) error) error { +// occur). The passed reset function is called before the start of the +// transaction and can be used to reset intermediate state. As callers may +// expect retries of the f closure (depending on the database backend used), the +// reset function will be called before each retry respectively. +func View(db Backend, f func(tx RTx) error, reset func()) error { if extendedDB, ok := db.(ExtendedBackend); ok { - return extendedDB.View(f) + return extendedDB.View(f, reset) } + // Since we know that walletdb simply calls into bbolt which never + // retries transactions, we'll call the reset function here before View. + reset() + return walletdb.View(db, f) } @@ -55,11 +62,15 @@ type ExtendedBackend interface { // PrintStats returns all collected stats pretty printed into a string. PrintStats() string - // View opens a database read transaction and executes the function f with - // the transaction passed as a parameter. After f exits, the transaction is - // rolled back. If f errors, its error is returned, not a rollback error - // (if any occur). - View(f func(tx walletdb.ReadTx) error) error + // View opens a database read transaction and executes the function f + // with the transaction passed as a parameter. After f exits, the + // transaction is rolled back. If f errors, its error is returned, not a + // rollback error (if any occur). The passed reset function is called + // before the start of the transaction and can be used to reset + // intermediate state. As callers may expect retries of the f closure + // (depending on the database backend used), the reset function will be + //called before each retry respectively. + View(f func(tx walletdb.ReadTx) error, reset func()) error // Update opens a database read/write transaction and executes the function // f with the transaction passed as a parameter. After f exits, if f did not diff --git a/channeldb/meta.go b/channeldb/meta.go index f7f9ae1e..c8ade44b 100644 --- a/channeldb/meta.go +++ b/channeldb/meta.go @@ -23,10 +23,12 @@ type Meta struct { // FetchMeta fetches the meta data from boltdb and returns filled meta // structure. func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) { - meta := &Meta{} + var meta *Meta err := kvdb.View(d, func(tx kvdb.RTx) error { return fetchMeta(meta, tx) + }, func() { + meta = &Meta{} }) if err != nil { return nil, err diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 98e9c88a..7eabfc2c 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) { } return nil - }) + }, func() {}) if err != nil { t.Fatalf("unable to apply after func: %v", err) } diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index f5890246..71128a11 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro chanSummaries = append(chanSummaries, chanSummary) return nil }) + }, func() { + chanSummaries = nil }); err != nil { return nil, err } diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go index f3b88539..0e34b405 100644 --- a/channeldb/migration_01_to_11/graph.go +++ b/channeldb/migration_01_to_11/graph.go @@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { source = node return nil + }, func() { + source = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go index dcba1d54..ceb21a33 100644 --- a/channeldb/migration_01_to_11/invoices.go +++ b/channeldb/migration_01_to_11/invoices.go @@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { return nil }) + }, func() { + invoices = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go index acd61b0a..b8f86d38 100644 --- a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go +++ b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go @@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { payments = append(payments, payment) return nil }) + }, func() { + payments = nil }) if err != nil { return nil, err @@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { var err error paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) return err + }, func() { + paymentStatus = StatusUnknown }) if err != nil { return StatusUnknown, err @@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) { return nil }) }) + }, func() { + payments = nil }) if err != nil { return nil, err diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 6cd855e8..010c10a1 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { "serialization") } return nil + }, func() { + dbSummary = nil }) if err != nil { t.Fatalf("unable to view DB: %v", err) @@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { } return nil + }, func() { + rawMsg = nil }) if err != nil { t.Fatal(err) diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go index 697da0e0..e44be003 100644 --- a/channeldb/migration_01_to_11/payments.go +++ b/channeldb/migration_01_to_11/payments.go @@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) { return nil }) }) + }, func() { + payments = nil }) if err != nil { return nil, err diff --git a/channeldb/nodes.go b/channeldb/nodes.go index dcbdf63a..03871264 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { linkNode = node return nil + }, func() { + linkNode = nil }) return linkNode, err @@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { linkNodes = nodes return nil + }, func() { + linkNodes = nil }) if err != nil { return nil, err diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index d67b0a8e..9ed5e639 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( payment, err = fetchPayment(bucket) return err + }, func() { + payment = nil }) if err != nil { return nil, err @@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { inFlights = append(inFlights, inFlight) return nil }) + }, func() { + inFlights = nil }) if err != nil { return nil, err diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 4f901462..e395a93f 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { indexCount++ return nil }) - }) + }, func() { indexCount = 0 }) require.NoError(t, err) require.Equal(t, 1, indexCount) @@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, var err error hash, err = deserializePaymentIndex(r) return err - + }, func() { + hash = lntypes.Hash{} }); err != nil { return nil, err } diff --git a/channeldb/payments.go b/channeldb/payments.go index 5c2475bd..3344451c 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) { payments = append(payments, duplicatePayments...) return nil }) + }, func() { + payments = nil }) if err != nil { return nil, err @@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { } return nil + }, func() { + resp = PaymentsResponse{} }); err != nil { return resp, err } diff --git a/channeldb/peers.go b/channeldb/peers.go index 4dea9f2f..55920b0d 100644 --- a/channeldb/peers.go +++ b/channeldb/peers.go @@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) { } return ReadElements(r, &flapCount.Count) + }, func() { + flapCount = FlapCount{} }); err != nil { return nil, err } diff --git a/channeldb/reports.go b/channeldb/reports.go index f71a1c4c..111e787e 100644 --- a/channeldb/reports.go +++ b/channeldb/reports.go @@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash, return nil }) + }, func() { + reports = nil }); err != nil { return nil, err } diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index 2ea706c8..43a4fd8a 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -42,13 +42,14 @@ type WaitingProofStore struct { // NewWaitingProofStore creates new instance of proofs storage. func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { s := &WaitingProofStore{ - db: db, - cache: make(map[WaitingProofKey]struct{}), + db: db, } if err := s.ForAll(func(proof *WaitingProof) error { s.cache[proof.Key()] = struct{}{} return nil + }, func() { + s.cache = make(map[WaitingProofKey]struct{}) }); err != nil && err != ErrWaitingProofNotFound { return nil, err } @@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error { // ForAll iterates thought all waiting proofs and passing the waiting proof // in the given callback. -func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { +func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error, + reset func()) error { + return kvdb.View(s.db, func(tx kvdb.RTx) error { bucket := tx.ReadBucket(waitingProofsBucketKey) if bucket == nil { @@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { return cb(proof) }) - }) + }, reset) } // Get returns the object which corresponds to the given index. func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { - proof := &WaitingProof{} + var proof *WaitingProof s.mu.RLock() defer s.mu.RUnlock() @@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { r := bytes.NewReader(v) return proof.Decode(r) + }, func() { + proof = &WaitingProof{} }) return proof, err diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index 12679b69..ed1bf050 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) { if err := store.ForAll(func(proof *WaitingProof) error { return errors.New("storage should be empty") - }); err != nil && err != ErrWaitingProofNotFound { + }, func() {}); err != nil && err != ErrWaitingProofNotFound { t.Fatal(err) } } diff --git a/channeldb/witness_cache.go b/channeldb/witness_cache.go index 3a91fc44..d4abfeed 100644 --- a/channeldb/witness_cache.go +++ b/channeldb/witness_cache.go @@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by copy(witness[:], dbWitness) return nil + }, func() { + witness = nil }) if err != nil { return nil, err diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index fbf2e96a..d1dc17dd 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { s = ArbitratorState(stateBytes[0]) return nil + }, func() { + s = 0 }) if err != nil && err != errScopeBucketNoExist { return s, err @@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro contracts = append(contracts, res) return nil }) + }, func() { + contracts = nil }) if err != nil && err != errScopeBucketNoExist && err != errNoContracts { return nil, err @@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error // // NOTE: Part of the ContractResolver interface. func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { - c := &ContractResolutions{} + var c *ContractResolutions err := kvdb.View(b.db, func(tx kvdb.RTx) error { scopeBucket := tx.ReadBucket(b.scopeKey[:]) if scopeBucket == nil { @@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er } return nil + }, func() { + c = &ContractResolutions{} }) if err != nil { return nil, err @@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er // // NOTE: Part of the ContractResolver interface. func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { - actionsMap := make(ChainActionMap) + var actionsMap ChainActionMap err := kvdb.View(b.db, func(tx kvdb.RTx) error { scopeBucket := tx.ReadBucket(b.scopeKey[:]) @@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { return nil }) + }, func() { + actionsMap = make(ChainActionMap) }) if err != nil { return nil, err @@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { c = commitSet return nil + }, func() { + c = nil }) if err != nil { return nil, err diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 9dfad6ba..f37caa8d 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1570,6 +1585,9 @@ out: number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil { t.Fatalf("unable to retrieve objects from store: %v", err) } @@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { number++ return nil }, + func() { + number = 0 + }, ); err != nil && err != channeldb.ErrWaitingProofNotFound { t.Fatalf("unable to retrieve objects from store: %v", err) } diff --git a/discovery/message_store.go b/discovery/message_store.go index f86ede20..6cce5494 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) { // Messages returns the total set of messages that exist within the store for // all peers. func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { - msgs := make(map[[33]byte][]lnwire.Message) + var msgs map[[33]byte][]lnwire.Message err := kvdb.View(s.db, func(tx kvdb.RTx) error { messageStore := tx.ReadBucket(messageStoreBucket) if messageStore == nil { @@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { msgs[pubKey] = append(msgs[pubKey], msg) return nil }) + }, func() { + msgs = make(map[[33]byte][]lnwire.Message) }) if err != nil { return nil, err @@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer( } return nil + }, func() { + msgs = nil }) if err != nil { return nil, err @@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer( // Peers returns the public key of all peers with messages within the store. func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { - peers := make(map[[33]byte]struct{}) + var peers map[[33]byte]struct{} err := kvdb.View(s.db, func(tx kvdb.RTx) error { messageStore := tx.ReadBucket(messageStoreBucket) if messageStore == nil { @@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { peers[pubKey] = struct{}{} return nil }) + }, func() { + peers = make(map[[33]byte]struct{}) }) if err != nil { return nil, err diff --git a/fundingmanager.go b/fundingmanager.go index ebe95d42..1bd018b1 100644 --- a/fundingmanager.go +++ b/fundingmanager.go @@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) ( state = channelOpeningState(byteOrder.Uint16(value[:2])) shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) return nil - }) + }, func() {}) if err != nil { return 0, nil, err } diff --git a/htlcswitch/decayedlog.go b/htlcswitch/decayedlog.go index db27ed39..8b8a78ef 100644 --- a/htlcswitch/decayedlog.go +++ b/htlcswitch/decayedlog.go @@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) { value = uint32(binary.BigEndian.Uint32(valueBytes)) return nil + }, func() { + value = 0 }) if err != nil { return value, err diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index e6a1e59f..d341ba12 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( default: return nil } + }, func() { + result = nil }) if err != nil { return nil, err @@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) ( var err error result, err = fetchResult(tx, pid) return err + }, func() { + result = nil }) if err != nil { return nil, err diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d03d7496..1f755bb0 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb. tx, source, ) return err + }, func() { + fwdPkgs = nil }); err != nil { return nil, err } diff --git a/macaroons/store.go b/macaroons/store.go index 97baae49..1e4447a8 100644 --- a/macaroons/store.go +++ b/macaroons/store.go @@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) { rootKey = make([]byte, len(decKey)) copy(rootKey[:], decKey) return nil + }, func() { + rootKey = nil }) if err != nil { return nil, err @@ -257,6 +259,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) { } return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey) + }, func() { + rootKeySlice = nil }) if err != nil { return nil, err diff --git a/nursery_store.go b/nursery_store.go index 42a13607..97fe08cc 100644 --- a/nursery_store.go +++ b/nursery_store.go @@ -129,7 +129,7 @@ type NurseryStore interface { // the caller to process each key-value pair. The key will be a prefixed // outpoint, and the value will be the serialized bytes for an output, // whose type should be inferred from the key's prefix. - ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error + ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error // ListChannels returns all channels the nursery is currently tracking. ListChannels() ([]wire.OutPoint, error) @@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass( }) + }, func() { + kids = nil + babies = nil }); err != nil { return nil, nil, err } @@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) { } return nil + }, func() { + kids = nil }); err != nil { return nil, err } @@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) { } return nil + }, func() { + activeHeights = nil }) if err != nil { return nil, err @@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) { // NOTE: The callback should not modify the provided byte slices and is // preferably non-blocking. func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint, - callback func([]byte, []byte) error) error { + callback func([]byte, []byte) error, reset func()) error { return kvdb.View(ns.db, func(tx kvdb.RTx) error { return ns.forChanOutputs(tx, chanPoint, callback) - }) + }, reset) } // ListChannels returns all channels the nursery is currently tracking. @@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) { return nil }) + }, func() { + activeChannels = nil }); err != nil { return nil, err } @@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error) return nil }) - }) + }, func() {}) if err != nil && err != ErrImmatureChannel { return false, err } diff --git a/nursery_store_test.go b/nursery_store_test.go index 052b2304..1d4caf51 100644 --- a/nursery_store_test.go +++ b/nursery_store_test.go @@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore, err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error { count++ return nil + }, func() { + count = 0 }) if count == 0 && err == ErrContractNotFound { diff --git a/routing/missioncontrol_store.go b/routing/missioncontrol_store.go index e0b2e82a..5113b0b0 100644 --- a/routing/missioncontrol_store.go +++ b/routing/missioncontrol_store.go @@ -103,6 +103,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) { return nil }) + }, func() { + results = nil }) if err != nil { return nil, err diff --git a/sweep/store.go b/sweep/store.go index 7809db96..c78759c4 100644 --- a/sweep/store.go +++ b/sweep/store.go @@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) { } return nil + }, func() { + sweepTx = nil }) if err != nil { return nil, err @@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) { ours = txHashesBucket.Get(hash[:]) != nil return nil + }, func() { + ours = false }) if err != nil { return false, err @@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) { return nil }) + }, func() { + sweepTxns = nil }); err != nil { return nil, err } diff --git a/utxonursery.go b/utxonursery.go index 217aa2da..d73ceba5 100644 --- a/utxonursery.go +++ b/utxonursery.go @@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport( utxnLog.Debugf("NurseryReport: building nursery report for channel %v", chanPoint) - report := &contractMaturityReport{} + var report *contractMaturityReport if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error { switch { @@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport( } return nil + }, func() { + report = &contractMaturityReport{} }); err != nil { return nil, err } diff --git a/utxonursery_test.go b/utxonursery_test.go index 5675c9ea..4e02b24d 100644 --- a/utxonursery_test.go +++ b/utxonursery_test.go @@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) ( } func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint, - callback func([]byte, []byte) error) error { + callback func([]byte, []byte) error, reset func()) error { - return i.ns.ForChanOutputs(chanPoint, callback) + return i.ns.ForChanOutputs(chanPoint, callback, reset) } func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) { diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 66aeb345..dc88d45c 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) { var err error version, err = getDBVersion(tx) return err + }, func() { + version = 0 }) if err != nil { return 0, err @@ -392,6 +394,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) { var err error tower, err = getTower(towers, towerID.Bytes()) return err + }, func() { + tower = nil }) if err != nil { return nil, err @@ -421,6 +425,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) { var err error tower, err = getTower(towers, towerIDBytes) return err + }, func() { + tower = nil }) if err != nil { return nil, err @@ -446,6 +452,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) { towers = append(towers, tower) return nil }) + }, func() { + towers = nil }) if err != nil { return nil, err @@ -566,6 +574,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession var err error clientSessions, err = listClientSessions(sessions, id) return err + }, func() { + clientSessions = nil }) if err != nil { return nil, err @@ -611,7 +621,7 @@ func listClientSessions(sessions kvdb.RBucket, // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { - summaries := make(map[lnwire.ChannelID]ClientChanSummary) + var summaries map[lnwire.ChannelID]ClientChanSummary err := kvdb.View(c.db, func(tx kvdb.RTx) error { chanSummaries := tx.ReadBucket(cChanSummaryBkt) if chanSummaries == nil { @@ -632,6 +642,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { return nil }) + }, func() { + summaries = make(map[lnwire.ChannelID]ClientChanSummary) }) if err != nil { return nil, err diff --git a/watchtower/wtdb/db_common.go b/watchtower/wtdb/db_common.go index d769f13e..c4f8bb17 100644 --- a/watchtower/wtdb/db_common.go +++ b/watchtower/wtdb/db_common.go @@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) { err = kvdb.View(bdb, func(tx kvdb.RTx) error { metadataExists = tx.ReadBucket(metadataBkt) != nil return nil + }, func() { + metadataExists = false }) if err != nil { return nil, false, err diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go index d98fdad5..a788a100 100644 --- a/watchtower/wtdb/tower_db.go +++ b/watchtower/wtdb/tower_db.go @@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) { var err error version, err = getDBVersion(tx) return err + }, func() { + version = 0 }) if err != nil { return 0, err @@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { var err error session, err = getSession(sessions, id[:]) return err + }, func() { + session = nil }) if err != nil { return nil, err @@ -460,6 +464,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) { } return nil + }, func() { + matches = nil }) if err != nil { return nil, err @@ -494,6 +500,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { epoch = getLookoutEpoch(lookoutTip) return nil + }, func() { + epoch = nil }) if err != nil { return nil, err