Merge pull request #4705 from bhandras/kvdb_view_reset

multi: add reset closure to `kvdb.View/Update` to be able to clean up external state before retries
This commit is contained in:
András Bánki-Horváth 2020-11-06 12:47:35 +01:00 committed by GitHub
commit 0c3c6e6155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 564 additions and 260 deletions

@ -152,10 +152,12 @@ func (b *breachArbiter) start() error {
brarLog.Tracef("Starting breach arbiter") brarLog.Tracef("Starting breach arbiter")
// Load all retributions currently persisted in the retribution store. // Load all retributions currently persisted in the retribution store.
breachRetInfos := make(map[wire.OutPoint]retributionInfo) var breachRetInfos map[wire.OutPoint]retributionInfo
if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error { if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error {
breachRetInfos[ret.chanPoint] = *ret breachRetInfos[ret.chanPoint] = *ret
return nil return nil
}, func() {
breachRetInfos = make(map[wire.OutPoint]retributionInfo)
}); err != nil { }); err != nil {
return err return err
} }
@ -1223,7 +1225,7 @@ type RetributionStore interface {
// ForAll iterates over the existing on-disk contents and applies a // ForAll iterates over the existing on-disk contents and applies a
// chosen, read-only callback to each. This method should ensure that it // chosen, read-only callback to each. This method should ensure that it
// immediately propagate any errors generated by the callback. // immediately propagate any errors generated by the callback.
ForAll(cb func(*retributionInfo) error) error ForAll(cb func(*retributionInfo) error, reset func()) error
} }
// retributionStore handles persistence of retribution states to disk and is // retributionStore handles persistence of retribution states to disk and is
@ -1263,7 +1265,7 @@ func (rs *retributionStore) Add(ret *retributionInfo) error {
} }
return retBucket.Put(outBuf.Bytes(), retBuf.Bytes()) return retBucket.Put(outBuf.Bytes(), retBuf.Bytes())
}) }, func() {})
} }
// Finalize writes a signed justice transaction to the retribution store. This // Finalize writes a signed justice transaction to the retribution store. This
@ -1288,7 +1290,7 @@ func (rs *retributionStore) Finalize(chanPoint *wire.OutPoint,
} }
return justiceBkt.Put(chanBuf.Bytes(), txBuf.Bytes()) return justiceBkt.Put(chanBuf.Bytes(), txBuf.Bytes())
}) }, func() {})
} }
// GetFinalizedTxn loads the finalized justice transaction for the provided // GetFinalizedTxn loads the finalized justice transaction for the provided
@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn(
finalTxBytes = justiceBkt.Get(chanBuf.Bytes()) finalTxBytes = justiceBkt.Get(chanBuf.Bytes())
return nil return nil
}, func() {
finalTxBytes = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
} }
return nil return nil
}, func() {
found = false
}) })
return found, err return found, err
@ -1390,12 +1396,14 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error {
} }
return justiceBkt.Delete(chanBytes) return justiceBkt.Delete(chanBytes)
}) }, func() {})
} }
// ForAll iterates through all stored retributions and executes the passed // ForAll iterates through all stored retributions and executes the passed
// callback function on each retribution. // callback function on each retribution.
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error { func (rs *retributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
return kvdb.View(rs.db, func(tx kvdb.RTx) error { return kvdb.View(rs.db, func(tx kvdb.RTx) error {
// If the bucket does not exist, then there are no pending // If the bucket does not exist, then there are no pending
// retributions. // retributions.
@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
return cb(ret) return cb(ret)
}) })
}) }, reset)
} }
// Encode serializes the retribution into the passed byte stream. // Encode serializes the retribution into the passed byte stream.

@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error {
return frs.rs.Remove(key) return frs.rs.Remove(key)
} }
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error { func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
frs.mu.Lock() frs.mu.Lock()
defer frs.mu.Unlock() defer frs.mu.Unlock()
return frs.rs.ForAll(cb) return frs.rs.ForAll(cb, reset)
} }
// Parse the pubkeys in the breached outputs. // Parse the pubkeys in the breached outputs.
@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error {
return nil return nil
} }
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error { func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
reset()
for _, retInfo := range rs.state { for _, retInfo := range rs.state {
if err := cb(copyRetInfo(retInfo)); err != nil { if err := cb(copyRetInfo(retInfo)); err != nil {
return err return err
@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int {
err := rs.ForAll(func(_ *retributionInfo) error { err := rs.ForAll(func(_ *retributionInfo) error {
count++ count++
return nil return nil
}, func() {
count = 0
}) })
if err != nil { if err != nil {
t.Fatalf("unable to list retributions in db: %v", err) t.Fatalf("unable to list retributions in db: %v", err)
@ -919,7 +926,7 @@ restartCheck:
// Construct a set of all channel points presented by the store. Entries // Construct a set of all channel points presented by the store. Entries
// are only be added to the set if their corresponding retribution // are only be added to the set if their corresponding retribution
// information matches the test vector. // information matches the test vector.
var foundSet = make(map[wire.OutPoint]struct{}) var foundSet map[wire.OutPoint]struct{}
// Iterate through the stored retributions, checking to see if we have // Iterate through the stored retributions, checking to see if we have
// an equivalent retribution in the test vector. This will return an // an equivalent retribution in the test vector. This will return an
@ -948,6 +955,8 @@ restartCheck:
} }
return nil return nil
}, func() {
foundSet = make(map[wire.OutPoint]struct{})
}); err != nil { }); err != nil {
t.Fatalf("failed to iterate over persistent retributions: %v", t.Fatalf("failed to iterate over persistent retributions: %v",
err) err)

@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err
} }
return channeldb.ReadElement(bytes.NewReader(spendHint), &hint) return channeldb.ReadElement(bytes.NewReader(spendHint), &hint)
}, func() {
hint = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err
} }
return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint) return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint)
}, func() {
hint = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err

@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -893,7 +893,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
channel.ShortChannelID = openLoc channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) return putOpenChannel(chanBucket.(kvdb.RwBucket), channel)
}); err != nil { }, func() {}); err != nil {
return err return err
} }
@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
} }
return nil return nil
}, func() {
commitPoint = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
} }
r := bytes.NewReader(bs) r := bytes.NewReader(bs)
return ReadElement(r, &closeTx) return ReadElement(r, &closeTx)
}, func() {
closeTx = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -1215,7 +1219,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus,
} }
return nil return nil
}); err != nil { }, func() {}); err != nil {
return err return err
} }
@ -1244,7 +1248,7 @@ func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
channel.chanStatus = status channel.chanStatus = status
return putOpenChannel(chanBucket, channel) return putOpenChannel(chanBucket, channel)
}); err != nil { }, func() {}); err != nil {
return err return err
} }
@ -1352,7 +1356,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return syncNewChannel(tx, c, []net.Addr{addr}) return syncNewChannel(tx, c, []net.Addr{addr})
}) }, func() {})
} }
// syncNewChannel will write the passed channel to disk, and also create a // syncNewChannel will write the passed channel to disk, and also create a
@ -1486,7 +1490,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment,
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -2026,7 +2030,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
return err return err
} }
return chanBucket.Put(commitDiffKey, b.Bytes()) return chanBucket.Put(commitDiffKey, b.Bytes())
}) }, func() {})
} }
// RemoteCommitChainTip returns the "tip" of the current remote commitment // RemoteCommitChainTip returns the "tip" of the current remote commitment
@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
cd = dcd cd = dcd
return nil return nil
}, func() {
cd = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes) r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r) updates, err = deserializeLogUpdates(r)
return err return err
}, func() {
updates = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes) r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r) updates, err = deserializeLogUpdates(r)
return err return err
}, func() {
updates = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -2157,7 +2167,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
} }
return putChanRevocationState(chanBucket, c) return putChanRevocationState(chanBucket, c)
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -2317,6 +2327,8 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
newRemoteCommit = &newCommit.Commitment newRemoteCommit = &newCommit.Commitment
return nil return nil
}, func() {
newRemoteCommit = nil
}) })
if err != nil { if err != nil {
return err return err
@ -2365,6 +2377,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
var err error var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -2381,7 +2395,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.AckAddHtlcs(tx, addRefs...) return c.Packager.AckAddHtlcs(tx, addRefs...)
}) }, func() {})
} }
// AckSettleFails updates the SettleFailFilter containing any of the provided // AckSettleFails updates the SettleFailFilter containing any of the provided
@ -2394,7 +2408,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.AckSettleFails(tx, settleFailRefs...) return c.Packager.AckSettleFails(tx, settleFailRefs...)
}) }, func() {})
} }
// SetFwdFilter atomically sets the forwarding filter for the forwarding package // SetFwdFilter atomically sets the forwarding filter for the forwarding package
@ -2405,7 +2419,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter) return c.Packager.SetFwdFilter(tx, height, fwdFilter)
}) }, func() {})
} }
// RemoveFwdPkgs atomically removes forwarding packages specified by the remote // RemoveFwdPkgs atomically removes forwarding packages specified by the remote
@ -2426,7 +2440,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error {
} }
return nil return nil
}) }, func() {})
} }
// RevocationLogTail returns the "tail", or the end of the current revocation // RevocationLogTail returns the "tail", or the end of the current revocation
@ -2475,7 +2489,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
} }
return nil return nil
}); err != nil { }, func() {}); err != nil {
return nil, err return nil, err
} }
@ -2509,6 +2523,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
height = commit.CommitHeight height = commit.CommitHeight
return nil return nil
}, func() {
height = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -2547,7 +2563,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
commit = c commit = c
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2785,7 +2801,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
return putChannelCloseSummary( return putChannelCloseSummary(
tx, chanPointBuf.Bytes(), summary, chanState, tx, chanPointBuf.Bytes(), summary, chanState,
) )
}) }, func() {})
} }
// ChannelSnapshot is a frozen snapshot of the current channel state. A // ChannelSnapshot is a frozen snapshot of the current channel state. A
@ -2870,7 +2886,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
} }
return fetchChanCommitments(chanBucket, c) return fetchChanCommitments(chanBucket, c)
}) }, func() {})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -2892,7 +2908,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
} }
return fetchChanRevocationState(chanBucket, c) return fetchChanRevocationState(chanBucket, c)
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -1447,7 +1447,7 @@ func TestBalanceAtHeight(t *testing.T) {
commit.RemoteBalance = remote commit.RemoteBalance = remote
return appendChannelLogEntry(logBucket, &commit) return appendChannelLogEntry(logBucket, &commit)
}) }, func() {})
return err return err
} }

