multi: add reset closure to kvdb.View

This commit adds a reset() closure to the kvdb.View function which will
be called before each retry (including the first) of the view
transaction. The reset() closure can be used to reset external state
(eg slices or maps) where the view closure puts intermediate results.
This commit is contained in:
Andras Banki-Horvath 2020-10-20 16:18:40 +02:00
parent ffb27284df
commit 2a358327f4
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8
47 changed files with 340 additions and 82 deletions

@ -152,10 +152,12 @@ func (b *breachArbiter) start() error {
brarLog.Tracef("Starting breach arbiter") brarLog.Tracef("Starting breach arbiter")
// Load all retributions currently persisted in the retribution store. // Load all retributions currently persisted in the retribution store.
breachRetInfos := make(map[wire.OutPoint]retributionInfo) var breachRetInfos map[wire.OutPoint]retributionInfo
if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error { if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error {
breachRetInfos[ret.chanPoint] = *ret breachRetInfos[ret.chanPoint] = *ret
return nil return nil
}, func() {
breachRetInfos = make(map[wire.OutPoint]retributionInfo)
}); err != nil { }); err != nil {
return err return err
} }
@ -1223,7 +1225,7 @@ type RetributionStore interface {
// ForAll iterates over the existing on-disk contents and applies a // ForAll iterates over the existing on-disk contents and applies a
// chosen, read-only callback to each. This method should ensure that it // chosen, read-only callback to each. This method should ensure that it
// immediately propagate any errors generated by the callback. // immediately propagate any errors generated by the callback.
ForAll(cb func(*retributionInfo) error) error ForAll(cb func(*retributionInfo) error, reset func()) error
} }
// retributionStore handles persistence of retribution states to disk and is // retributionStore handles persistence of retribution states to disk and is
@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn(
finalTxBytes = justiceBkt.Get(chanBuf.Bytes()) finalTxBytes = justiceBkt.Get(chanBuf.Bytes())
return nil return nil
}, func() {
finalTxBytes = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
} }
return nil return nil
}, func() {
found = false
}) })
return found, err return found, err
@ -1395,7 +1401,9 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error {
// ForAll iterates through all stored retributions and executes the passed // ForAll iterates through all stored retributions and executes the passed
// callback function on each retribution. // callback function on each retribution.
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { func (rs *retributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
return kvdb.View(rs.db, func(tx kvdb.RTx) error { return kvdb.View(rs.db, func(tx kvdb.RTx) error {
// If the bucket does not exist, then there are no pending // If the bucket does not exist, then there are no pending
// retributions. // retributions.
@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
return cb(ret) return cb(ret)
}) })
}) }, reset)
} }
// Encode serializes the retribution into the passed byte stream. // Encode serializes the retribution into the passed byte stream.

