multi: add reset closure to kvdb.View

This commit adds a reset() closure to the kvdb.View function which will
be called before each retry (including the first) of the view
transaction. The reset() closure can be used to reset external state
(eg slices or maps) where the view closure puts intermediate results.
This commit is contained in:
Andras Banki-Horvath 2020-10-20 16:18:40 +02:00
parent ffb27284df
commit 2a358327f4
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8
47 changed files with 340 additions and 82 deletions

@ -152,10 +152,12 @@ func (b *breachArbiter) start() error {
brarLog.Tracef("Starting breach arbiter")
// Load all retributions currently persisted in the retribution store.
breachRetInfos := make(map[wire.OutPoint]retributionInfo)
var breachRetInfos map[wire.OutPoint]retributionInfo
if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error {
breachRetInfos[ret.chanPoint] = *ret
return nil
}, func() {
breachRetInfos = make(map[wire.OutPoint]retributionInfo)
}); err != nil {
return err
}
@ -1223,7 +1225,7 @@ type RetributionStore interface {
// ForAll iterates over the existing on-disk contents and applies a
// chosen, read-only callback to each. This method should ensure that it
// immediately propagate any errors generated by the callback.
ForAll(cb func(*retributionInfo) error) error
ForAll(cb func(*retributionInfo) error, reset func()) error
}
// retributionStore handles persistence of retribution states to disk and is
@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn(
finalTxBytes = justiceBkt.Get(chanBuf.Bytes())
return nil
}, func() {
finalTxBytes = nil
}); err != nil {
return nil, err
}
@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
}
return nil
}, func() {
found = false
})
return found, err
@ -1395,7 +1401,9 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error {
// ForAll iterates through all stored retributions and executes the passed
// callback function on each retribution.
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
return kvdb.View(rs.db, func(tx kvdb.RTx) error {
// If the bucket does not exist, then there are no pending
// retributions.
@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
return cb(ret)
})
})
}, reset)
}
// Encode serializes the retribution into the passed byte stream.

@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error {
return frs.rs.Remove(key)
}
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error {
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
frs.mu.Lock()
defer frs.mu.Unlock()
return frs.rs.ForAll(cb)
return frs.rs.ForAll(cb, reset)
}
// Parse the pubkeys in the breached outputs.
@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error {
return nil
}
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error {
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
rs.mu.Lock()
defer rs.mu.Unlock()
reset()
for _, retInfo := range rs.state {
if err := cb(copyRetInfo(retInfo)); err != nil {
return err
@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int {
err := rs.ForAll(func(_ *retributionInfo) error {
count++
return nil
}, func() {
count = 0
})
if err != nil {
t.Fatalf("unable to list retributions in db: %v", err)
@ -919,7 +926,7 @@ restartCheck:
// Construct a set of all channel points presented by the store. Entries
// are only be added to the set if their corresponding retribution
// information matches the test vector.
var foundSet = make(map[wire.OutPoint]struct{})
var foundSet map[wire.OutPoint]struct{}
// Iterate through the stored retributions, checking to see if we have
// an equivalent retribution in the test vector. This will return an
@ -948,6 +955,8 @@ restartCheck:
}
return nil
}, func() {
foundSet = make(map[wire.OutPoint]struct{})
}); err != nil {
t.Fatalf("failed to iterate over persistent retributions: %v",
err)

@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err
}
return channeldb.ReadElement(bytes.NewReader(spendHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err
@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err
}
return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err

@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
}
return nil
})
}, func() {})
if err != nil {
return err
}
@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
}
return nil
}, func() {
commitPoint = nil
})
if err != nil {
return nil, err
@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
}
r := bytes.NewReader(bs)
return ReadElement(r, &closeTx)
}, func() {
closeTx = nil
})
if err != nil {
return nil, err
@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
cd = dcd
return nil
}, func() {
cd = nil
})
if err != nil {
return nil, err
@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
@ -2365,6 +2375,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}
@ -2475,7 +2487,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
}
return nil
}); err != nil {
}, func() {}); err != nil {
return nil, err
}
@ -2509,6 +2521,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
height = commit.CommitHeight
return nil
}, func() {
height = 0
})
if err != nil {
return 0, err
@ -2547,7 +2561,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
commit = c
return nil
})
}, func() {})
if err != nil {
return nil, err
}
@ -2870,7 +2884,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
}
return fetchChanCommitments(chanBucket, c)
})
}, func() {})
if err != nil {
return nil, nil, err
}
@ -2892,7 +2906,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
}
return fetchChanRevocationState(chanBucket, c)
})
}, func() {})
if err != nil {
return nil, err
}