@ -191,22 +191,31 @@ type DB struct {
// Update is a wrapper around walletdb.Update which calls into the extended // Update is a wrapper around walletdb.Update which calls into the extended
// backend when available. This call is needed to be able to cast DB to // backend when available. This call is needed to be able to cast DB to
// ExtendedBackend. // ExtendedBackend. The passed reset function is called before the start of the
func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok { if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.Update(f) return v.Update(f, reset)
} }
reset()
return walletdb.Update(db, f) return walletdb.Update(db, f)
} }
// View is a wrapper around walletdb.View which calls into the extended // View is a wrapper around walletdb.View which calls into the extended
// backend when available. This call is needed to be able to cast DB to // backend when available. This call is needed to be able to cast DB to
// ExtendedBackend. // ExtendedBackend. The passed reset function is called before the start of the
func (db *DB) View(f func(tx walletdb.ReadTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok { if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.View(f) return v.View(f, reset)
} }
reset()
return walletdb.View(db, f) return walletdb.View(db, f)
} }
@ -306,7 +315,7 @@ func (d *DB) Wipe() error {
} }
} }
return nil return nil
}) }, func() {})
} }
// createChannelDB creates and initializes a fresh version of channeldb. In // createChannelDB creates and initializes a fresh version of channeldb. In
@ -360,7 +369,7 @@ func initChannelDB(db kvdb.Backend) error {
meta.DbVersionNumber = getLatestDBVersion(dbVersions) meta.DbVersionNumber = getLatestDBVersion(dbVersions)
return putMeta(meta, tx) return putMeta(meta, tx)
}) }, func() {})
if err != nil { if err != nil {
return fmt.Errorf("unable to create new channeldb: %v", err) return fmt.Errorf("unable to create new channeldb: %v", err)
} }
@ -389,6 +398,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
var err error var err error
channels, err = d.fetchOpenChannels(tx, nodeID) channels, err = d.fetchOpenChannels(tx, nodeID)
return err return err
}, func() {
channels = nil
}) })
return channels, err return channels, err
@ -574,7 +585,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
}) })
} }
err := kvdb.View(d, chanScan) err := kvdb.View(d, chanScan, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -741,6 +752,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
}) })
}) })
}, func() {
channels = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -781,6 +794,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary) chanSummaries = append(chanSummaries, chanSummary)
return nil return nil
}) })
}, func() {
chanSummaries = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -817,6 +832,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
chanSummary, err = deserializeCloseChannelSummary(summaryReader) chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err return err
}, func() {
chanSummary = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -865,6 +882,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
return nil return nil
} }
return ErrClosedChannelNotFound return ErrClosedChannelNotFound
}, func() {
chanSummary = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -925,7 +944,7 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
// garbage collect it to ensure we don't establish persistent // garbage collect it to ensure we don't establish persistent
// connections to peers without open channels. // connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub) return d.pruneLinkNode(tx, chanSummary.RemotePub)
}) }, func() {})
} }
// pruneLinkNode determines whether we should garbage collect a link node from // pruneLinkNode determines whether we should garbage collect a link node from
@ -965,7 +984,7 @@ func (d *DB) PruneLinkNodes() error {
} }
return nil return nil
}) }, func() {})
} }
// ChannelShell is a shell of a channel that is meant to be used for channel // ChannelShell is a shell of a channel that is meant to be used for channel
@ -1012,7 +1031,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -1052,6 +1071,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
} }
return nil return nil
}, func() {
linkNode = nil
}) })
if dbErr != nil { if dbErr != nil {
return nil, dbErr return nil, dbErr
@ -1194,7 +1215,7 @@ func (d *DB) syncVersions(versions []version) error {
} }
return nil return nil
}) }, func() {})
} }
// ChannelGraph returns a new instance of the directed channel graph. // ChannelGraph returns a new instance of the directed channel graph.
@ -1261,6 +1282,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err
channel, err = fetchOpenChannel(chanBucket, outPoint) channel, err = fetchOpenChannel(chanBucket, outPoint)
return err return err
}, func() {
channel = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct {
// //
// TODO(roasbeef): rename? // TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
resp := ForwardingLogTimeSlice{ var resp ForwardingLogTimeSlice
ForwardingEventQuery: q,
}
// If the user provided an index offset, then we'll not know how many // If the user provided an index offset, then we'll not know how many
// records we need to skip. We'll also keep track of the record offset // records we need to skip. We'll also keep track of the record offset
@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e
} }
return nil return nil
}, func() {
resp = ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
}) })
if err != nil && err != ErrNoForwardingEvents { if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err return ForwardingLogTimeSlice{}, err

@ -209,7 +209,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg) return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err) t.Fatalf("unable to add fwd pkg: %v", err)
} }
@ -228,7 +228,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
// fwd filter. // fwd filter.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err) t.Fatalf("unable to set fwdfiter: %v", err)
} }
@ -246,7 +246,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
// Lastly, remove the completed forwarding package from disk. // Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height) return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err) t.Fatalf("unable to remove fwdpkg: %v", err)
} }
@ -281,7 +281,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg) return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err) t.Fatalf("unable to add fwd pkg: %v", err)
} }
@ -302,7 +302,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
// was failed locally. // was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err) t.Fatalf("unable to set fwdfiter: %v", err)
} }
@ -326,7 +326,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef) return packager.AckAddHtlcs(tx, addRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err) t.Fatalf("unable to ack add htlc: %v", err)
} }
} }
@ -345,7 +345,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
// Lastly, remove the completed forwarding package from disk. // Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height) return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err) t.Fatalf("unable to remove fwdpkg: %v", err)
} }
@ -383,7 +383,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg) return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err) t.Fatalf("unable to add fwd pkg: %v", err)
} }
@ -404,7 +404,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
// was failed locally. // was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err) t.Fatalf("unable to set fwdfiter: %v", err)
} }
@ -430,7 +430,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef) return packager.AckSettleFails(tx, failSettleRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err) t.Fatalf("unable to ack add htlc: %v", err)
} }
} }
@ -450,7 +450,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
// Lastly, remove the completed forwarding package from disk. // Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height) return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err) t.Fatalf("unable to remove fwdpkg: %v", err)
} }
@ -488,7 +488,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg) return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err) t.Fatalf("unable to add fwd pkg: %v", err)
} }
@ -509,7 +509,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
// was failed locally. // was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err) t.Fatalf("unable to set fwdfiter: %v", err)
} }
@ -534,7 +534,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef) return packager.AckAddHtlcs(tx, addRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err) t.Fatalf("unable to ack add htlc: %v", err)
} }
} }
@ -561,7 +561,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef) return packager.AckSettleFails(tx, failSettleRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err) t.Fatalf("unable to remove settle/fail htlc: %v", err)
} }
} }
@ -581,7 +581,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
// Lastly, remove the completed forwarding package from disk. // Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height) return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err) t.Fatalf("unable to remove fwdpkg: %v", err)
} }
@ -621,7 +621,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg) return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err) t.Fatalf("unable to add fwd pkg: %v", err)
} }
@ -642,7 +642,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
// was failed locally. // was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err) t.Fatalf("unable to set fwdfiter: %v", err)
} }
@ -671,7 +671,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef) return packager.AckSettleFails(tx, failSettleRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err) t.Fatalf("unable to remove settle/fail htlc: %v", err)
} }
} }
@ -698,7 +698,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef) return packager.AckAddHtlcs(tx, addRef)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err) t.Fatalf("unable to ack add htlc: %v", err)
} }
} }
@ -718,7 +718,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
// Lastly, remove the completed forwarding package from disk. // Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error { if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height) return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil { }, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err) t.Fatalf("unable to remove fwdpkg: %v", err)
} }
@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend,
var err error var err error
fwdPkgs, err = packager.LoadFwdPkgs(tx) fwdPkgs, err = packager.LoadFwdPkgs(tx)
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
t.Fatalf("unable to load fwd pkgs: %v", err) t.Fatalf("unable to load fwd pkgs: %v", err)
} }