@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error {
return frs.rs.Remove(key) return frs.rs.Remove(key)
} }
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error { func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
frs.mu.Lock() frs.mu.Lock()
defer frs.mu.Unlock() defer frs.mu.Unlock()
return frs.rs.ForAll(cb) return frs.rs.ForAll(cb, reset)
} }
// Parse the pubkeys in the breached outputs. // Parse the pubkeys in the breached outputs.
@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error {
return nil return nil
} }
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error { func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
reset()
for _, retInfo := range rs.state { for _, retInfo := range rs.state {
if err := cb(copyRetInfo(retInfo)); err != nil { if err := cb(copyRetInfo(retInfo)); err != nil {
return err return err
@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int {
err := rs.ForAll(func(_ *retributionInfo) error { err := rs.ForAll(func(_ *retributionInfo) error {
count++ count++
return nil return nil
}, func() {
count = 0
}) })
if err != nil { if err != nil {
t.Fatalf("unable to list retributions in db: %v", err) t.Fatalf("unable to list retributions in db: %v", err)
@ -919,7 +926,7 @@ restartCheck:
// Construct a set of all channel points presented by the store. Entries // Construct a set of all channel points presented by the store. Entries
// are only be added to the set if their corresponding retribution // are only be added to the set if their corresponding retribution
// information matches the test vector. // information matches the test vector.
var foundSet = make(map[wire.OutPoint]struct{}) var foundSet map[wire.OutPoint]struct{}
// Iterate through the stored retributions, checking to see if we have // Iterate through the stored retributions, checking to see if we have
// an equivalent retribution in the test vector. This will return an // an equivalent retribution in the test vector. This will return an
@ -948,6 +955,8 @@ restartCheck:
} }
return nil return nil
}, func() {
foundSet = make(map[wire.OutPoint]struct{})
}); err != nil { }); err != nil {
t.Fatalf("failed to iterate over persistent retributions: %v", t.Fatalf("failed to iterate over persistent retributions: %v",
err) err)

@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err
} }
return channeldb.ReadElement(bytes.NewReader(spendHint), &hint) return channeldb.ReadElement(bytes.NewReader(spendHint), &hint)
}, func() {
hint = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err
} }
return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint) return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint)
}, func() {
hint = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err

@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
} }
return nil return nil
}, func() {
commitPoint = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
} }
r := bytes.NewReader(bs) r := bytes.NewReader(bs)
return ReadElement(r, &closeTx) return ReadElement(r, &closeTx)
}, func() {
closeTx = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
cd = dcd cd = dcd
return nil return nil
}, func() {
cd = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes) r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r) updates, err = deserializeLogUpdates(r)
return err return err
}, func() {
updates = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes) r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r) updates, err = deserializeLogUpdates(r)
return err return err
}, func() {
updates = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2365,6 +2375,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
var err error var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -2475,7 +2487,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
} }
return nil return nil
}); err != nil { }, func() {}); err != nil {
return nil, err return nil, err
} }
@ -2509,6 +2521,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
height = commit.CommitHeight height = commit.CommitHeight
return nil return nil
}, func() {
height = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -2547,7 +2561,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
commit = c commit = c
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2870,7 +2884,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
} }
return fetchChanCommitments(chanBucket, c) return fetchChanCommitments(chanBucket, c)
}) }, func() {})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -2892,7 +2906,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
} }
return fetchChanRevocationState(chanBucket, c) return fetchChanRevocationState(chanBucket, c)
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -201,12 +201,16 @@ func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error {
// View is a wrapper around walletdb.View which calls into the extended // View is a wrapper around walletdb.View which calls into the extended
// backend when available. This call is needed to be able to cast DB to // backend when available. This call is needed to be able to cast DB to
// ExtendedBackend. // ExtendedBackend. The passed reset function is called before the start of the
func (db *DB) View(f func(tx walletdb.ReadTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok { if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.View(f) return v.View(f, reset)
} }
reset()
return walletdb.View(db, f) return walletdb.View(db, f)
} }
@ -389,6 +393,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
var err error var err error
channels, err = d.fetchOpenChannels(tx, nodeID) channels, err = d.fetchOpenChannels(tx, nodeID)
return err return err
}, func() {
channels = nil
}) })
return channels, err return channels, err
@ -574,7 +580,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
}) })
} }
err := kvdb.View(d, chanScan) err := kvdb.View(d, chanScan, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -741,6 +747,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
}) })
}) })
}, func() {
channels = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -781,6 +789,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary) chanSummaries = append(chanSummaries, chanSummary)
return nil return nil
}) })
}, func() {
chanSummaries = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -817,6 +827,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
chanSummary, err = deserializeCloseChannelSummary(summaryReader) chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err return err
}, func() {
chanSummary = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -865,6 +877,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
return nil return nil
} }
return ErrClosedChannelNotFound return ErrClosedChannelNotFound
}, func() {
chanSummary = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -1052,6 +1066,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
} }
return nil return nil
}, func() {
linkNode = nil
}) })
if dbErr != nil { if dbErr != nil {
return nil, dbErr return nil, dbErr
@ -1261,6 +1277,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err
channel, err = fetchOpenChannel(chanBucket, outPoint) channel, err = fetchOpenChannel(chanBucket, outPoint)
return err return err
}, func() {
channel = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct {
// //
// TODO(roasbeef): rename? // TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
resp := ForwardingLogTimeSlice{ var resp ForwardingLogTimeSlice
ForwardingEventQuery: q,
}
// If the user provided an index offset, then we'll not know how many // If the user provided an index offset, then we'll not know how many
// records we need to skip. We'll also keep track of the record offset // records we need to skip. We'll also keep track of the record offset
@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e
} }
return nil return nil
}, func() {
resp = ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
}) })
if err != nil && err != ErrNoForwardingEvents { if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err return ForwardingLogTimeSlice{}, err

@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend,
var err error var err error
fwdPkgs, err = packager.LoadFwdPkgs(tx) fwdPkgs, err = packager.LoadFwdPkgs(tx)
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
t.Fatalf("unable to load fwd pkgs: %v", err) t.Fatalf("unable to load fwd pkgs: %v", err)
} }