@ -201,12 +201,16 @@ func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error {
// View is a wrapper around walletdb.View which calls into the extended
// backend when available. This call is needed to be able to cast DB to
// ExtendedBackend.
func (db *DB) View(f func(tx walletdb.ReadTx) error) error {
// ExtendedBackend. The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.View(f)
return v.View(f, reset)
}
reset()
return walletdb.View(db, f)
}
@ -389,6 +393,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
var err error
channels, err = d.fetchOpenChannels(tx, nodeID)
return err
}, func() {
channels = nil
})
return channels, err
@ -574,7 +580,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
})
}
err := kvdb.View(d, chanScan)
err := kvdb.View(d, chanScan, func() {})
if err != nil {
return nil, err
}
@ -741,6 +747,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
})
})
}, func() {
channels = nil
})
if err != nil {
return nil, err
@ -781,6 +789,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}
@ -817,6 +827,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
@ -865,6 +877,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
return nil
}
return ErrClosedChannelNotFound
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
@ -1052,6 +1066,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
}
return nil
}, func() {
linkNode = nil
})
if dbErr != nil {
return nil, dbErr
@ -1261,6 +1277,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err
channel, err = fetchOpenChannel(chanBucket, outPoint)
return err
}, func() {
channel = nil
})
if err != nil {
return nil, err

@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct {
//
// TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
resp := ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
var resp ForwardingLogTimeSlice
// If the user provided an index offset, then we'll not know how many
// records we need to skip. We'll also keep track of the record offset
@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e
}
return nil
}, func() {
resp = ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
})
if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err

@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend,
var err error
fwdPkgs, err = packager.LoadFwdPkgs(tx)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
t.Fatalf("unable to load fwd pkgs: %v", err)
}