@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// be aborted. // be aborted.
return cb(&edgeInfo, edge1, edge2) return cb(&edgeInfo, edge1, edge2)
}) })
}) }, func() {})
} }
// ForEachNodeChannel iterates through all channels of a given node, executing the // ForEachNodeChannel iterates through all channels of a given node, executing the
@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
// have their disabled bit on. // have their disabled bit on.
func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
var disabledChanIDs []uint64 var disabledChanIDs []uint64
chanEdgeFound := make(map[uint64]struct{}) var chanEdgeFound map[uint64]struct{}
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket) edges := tx.ReadBucket(edgeBucket)
@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
chanEdgeFound[chanID] = struct{}{} chanEdgeFound[chanID] = struct{}{}
return nil return nil
}) })
}, func() {
disabledChanIDs = nil
chanEdgeFound = make(map[uint64]struct{})
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
}) })
} }
return kvdb.View(c.db, traversal) return kvdb.View(c.db, traversal, func() {})
} }
// SourceNode returns the source node of the graph. The source node is treated // SourceNode returns the source node of the graph. The source node is treated
@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node source = node
return nil return nil
}, func() {
source = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -429,7 +434,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
// Finally, we commit the information of the lightning node // Finally, we commit the information of the lightning node
// itself. // itself.
return addLightningNode(tx, node) return addLightningNode(tx, node)
}) }, func() {})
} }
// AddLightningNode adds a vertex/node to the graph database. If the node is not // AddLightningNode adds a vertex/node to the graph database. If the node is not
@ -443,7 +448,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
func (c *ChannelGraph) AddLightningNode(node *LightningNode) error { func (c *ChannelGraph) AddLightningNode(node *LightningNode) error {
return kvdb.Update(c.db, func(tx kvdb.RwTx) error { return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
return addLightningNode(tx, node) return addLightningNode(tx, node)
}) }, func() {})
} }
func addLightningNode(tx kvdb.RwTx, node *LightningNode) error { func addLightningNode(tx kvdb.RwTx, node *LightningNode) error {
@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
// package... // package...
alias = string(a) alias = string(a)
return nil return nil
}, func() {
alias = ""
}) })
if err != nil { if err != nil {
return "", err return "", err
@ -512,7 +519,7 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
} }
return c.deleteLightningNode(nodes, nodePub[:]) return c.deleteLightningNode(nodes, nodePub[:])
}) }, func() {})
} }
// deleteLightningNode uses an existing database transaction to remove a // deleteLightningNode uses an existing database transaction to remove a
@ -572,7 +579,7 @@ func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error {
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
return c.addChannelEdge(tx, edge) return c.addChannelEdge(tx, edge)
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge(
} }
return nil return nil
}); err != nil { }, func() {}); err != nil {
return time.Time{}, time.Time{}, exists, isZombie, err return time.Time{}, time.Time{}, exists, isZombie, err
} }
@ -813,7 +820,7 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
} }
return putChanEdgeInfo(edgeIndex, edge, chanKey) return putChanEdgeInfo(edgeIndex, edge, chanKey)
}) }, func() {})
} }
const ( const (
@ -936,6 +943,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// prune any nodes that have had a channel closed within the // prune any nodes that have had a channel closed within the
// latest block. // latest block.
return c.pruneGraphNodes(nodes, edgeIndex) return c.pruneGraphNodes(nodes, edgeIndex)
}, func() {
chansClosed = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -969,7 +978,7 @@ func (c *ChannelGraph) PruneGraphNodes() error {
} }
return c.pruneGraphNodes(nodes, edgeIndex) return c.pruneGraphNodes(nodes, edgeIndex)
}) }, func() {})
} }
// pruneGraphNodes attempts to remove any nodes from the graph who have had a // pruneGraphNodes attempts to remove any nodes from the graph who have had a
@ -1188,6 +1197,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf
} }
return nil return nil
}, func() {
removedChans = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -1235,7 +1246,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
tipHeight = byteOrder.Uint32(k[:]) tipHeight = byteOrder.Uint32(k[:])
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -1290,7 +1301,7 @@ func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -1312,6 +1323,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var err error var err error
chanID, err = getChanID(tx, chanPoint) chanID, err = getChanID(tx, chanPoint)
return err return err
}, func() {
chanID = 0
}); err != nil { }); err != nil {
return 0, err return 0, err
} }
@ -1379,6 +1392,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) {
// to the caller. // to the caller.
cid = byteOrder.Uint64(lastChanID) cid = byteOrder.Uint64(lastChanID)
return nil return nil
}, func() {
cid = 0
}) })
if err != nil && err != ErrGraphNoEdgesFound { if err != nil && err != ErrGraphNoEdgesFound {
return 0, err return 0, err
@ -1409,8 +1424,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// To ensure we don't return duplicate ChannelEdges, we'll use an // To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent // additional map to keep track of the edges already seen to prevent
// re-adding it. // re-adding it.
edgesSeen := make(map[uint64]struct{}) var edgesSeen map[uint64]struct{}
edgesToCache := make(map[uint64]ChannelEdge) var edgesToCache map[uint64]ChannelEdge
var edgesInHorizon []ChannelEdge var edgesInHorizon []ChannelEdge
c.cacheMu.Lock() c.cacheMu.Lock()
@ -1507,6 +1522,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
} }
return nil return nil
}, func() {
edgesSeen = make(map[uint64]struct{})
edgesToCache = make(map[uint64]ChannelEdge)
edgesInHorizon = nil
}) })
switch { switch {
case err == ErrGraphNoEdgesFound: case err == ErrGraphNoEdgesFound:
@ -1577,6 +1596,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig
} }
return nil return nil
}, func() {
nodesInHorizon = nil
}) })
switch { switch {
case err == ErrGraphNoEdgesFound: case err == ErrGraphNoEdgesFound:
@ -1637,6 +1658,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
} }
return nil return nil
}, func() {
newChanIDs = nil
}) })
switch { switch {
// If we don't know of any edges yet, then we'll return the entire set // If we don't know of any edges yet, then we'll return the entire set
@ -1701,7 +1724,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
} }
return nil return nil
}, func() {
chanIDs = nil
}) })
switch { switch {
// If we don't know of any channels yet, then there's nothing to // If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice. // filter, so we'll return an empty slice.
@ -1775,6 +1801,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
}) })
} }
return nil return nil
}, func() {
chanEdges = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -1912,6 +1940,8 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error {
var err error var err error
isUpdate1, err = updateEdgePolicy(tx, edge) isUpdate1, err = updateEdgePolicy(tx, edge)
return err return err
}, func() {
isUpdate1 = false
}) })
if err != nil { if err != nil {
return err return err
@ -2209,7 +2239,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
var err error var err error
if tx == nil { if tx == nil {
err = kvdb.View(c.db, fetchNode) err = kvdb.View(c.db, fetchNode, func() {})
} else { } else {
err = fetchNode(tx) err = fetchNode(tx)
} }
@ -2259,6 +2289,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
exists = true exists = true
updateTime = node.LastUpdate updateTime = node.LastUpdate
return nil return nil
}, func() {
updateTime = time.Time{}
exists = false
}) })
if err != nil { if err != nil {
return time.Time{}, exists, err return time.Time{}, exists, err
@ -2346,7 +2379,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// If no transaction was provided, then we'll create a new transaction // If no transaction was provided, then we'll create a new transaction
// to execute the transaction within. // to execute the transaction within.
if tx == nil { if tx == nil {
return kvdb.View(db, traversal) return kvdb.View(db, traversal, func() {})
} }
// Otherwise, we re-use the existing transaction to execute the graph // Otherwise, we re-use the existing transaction to execute the graph
@ -2596,7 +2629,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh
// otherwise we can use the existing db transaction. // otherwise we can use the existing db transaction.
var err error var err error
if tx == nil { if tx == nil {
err = kvdb.View(c.db, fetchNodeFunc) err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil })
} else { } else {
err = fetchNodeFunc(tx) err = fetchNodeFunc(tx)
} }
@ -2929,6 +2962,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint,
policy1 = e1 policy1 = e1
policy2 = e2 policy2 = e2
return nil return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
}) })
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -3030,6 +3067,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64,
policy1 = e1 policy1 = e1
policy2 = e2 policy2 = e2
return nil return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
}) })
if err == ErrZombieEdge { if err == ErrZombieEdge {
return edgeInfo, nil, nil, err return edgeInfo, nil, nil, err
@ -3062,6 +3103,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) {
nodeIsPublic, err = node.isPublic(tx, ourPubKey) nodeIsPublic, err = node.isPublic(tx, ourPubKey)
return err return err
}, func() {
nodeIsPublic = false
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -3183,6 +3226,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) {
return nil return nil
}) })
}, func() {
edgePoints = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -3229,7 +3274,7 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
var k [8]byte var k [8]byte
byteOrder.PutUint64(k[:], chanID) byteOrder.PutUint64(k[:], chanID)
return zombieIndex.Delete(k[:]) return zombieIndex.Delete(k[:])
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -3261,6 +3306,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
return nil return nil
}, func() {
isZombie = false
pubKey1 = [33]byte{}
pubKey2 = [33]byte{}
}) })
if err != nil { if err != nil {
return false, [33]byte{}, [33]byte{} return false, [33]byte{}, [33]byte{}
@ -3307,6 +3356,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) {
numZombies++ numZombies++
return nil return nil
}) })
}, func() {
numZombies = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err

@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
return nil return nil
}) })
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("error reading db: %v", err) t.Fatalf("error reading db: %v", err)
} }
@ -2894,7 +2894,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
} }
return edges.Put(edgeKey[:], stripped) return edges.Put(edgeKey[:], stripped)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("error writing db: %v", err) t.Fatalf("error writing db: %v", err)
} }

@ -562,6 +562,8 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
invoiceAddIndex = newIndex invoiceAddIndex = newIndex
return nil return nil
}, func() {
invoiceAddIndex = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -624,6 +626,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
} }
return nil return nil
}, func() {
newInvoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -669,7 +673,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
invoice = i invoice = i
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return invoice, err return invoice, err
} }
@ -731,13 +735,6 @@ func (d *DB) ScanInvoices(
scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error { scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
return kvdb.View(d, func(tx kvdb.RTx) error { return kvdb.View(d, func(tx kvdb.RTx) error {
// Reset partial results. As transaction commit success is not
// guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
invoices := tx.ReadBucket(invoiceBucket) invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil { if invoices == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
@ -773,7 +770,7 @@ func (d *DB) ScanInvoices(
return scanFunc(paymentHash, &invoice) return scanFunc(paymentHash, &invoice)
}) })
}) }, reset)
} }
// InvoiceQuery represents a query to the invoice database. The query allows a // InvoiceQuery represents a query to the invoice database. The query allows a
@ -825,9 +822,7 @@ type InvoiceSlice struct {
// QueryInvoices allows a caller to query the invoice database for invoices // QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range. // within the specified add index range.
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
resp := InvoiceSlice{ var resp InvoiceSlice
InvoiceQuery: q,
}
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(d, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any invoices // If the bucket wasn't found, then there aren't any invoices
@ -892,6 +887,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
} }
return nil return nil
}, func() {
resp = InvoiceSlice{
InvoiceQuery: q,
}
}) })
if err != nil && err != ErrNoInvoicesCreated { if err != nil && err != ErrNoInvoicesCreated {
return resp, err return resp, err
@ -953,6 +952,8 @@ func (d *DB) UpdateInvoice(ref InvoiceRef,
) )
return err return err
}, func() {
updatedInvoice = nil
}) })
return updatedInvoice, err return updatedInvoice, err
@ -1011,6 +1012,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
} }
return nil return nil
}, func() {
settledInvoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -1867,7 +1870,7 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error {
} }
return nil return nil
}) }, func() {})
return err return err
} }