@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// be aborted. // be aborted.
return cb(&edgeInfo, edge1, edge2) return cb(&edgeInfo, edge1, edge2)
}) })
}) }, func() {})
} }
// ForEachNodeChannel iterates through all channels of a given node, executing the // ForEachNodeChannel iterates through all channels of a given node, executing the
@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
// have their disabled bit on. // have their disabled bit on.
func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
var disabledChanIDs []uint64 var disabledChanIDs []uint64
chanEdgeFound := make(map[uint64]struct{}) var chanEdgeFound map[uint64]struct{}
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket) edges := tx.ReadBucket(edgeBucket)
@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
chanEdgeFound[chanID] = struct{}{} chanEdgeFound[chanID] = struct{}{}
return nil return nil
}) })
}, func() {
disabledChanIDs = nil
chanEdgeFound = make(map[uint64]struct{})
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
}) })
} }
return kvdb.View(c.db, traversal) return kvdb.View(c.db, traversal, func() {})
} }
// SourceNode returns the source node of the graph. The source node is treated // SourceNode returns the source node of the graph. The source node is treated
@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node source = node
return nil return nil
}, func() {
source = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
// package... // package...
alias = string(a) alias = string(a)
return nil return nil
}, func() {
alias = ""
}) })
if err != nil { if err != nil {
return "", err return "", err
@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge(
} }
return nil return nil
}); err != nil { }, func() {}); err != nil {
return time.Time{}, time.Time{}, exists, isZombie, err return time.Time{}, time.Time{}, exists, isZombie, err
} }
@ -1235,7 +1242,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
tipHeight = byteOrder.Uint32(k[:]) tipHeight = byteOrder.Uint32(k[:])
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -1312,6 +1319,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var err error var err error
chanID, err = getChanID(tx, chanPoint) chanID, err = getChanID(tx, chanPoint)
return err return err
}, func() {
chanID = 0
}); err != nil { }); err != nil {
return 0, err return 0, err
} }
@ -1379,6 +1388,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) {
// to the caller. // to the caller.
cid = byteOrder.Uint64(lastChanID) cid = byteOrder.Uint64(lastChanID)
return nil return nil
}, func() {
cid = 0
}) })
if err != nil && err != ErrGraphNoEdgesFound { if err != nil && err != ErrGraphNoEdgesFound {
return 0, err return 0, err
@ -1409,8 +1420,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// To ensure we don't return duplicate ChannelEdges, we'll use an // To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent // additional map to keep track of the edges already seen to prevent
// re-adding it. // re-adding it.
edgesSeen := make(map[uint64]struct{}) var edgesSeen map[uint64]struct{}
edgesToCache := make(map[uint64]ChannelEdge) var edgesToCache map[uint64]ChannelEdge
var edgesInHorizon []ChannelEdge var edgesInHorizon []ChannelEdge
c.cacheMu.Lock() c.cacheMu.Lock()
@ -1507,6 +1518,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
} }
return nil return nil
}, func() {
edgesSeen = make(map[uint64]struct{})
edgesToCache = make(map[uint64]ChannelEdge)
edgesInHorizon = nil
}) })
switch { switch {
case err == ErrGraphNoEdgesFound: case err == ErrGraphNoEdgesFound:
@ -1577,6 +1592,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig
} }
return nil return nil
}, func() {
nodesInHorizon = nil
}) })
switch { switch {
case err == ErrGraphNoEdgesFound: case err == ErrGraphNoEdgesFound:
@ -1637,6 +1654,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
} }
return nil return nil
}, func() {
newChanIDs = nil
}) })
switch { switch {
// If we don't know of any edges yet, then we'll return the entire set // If we don't know of any edges yet, then we'll return the entire set
@ -1701,7 +1720,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
} }
return nil return nil
}, func() {
chanIDs = nil
}) })
switch { switch {
// If we don't know of any channels yet, then there's nothing to // If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice. // filter, so we'll return an empty slice.
@ -1775,6 +1797,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
}) })
} }
return nil return nil
}, func() {
chanEdges = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2209,7 +2233,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
var err error var err error
if tx == nil { if tx == nil {
err = kvdb.View(c.db, fetchNode) err = kvdb.View(c.db, fetchNode, func() {})
} else { } else {
err = fetchNode(tx) err = fetchNode(tx)
} }
@ -2259,6 +2283,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
exists = true exists = true
updateTime = node.LastUpdate updateTime = node.LastUpdate
return nil return nil
}, func() {
updateTime = time.Time{}
exists = false
}) })
if err != nil { if err != nil {
return time.Time{}, exists, err return time.Time{}, exists, err
@ -2346,7 +2373,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// If no transaction was provided, then we'll create a new transaction // If no transaction was provided, then we'll create a new transaction
// to execute the transaction within. // to execute the transaction within.
if tx == nil { if tx == nil {
return kvdb.View(db, traversal) return kvdb.View(db, traversal, func() {})
} }
// Otherwise, we re-use the existing transaction to execute the graph // Otherwise, we re-use the existing transaction to execute the graph
@ -2596,7 +2623,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh
// otherwise we can use the existing db transaction. // otherwise we can use the existing db transaction.
var err error var err error
if tx == nil { if tx == nil {
err = kvdb.View(c.db, fetchNodeFunc) err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil })
} else { } else {
err = fetchNodeFunc(tx) err = fetchNodeFunc(tx)
} }
@ -2929,6 +2956,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint,
policy1 = e1 policy1 = e1
policy2 = e2 policy2 = e2
return nil return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
}) })
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -3030,6 +3061,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64,
policy1 = e1 policy1 = e1
policy2 = e2 policy2 = e2
return nil return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
}) })
if err == ErrZombieEdge { if err == ErrZombieEdge {
return edgeInfo, nil, nil, err return edgeInfo, nil, nil, err
@ -3062,6 +3097,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) {
nodeIsPublic, err = node.isPublic(tx, ourPubKey) nodeIsPublic, err = node.isPublic(tx, ourPubKey)
return err return err
}, func() {
nodeIsPublic = false
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -3183,6 +3220,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) {
return nil return nil
}) })
}, func() {
edgePoints = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -3261,6 +3300,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
return nil return nil
}, func() {
isZombie = false
pubKey1 = [33]byte{}
pubKey2 = [33]byte{}
}) })
if err != nil { if err != nil {
return false, [33]byte{}, [33]byte{} return false, [33]byte{}, [33]byte{}
@ -3307,6 +3350,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) {
numZombies++ numZombies++
return nil return nil
}) })
}, func() {
numZombies = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err

@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
return nil return nil
}) })
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("error reading db: %v", err) t.Fatalf("error reading db: %v", err)
} }