@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// be aborted.
return cb(&edgeInfo, edge1, edge2)
})
})
}, func() {})
}
// ForEachNodeChannel iterates through all channels of a given node, executing the
@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
// have their disabled bit on.
func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
var disabledChanIDs []uint64
chanEdgeFound := make(map[uint64]struct{})
var chanEdgeFound map[uint64]struct{}
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
chanEdgeFound[chanID] = struct{}{}
return nil
})
}, func() {
disabledChanIDs = nil
chanEdgeFound = make(map[uint64]struct{})
})
if err != nil {
return nil, err
@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
})
}
return kvdb.View(c.db, traversal)
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated
@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err
@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
// package...
alias = string(a)
return nil
}, func() {
alias = ""
})
if err != nil {
return "", err
@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge(
}
return nil
}); err != nil {
}, func() {}); err != nil {
return time.Time{}, time.Time{}, exists, isZombie, err
}
@ -1235,7 +1242,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
tipHeight = byteOrder.Uint32(k[:])
return nil
})
}, func() {})
if err != nil {
return nil, 0, err
}
@ -1312,6 +1319,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var err error
chanID, err = getChanID(tx, chanPoint)
return err
}, func() {
chanID = 0
}); err != nil {
return 0, err
}
@ -1379,6 +1388,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) {
// to the caller.
cid = byteOrder.Uint64(lastChanID)
return nil
}, func() {
cid = 0
})
if err != nil && err != ErrGraphNoEdgesFound {
return 0, err
@ -1409,8 +1420,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent
// re-adding it.
edgesSeen := make(map[uint64]struct{})
edgesToCache := make(map[uint64]ChannelEdge)
var edgesSeen map[uint64]struct{}
var edgesToCache map[uint64]ChannelEdge
var edgesInHorizon []ChannelEdge
c.cacheMu.Lock()
@ -1507,6 +1518,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
}
return nil
}, func() {
edgesSeen = make(map[uint64]struct{})
edgesToCache = make(map[uint64]ChannelEdge)
edgesInHorizon = nil
})
switch {
case err == ErrGraphNoEdgesFound:
@ -1577,6 +1592,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig
}
return nil
}, func() {
nodesInHorizon = nil
})
switch {
case err == ErrGraphNoEdgesFound:
@ -1637,6 +1654,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
}
return nil
}, func() {
newChanIDs = nil
})
switch {
// If we don't know of any edges yet, then we'll return the entire set
@ -1701,7 +1720,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
}
return nil
}, func() {
chanIDs = nil
})
switch {
// If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice.
@ -1775,6 +1797,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
})
}
return nil
}, func() {
chanEdges = nil
})
if err != nil {
return nil, err
@ -2209,7 +2233,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNode)
err = kvdb.View(c.db, fetchNode, func() {})
} else {
err = fetchNode(tx)
}
@ -2259,6 +2283,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
exists = true
updateTime = node.LastUpdate
return nil
}, func() {
updateTime = time.Time{}
exists = false
})
if err != nil {
return time.Time{}, exists, err
@ -2346,7 +2373,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// If no transaction was provided, then we'll create a new transaction
// to execute the transaction within.
if tx == nil {
return kvdb.View(db, traversal)
return kvdb.View(db, traversal, func() {})
}
// Otherwise, we re-use the existing transaction to execute the graph
@ -2596,7 +2623,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh
// otherwise we can use the existing db transaction.
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNodeFunc)
err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil })
} else {
err = fetchNodeFunc(tx)
}
@ -2929,6 +2956,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint,
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if err != nil {
return nil, nil, nil, err
@ -3030,6 +3061,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64,
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if err == ErrZombieEdge {
return edgeInfo, nil, nil, err
@ -3062,6 +3097,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) {
nodeIsPublic, err = node.isPublic(tx, ourPubKey)
return err
}, func() {
nodeIsPublic = false
})
if err != nil {
return false, err
@ -3183,6 +3220,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) {
return nil
})
}, func() {
edgePoints = nil
}); err != nil {
return nil, err
}
@ -3261,6 +3300,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
return nil
}, func() {
isZombie = false
pubKey1 = [33]byte{}
pubKey2 = [33]byte{}
})
if err != nil {
return false, [33]byte{}, [33]byte{}
@ -3307,6 +3350,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) {
numZombies++
return nil
})
}, func() {
numZombies = 0
})
if err != nil {
return 0, err

@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
return nil
})
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatalf("error reading db: %v", err)
}

@ -624,6 +624,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
}
return nil
}, func() {
newInvoices = nil
})
if err != nil {
return nil, err
@ -669,7 +671,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
invoice = i
return nil
})
}, func() {})
if err != nil {
return invoice, err
}
@ -731,13 +733,6 @@ func (d *DB) ScanInvoices(
scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
return kvdb.View(d, func(tx kvdb.RTx) error {
// Reset partial results. As transaction commit success is not
// guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
@ -773,7 +768,7 @@ func (d *DB) ScanInvoices(
return scanFunc(paymentHash, &invoice)
})
})
}, reset)
}
// InvoiceQuery represents a query to the invoice database. The query allows a
@ -825,9 +820,7 @@ type InvoiceSlice struct {
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
resp := InvoiceSlice{
InvoiceQuery: q,
}
var resp InvoiceSlice
err := kvdb.View(d, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any invoices
@ -892,6 +885,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
}
return nil
}, func() {
resp = InvoiceSlice{
InvoiceQuery: q,
}
})
if err != nil && err != ErrNoInvoicesCreated {
return resp, err
@ -1011,6 +1008,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
}
return nil
}, func() {
settledInvoices = nil
})
if err != nil {
return nil, err

@ -218,11 +218,15 @@ func (db *db) getSTMOptions() []STMOptionFunc {
}
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur).
func (db *db) View(f func(tx walletdb.ReadTx) error) error {
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix))
}