@ -218,11 +218,15 @@ func (db *db) getSTMOptions() []STMOptionFunc {
} }
// View opens a database read transaction and executes the function f with the // View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled // transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any // back. If f errors, its error is returned, not a rollback error (if any
// occur). // occur). The passed reset function is called before the start of the
func (db *db) View(f func(tx walletdb.ReadTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
apply := func(stm STM) error { apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix)) return f(newReadWriteTx(stm, db.config.Prefix))
} }
@ -230,13 +234,15 @@ func (db *db) View(f func(tx walletdb.ReadTx) error) error {
} }
// Update opens a database read/write transaction and executes the function f // Update opens a database read/write transaction and executes the function f
// with the transaction passed as a parameter. After f exits, if f did not // with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the // error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error // transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is // returned by f is still returned. If the commit fails, the commit error is
// returned. // returned. As callers may expect retries of the f closure, the reset function
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error) error { // will be called before each retry respectively.
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error {
apply := func(stm STM) error { apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix)) return f(newReadWriteTx(stm, db.config.Prefix))
} }
@ -300,5 +306,5 @@ func (db *db) Close() error {
// //
// Batch is only useful when there are multiple goroutines calling it. // Batch is only useful when there are multiple goroutines calling it.
func (db *db) Batch(apply func(tx walletdb.ReadWriteTx) error) error { func (db *db) Batch(apply func(tx walletdb.ReadWriteTx) error) error {
return db.Update(apply) return db.Update(apply, func() {})
} }

@ -28,7 +28,7 @@ func TestCopy(t *testing.T) {
require.NoError(t, apple.Put([]byte("key"), []byte("val"))) require.NoError(t, apple.Put([]byte("key"), []byte("val")))
return nil return nil
}) }, func() {})
// Expect non-zero copy. // Expect non-zero copy.
var buf bytes.Buffer var buf bytes.Buffer
@ -66,7 +66,7 @@ func TestAbortContext(t *testing.T) {
require.Error(t, err, "context canceled") require.Error(t, err, "context canceled")
return nil return nil
}) }, func() {})
require.Error(t, err, "context canceled") require.Error(t, err, "context canceled")

@ -290,6 +290,9 @@ func (b *readWriteBucket) Put(key, value []byte) error {
// Delete deletes the key/value pointed to by the passed key. // Delete deletes the key/value pointed to by the passed key.
// Returns ErrKeyRequred if the passed key is empty. // Returns ErrKeyRequred if the passed key is empty.
func (b *readWriteBucket) Delete(key []byte) error { func (b *readWriteBucket) Delete(key []byte) error {
if key == nil {
return nil
}
if len(key) == 0 { if len(key) == 0 {
return walletdb.ErrKeyRequired return walletdb.ErrKeyRequired
} }

@ -79,7 +79,7 @@ func TestBucketCreation(t *testing.T) {
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
require.NotNil(t, apple.NestedReadBucket([]byte("banana"))) require.NotNil(t, apple.NestedReadBucket([]byte("banana")))
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -189,7 +189,7 @@ func TestBucketDeletion(t *testing.T) {
// "aple/banana" exists // "aple/banana" exists
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -261,7 +261,7 @@ func TestBucketForEach(t *testing.T) {
require.Equal(t, expected, got) require.Equal(t, expected, got)
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -354,7 +354,7 @@ func TestBucketForEachWithError(t *testing.T) {
require.Equal(t, expected, got) require.Equal(t, expected, got)
require.Error(t, err) require.Error(t, err)
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -399,7 +399,7 @@ func TestBucketSequence(t *testing.T) {
} }
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
} }
@ -431,7 +431,7 @@ func TestKeyClash(t *testing.T) {
require.NotNil(t, banana) require.NotNil(t, banana)
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -457,7 +457,7 @@ func TestKeyClash(t *testing.T) {
require.Error(t, walletdb.ErrIncompatibleValue, b) require.Error(t, walletdb.ErrIncompatibleValue, b)
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
@ -494,7 +494,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NotNil(t, banana) require.NotNil(t, banana)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
err = db.Update(func(tx walletdb.ReadWriteTx) error { err = db.Update(func(tx walletdb.ReadWriteTx) error {
@ -503,7 +503,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NoError(t, apple.DeleteNestedBucket([]byte("banana"))) require.NoError(t, apple.DeleteNestedBucket([]byte("banana")))
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
err = db.Update(func(tx walletdb.ReadWriteTx) error { err = db.Update(func(tx walletdb.ReadWriteTx) error {
@ -512,7 +512,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NoError(t, apple.Put([]byte("banana"), []byte("value"))) require.NoError(t, apple.Put([]byte("banana"), []byte("value")))
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
expected := map[string]string{ expected := map[string]string{

@ -24,7 +24,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.NotNil(t, b) require.NotNil(t, b)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
err = db.View(func(tx walletdb.ReadTx) error { err = db.View(func(tx walletdb.ReadTx) error {
@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.Nil(t, v) require.Nil(t, v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
} }
@ -78,7 +78,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val)))
} }
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.Nil(t, v) require.Nil(t, v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
} }
@ -162,7 +162,7 @@ func TestReadWriteCursor(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
return nil return nil
})) }, func() {}))
err = db.Update(func(tx walletdb.ReadWriteTx) error { err = db.Update(func(tx walletdb.ReadWriteTx) error {
b := tx.ReadWriteBucket([]byte("apple")) b := tx.ReadWriteBucket([]byte("apple"))
@ -276,7 +276,7 @@ func TestReadWriteCursor(t *testing.T) {
require.Equal(t, reverseKVs(expected), kvs) require.Equal(t, reverseKVs(expected), kvs)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
@ -320,7 +320,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.NotNil(t, b2) require.NotNil(t, b2)
return nil return nil
})) }, func() {}))
err = db.View(func(tx walletdb.ReadTx) error { err = db.View(func(tx walletdb.ReadTx) error {
b := tx.ReadBucket([]byte("apple")) b := tx.ReadBucket([]byte("apple"))
@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.Equal(t, []byte("val"), v) require.Equal(t, []byte("val"), v)
return nil return nil
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)

@ -142,7 +142,7 @@ func TestChangeDuringUpdate(t *testing.T) {
count++ count++
return nil return nil
}) }, func() {})
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, count, 2) require.Equal(t, count, 2)

@ -10,23 +10,34 @@ import (
// error, the transaction is committed. Otherwise, if f did error, the // error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error // transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is // returned by f is still returned. If the commit fails, the commit error is
// returned. // returned. As callers may expect retries of the f closure (depending on the
func Update(db Backend, f func(tx RwTx) error) error { // database backend used), the reset function will be called before each retry
// respectively.
func Update(db Backend, f func(tx RwTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok { if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.Update(f) return extendedDB.Update(f, reset)
} }
reset()
return walletdb.Update(db, f) return walletdb.Update(db, f)
} }
// View opens a database read transaction and executes the function f with the // View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled // transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any // back. If f errors, its error is returned, not a rollback error (if any
// occur). // occur). The passed reset function is called before the start of the
func View(db Backend, f func(tx RTx) error) error { // transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func View(db Backend, f func(tx RTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok { if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.View(f) return extendedDB.View(f, reset)
} }
// Since we know that walletdb simply calls into bbolt which never
// retries transactions, we'll call the reset function here before View.
reset()
return walletdb.View(db, f) return walletdb.View(db, f)
} }
@ -55,19 +66,25 @@ type ExtendedBackend interface {
// PrintStats returns all collected stats pretty printed into a string. // PrintStats returns all collected stats pretty printed into a string.
PrintStats() string PrintStats() string
// View opens a database read transaction and executes the function f with // View opens a database read transaction and executes the function f
// the transaction passed as a parameter. After f exits, the transaction is // with the transaction passed as a parameter. After f exits, the
// rolled back. If f errors, its error is returned, not a rollback error // transaction is rolled back. If f errors, its error is returned, not a
// (if any occur). // rollback error (if any occur). The passed reset function is called
View(f func(tx walletdb.ReadTx) error) error // before the start of the transaction and can be used to reset
// intermediate state. As callers may expect retries of the f closure
// (depending on the database backend used), the reset function will be
//called before each retry respectively.
View(f func(tx walletdb.ReadTx) error, reset func()) error
// Update opens a database read/write transaction and executes the function // Update opens a database read/write transaction and executes the
// f with the transaction passed as a parameter. After f exits, if f did not // function f with the transaction passed as a parameter. After f exits,
// error, the transaction is committed. Otherwise, if f did error, the // if f did not error, the transaction is committed. Otherwise, if f did
// transaction is rolled back. If the rollback fails, the original error // error, the transaction is rolled back. If the rollback fails, the
// returned by f is still returned. If the commit fails, the commit error is // original error returned by f is still returned. If the commit fails,
// returned. // the commit error is returned. As callers may expect retries of the f
Update(f func(tx walletdb.ReadWriteTx) error) error // closure (depending on the database backend used), the reset function
// will be called before each retry respectively.
Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error
} }
// Open opens an existing database for the specified type. The arguments are // Open opens an existing database for the specified type. The arguments are

@ -23,10 +23,12 @@ type Meta struct {
// FetchMeta fetches the meta data from boltdb and returns filled meta // FetchMeta fetches the meta data from boltdb and returns filled meta
// structure. // structure.
func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) { func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) {
meta := &Meta{} var meta *Meta
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(d, func(tx kvdb.RTx) error {
return fetchMeta(meta, tx) return fetchMeta(meta, tx)
}, func() {
meta = &Meta{}
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,7 +60,7 @@ func fetchMeta(meta *Meta, tx kvdb.RTx) error {
func (d *DB) PutMeta(meta *Meta) error { func (d *DB) PutMeta(meta *Meta) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error { return kvdb.Update(d, func(tx kvdb.RwTx) error {
return putMeta(meta, tx) return putMeta(meta, tx)
}) }, func() {})
} }
// putMeta is an internal helper function used in order to allow callers to // putMeta is an internal helper function used in order to allow callers to

@ -209,7 +209,7 @@ func TestMigrationWithPanic(t *testing.T) {
} }
return bucket.Put(keyPrefix, beforeMigration) return bucket.Put(keyPrefix, beforeMigration)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to insert: %v", err) t.Fatalf("unable to insert: %v", err)
} }
@ -251,7 +251,7 @@ func TestMigrationWithPanic(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -283,7 +283,7 @@ func TestMigrationWithFatal(t *testing.T) {
} }
return bucket.Put(keyPrefix, beforeMigration) return bucket.Put(keyPrefix, beforeMigration)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to insert pre migration key: %v", err) t.Fatalf("unable to insert pre migration key: %v", err)
} }
@ -326,7 +326,7 @@ func TestMigrationWithFatal(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -359,7 +359,7 @@ func TestMigrationWithoutErrors(t *testing.T) {
} }
return bucket.Put(keyPrefix, beforeMigration) return bucket.Put(keyPrefix, beforeMigration)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to update db pre migration: %v", err) t.Fatalf("unable to update db pre migration: %v", err)
} }
@ -401,7 +401,7 @@ func TestMigrationWithoutErrors(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -448,7 +448,7 @@ func TestMigrationReversion(t *testing.T) {
} }
return putMeta(newMeta, tx) return putMeta(newMeta, tx)
}) }, func() {})
// Close the database. Even if we succeeded, our next step is to reopen. // Close the database. Even if we succeeded, our next step is to reopen.
cdb.Close() cdb.Close()
@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to apply after func: %v", err) t.Fatalf("unable to apply after func: %v", err)
} }