@ -624,6 +624,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
} }
return nil return nil
}, func() {
newInvoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -669,7 +671,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
invoice = i invoice = i
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return invoice, err return invoice, err
} }
@ -731,13 +733,6 @@ func (d *DB) ScanInvoices(
scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
return kvdb.View(d, func(tx kvdb.RTx) error { return kvdb.View(d, func(tx kvdb.RTx) error {
// Reset partial results. As transaction commit success is not
// guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
invoices := tx.ReadBucket(invoiceBucket) invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil { if invoices == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
@ -773,7 +768,7 @@ func (d *DB) ScanInvoices(
return scanFunc(paymentHash, &invoice) return scanFunc(paymentHash, &invoice)
}) })
}) }, reset)
} }
// InvoiceQuery represents a query to the invoice database. The query allows a // InvoiceQuery represents a query to the invoice database. The query allows a
@ -825,9 +820,7 @@ type InvoiceSlice struct {
// QueryInvoices allows a caller to query the invoice database for invoices // QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range. // within the specified add index range.
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
resp := InvoiceSlice{ var resp InvoiceSlice
InvoiceQuery: q,
}
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(d, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any invoices // If the bucket wasn't found, then there aren't any invoices
@ -892,6 +885,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
} }
return nil return nil
}, func() {
resp = InvoiceSlice{
InvoiceQuery: q,
}
}) })
if err != nil && err != ErrNoInvoicesCreated { if err != nil && err != ErrNoInvoicesCreated {
return resp, err return resp, err
@ -1011,6 +1008,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
} }
return nil return nil
}, func() {
settledInvoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -220,9 +220,13 @@ func (db *db) getSTMOptions() []STMOptionFunc {
// View opens a database read transaction and executes the function f with the // View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled // transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any // back. If f errors, its error is returned, not a rollback error (if any
// occur). // occur). The passed reset function is called before the start of the
func (db *db) View(f func(tx walletdb.ReadTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
apply := func(stm STM) error { apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix)) return f(newReadWriteTx(stm, db.config.Prefix))
} }

@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.Nil(t, v) require.Nil(t, v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
} }
@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.Nil(t, v) require.Nil(t, v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
} }
@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.Equal(t, []byte("val"), v) require.Equal(t, []byte("val"), v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)