@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.Nil(t, v)
return nil
})
}, func() {})
require.NoError(t, err)
}
@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.Nil(t, v)
return nil
})
}, func() {})
require.NoError(t, err)
}
@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.Equal(t, []byte("val"), v)
return nil
})
}, func() {})
require.NoError(t, err)

@ -21,12 +21,19 @@ func Update(db Backend, f func(tx RwTx) error) error {
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur).
func View(db Backend, f func(tx RTx) error) error {
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func View(db Backend, f func(tx RTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.View(f)
return extendedDB.View(f, reset)
}
// Since we know that walletdb simply calls into bbolt which never
// retries transactions, we'll call the reset function here before View.
reset()
return walletdb.View(db, f)
}
@ -55,11 +62,15 @@ type ExtendedBackend interface {
// PrintStats returns all collected stats pretty printed into a string.
PrintStats() string
// View opens a database read transaction and executes the function f with
// the transaction passed as a parameter. After f exits, the transaction is
// rolled back. If f errors, its error is returned, not a rollback error
// (if any occur).
View(f func(tx walletdb.ReadTx) error) error
// View opens a database read transaction and executes the function f
// with the transaction passed as a parameter. After f exits, the
// transaction is rolled back. If f errors, its error is returned, not a
// rollback error (if any occur). The passed reset function is called
// before the start of the transaction and can be used to reset
// intermediate state. As callers may expect retries of the f closure
// (depending on the database backend used), the reset function will be
//called before each retry respectively.
View(f func(tx walletdb.ReadTx) error, reset func()) error
// Update opens a database read/write transaction and executes the function
// f with the transaction passed as a parameter. After f exits, if f did not

@ -23,10 +23,12 @@ type Meta struct {
// FetchMeta fetches the meta data from boltdb and returns filled meta
// structure.
func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) {
meta := &Meta{}
var meta *Meta
err := kvdb.View(d, func(tx kvdb.RTx) error {
return fetchMeta(meta, tx)
}, func() {
meta = &Meta{}
})
if err != nil {
return nil, err

@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatalf("unable to apply after func: %v", err)
}

@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}

@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err

@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
return nil
})
}, func() {
invoices = nil
})
if err != nil {
return nil, err

@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) {
payments = append(payments, payment)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var err error
paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash)
return err
}, func() {
paymentStatus = StatusUnknown
})
if err != nil {
return StatusUnknown, err
@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) {
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err

@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
"serialization")
}
return nil
}, func() {
dbSummary = nil
})
if err != nil {
t.Fatalf("unable to view DB: %v", err)
@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
}
return nil
}, func() {
rawMsg = nil
})
if err != nil {
t.Fatal(err)

@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) {
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err

@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
linkNode = node
return nil
}, func() {
linkNode = nil
})
return linkNode, err
@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
linkNodes = nodes
return nil
}, func() {
linkNodes = nil
})
if err != nil {
return nil, err

@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
payment, err = fetchPayment(bucket)
return err
}, func() {
payment = nil
})
if err != nil {
return nil, err
@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
inFlights = append(inFlights, inFlight)
return nil
})
}, func() {
inFlights = nil
})
if err != nil {
return nil, err

@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
indexCount++
return nil
})
})
}, func() { indexCount = 0 })
require.NoError(t, err)
require.Equal(t, 1, indexCount)
@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
var err error
hash, err = deserializePaymentIndex(r)
return err
}, func() {
hash = lntypes.Hash{}
}); err != nil {
return nil, err
}

@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) {
payments = append(payments, duplicatePayments...)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
}
return nil
}, func() {
resp = PaymentsResponse{}
}); err != nil {
return resp, err
}

@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
}
return ReadElements(r, &flapCount.Count)
}, func() {
flapCount = FlapCount{}
}); err != nil {
return nil, err
}

@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash,
return nil
})
}, func() {
reports = nil
}); err != nil {
return nil, err
}