@ -152,7 +152,7 @@ func createChannelDB(dbPath string) error {
DbVersionNumber: 0, DbVersionNumber: 0,
} }
return putMeta(meta, tx) return putMeta(meta, tx)
}) }, func() {})
if err != nil { if err != nil {
return fmt.Errorf("unable to create new channeldb") return fmt.Errorf("unable to create new channeldb")
} }
@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary) chanSummaries = append(chanSummaries, chanSummary)
return nil return nil
}) })
}, func() {
chanSummaries = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node source = node
return nil return nil
}, func() {
source = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -242,7 +244,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
// Finally, we commit the information of the lightning node // Finally, we commit the information of the lightning node
// itself. // itself.
return addLightningNode(tx, node) return addLightningNode(tx, node)
}) }, func() {})
} }
func addLightningNode(tx kvdb.RwTx, node *LightningNode) error { func addLightningNode(tx kvdb.RwTx, node *LightningNode) error {

@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
return nil return nil
}) })
}, func() {
invoices = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -51,7 +51,7 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
// Apply migration. // Apply migration.
err = kvdb.Update(cdb, func(tx kvdb.RwTx) error { err = kvdb.Update(cdb, func(tx kvdb.RwTx) error {
return migrationFunc(tx) return migrationFunc(tx)
}) }, func() {})
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }

@ -95,7 +95,7 @@ func (db *DB) addPayment(payment *outgoingPayment) error {
binary.BigEndian.PutUint64(paymentIDBytes, paymentID) binary.BigEndian.PutUint64(paymentIDBytes, paymentID)
return payments.Put(paymentIDBytes, paymentBytes) return payments.Put(paymentIDBytes, paymentBytes)
}) }, func() {})
} }
// fetchAllPayments returns all outgoing payments in DB. // fetchAllPayments returns all outgoing payments in DB.
@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) {
payments = append(payments, payment) payments = append(payments, payment)
return nil return nil
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var err error var err error
paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash)
return err return err
}, func() {
paymentStatus = StatusUnknown
}) })
if err != nil { if err != nil {
return StatusUnknown, err return StatusUnknown, err
@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) {
return nil return nil
}) })
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -55,7 +55,7 @@ func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -125,7 +125,7 @@ func TestPaymentStatusesMigration(t *testing.T) {
} }
return circuits.Put(inFlightKey, inFlightCircuit) return circuits.Put(inFlightKey, inFlightCircuit)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to add circuit map entry: %v", err) t.Fatalf("unable to add circuit map entry: %v", err)
} }
@ -385,7 +385,7 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
return err return err
} }
return closedChanBucket.Put(chanID, old) return closedChanBucket.Put(chanID, old)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to add old serialization: %v", t.Fatalf("unable to add old serialization: %v",
err) err)
@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
"serialization") "serialization")
} }
return nil return nil
}, func() {
dbSummary = nil
}) })
if err != nil { if err != nil {
t.Fatalf("unable to view DB: %v", err) t.Fatalf("unable to view DB: %v", err)
@ -491,7 +493,7 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
} }
return messageStore.Put(oldMsgKey[:], b.Bytes()) return messageStore.Put(oldMsgKey[:], b.Bytes())
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
} }
return nil return nil
}, func() {
rawMsg = nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -679,7 +683,7 @@ func TestOutgoingPaymentsMigration(t *testing.T) {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -855,6 +859,8 @@ func TestPaymentRouteSerialization(t *testing.T) {
} }
return nil return nil
}, func() {
oldPayments = nil
}) })
if err != nil { if err != nil {
t.Fatalf("unable to create test payments: %v", err) t.Fatalf("unable to create test payments: %v", err)

@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) {
return nil return nil
}) })
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -47,7 +47,7 @@ func ApplyMigration(t *testing.T,
// beforeMigration usually used for populating the database // beforeMigration usually used for populating the database
// with test data. // with test data.
err = kvdb.Update(cdb, beforeMigration) err = kvdb.Update(cdb, beforeMigration, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -65,14 +65,14 @@ func ApplyMigration(t *testing.T,
// afterMigration usually used for checking the database state and // afterMigration usually used for checking the database state and
// throwing the error if something went wrong. // throwing the error if something went wrong.
err = kvdb.Update(cdb, afterMigration) err = kvdb.Update(cdb, afterMigration, func() {})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
// Apply migration. // Apply migration.
err = kvdb.Update(cdb, migrationFunc) err = kvdb.Update(cdb, migrationFunc, func() {})
if err != nil { if err != nil {
t.Logf("migration error: %v", err) t.Logf("migration error: %v", err)
} }

@ -108,7 +108,7 @@ func (l *LinkNode) Sync() error {
} }
return putLinkNode(nodeMetaBucket, l) return putLinkNode(nodeMetaBucket, l)
}) }, func() {})
} }
// putLinkNode serializes then writes the encoded version of the passed link // putLinkNode serializes then writes the encoded version of the passed link
@ -132,7 +132,7 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error {
func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error { func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(db, func(tx kvdb.RwTx) error { return kvdb.Update(db, func(tx kvdb.RwTx) error {
return db.deleteLinkNode(tx, identity) return db.deleteLinkNode(tx, identity)
}) }, func() {})
} }
func (db *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { func (db *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
linkNode = node linkNode = node
return nil return nil
}, func() {
linkNode = nil
}) })
return linkNode, err return linkNode, err
@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
linkNodes = nodes linkNodes = nodes
return nil return nil
}, func() {
linkNodes = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
payment, err = fetchPayment(bucket) payment, err = fetchPayment(bucket)
return err return err
}, func() {
payment = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
inFlights = append(inFlights, inFlight) inFlights = append(inFlights, inFlight)
return nil return nil
}) })
}, func() {
inFlights = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
indexCount++ indexCount++
return nil return nil
}) })
}) }, func() { indexCount = 0 })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, indexCount) require.Equal(t, 1, indexCount)
@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
var err error var err error
hash, err = deserializePaymentIndex(r) hash, err = deserializePaymentIndex(r)
return err return err
}, func() {
hash = lntypes.Hash{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) {
payments = append(payments, duplicatePayments...) payments = append(payments, duplicatePayments...)
return nil return nil
}) })
}, func() {
payments = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
} }
return nil return nil
}, func() {
resp = PaymentsResponse{}
}); err != nil { }); err != nil {
return resp, err return resp, err
} }
@ -745,7 +749,7 @@ func (db *DB) DeletePayments() error {
} }
return nil return nil
}) }, func() {})
} }
// fetchSequenceNumbers fetches all the sequence numbers associated with a // fetchSequenceNumbers fetches all the sequence numbers associated with a

@ -183,7 +183,7 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64)
// Delete the index that references this payment. // Delete the index that references this payment.
indexes := tx.ReadWriteBucket(paymentsIndexBucket) indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Delete(key) return indexes.Delete(key)
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("could not delete "+ t.Fatalf("could not delete "+
@ -622,7 +622,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) {
tx, test.paymentHash, seqNrBytes[:], tx, test.paymentHash, seqNrBytes[:],
) )
return err return err
}) }, func() {})
require.Equal(t, test.expectedErr, err) require.Equal(t, test.expectedErr, err)
}) })
} }
@ -666,7 +666,7 @@ func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash,
require.NoError(t, err) require.NoError(t, err)
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("could not create payment: %v", err) t.Fatalf("could not create payment: %v", err)
} }

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
@ -50,7 +50,7 @@ type FlapCount struct {
// bucket for the peer's pubkey if necessary. Note that this function overwrites // bucket for the peer's pubkey if necessary. Note that this function overwrites
// the current value. // the current value.
func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error { func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
return d.Update(func(tx walletdb.ReadWriteTx) error { return kvdb.Update(d, func(tx kvdb.RwTx) error {
// Run through our set of flap counts and record them for // Run through our set of flap counts and record them for
// each peer, creating a bucket for the peer pubkey if required. // each peer, creating a bucket for the peer pubkey if required.
for peer, flapCount := range flapCounts { for peer, flapCount := range flapCounts {
@ -80,7 +80,7 @@ func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
} }
return nil return nil
}) }, func() {})
} }
// ReadFlapCount attempts to read the flap count for a peer, failing if the // ReadFlapCount attempts to read the flap count for a peer, failing if the
@ -88,7 +88,7 @@ func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) { func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
var flapCount FlapCount var flapCount FlapCount
if err := d.View(func(tx walletdb.ReadTx) error { if err := kvdb.View(d, func(tx kvdb.RTx) error {
peers := tx.ReadBucket(peersBucket) peers := tx.ReadBucket(peersBucket)
peerBucket := peers.NestedReadBucket(pubkey[:]) peerBucket := peers.NestedReadBucket(pubkey[:])
@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
} }
return ReadElements(r, &flapCount.Count) return ReadElements(r, &flapCount.Count)
}, func() {
flapCount = FlapCount{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -130,7 +130,7 @@ func (d *DB) PutResolverReport(tx kvdb.RwTx, chainHash chainhash.Hash,
// If the transaction is nil, we'll create a new one. // If the transaction is nil, we'll create a new one.
if tx == nil { if tx == nil {
return kvdb.Update(d, putReportFunc) return kvdb.Update(d, putReportFunc, func() {})
} }
// Otherwise, we can write the report to disk using the existing // Otherwise, we can write the report to disk using the existing
@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash,
return nil return nil
}) })
}, func() {
reports = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -202,7 +202,7 @@ func TestFetchChannelWriteBucket(t *testing.T) {
defer cleanup() defer cleanup()
// Update our db to the starting state we expect. // Update our db to the starting state we expect.
err = kvdb.Update(db, test.setup) err = kvdb.Update(db, test.setup, func() {})
require.NoError(t, err) require.NoError(t, err)
// Try to get our report bucket. // Try to get our report bucket.
@ -211,7 +211,7 @@ func TestFetchChannelWriteBucket(t *testing.T) {
tx, testChainHash, &testChanPoint1, tx, testChainHash, &testChanPoint1,
) )
return err return err
}) }, func() {})
require.NoError(t, err) require.NoError(t, err)
}) })
} }