@ -21,12 +21,19 @@ func Update(db Backend, f func(tx RwTx) error) error {
// View opens a database read transaction and executes the function f with the // View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled // transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any // back. If f errors, its error is returned, not a rollback error (if any
// occur). // occur). The passed reset function is called before the start of the
func View(db Backend, f func(tx RTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func View(db Backend, f func(tx RTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok { if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.View(f) return extendedDB.View(f, reset)
} }
// Since we know that walletdb simply calls into bbolt which never
// retries transactions, we'll call the reset function here before View.
reset()
return walletdb.View(db, f) return walletdb.View(db, f)
} }
@ -55,11 +62,15 @@ type ExtendedBackend interface {
// PrintStats returns all collected stats pretty printed into a string. // PrintStats returns all collected stats pretty printed into a string.
PrintStats() string PrintStats() string
// View opens a database read transaction and executes the function f with // View opens a database read transaction and executes the function f
// the transaction passed as a parameter. After f exits, the transaction is // with the transaction passed as a parameter. After f exits, the
// rolled back. If f errors, its error is returned, not a rollback error // transaction is rolled back. If f errors, its error is returned, not a
// (if any occur). // rollback error (if any occur). The passed reset function is called
View(f func(tx walletdb.ReadTx) error) error // before the start of the transaction and can be used to reset
// intermediate state. As callers may expect retries of the f closure
// (depending on the database backend used), the reset function will be
//called before each retry respectively.
View(f func(tx walletdb.ReadTx) error, reset func()) error
// Update opens a database read/write transaction and executes the function // Update opens a database read/write transaction and executes the function
// f with the transaction passed as a parameter. After f exits, if f did not // f with the transaction passed as a parameter. After f exits, if f did not

@ -23,10 +23,12 @@ type Meta struct {
// FetchMeta fetches the meta data from boltdb and returns filled meta // FetchMeta fetches the meta data from boltdb and returns filled meta
// structure. // structure.
func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) { func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) {
meta := &Meta{} var meta *Meta
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(d, func(tx kvdb.RTx) error {
return fetchMeta(meta, tx) return fetchMeta(meta, tx)
}, func() {
meta = &Meta{}
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to apply after func: %v", err) t.Fatalf("unable to apply after func: %v", err)
} }

@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary) chanSummaries = append(chanSummaries, chanSummary)
return nil return nil
}) })
}, func() {
chanSummaries = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node source = node
return nil return nil
}, func() {
source = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
return nil return nil
}) })
}, func() {
invoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) {
payments = append(payments, payment) payments = append(payments, payment)
return nil return nil
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var err error var err error
paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash)
return err return err
}, func() {
paymentStatus = StatusUnknown
}) })
if err != nil { if err != nil {
return StatusUnknown, err return StatusUnknown, err
@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) {
return nil return nil
}) })
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
"serialization") "serialization")
} }
return nil return nil
}, func() {
dbSummary = nil
}) })
if err != nil { if err != nil {
t.Fatalf("unable to view DB: %v", err) t.Fatalf("unable to view DB: %v", err)
@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
} }
return nil return nil
}, func() {
rawMsg = nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) {
return nil return nil
}) })
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
linkNode = node linkNode = node
return nil return nil
}, func() {
linkNode = nil
}) })
return linkNode, err return linkNode, err
@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
linkNodes = nodes linkNodes = nodes
return nil return nil
}, func() {
linkNodes = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
payment, err = fetchPayment(bucket) payment, err = fetchPayment(bucket)
return err return err
}, func() {
payment = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
inFlights = append(inFlights, inFlight) inFlights = append(inFlights, inFlight)
return nil return nil
}) })
}, func() {
inFlights = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
indexCount++ indexCount++
return nil return nil
}) })
}) }, func() { indexCount = 0 })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, indexCount) require.Equal(t, 1, indexCount)
@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
var err error var err error
hash, err = deserializePaymentIndex(r) hash, err = deserializePaymentIndex(r)
return err return err
}, func() {
hash = lntypes.Hash{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) {
payments = append(payments, duplicatePayments...) payments = append(payments, duplicatePayments...)
return nil return nil
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
} }
return nil return nil
}, func() {
resp = PaymentsResponse{}
}); err != nil { }); err != nil {
return resp, err return resp, err
} }