@ -42,13 +42,14 @@ type WaitingProofStore struct {
// NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
s := &WaitingProofStore{
db: db,
cache: make(map[WaitingProofKey]struct{}),
db: db,
}
if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{}
return nil
}, func() {
s.cache = make(map[WaitingProofKey]struct{})
}); err != nil && err != ErrWaitingProofNotFound {
return nil, err
}
@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
// ForAll iterates thought all waiting proofs and passing the waiting proof
// in the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error,
reset func()) error {
return kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil {
@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
return cb(proof)
})
})
}, reset)
}
// Get returns the object which corresponds to the given index.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
proof := &WaitingProof{}
var proof *WaitingProof
s.mu.RLock()
defer s.mu.RUnlock()
@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
r := bytes.NewReader(v)
return proof.Decode(r)
}, func() {
proof = &WaitingProof{}
})
return proof, err

@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) {
if err := store.ForAll(func(proof *WaitingProof) error {
return errors.New("storage should be empty")
}); err != nil && err != ErrWaitingProofNotFound {
}, func() {}); err != nil && err != ErrWaitingProofNotFound {
t.Fatal(err)
}
}

@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by
copy(witness[:], dbWitness)
return nil
}, func() {
witness = nil
})
if err != nil {
return nil, err

@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) {
s = ArbitratorState(stateBytes[0])
return nil
}, func() {
s = 0
})
if err != nil && err != errScopeBucketNoExist {
return s, err
@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro
contracts = append(contracts, res)
return nil
})
}, func() {
contracts = nil
})
if err != nil && err != errScopeBucketNoExist && err != errNoContracts {
return nil, err
@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) {
c := &ContractResolutions{}
var c *ContractResolutions
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
}
return nil
}, func() {
c = &ContractResolutions{}
})
if err != nil {
return nil, err
@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
actionsMap := make(ChainActionMap)
var actionsMap ChainActionMap
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
return nil
})
}, func() {
actionsMap = make(ChainActionMap)
})
if err != nil {
return nil, err
@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) {
c = commitSet
return nil
}, func() {
c = nil
})
if err != nil {
return nil, err

@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1570,6 +1585,9 @@ out:
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}

@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) {
// Messages returns the total set of messages that exist within the store for
// all peers.
func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs := make(map[[33]byte][]lnwire.Message)
var msgs map[[33]byte][]lnwire.Message
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil {
@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs[pubKey] = append(msgs[pubKey], msg)
return nil
})
}, func() {
msgs = make(map[[33]byte][]lnwire.Message)
})
if err != nil {
return nil, err
@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer(
}
return nil
}, func() {
msgs = nil
})
if err != nil {
return nil, err
@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer(
// Peers returns the public key of all peers with messages within the store.
func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers := make(map[[33]byte]struct{})
var peers map[[33]byte]struct{}
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil {
@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers[pubKey] = struct{}{}
return nil
})
}, func() {
peers = make(map[[33]byte]struct{})
})
if err != nil {
return nil, err

@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) (
state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil
})
}, func() {})
if err != nil {
return 0, nil, err
}

@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
value = uint32(binary.BigEndian.Uint32(valueBytes))
return nil
}, func() {
value = 0
})
if err != nil {
return value, err

@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
default:
return nil
}
}, func() {
result = nil
})
if err != nil {
return nil, err
@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) (
var err error
result, err = fetchResult(tx, pid)
return err
}, func() {
result = nil
})
if err != nil {
return nil, err

@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.
tx, source,
)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}