@ -42,13 +42,14 @@ type WaitingProofStore struct {
// NewWaitingProofStore creates new instance of proofs storage. // NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
s := &WaitingProofStore{ s := &WaitingProofStore{
db: db, db: db,
cache: make(map[WaitingProofKey]struct{}),
} }
if err := s.ForAll(func(proof *WaitingProof) error { if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{} s.cache[proof.Key()] = struct{}{}
return nil return nil
}, func() {
s.cache = make(map[WaitingProofKey]struct{})
}); err != nil && err != ErrWaitingProofNotFound { }); err != nil && err != ErrWaitingProofNotFound {
return nil, err return nil, err
} }
@ -79,7 +80,7 @@ func (s *WaitingProofStore) Add(proof *WaitingProof) error {
key := proof.Key() key := proof.Key()
return bucket.Put(key[:], b.Bytes()) return bucket.Put(key[:], b.Bytes())
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -108,7 +109,7 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
} }
return bucket.Delete(key[:]) return bucket.Delete(key[:])
}) }, func() {})
if err != nil { if err != nil {
return err return err
} }
@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
// ForAll iterates thought all waiting proofs and passing the waiting proof // ForAll iterates thought all waiting proofs and passing the waiting proof
// in the given callback. // in the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error,
reset func()) error {
return kvdb.View(s.db, func(tx kvdb.RTx) error { return kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey) bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil { if bucket == nil {
@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
return cb(proof) return cb(proof)
}) })
}) }, reset)
} }
// Get returns the object which corresponds to the given index. // Get returns the object which corresponds to the given index.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
proof := &WaitingProof{} var proof *WaitingProof
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
r := bytes.NewReader(v) r := bytes.NewReader(v)
return proof.Decode(r) return proof.Decode(r)
}, func() {
proof = &WaitingProof{}
}) })
return proof, err return proof, err

@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) {
if err := store.ForAll(func(proof *WaitingProof) error { if err := store.ForAll(func(proof *WaitingProof) error {
return errors.New("storage should be empty") return errors.New("storage should be empty")
}); err != nil && err != ErrWaitingProofNotFound { }, func() {}); err != nil && err != ErrWaitingProofNotFound {
t.Fatal(err) t.Fatal(err)
} }
} }

@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by
copy(witness[:], dbWitness) copy(witness[:], dbWitness)
return nil return nil
}, func() {
witness = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) {
s = ArbitratorState(stateBytes[0]) s = ArbitratorState(stateBytes[0])
return nil return nil
}, func() {
s = 0
}) })
if err != nil && err != errScopeBucketNoExist { if err != nil && err != errScopeBucketNoExist {
return s, err return s, err
@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro
contracts = append(contracts, res) contracts = append(contracts, res)
return nil return nil
}) })
}, func() {
contracts = nil
}) })
if err != nil && err != errScopeBucketNoExist && err != errNoContracts { if err != nil && err != errScopeBucketNoExist && err != errNoContracts {
return nil, err return nil, err
@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) {
c := &ContractResolutions{} var c *ContractResolutions
err := kvdb.View(b.db, func(tx kvdb.RTx) error { err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:]) scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil { if scopeBucket == nil {
@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
} }
return nil return nil
}, func() {
c = &ContractResolutions{}
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
actionsMap := make(ChainActionMap) var actionsMap ChainActionMap
err := kvdb.View(b.db, func(tx kvdb.RTx) error { err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:]) scopeBucket := tx.ReadBucket(b.scopeKey[:])
@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
return nil return nil
}) })
}, func() {
actionsMap = make(ChainActionMap)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) {
c = commitSet c = commitSet
return nil return nil
}, func() {
c = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -914,7 +924,7 @@ func (b *boltArbitratorLog) WipeHistory() error {
// Finally, we'll delete the enclosing bucket itself. // Finally, we'll delete the enclosing bucket itself.
return tx.DeleteTopLevelBucket(b.scopeKey[:]) return tx.DeleteTopLevelBucket(b.scopeKey[:])
}) }, func() {})
} }
// checkpointContract is a private method that will be fed into // checkpointContract is a private method that will be fed into
@ -941,7 +951,7 @@ func (b *boltArbitratorLog) checkpointContract(c ContractResolver,
} }
return nil return nil
}) }, func() {})
} }
func encodeIncomingResolution(w io.Writer, i *lnwallet.IncomingHtlcResolution) error { func encodeIncomingResolution(w io.Writer, i *lnwallet.IncomingHtlcResolution) error {

@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1570,6 +1585,9 @@ out:
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil { ); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }
@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++ number++
return nil return nil
}, },
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound { ); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err) t.Fatalf("unable to retrieve objects from store: %v", err)
} }

@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) {
// Messages returns the total set of messages that exist within the store for // Messages returns the total set of messages that exist within the store for
// all peers. // all peers.
func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) { func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs := make(map[[33]byte][]lnwire.Message) var msgs map[[33]byte][]lnwire.Message
err := kvdb.View(s.db, func(tx kvdb.RTx) error { err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket) messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs[pubKey] = append(msgs[pubKey], msg) msgs[pubKey] = append(msgs[pubKey], msg)
return nil return nil
}) })
}, func() {
msgs = make(map[[33]byte][]lnwire.Message)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer(
} }
return nil return nil
}, func() {
msgs = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer(
// Peers returns the public key of all peers with messages within the store. // Peers returns the public key of all peers with messages within the store.
func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) { func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers := make(map[[33]byte]struct{}) var peers map[[33]byte]struct{}
err := kvdb.View(s.db, func(tx kvdb.RTx) error { err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket) messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers[pubKey] = struct{}{} peers[pubKey] = struct{}{}
return nil return nil
}) })
}, func() {
peers = make(map[[33]byte]struct{})
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -239,7 +239,7 @@ func TestMessageStoreUnsupportedMessage(t *testing.T) {
err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error { err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket) messageStore := tx.ReadWriteBucket(messageStoreBucket)
return messageStore.Put(msgKey, rawMsg.Bytes()) return messageStore.Put(msgKey, rawMsg.Bytes())
}) }, func() {})
if err != nil { if err != nil {
t.Fatalf("unable to add unsupported message to store: %v", err) t.Fatalf("unable to add unsupported message to store: %v", err)
} }

@ -3505,7 +3505,7 @@ func (f *fundingManager) saveChannelOpeningState(chanPoint *wire.OutPoint,
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
return bucket.Put(outpointBytes.Bytes(), scratch) return bucket.Put(outpointBytes.Bytes(), scratch)
}) }, func() {})
} }
// getChannelOpeningState fetches the channelOpeningState for the provided // getChannelOpeningState fetches the channelOpeningState for the provided
@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) (
state = channelOpeningState(byteOrder.Uint16(value[:2])) state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
@ -3560,5 +3560,5 @@ func (f *fundingManager) deleteChannelOpeningState(chanPoint *wire.OutPoint) err
} }
return bucket.Delete(outpointBytes.Bytes()) return bucket.Delete(outpointBytes.Bytes())
}) }, func() {})
} }

@ -227,7 +227,7 @@ func (cm *circuitMap) initBuckets() error {
_, err = tx.CreateTopLevelBucket(circuitAddKey) _, err = tx.CreateTopLevelBucket(circuitAddKey)
return err return err
}) }, func() {})
} }
// restoreMemState loads the contents of the half circuit and full circuit // restoreMemState loads the contents of the half circuit and full circuit
@ -240,8 +240,8 @@ func (cm *circuitMap) restoreMemState() error {
log.Infof("Restoring in-memory circuit state from disk") log.Infof("Restoring in-memory circuit state from disk")
var ( var (
opened = make(map[CircuitKey]*PaymentCircuit) opened map[CircuitKey]*PaymentCircuit
pending = make(map[CircuitKey]*PaymentCircuit) pending map[CircuitKey]*PaymentCircuit
) )
if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error { if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
@ -331,6 +331,9 @@ func (cm *circuitMap) restoreMemState() error {
return nil return nil
}, func() {
opened = make(map[CircuitKey]*PaymentCircuit)
pending = make(map[CircuitKey]*PaymentCircuit)
}); err != nil { }); err != nil {
return err return err
} }
@ -483,7 +486,7 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
} }
return nil return nil
}) }, func() {})
} }
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC // LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
@ -730,7 +733,7 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
} }
return nil return nil
}) }, func() {})
if err != nil { if err != nil {
return err return err

@ -131,7 +131,7 @@ func (d *DecayedLog) initBuckets() error {
} }
return nil return nil
}) }, func() {})
} }
// Stop halts the garbage collector and closes boltdb. // Stop halts the garbage collector and closes boltdb.
@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
value = uint32(binary.BigEndian.Uint32(valueBytes)) value = uint32(binary.BigEndian.Uint32(valueBytes))
return nil return nil
}, func() {
value = 0
}) })
if err != nil { if err != nil {
return value, err return value, err

@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
default: default:
return nil return nil
} }
}, func() {
result = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) (
var err error var err error
result, err = fetchResult(tx, pid) result, err = fetchResult(tx, pid)
return err return err
}, func() {
result = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -301,5 +305,5 @@ func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
} }
return nil return nil
}) }, func() {})
} }

@ -100,6 +100,8 @@ func (s *persistentSequencer) NextID() (uint64, error) {
nextIDBkt.SetSequence(nextHorizonID) nextIDBkt.SetSequence(nextHorizonID)
return nil return nil
}, func() {
nextHorizonID = 0
}); err != nil { }); err != nil {
return 0, err return 0, err
} }
@ -124,5 +126,5 @@ func (s *persistentSequencer) initDB() error {
return kvdb.Update(s.db, func(tx kvdb.RwTx) error { return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(nextPaymentIDKey) _, err := tx.CreateTopLevelBucket(nextPaymentIDKey)
return err return err
}) }, func() {})
} }