@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
} }
return ReadElements(r, &flapCount.Count) return ReadElements(r, &flapCount.Count)
}, func() {
flapCount = FlapCount{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash,
return nil return nil
}) })
}, func() {
reports = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -43,12 +43,13 @@ type WaitingProofStore struct {
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
s := &WaitingProofStore{ s := &WaitingProofStore{
db: db, db: db,
cache: make(map[WaitingProofKey]struct{}),
} }
if err := s.ForAll(func(proof *WaitingProof) error { if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{} s.cache[proof.Key()] = struct{}{}
return nil return nil
}, func() {
s.cache = make(map[WaitingProofKey]struct{})
}); err != nil && err != ErrWaitingProofNotFound { }); err != nil && err != ErrWaitingProofNotFound {
return nil, err return nil, err
} }
@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
// ForAll iterates thought all waiting proofs and passing the waiting proof // ForAll iterates thought all waiting proofs and passing the waiting proof
// in the given callback. // in the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error,
reset func()) error {
return kvdb.View(s.db, func(tx kvdb.RTx) error { return kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey) bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil { if bucket == nil {
@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
return cb(proof) return cb(proof)
}) })
}) }, reset)
} }
// Get returns the object which corresponds to the given index. // Get returns the object which corresponds to the given index.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
proof := &WaitingProof{} var proof *WaitingProof
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
r := bytes.NewReader(v) r := bytes.NewReader(v)
return proof.Decode(r) return proof.Decode(r)
}, func() {
proof = &WaitingProof{}
}) })
return proof, err return proof, err

@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) {
if err := store.ForAll(func(proof *WaitingProof) error { if err := store.ForAll(func(proof *WaitingProof) error {
return errors.New("storage should be empty") return errors.New("storage should be empty")
}); err != nil && err != ErrWaitingProofNotFound { }, func() {}); err != nil && err != ErrWaitingProofNotFound {
t.Fatal(err) t.Fatal(err)
} }
} }