@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) {
rootKey = make([]byte, len(decKey))
copy(rootKey[:], decKey)
return nil
}, func() {
rootKey = nil
})
if err != nil {
return nil, err
@ -257,6 +259,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) {
}
return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey)
}, func() {
rootKeySlice = nil
})
if err != nil {
return nil, err

@ -129,7 +129,7 @@ type NurseryStore interface {
// the caller to process each key-value pair. The key will be a prefixed
// outpoint, and the value will be the serialized bytes for an output,
// whose type should be inferred from the key's prefix.
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error
// ListChannels returns all channels the nursery is currently tracking.
ListChannels() ([]wire.OutPoint, error)
@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass(
})
}, func() {
kids = nil
babies = nil
}); err != nil {
return nil, nil, err
}
@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) {
}
return nil
}, func() {
kids = nil
}); err != nil {
return nil, err
}
@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
}
return nil
}, func() {
activeHeights = nil
})
if err != nil {
return nil, err
@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
// NOTE: The callback should not modify the provided byte slices and is
// preferably non-blocking.
func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error {
callback func([]byte, []byte) error, reset func()) error {
return kvdb.View(ns.db, func(tx kvdb.RTx) error {
return ns.forChanOutputs(tx, chanPoint, callback)
})
}, reset)
}
// ListChannels returns all channels the nursery is currently tracking.
@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) {
return nil
})
}, func() {
activeChannels = nil
}); err != nil {
return nil, err
}
@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error)
return nil
})
})
}, func() {})
if err != nil && err != ErrImmatureChannel {
return false, err
}

@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore,
err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error {
count++
return nil
}, func() {
count = 0
})
if count == 0 && err == ErrContractNotFound {

@ -103,6 +103,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
return nil
})
}, func() {
results = nil
})
if err != nil {
return nil, err

@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) {
}
return nil
}, func() {
sweepTx = nil
})
if err != nil {
return nil, err
@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
ours = txHashesBucket.Get(hash[:]) != nil
return nil
}, func() {
ours = false
})
if err != nil {
return false, err
@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) {
return nil
})
}, func() {
sweepTxns = nil
}); err != nil {
return nil, err
}

@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport(
utxnLog.Debugf("NurseryReport: building nursery report for channel %v",
chanPoint)
report := &contractMaturityReport{}
var report *contractMaturityReport
if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error {
switch {
@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport(
}
return nil
}, func() {
report = &contractMaturityReport{}
}); err != nil {
return nil, err
}

@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) (
}
func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error {
callback func([]byte, []byte) error, reset func()) error {
return i.ns.ForChanOutputs(chanPoint, callback)
return i.ns.ForChanOutputs(chanPoint, callback, reset)
}
func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) {

@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) {
var err error
version, err = getDBVersion(tx)
return err
}, func() {
version = 0
})
if err != nil {
return 0, err
@ -392,6 +394,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var err error
tower, err = getTower(towers, towerID.Bytes())
return err
}, func() {
tower = nil
})
if err != nil {
return nil, err
@ -421,6 +425,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var err error
tower, err = getTower(towers, towerIDBytes)
return err
}, func() {
tower = nil
})
if err != nil {
return nil, err
@ -446,6 +452,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
towers = append(towers, tower)
return nil
})
}, func() {
towers = nil
})
if err != nil {
return nil, err
@ -566,6 +574,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
var err error
clientSessions, err = listClientSessions(sessions, id)
return err
}, func() {
clientSessions = nil
})
if err != nil {
return nil, err
@ -611,7 +621,7 @@ func listClientSessions(sessions kvdb.RBucket,
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
var summaries map[lnwire.ChannelID]ClientChanSummary
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil {
@ -632,6 +642,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
return nil
})
}, func() {
summaries = make(map[lnwire.ChannelID]ClientChanSummary)
})
if err != nil {
return nil, err

@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) {
err = kvdb.View(bdb, func(tx kvdb.RTx) error {
metadataExists = tx.ReadBucket(metadataBkt) != nil
return nil
}, func() {
metadataExists = false
})
if err != nil {
return nil, false, err

@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) {
var err error
version, err = getDBVersion(tx)
return err
}, func() {
version = 0
})
if err != nil {
return 0, err
@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var err error
session, err = getSession(sessions, id[:])
return err
}, func() {
session = nil
})
if err != nil {
return nil, err
@ -460,6 +464,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
}
return nil
}, func() {
matches = nil
})
if err != nil {
return nil, err
@ -494,6 +500,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
epoch = getLookoutEpoch(lookoutTip)
return nil
}, func() {
epoch = nil
})
if err != nil {
return nil, err