@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.
tx, source, tx, source,
) )
return err return err
}, func() {
fwdPkgs = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -47,6 +47,9 @@ type AddInvoiceConfig struct {
// channel graph. // channel graph.
ChanDB *channeldb.DB ChanDB *channeldb.DB
// Graph holds a reference to the ChannelGraph database.
Graph *channeldb.ChannelGraph
// GenInvoiceFeatures returns a feature containing feature bits that // GenInvoiceFeatures returns a feature containing feature bits that
// should be advertised on freshly generated invoices. // should be advertised on freshly generated invoices.
GenInvoiceFeatures func() *lnwire.FeatureVector GenInvoiceFeatures func() *lnwire.FeatureVector
@ -330,9 +333,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
// chanCanBeHopHint returns true if the target channel is eligible to be a hop // chanCanBeHopHint returns true if the target channel is eligible to be a hop
// hint. // hint.
func chanCanBeHopHint(channel *channeldb.OpenChannel, func chanCanBeHopHint(channel *channeldb.OpenChannel, cfg *AddInvoiceConfig) (
graph *channeldb.ChannelGraph, *channeldb.ChannelEdgePolicy, bool) {
cfg *AddInvoiceConfig) (*channeldb.ChannelEdgePolicy, bool) {
// Since we're only interested in our private channels, we'll skip // Since we're only interested in our private channels, we'll skip
// public ones. // public ones.
@ -359,7 +361,7 @@ func chanCanBeHopHint(channel *channeldb.OpenChannel,
// channels. // channels.
var remotePub [33]byte var remotePub [33]byte
copy(remotePub[:], channel.IdentityPub.SerializeCompressed()) copy(remotePub[:], channel.IdentityPub.SerializeCompressed())
isRemoteNodePublic, err := graph.IsPublicNode(remotePub) isRemoteNodePublic, err := cfg.Graph.IsPublicNode(remotePub)
if err != nil { if err != nil {
log.Errorf("Unable to determine if node %x "+ log.Errorf("Unable to determine if node %x "+
"is advertised: %v", remotePub, err) "is advertised: %v", remotePub, err)
@ -375,7 +377,7 @@ func chanCanBeHopHint(channel *channeldb.OpenChannel,
// Fetch the policies for each end of the channel. // Fetch the policies for each end of the channel.
chanID := channel.ShortChanID().ToUint64() chanID := channel.ShortChanID().ToUint64()
info, p1, p2, err := graph.FetchChannelEdgesByID(chanID) info, p1, p2, err := cfg.Graph.FetchChannelEdgesByID(chanID)
if err != nil { if err != nil {
log.Errorf("Unable to fetch the routing "+ log.Errorf("Unable to fetch the routing "+
"policies for the edges of the channel "+ "policies for the edges of the channel "+
@ -423,8 +425,6 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
openChannels []*channeldb.OpenChannel, openChannels []*channeldb.OpenChannel,
numMaxHophints int) []func(*zpay32.Invoice) { numMaxHophints int) []func(*zpay32.Invoice) {
graph := cfg.ChanDB.ChannelGraph()
// We'll add our hop hints in two passes, first we'll add all channels // We'll add our hop hints in two passes, first we'll add all channels
// that are eligible to be hop hints, and also have a local balance // that are eligible to be hop hints, and also have a local balance
// above the payment amount. // above the payment amount.
@ -433,9 +433,7 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
hopHints := make([]func(*zpay32.Invoice), 0, numMaxHophints) hopHints := make([]func(*zpay32.Invoice), 0, numMaxHophints)
for _, channel := range openChannels { for _, channel := range openChannels {
// If this channel can't be a hop hint, then skip it. // If this channel can't be a hop hint, then skip it.
edgePolicy, canBeHopHint := chanCanBeHopHint( edgePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg)
channel, graph, cfg,
)
if edgePolicy == nil || !canBeHopHint { if edgePolicy == nil || !canBeHopHint {
continue continue
} }
@ -485,9 +483,7 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
// If the channel can't be a hop hint, then we'll skip it. // If the channel can't be a hop hint, then we'll skip it.
// Otherwise, we'll use the policy information to populate the // Otherwise, we'll use the policy information to populate the
// hop hint. // hop hint.
remotePolicy, canBeHopHint := chanCanBeHopHint( remotePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg)
channel, graph, cfg,
)
if !canBeHopHint || remotePolicy == nil { if !canBeHopHint || remotePolicy == nil {
continue continue
} }

@ -62,7 +62,7 @@ func NewRootKeyStorage(db kvdb.Backend) (*RootKeyStorage, error) {
err := kvdb.Update(db, func(tx kvdb.RwTx) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(rootKeyBucketName) _, err := tx.CreateTopLevelBucket(rootKeyBucketName)
return err return err
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,7 +123,7 @@ func (r *RootKeyStorage) CreateUnlock(password *[]byte) error {
r.encKey = encKey r.encKey = encKey
return nil return nil
}) }, func() {})
} }
// Get implements the Get method for the bakery.RootKeyStorage interface. // Get implements the Get method for the bakery.RootKeyStorage interface.
@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) {
rootKey = make([]byte, len(decKey)) rootKey = make([]byte, len(decKey))
copy(rootKey[:], decKey) copy(rootKey[:], decKey)
return nil return nil
}, func() {
rootKey = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -209,6 +211,8 @@ func (r *RootKeyStorage) RootKey(ctx context.Context) ([]byte, []byte, error) {
return err return err
} }
return ns.Put(id, encKey) return ns.Put(id, encKey)
}, func() {
rootKey = nil
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -257,6 +261,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) {
} }
return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey) return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey)
}, func() {
rootKeySlice = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -306,6 +312,8 @@ func (r *RootKeyStorage) DeleteMacaroonID(
rootKeyIDDeleted = rootKeyID rootKeyIDDeleted = rootKeyID
return nil return nil
}, func() {
rootKeyIDDeleted = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -129,7 +129,7 @@ type NurseryStore interface {
// the caller to process each key-value pair. The key will be a prefixed // the caller to process each key-value pair. The key will be a prefixed
// outpoint, and the value will be the serialized bytes for an output, // outpoint, and the value will be the serialized bytes for an output,
// whose type should be inferred from the key's prefix. // whose type should be inferred from the key's prefix.
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error
// ListChannels returns all channels the nursery is currently tracking. // ListChannels returns all channels the nursery is currently tracking.
ListChannels() ([]wire.OutPoint, error) ListChannels() ([]wire.OutPoint, error)
@ -283,7 +283,7 @@ func (ns *nurseryStore) Incubate(kids []kidOutput, babies []babyOutput) error {
} }
return nil return nil
}) }, func() {})
} }
// CribToKinder atomically moves a babyOutput in the crib bucket to the // CribToKinder atomically moves a babyOutput in the crib bucket to the
@ -365,7 +365,7 @@ func (ns *nurseryStore) CribToKinder(bby *babyOutput) error {
// This informs the utxo nursery that it should attempt to spend // This informs the utxo nursery that it should attempt to spend
// this output when the blockchain reaches the maturity height. // this output when the blockchain reaches the maturity height.
return hghtChanBucketCsv.Put(pfxOutputKey, []byte{}) return hghtChanBucketCsv.Put(pfxOutputKey, []byte{})
}) }, func() {})
} }
// PreschoolToKinder atomically moves a kidOutput from the preschool bucket to // PreschoolToKinder atomically moves a kidOutput from the preschool bucket to
@ -463,7 +463,7 @@ func (ns *nurseryStore) PreschoolToKinder(kid *kidOutput,
// that this CSV delayed output will be ready to broadcast at // that this CSV delayed output will be ready to broadcast at
// the maturity height, after a brief period of incubation. // the maturity height, after a brief period of incubation.
return hghtChanBucket.Put(pfxOutputKey, []byte{}) return hghtChanBucket.Put(pfxOutputKey, []byte{})
}) }, func() {})
} }
// GraduateKinder atomically moves an output at the provided height into the // GraduateKinder atomically moves an output at the provided height into the
@ -525,7 +525,7 @@ func (ns *nurseryStore) GraduateKinder(height uint32, kid *kidOutput) error {
// using graduate-prefixed key. // using graduate-prefixed key.
return chanBucket.Put(pfxOutputKey, return chanBucket.Put(pfxOutputKey,
gradBuffer.Bytes()) gradBuffer.Bytes())
}) }, func() {})
} }
// FetchClass returns a list of babyOutputs in the crib bucket whose CLTV // FetchClass returns a list of babyOutputs in the crib bucket whose CLTV
@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass(
}) })
}, func() {
kids = nil
babies = nil
}); err != nil { }); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) {
} }
return nil return nil
}, func() {
kids = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
} }
return nil return nil
}, func() {
activeHeights = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
// NOTE: The callback should not modify the provided byte slices and is // NOTE: The callback should not modify the provided byte slices and is
// preferably non-blocking. // preferably non-blocking.
func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint, func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error { callback func([]byte, []byte) error, reset func()) error {
return kvdb.View(ns.db, func(tx kvdb.RTx) error { return kvdb.View(ns.db, func(tx kvdb.RTx) error {
return ns.forChanOutputs(tx, chanPoint, callback) return ns.forChanOutputs(tx, chanPoint, callback)
}) }, reset)
} }
// ListChannels returns all channels the nursery is currently tracking. // ListChannels returns all channels the nursery is currently tracking.
@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) {
return nil return nil
}) })
}, func() {
activeChannels = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error)
return nil return nil
}) })
}) }, func() {})
if err != nil && err != ErrImmatureChannel { if err != nil && err != ErrImmatureChannel {
return false, err return false, err
} }
@ -835,7 +844,7 @@ func (ns *nurseryStore) RemoveChannel(chanPoint *wire.OutPoint) error {
} }
return removeBucketIfExists(chanIndex, chanBytes) return removeBucketIfExists(chanIndex, chanBytes)
}) }, func() {})
} }
// Helper Methods // Helper Methods

@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore,
err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error { err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error {
count++ count++
return nil return nil
}, func() {
count = 0
}) })
if count == 0 && err == ErrContractNotFound { if count == 0 && err == ErrContractNotFound {

@ -41,10 +41,7 @@ type missionControlStore struct {
} }
func newMissionControlStore(db kvdb.Backend, maxRecords int) (*missionControlStore, error) { func newMissionControlStore(db kvdb.Backend, maxRecords int) (*missionControlStore, error) {
store := &missionControlStore{ var store *missionControlStore
db: db,
maxRecords: maxRecords,
}
// Create buckets if not yet existing. // Create buckets if not yet existing.
err := kvdb.Update(db, func(tx kvdb.RwTx) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error {
@ -64,6 +61,11 @@ func newMissionControlStore(db kvdb.Backend, maxRecords int) (*missionControlSto
} }
return nil return nil
}, func() {
store = &missionControlStore{
db: db,
maxRecords: maxRecords,
}
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -81,7 +83,7 @@ func (b *missionControlStore) clear() error {
_, err := tx.CreateTopLevelBucket(resultsKey) _, err := tx.CreateTopLevelBucket(resultsKey)
return err return err
}) }, func() {})
} }
// fetchAll returns all results currently stored in the database. // fetchAll returns all results currently stored in the database.
@ -103,6 +105,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
return nil return nil
}) })
}, func() {
results = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -249,7 +253,7 @@ func (b *missionControlStore) AddResult(rp *paymentResult) error {
// Put into results bucket. // Put into results bucket.
return bucket.Put(k, v) return bucket.Put(k, v)
}) }, func() {})
} }
// getResultKey returns a byte slice representing a unique key for this payment // getResultKey returns a byte slice representing a unique key for this payment