@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by
copy(witness[:], dbWitness) copy(witness[:], dbWitness)
return nil return nil
}, func() {
witness = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) {
s = ArbitratorState(stateBytes[0]) s = ArbitratorState(stateBytes[0])
return nil return nil
}, func() {
s = 0
}) })
if err != nil && err != errScopeBucketNoExist { if err != nil && err != errScopeBucketNoExist {
return s, err return s, err
@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro
contracts = append(contracts, res) contracts = append(contracts, res)
return nil return nil
}) })
}, func() {
contracts = nil
}) })
if err != nil && err != errScopeBucketNoExist && err != errNoContracts { if err != nil && err != errScopeBucketNoExist && err != errNoContracts {
return nil, err return nil, err
@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) {
c := &ContractResolutions{} var c *ContractResolutions
err := kvdb.View(b.db, func(tx kvdb.RTx) error { err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:]) scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil { if scopeBucket == nil {
@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
} }
return nil return nil
}, func() {
c = &ContractResolutions{}
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
actionsMap := make(ChainActionMap) var actionsMap ChainActionMap
err := kvdb.View(b.db, func(tx kvdb.RTx) error { err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:]) scopeBucket := tx.ReadBucket(b.scopeKey[:])
@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
return nil return nil
}) })
}, func() {
actionsMap = make(ChainActionMap)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) {
c = commitSet c = commitSet
return nil return nil
}, func() {
c = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1570,6 +1585,9 @@ out:
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }

@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) {
// Messages returns the total set of messages that exist within the store for // Messages returns the total set of messages that exist within the store for
// all peers. // all peers.
func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs := make(map[[33]byte][]lnwire.Message) var msgs map[[33]byte][]lnwire.Message
err := kvdb.View(s.db, func(tx kvdb.RTx) error { err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket) messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs[pubKey] = append(msgs[pubKey], msg) msgs[pubKey] = append(msgs[pubKey], msg)
return nil return nil
}) })
}, func() {
msgs = make(map[[33]byte][]lnwire.Message)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer(
} }
return nil return nil
}, func() {
msgs = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer(
// Peers returns the public key of all peers with messages within the store. // Peers returns the public key of all peers with messages within the store.
func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers := make(map[[33]byte]struct{}) var peers map[[33]byte]struct{}
err := kvdb.View(s.db, func(tx kvdb.RTx) error { err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket) messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers[pubKey] = struct{}{} peers[pubKey] = struct{}{}
return nil return nil
}) })
}, func() {
peers = make(map[[33]byte]struct{})
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) (
state = channelOpeningState(byteOrder.Uint16(value[:2])) state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }

@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
value = uint32(binary.BigEndian.Uint32(valueBytes)) value = uint32(binary.BigEndian.Uint32(valueBytes))
return nil return nil
}, func() {
value = 0
}) })
if err != nil { if err != nil {
return value, err return value, err

@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
default: default:
return nil return nil
} }
}, func() {
result = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) (
var err error var err error
result, err = fetchResult(tx, pid) result, err = fetchResult(tx, pid)
return err return err
}, func() {
result = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.
tx, source, tx, source,
) )
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) {
rootKey = make([]byte, len(decKey)) rootKey = make([]byte, len(decKey))
copy(rootKey[:], decKey) copy(rootKey[:], decKey)
return nil return nil
}, func() {
rootKey = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -257,6 +259,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) {
} }
return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey) return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey)
}, func() {
rootKeySlice = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -129,7 +129,7 @@ type NurseryStore interface {
// the caller to process each key-value pair. The key will be a prefixed // the caller to process each key-value pair. The key will be a prefixed
// outpoint, and the value will be the serialized bytes for an output, // outpoint, and the value will be the serialized bytes for an output,
// whose type should be inferred from the key's prefix. // whose type should be inferred from the key's prefix.
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error
// ListChannels returns all channels the nursery is currently tracking. // ListChannels returns all channels the nursery is currently tracking.
ListChannels() ([]wire.OutPoint, error) ListChannels() ([]wire.OutPoint, error)
@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass(
}) })
}, func() {
kids = nil
babies = nil
}); err != nil { }); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) {
} }
return nil return nil
}, func() {
kids = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
} }
return nil return nil
}, func() {
activeHeights = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
// NOTE: The callback should not modify the provided byte slices and is // NOTE: The callback should not modify the provided byte slices and is
// preferably non-blocking. // preferably non-blocking.
func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint, func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error { callback func([]byte, []byte) error, reset func()) error {
return kvdb.View(ns.db, func(tx kvdb.RTx) error { return kvdb.View(ns.db, func(tx kvdb.RTx) error {
return ns.forChanOutputs(tx, chanPoint, callback) return ns.forChanOutputs(tx, chanPoint, callback)
}) }, reset)
} }
// ListChannels returns all channels the nursery is currently tracking. // ListChannels returns all channels the nursery is currently tracking.
@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) {
return nil return nil
}) })
}, func() {
activeChannels = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error)
return nil return nil
}) })
}) }, func() {})
if err != nil && err != ErrImmatureChannel { if err != nil && err != ErrImmatureChannel {
return false, err return false, err
} }

@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore,
err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error { err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error {
count++ count++
return nil return nil
}, func() {
count = 0
}) })
if count == 0 && err == ErrContractNotFound { if count == 0 && err == ErrContractNotFound {

@ -103,6 +103,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
return nil return nil
}) })
}, func() {
results = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) {
} }
return nil return nil
}, func() {
sweepTx = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
ours = txHashesBucket.Get(hash[:]) != nil ours = txHashesBucket.Get(hash[:]) != nil
return nil return nil
}, func() {
ours = false
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) {
return nil return nil
}) })
}, func() {
sweepTxns = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport(
utxnLog.Debugf("NurseryReport: building nursery report for channel %v", utxnLog.Debugf("NurseryReport: building nursery report for channel %v",
chanPoint) chanPoint)
report := &contractMaturityReport{} var report *contractMaturityReport
if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error { if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error {
switch { switch {
@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport(
} }
return nil return nil
}, func() {
report = &contractMaturityReport{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) (
} }
func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint, func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error { callback func([]byte, []byte) error, reset func()) error {
return i.ns.ForChanOutputs(chanPoint, callback) return i.ns.ForChanOutputs(chanPoint, callback, reset)
} }
func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) { func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) {

@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
}, func() {
version = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -392,6 +394,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var err error var err error
tower, err = getTower(towers, towerID.Bytes()) tower, err = getTower(towers, towerID.Bytes())
return err return err
}, func() {
tower = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -421,6 +425,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var err error var err error
tower, err = getTower(towers, towerIDBytes) tower, err = getTower(towers, towerIDBytes)
return err return err
}, func() {
tower = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -446,6 +452,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
towers = append(towers, tower) towers = append(towers, tower)
return nil return nil
}) })
}, func() {
towers = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -566,6 +574,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
var err error var err error
clientSessions, err = listClientSessions(sessions, id) clientSessions, err = listClientSessions(sessions, id)
return err return err
}, func() {
clientSessions = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -611,7 +621,7 @@ func listClientSessions(sessions kvdb.RBucket,
// FetchChanSummaries loads a mapping from all registered channels to their // FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries. // channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary) var summaries map[lnwire.ChannelID]ClientChanSummary
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
chanSummaries := tx.ReadBucket(cChanSummaryBkt) chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil { if chanSummaries == nil {
@ -632,6 +642,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
return nil return nil
}) })
}, func() {
summaries = make(map[lnwire.ChannelID]ClientChanSummary)
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) {
err = kvdb.View(bdb, func(tx kvdb.RTx) error { err = kvdb.View(bdb, func(tx kvdb.RTx) error {
metadataExists = tx.ReadBucket(metadataBkt) != nil metadataExists = tx.ReadBucket(metadataBkt) != nil
return nil return nil
}, func() {
metadataExists = false
}) })
if err != nil { if err != nil {
return nil, false, err return nil, false, err

@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
}, func() {
version = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var err error var err error
session, err = getSession(sessions, id[:]) session, err = getSession(sessions, id[:])
return err return err
}, func() {
session = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -460,6 +464,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
} }
return nil return nil
}, func() {
matches = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -494,6 +500,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
epoch = getLookoutEpoch(lookoutTip) epoch = getLookoutEpoch(lookoutTip)
return nil return nil
}, func() {
epoch = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err