@ -4736,6 +4736,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
NodeSigner: r.server.nodeSigner, NodeSigner: r.server.nodeSigner,
DefaultCLTVExpiry: defaultDelta, DefaultCLTVExpiry: defaultDelta,
ChanDB: r.server.remoteChanDB, ChanDB: r.server.remoteChanDB,
Graph: r.server.localChanDB.ChannelGraph(),
GenInvoiceFeatures: func() *lnwire.FeatureVector { GenInvoiceFeatures: func() *lnwire.FeatureVector {
return r.server.featureMgr.Get(feature.SetInvoice) return r.server.featureMgr.Get(feature.SetInvoice)
}, },

@ -92,7 +92,7 @@ func NewSweeperStore(db kvdb.Backend, chainHash *chainhash.Hash) (
err = migrateTxHashes(tx, txHashesBucket, chainHash) err = migrateTxHashes(tx, txHashesBucket, chainHash)
return err return err
}) }, func() {})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -193,7 +193,7 @@ func (s *sweeperStore) NotifyPublishTx(sweepTx *wire.MsgTx) error {
hash := sweepTx.TxHash() hash := sweepTx.TxHash()
return txHashesBucket.Put(hash[:], []byte{}) return txHashesBucket.Put(hash[:], []byte{})
}) }, func() {})
} }
// GetLastPublishedTx returns the last tx that we called NotifyPublishTx // GetLastPublishedTx returns the last tx that we called NotifyPublishTx
@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) {
} }
return nil return nil
}, func() {
sweepTx = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
ours = txHashesBucket.Get(hash[:]) != nil ours = txHashesBucket.Get(hash[:]) != nil
return nil return nil
}, func() {
ours = false
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) {
return nil return nil
}) })
}, func() {
sweepTxns = nil
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport(
utxnLog.Debugf("NurseryReport: building nursery report for channel %v", utxnLog.Debugf("NurseryReport: building nursery report for channel %v",
chanPoint) chanPoint)
report := &contractMaturityReport{} var report *contractMaturityReport
if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error { if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error {
switch { switch {
@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport(
} }
return nil return nil
}, func() {
report = &contractMaturityReport{}
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) (
} }
func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint, func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error { callback func([]byte, []byte) error, reset func()) error {
return i.ns.ForChanOutputs(chanPoint, callback) return i.ns.ForChanOutputs(chanPoint, callback, reset)
} }
func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) { func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) {

@ -146,7 +146,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initialized. This allows us to assume their presence throughout all // initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is // operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error. // missing, this will trigger a ErrUninitializedDB error.
err = kvdb.Update(clientDB.db, initClientDBBuckets) err = kvdb.Update(clientDB.db, initClientDBBuckets, func() {})
if err != nil { if err != nil {
bdb.Close() bdb.Close()
return nil, err return nil, err
@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
}, func() {
version = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -291,6 +293,8 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// Store the new or updated tower under its tower id. // Store the new or updated tower under its tower id.
return putTower(towers, tower) return putTower(towers, tower)
}, func() {
tower = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -377,7 +381,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
} }
return nil return nil
}) }, func() {})
} }
// LoadTowerByID retrieves a tower by its tower ID. // LoadTowerByID retrieves a tower by its tower ID.
@ -392,6 +396,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var err error var err error
tower, err = getTower(towers, towerID.Bytes()) tower, err = getTower(towers, towerID.Bytes())
return err return err
}, func() {
tower = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -421,6 +427,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var err error var err error
tower, err = getTower(towers, towerIDBytes) tower, err = getTower(towers, towerIDBytes)
return err return err
}, func() {
tower = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -446,6 +454,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
towers = append(towers, tower) towers = append(towers, tower)
return nil return nil
}) })
}, func() {
towers = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -498,6 +508,8 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
// Record the reserved session key index under this tower's id. // Record the reserved session key index under this tower's id.
return keyIndex.Put(towerIDBytes, indexBuf[:]) return keyIndex.Put(towerIDBytes, indexBuf[:])
}, func() {
index = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -550,7 +562,7 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
// Finally, write the client session's body in the sessions // Finally, write the client session's body in the sessions
// bucket. // bucket.
return putClientSessionBody(sessions, session) return putClientSessionBody(sessions, session)
}) }, func() {})
} }
// ListClientSessions returns the set of all client sessions known to the db. An // ListClientSessions returns the set of all client sessions known to the db. An
@ -566,6 +578,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
var err error var err error
clientSessions, err = listClientSessions(sessions, id) clientSessions, err = listClientSessions(sessions, id)
return err return err
}, func() {
clientSessions = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -611,7 +625,7 @@ func listClientSessions(sessions kvdb.RBucket,
// FetchChanSummaries loads a mapping from all registered channels to their // FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries. // channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary) var summaries map[lnwire.ChannelID]ClientChanSummary
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
chanSummaries := tx.ReadBucket(cChanSummaryBkt) chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil { if chanSummaries == nil {
@ -632,6 +646,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
return nil return nil
}) })
}, func() {
summaries = make(map[lnwire.ChannelID]ClientChanSummary)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -674,7 +690,7 @@ func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
} }
return putChanSummary(chanSummaries, chanID, &summary) return putChanSummary(chanSummaries, chanID, &summary)
}) }, func() {})
} }
// MarkBackupIneligible records that the state identified by the (channel id, // MarkBackupIneligible records that the state identified by the (channel id,
@ -782,6 +798,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
return nil return nil
}, func() {
lastApplied = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -887,7 +905,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// Finally, insert the ack into the sessionAcks sub-bucket. // Finally, insert the ack into the sessionAcks sub-bucket.
return sessionAcks.Put(seqNumBuf[:], b.Bytes()) return sessionAcks.Put(seqNumBuf[:], b.Bytes())
}) }, func() {})
} }
// getClientSessionBody loads the body of a ClientSession from the sessions // getClientSessionBody loads the body of a ClientSession from the sessions

@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) {
err = kvdb.View(bdb, func(tx kvdb.RTx) error { err = kvdb.View(bdb, func(tx kvdb.RTx) error {
metadataExists = tx.ReadBucket(metadataBkt) != nil metadataExists = tx.ReadBucket(metadataBkt) != nil
return nil return nil
}, func() {
metadataExists = false
}) })
if err != nil { if err != nil {
return nil, false, err return nil, false, err

@ -88,7 +88,7 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
// initialized. This allows us to assume their presence throughout all // initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is // operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error. // missing, this will trigger a ErrUninitializedDB error.
err = kvdb.Update(towerDB.db, initTowerDBBuckets) err = kvdb.Update(towerDB.db, initTowerDBBuckets, func() {})
if err != nil { if err != nil {
bdb.Close() bdb.Close()
return nil, err return nil, err
@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) {
var err error var err error
version, err = getDBVersion(tx) version, err = getDBVersion(tx)
return err return err
}, func() {
version = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var err error var err error
session, err = getSession(sessions, id[:]) session, err = getSession(sessions, id[:])
return err return err
}, func() {
session = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -210,7 +214,7 @@ func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
// be deleted without needing to iterate over the entire // be deleted without needing to iterate over the entire
// database. // database.
return touchSessionHintBkt(updateIndex, &session.ID) return touchSessionHintBkt(updateIndex, &session.ID)
}) }, func() {})
} }
// InsertStateUpdate stores an update sent by the client after validating that // InsertStateUpdate stores an update sent by the client after validating that
@ -292,6 +296,8 @@ func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error)
// hint under its session id. This will allow us to delete the // hint under its session id. This will allow us to delete the
// entries efficiently if the session is ever removed. // entries efficiently if the session is ever removed.
return putHintForSession(updateIndex, &update.ID, update.Hint) return putHintForSession(updateIndex, &update.ID, update.Hint)
}, func() {
lastApplied = 0
}) })
if err != nil { if err != nil {
return 0, err return 0, err
@ -381,7 +387,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
// Finally, remove this session from the update index, which // Finally, remove this session from the update index, which
// also removes any of the indexed hints beneath it. // also removes any of the indexed hints beneath it.
return removeSessionHintBkt(updateIndex, &target) return removeSessionHintBkt(updateIndex, &target)
}) }, func() {})
} }
// QueryMatches searches against all known state updates for any that match the // QueryMatches searches against all known state updates for any that match the
@ -460,6 +466,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
} }
return nil return nil
}, func() {
matches = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -478,7 +486,7 @@ func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
} }
return putLookoutEpoch(lookoutTip, epoch) return putLookoutEpoch(lookoutTip, epoch)
}) }, func() {})
} }
// GetLookoutTip retrieves the current lookout tip block epoch from the tower // GetLookoutTip retrieves the current lookout tip block epoch from the tower
@ -494,6 +502,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
epoch = getLookoutEpoch(lookoutTip) epoch = getLookoutEpoch(lookoutTip)
return nil return nil
}, func() {
epoch = nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

@ -107,7 +107,7 @@ func initOrSyncVersions(db versionedDB, init bool, versions []version) error {
if init { if init {
return kvdb.Update(db.bdb(), func(tx kvdb.RwTx) error { return kvdb.Update(db.bdb(), func(tx kvdb.RwTx) error {
return initDBVersion(tx, getLatestDBVersion(versions)) return initDBVersion(tx, getLatestDBVersion(versions))
}) }, func() {})
} }
// Otherwise, ensure that any migrations are applied to ensure the data // Otherwise, ensure that any migrations are applied to ensure the data
@ -159,5 +159,5 @@ func syncVersions(db versionedDB, versions []version) error {
} }
return putDBVersion(tx, latestVersion) return putDBVersion(tx, latestVersion)
}) }, func() {})
} }