Merge pull request #4705 from bhandras/kvdb_view_reset

multi: add reset closure to `kvdb.View/Update` to be able to clean up external state before retries
This commit is contained in:
András Bánki-Horváth 2020-11-06 12:47:35 +01:00 committed by GitHub
commit 0c3c6e6155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 564 additions and 260 deletions

@ -152,10 +152,12 @@ func (b *breachArbiter) start() error {
brarLog.Tracef("Starting breach arbiter")
// Load all retributions currently persisted in the retribution store.
breachRetInfos := make(map[wire.OutPoint]retributionInfo)
var breachRetInfos map[wire.OutPoint]retributionInfo
if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error {
breachRetInfos[ret.chanPoint] = *ret
return nil
}, func() {
breachRetInfos = make(map[wire.OutPoint]retributionInfo)
}); err != nil {
return err
}
@ -1223,7 +1225,7 @@ type RetributionStore interface {
// ForAll iterates over the existing on-disk contents and applies a
// chosen, read-only callback to each. This method should ensure that it
// immediately propagate any errors generated by the callback.
ForAll(cb func(*retributionInfo) error) error
ForAll(cb func(*retributionInfo) error, reset func()) error
}
// retributionStore handles persistence of retribution states to disk and is
@ -1263,7 +1265,7 @@ func (rs *retributionStore) Add(ret *retributionInfo) error {
}
return retBucket.Put(outBuf.Bytes(), retBuf.Bytes())
})
}, func() {})
}
// Finalize writes a signed justice transaction to the retribution store. This
@ -1288,7 +1290,7 @@ func (rs *retributionStore) Finalize(chanPoint *wire.OutPoint,
}
return justiceBkt.Put(chanBuf.Bytes(), txBuf.Bytes())
})
}, func() {})
}
// GetFinalizedTxn loads the finalized justice transaction for the provided
@ -1312,6 +1314,8 @@ func (rs *retributionStore) GetFinalizedTxn(
finalTxBytes = justiceBkt.Get(chanBuf.Bytes())
return nil
}, func() {
finalTxBytes = nil
}); err != nil {
return nil, err
}
@ -1349,6 +1353,8 @@ func (rs *retributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
}
return nil
}, func() {
found = false
})
return found, err
@ -1390,12 +1396,14 @@ func (rs *retributionStore) Remove(chanPoint *wire.OutPoint) error {
}
return justiceBkt.Delete(chanBytes)
})
}, func() {})
}
// ForAll iterates through all stored retributions and executes the passed
// callback function on each retribution.
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
func (rs *retributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
return kvdb.View(rs.db, func(tx kvdb.RTx) error {
// If the bucket does not exist, then there are no pending
// retributions.
@ -1416,7 +1424,7 @@ func (rs *retributionStore) ForAll(cb func(*retributionInfo) error) error {
return cb(ret)
})
})
}, reset)
}
// Encode serializes the retribution into the passed byte stream.

@ -431,11 +431,13 @@ func (frs *failingRetributionStore) Remove(key *wire.OutPoint) error {
return frs.rs.Remove(key)
}
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error) error {
func (frs *failingRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
frs.mu.Lock()
defer frs.mu.Unlock()
return frs.rs.ForAll(cb)
return frs.rs.ForAll(cb, reset)
}
// Parse the pubkeys in the breached outputs.
@ -592,10 +594,13 @@ func (rs *mockRetributionStore) Remove(key *wire.OutPoint) error {
return nil
}
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error) error {
func (rs *mockRetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
rs.mu.Lock()
defer rs.mu.Unlock()
reset()
for _, retInfo := range rs.state {
if err := cb(copyRetInfo(retInfo)); err != nil {
return err
@ -717,6 +722,8 @@ func countRetributions(t *testing.T, rs RetributionStore) int {
err := rs.ForAll(func(_ *retributionInfo) error {
count++
return nil
}, func() {
count = 0
})
if err != nil {
t.Fatalf("unable to list retributions in db: %v", err)
@ -919,7 +926,7 @@ restartCheck:
// Construct a set of all channel points presented by the store. Entries
// are only be added to the set if their corresponding retribution
// information matches the test vector.
var foundSet = make(map[wire.OutPoint]struct{})
var foundSet map[wire.OutPoint]struct{}
// Iterate through the stored retributions, checking to see if we have
// an equivalent retribution in the test vector. This will return an
@ -948,6 +955,8 @@ restartCheck:
}
return nil
}, func() {
foundSet = make(map[wire.OutPoint]struct{})
}); err != nil {
t.Fatalf("failed to iterate over persistent retributions: %v",
err)

@ -179,6 +179,8 @@ func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, err
}
return channeldb.ReadElement(bytes.NewReader(spendHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err
@ -278,6 +280,8 @@ func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, err
}
return channeldb.ReadElement(bytes.NewReader(confirmHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err

@ -756,7 +756,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
}
return nil
})
}, func() {})
if err != nil {
return err
}
@ -893,7 +893,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket.(kvdb.RwBucket), channel)
}); err != nil {
}, func() {}); err != nil {
return err
}
@ -950,6 +950,8 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
}
return nil
}, func() {
commitPoint = nil
})
if err != nil {
return nil, err
@ -1168,6 +1170,8 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
}
r := bytes.NewReader(bs)
return ReadElement(r, &closeTx)
}, func() {
closeTx = nil
})
if err != nil {
return nil, err
@ -1215,7 +1219,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus,
}
return nil
}); err != nil {
}, func() {}); err != nil {
return err
}
@ -1244,7 +1248,7 @@ func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
channel.chanStatus = status
return putOpenChannel(chanBucket, channel)
}); err != nil {
}, func() {}); err != nil {
return err
}
@ -1352,7 +1356,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return syncNewChannel(tx, c, []net.Addr{addr})
})
}, func() {})
}
// syncNewChannel will write the passed channel to disk, and also create a
@ -1486,7 +1490,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment,
}
return nil
})
}, func() {})
if err != nil {
return err
}
@ -2026,7 +2030,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
return err
}
return chanBucket.Put(commitDiffKey, b.Bytes())
})
}, func() {})
}
// RemoteCommitChainTip returns the "tip" of the current remote commitment
@ -2062,6 +2066,8 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
cd = dcd
return nil
}, func() {
cd = nil
})
if err != nil {
return nil, err
@ -2094,6 +2100,8 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
@ -2127,6 +2135,8 @@ func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
@ -2157,7 +2167,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
}
return putChanRevocationState(chanBucket, c)
})
}, func() {})
if err != nil {
return err
}
@ -2317,6 +2327,8 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
newRemoteCommit = &newCommit.Commitment
return nil
}, func() {
newRemoteCommit = nil
})
if err != nil {
return err
@ -2365,6 +2377,8 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}
@ -2381,7 +2395,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.AckAddHtlcs(tx, addRefs...)
})
}, func() {})
}
// AckSettleFails updates the SettleFailFilter containing any of the provided
@ -2394,7 +2408,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.AckSettleFails(tx, settleFailRefs...)
})
}, func() {})
}
// SetFwdFilter atomically sets the forwarding filter for the forwarding package
@ -2405,7 +2419,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter)
})
}, func() {})
}
// RemoveFwdPkgs atomically removes forwarding packages specified by the remote
@ -2426,7 +2440,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error {
}
return nil
})
}, func() {})
}
// RevocationLogTail returns the "tail", or the end of the current revocation
@ -2475,7 +2489,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
}
return nil
}); err != nil {
}, func() {}); err != nil {
return nil, err
}
@ -2509,6 +2523,8 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
height = commit.CommitHeight
return nil
}, func() {
height = 0
})
if err != nil {
return 0, err
@ -2547,7 +2563,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
commit = c
return nil
})
}, func() {})
if err != nil {
return nil, err
}
@ -2785,7 +2801,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
return putChannelCloseSummary(
tx, chanPointBuf.Bytes(), summary, chanState,
)
})
}, func() {})
}
// ChannelSnapshot is a frozen snapshot of the current channel state. A
@ -2870,7 +2886,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
}
return fetchChanCommitments(chanBucket, c)
})
}, func() {})
if err != nil {
return nil, nil, err
}
@ -2892,7 +2908,7 @@ func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
}
return fetchChanRevocationState(chanBucket, c)
})
}, func() {})
if err != nil {
return nil, err
}

@ -1447,7 +1447,7 @@ func TestBalanceAtHeight(t *testing.T) {
commit.RemoteBalance = remote
return appendChannelLogEntry(logBucket, &commit)
})
}, func() {})
return err
}

@ -191,22 +191,31 @@ type DB struct {
// Update is a wrapper around walletdb.Update which calls into the extended
// backend when available. This call is needed to be able to cast DB to
// ExtendedBackend.
func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error) error {
// ExtendedBackend. The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.Update(f)
return v.Update(f, reset)
}
reset()
return walletdb.Update(db, f)
}
// View is a wrapper around walletdb.View which calls into the extended
// backend when available. This call is needed to be able to cast DB to
// ExtendedBackend.
func (db *DB) View(f func(tx walletdb.ReadTx) error) error {
// ExtendedBackend. The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *DB) View(f func(tx walletdb.ReadTx) error, reset func()) error {
if v, ok := db.Backend.(kvdb.ExtendedBackend); ok {
return v.View(f)
return v.View(f, reset)
}
reset()
return walletdb.View(db, f)
}
@ -306,7 +315,7 @@ func (d *DB) Wipe() error {
}
}
return nil
})
}, func() {})
}
// createChannelDB creates and initializes a fresh version of channeldb. In
@ -360,7 +369,7 @@ func initChannelDB(db kvdb.Backend) error {
meta.DbVersionNumber = getLatestDBVersion(dbVersions)
return putMeta(meta, tx)
})
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channeldb: %v", err)
}
@ -389,6 +398,8 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
var err error
channels, err = d.fetchOpenChannels(tx, nodeID)
return err
}, func() {
channels = nil
})
return channels, err
@ -574,7 +585,7 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
})
}
err := kvdb.View(d, chanScan)
err := kvdb.View(d, chanScan, func() {})
if err != nil {
return nil, err
}
@ -741,6 +752,8 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
})
})
}, func() {
channels = nil
})
if err != nil {
return nil, err
@ -781,6 +794,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}
@ -817,6 +832,8 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
@ -865,6 +882,8 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
return nil
}
return ErrClosedChannelNotFound
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
@ -925,7 +944,7 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
// garbage collect it to ensure we don't establish persistent
// connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub)
})
}, func() {})
}
// pruneLinkNode determines whether we should garbage collect a link node from
@ -965,7 +984,7 @@ func (d *DB) PruneLinkNodes() error {
}
return nil
})
}, func() {})
}
// ChannelShell is a shell of a channel that is meant to be used for channel
@ -1012,7 +1031,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
}
return nil
})
}, func() {})
if err != nil {
return err
}
@ -1052,6 +1071,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
}
return nil
}, func() {
linkNode = nil
})
if dbErr != nil {
return nil, dbErr
@ -1194,7 +1215,7 @@ func (d *DB) syncVersions(versions []version) error {
}
return nil
})
}, func() {})
}
// ChannelGraph returns a new instance of the directed channel graph.
@ -1261,6 +1282,8 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err
channel, err = fetchOpenChannel(chanBucket, outPoint)
return err
}, func() {
channel = nil
})
if err != nil {
return nil, err

@ -228,9 +228,7 @@ type ForwardingLogTimeSlice struct {
//
// TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
resp := ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
var resp ForwardingLogTimeSlice
// If the user provided an index offset, then we'll not know how many
// records we need to skip. We'll also keep track of the record offset
@ -297,6 +295,10 @@ func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, e
}
return nil
}, func() {
resp = ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
})
if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err

@ -209,7 +209,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
@ -228,7 +228,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
// fwd filter.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
@ -246,7 +246,7 @@ func TestPackagerEmptyFwdPkg(t *testing.T) {
// Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
@ -281,7 +281,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
@ -302,7 +302,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
// was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
@ -326,7 +326,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
@ -345,7 +345,7 @@ func TestPackagerOnlyAdds(t *testing.T) {
// Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
@ -383,7 +383,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
@ -404,7 +404,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
// was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
@ -430,7 +430,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
@ -450,7 +450,7 @@ func TestPackagerOnlySettleFails(t *testing.T) {
// Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
@ -488,7 +488,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
@ -509,7 +509,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
// was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
@ -534,7 +534,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
@ -561,7 +561,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err)
}
}
@ -581,7 +581,7 @@ func TestPackagerAddsThenSettleFails(t *testing.T) {
// Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
@ -621,7 +621,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AddFwdPkg(tx, fwdPkg)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to add fwd pkg: %v", err)
}
@ -642,7 +642,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
// was failed locally.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to set fwdfiter: %v", err)
}
@ -671,7 +671,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckSettleFails(tx, failSettleRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove settle/fail htlc: %v", err)
}
}
@ -698,7 +698,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.AckAddHtlcs(tx, addRef)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to ack add htlc: %v", err)
}
}
@ -718,7 +718,7 @@ func TestPackagerSettleFailsThenAdds(t *testing.T) {
// Lastly, remove the completed forwarding package from disk.
if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
return packager.RemovePkg(tx, fwdPkg.Height)
}); err != nil {
}, func() {}); err != nil {
t.Fatalf("unable to remove fwdpkg: %v", err)
}
@ -786,6 +786,8 @@ func loadFwdPkgs(t *testing.T, db kvdb.Backend,
var err error
fwdPkgs, err = packager.LoadFwdPkgs(tx)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
t.Fatalf("unable to load fwd pkgs: %v", err)
}

@ -249,7 +249,7 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// be aborted.
return cb(&edgeInfo, edge1, edge2)
})
})
}, func() {})
}
// ForEachNodeChannel iterates through all channels of a given node, executing the
@ -279,7 +279,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
// have their disabled bit on.
func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
var disabledChanIDs []uint64
chanEdgeFound := make(map[uint64]struct{})
var chanEdgeFound map[uint64]struct{}
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
@ -308,6 +308,9 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
chanEdgeFound[chanID] = struct{}{}
return nil
})
}, func() {
disabledChanIDs = nil
chanEdgeFound = make(map[uint64]struct{})
})
if err != nil {
return nil, err
@ -353,7 +356,7 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
})
}
return kvdb.View(c.db, traversal)
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated
@ -377,6 +380,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err
@ -429,7 +434,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
// Finally, we commit the information of the lightning node
// itself.
return addLightningNode(tx, node)
})
}, func() {})
}
// AddLightningNode adds a vertex/node to the graph database. If the node is not
@ -443,7 +448,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
func (c *ChannelGraph) AddLightningNode(node *LightningNode) error {
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
return addLightningNode(tx, node)
})
}, func() {})
}
func addLightningNode(tx kvdb.RwTx, node *LightningNode) error {
@ -493,6 +498,8 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
// package...
alias = string(a)
return nil
}, func() {
alias = ""
})
if err != nil {
return "", err
@ -512,7 +519,7 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
}
return c.deleteLightningNode(nodes, nodePub[:])
})
}, func() {})
}
// deleteLightningNode uses an existing database transaction to remove a
@ -572,7 +579,7 @@ func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error {
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
return c.addChannelEdge(tx, edge)
})
}, func() {})
if err != nil {
return err
}
@ -774,7 +781,7 @@ func (c *ChannelGraph) HasChannelEdge(
}
return nil
}); err != nil {
}, func() {}); err != nil {
return time.Time{}, time.Time{}, exists, isZombie, err
}
@ -813,7 +820,7 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
}
return putChanEdgeInfo(edgeIndex, edge, chanKey)
})
}, func() {})
}
const (
@ -936,6 +943,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// prune any nodes that have had a channel closed within the
// latest block.
return c.pruneGraphNodes(nodes, edgeIndex)
}, func() {
chansClosed = nil
})
if err != nil {
return nil, err
@ -969,7 +978,7 @@ func (c *ChannelGraph) PruneGraphNodes() error {
}
return c.pruneGraphNodes(nodes, edgeIndex)
})
}, func() {})
}
// pruneGraphNodes attempts to remove any nodes from the graph who have had a
@ -1188,6 +1197,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf
}
return nil
}, func() {
removedChans = nil
}); err != nil {
return nil, err
}
@ -1235,7 +1246,7 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
tipHeight = byteOrder.Uint32(k[:])
return nil
})
}, func() {})
if err != nil {
return nil, 0, err
}
@ -1290,7 +1301,7 @@ func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error {
}
return nil
})
}, func() {})
if err != nil {
return err
}
@ -1312,6 +1323,8 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var err error
chanID, err = getChanID(tx, chanPoint)
return err
}, func() {
chanID = 0
}); err != nil {
return 0, err
}
@ -1379,6 +1392,8 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) {
// to the caller.
cid = byteOrder.Uint64(lastChanID)
return nil
}, func() {
cid = 0
})
if err != nil && err != ErrGraphNoEdgesFound {
return 0, err
@ -1409,8 +1424,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent
// re-adding it.
edgesSeen := make(map[uint64]struct{})
edgesToCache := make(map[uint64]ChannelEdge)
var edgesSeen map[uint64]struct{}
var edgesToCache map[uint64]ChannelEdge
var edgesInHorizon []ChannelEdge
c.cacheMu.Lock()
@ -1507,6 +1522,10 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
}
return nil
}, func() {
edgesSeen = make(map[uint64]struct{})
edgesToCache = make(map[uint64]ChannelEdge)
edgesInHorizon = nil
})
switch {
case err == ErrGraphNoEdgesFound:
@ -1577,6 +1596,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]Lig
}
return nil
}, func() {
nodesInHorizon = nil
})
switch {
case err == ErrGraphNoEdgesFound:
@ -1637,6 +1658,8 @@ func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
}
return nil
}, func() {
newChanIDs = nil
})
switch {
// If we don't know of any edges yet, then we'll return the entire set
@ -1701,7 +1724,10 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint
}
return nil
}, func() {
chanIDs = nil
})
switch {
// If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice.
@ -1775,6 +1801,8 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
})
}
return nil
}, func() {
chanEdges = nil
})
if err != nil {
return nil, err
@ -1912,6 +1940,8 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error {
var err error
isUpdate1, err = updateEdgePolicy(tx, edge)
return err
}, func() {
isUpdate1 = false
})
if err != nil {
return err
@ -2209,7 +2239,7 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNode)
err = kvdb.View(c.db, fetchNode, func() {})
} else {
err = fetchNode(tx)
}
@ -2259,6 +2289,9 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
exists = true
updateTime = node.LastUpdate
return nil
}, func() {
updateTime = time.Time{}
exists = false
})
if err != nil {
return time.Time{}, exists, err
@ -2346,7 +2379,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// If no transaction was provided, then we'll create a new transaction
// to execute the transaction within.
if tx == nil {
return kvdb.View(db, traversal)
return kvdb.View(db, traversal, func() {})
}
// Otherwise, we re-use the existing transaction to execute the graph
@ -2596,7 +2629,7 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*Ligh
// otherwise we can use the existing db transaction.
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNodeFunc)
err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil })
} else {
err = fetchNodeFunc(tx)
}
@ -2929,6 +2962,10 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint,
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if err != nil {
return nil, nil, nil, err
@ -3030,6 +3067,10 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64,
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if err == ErrZombieEdge {
return edgeInfo, nil, nil, err
@ -3062,6 +3103,8 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) {
nodeIsPublic, err = node.isPublic(tx, ourPubKey)
return err
}, func() {
nodeIsPublic = false
})
if err != nil {
return false, err
@ -3183,6 +3226,8 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) {
return nil
})
}, func() {
edgePoints = nil
}); err != nil {
return nil, err
}
@ -3229,7 +3274,7 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
var k [8]byte
byteOrder.PutUint64(k[:], chanID)
return zombieIndex.Delete(k[:])
})
}, func() {})
if err != nil {
return err
}
@ -3261,6 +3306,10 @@ func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
return nil
}, func() {
isZombie = false
pubKey1 = [33]byte{}
pubKey2 = [33]byte{}
})
if err != nil {
return false, [33]byte{}, [33]byte{}
@ -3307,6 +3356,8 @@ func (c *ChannelGraph) NumZombies() (uint64, error) {
numZombies++
return nil
})
}, func() {
numZombies = 0
})
if err != nil {
return 0, err

@ -2272,7 +2272,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
return nil
})
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -2858,7 +2858,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatalf("error reading db: %v", err)
}
@ -2894,7 +2894,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
}
return edges.Put(edgeKey[:], stripped)
})
}, func() {})
if err != nil {
t.Fatalf("error writing db: %v", err)
}

@ -562,6 +562,8 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
invoiceAddIndex = newIndex
return nil
}, func() {
invoiceAddIndex = 0
})
if err != nil {
return 0, err
@ -624,6 +626,8 @@ func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
}
return nil
}, func() {
newInvoices = nil
})
if err != nil {
return nil, err
@ -669,7 +673,7 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
invoice = i
return nil
})
}, func() {})
if err != nil {
return invoice, err
}
@ -731,13 +735,6 @@ func (d *DB) ScanInvoices(
scanFunc func(lntypes.Hash, *Invoice) error, reset func()) error {
return kvdb.View(d, func(tx kvdb.RTx) error {
// Reset partial results. As transaction commit success is not
// guaranteed when using etcd, we need to be prepared to redo
// the whole view transaction. In order to be able to do that
// we need a way to reset existing results. This is also done
// upon first run for initialization.
reset()
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
@ -773,7 +770,7 @@ func (d *DB) ScanInvoices(
return scanFunc(paymentHash, &invoice)
})
})
}, reset)
}
// InvoiceQuery represents a query to the invoice database. The query allows a
@ -825,9 +822,7 @@ type InvoiceSlice struct {
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
resp := InvoiceSlice{
InvoiceQuery: q,
}
var resp InvoiceSlice
err := kvdb.View(d, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any invoices
@ -892,6 +887,10 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
}
return nil
}, func() {
resp = InvoiceSlice{
InvoiceQuery: q,
}
})
if err != nil && err != ErrNoInvoicesCreated {
return resp, err
@ -953,6 +952,8 @@ func (d *DB) UpdateInvoice(ref InvoiceRef,
)
return err
}, func() {
updatedInvoice = nil
})
return updatedInvoice, err
@ -1011,6 +1012,8 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
}
return nil
}, func() {
settledInvoices = nil
})
if err != nil {
return nil, err
@ -1867,7 +1870,7 @@ func (d *DB) DeleteInvoice(invoicesToDelete []InvoiceDeleteRef) error {
}
return nil
})
}, func() {})
return err
}

@ -218,11 +218,15 @@ func (db *db) getSTMOptions() []STMOptionFunc {
}
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur).
func (db *db) View(f func(tx walletdb.ReadTx) error) error {
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix))
}
@ -230,13 +234,15 @@ func (db *db) View(f func(tx walletdb.ReadTx) error) error {
}
// Update opens a database read/write transaction and executes the function f
// with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned.
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error) error {
// with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned. As callers may expect retries of the f closure, the reset function
// will be called before each retry respectively.
func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error {
apply := func(stm STM) error {
reset()
return f(newReadWriteTx(stm, db.config.Prefix))
}
@ -300,5 +306,5 @@ func (db *db) Close() error {
//
// Batch is only useful when there are multiple goroutines calling it.
func (db *db) Batch(apply func(tx walletdb.ReadWriteTx) error) error {
return db.Update(apply)
return db.Update(apply, func() {})
}

@ -28,7 +28,7 @@ func TestCopy(t *testing.T) {
require.NoError(t, apple.Put([]byte("key"), []byte("val")))
return nil
})
}, func() {})
// Expect non-zero copy.
var buf bytes.Buffer
@ -66,7 +66,7 @@ func TestAbortContext(t *testing.T) {
require.Error(t, err, "context canceled")
return nil
})
}, func() {})
require.Error(t, err, "context canceled")

@ -290,6 +290,9 @@ func (b *readWriteBucket) Put(key, value []byte) error {
// Delete deletes the key/value pointed to by the passed key.
// Returns ErrKeyRequred if the passed key is empty.
func (b *readWriteBucket) Delete(key []byte) error {
if key == nil {
return nil
}
if len(key) == 0 {
return walletdb.ErrKeyRequired
}

@ -79,7 +79,7 @@ func TestBucketCreation(t *testing.T) {
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
require.NotNil(t, apple.NestedReadBucket([]byte("banana")))
return nil
})
}, func() {})
require.Nil(t, err)
@ -189,7 +189,7 @@ func TestBucketDeletion(t *testing.T) {
// "aple/banana" exists
require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana")))
return nil
})
}, func() {})
require.Nil(t, err)
@ -261,7 +261,7 @@ func TestBucketForEach(t *testing.T) {
require.Equal(t, expected, got)
return nil
})
}, func() {})
require.Nil(t, err)
@ -354,7 +354,7 @@ func TestBucketForEachWithError(t *testing.T) {
require.Equal(t, expected, got)
require.Error(t, err)
return nil
})
}, func() {})
require.Nil(t, err)
@ -399,7 +399,7 @@ func TestBucketSequence(t *testing.T) {
}
return nil
})
}, func() {})
require.Nil(t, err)
}
@ -431,7 +431,7 @@ func TestKeyClash(t *testing.T) {
require.NotNil(t, banana)
return nil
})
}, func() {})
require.Nil(t, err)
@ -457,7 +457,7 @@ func TestKeyClash(t *testing.T) {
require.Error(t, walletdb.ErrIncompatibleValue, b)
return nil
})
}, func() {})
require.Nil(t, err)
@ -494,7 +494,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NotNil(t, banana)
return nil
})
}, func() {})
require.NoError(t, err)
err = db.Update(func(tx walletdb.ReadWriteTx) error {
@ -503,7 +503,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NoError(t, apple.DeleteNestedBucket([]byte("banana")))
return nil
})
}, func() {})
require.NoError(t, err)
err = db.Update(func(tx walletdb.ReadWriteTx) error {
@ -512,7 +512,7 @@ func TestBucketCreateDelete(t *testing.T) {
require.NoError(t, apple.Put([]byte("banana"), []byte("value")))
return nil
})
}, func() {})
require.NoError(t, err)
expected := map[string]string{

@ -24,7 +24,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.NotNil(t, b)
return nil
})
}, func() {})
require.NoError(t, err)
err = db.View(func(tx walletdb.ReadTx) error {
@ -49,7 +49,7 @@ func TestReadCursorEmptyInterval(t *testing.T) {
require.Nil(t, v)
return nil
})
}, func() {})
require.NoError(t, err)
}
@ -78,7 +78,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val)))
}
return nil
})
}, func() {})
require.NoError(t, err)
@ -125,7 +125,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) {
require.Nil(t, v)
return nil
})
}, func() {})
require.NoError(t, err)
}
@ -162,7 +162,7 @@ func TestReadWriteCursor(t *testing.T) {
require.NoError(t, err)
}
return nil
}))
}, func() {}))
err = db.Update(func(tx walletdb.ReadWriteTx) error {
b := tx.ReadWriteBucket([]byte("apple"))
@ -276,7 +276,7 @@ func TestReadWriteCursor(t *testing.T) {
require.Equal(t, reverseKVs(expected), kvs)
return nil
})
}, func() {})
require.NoError(t, err)
@ -320,7 +320,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.NotNil(t, b2)
return nil
}))
}, func() {}))
err = db.View(func(tx walletdb.ReadTx) error {
b := tx.ReadBucket([]byte("apple"))
@ -354,7 +354,7 @@ func TestReadWriteCursorWithBucketAndValue(t *testing.T) {
require.Equal(t, []byte("val"), v)
return nil
})
}, func() {})
require.NoError(t, err)

@ -142,7 +142,7 @@ func TestChangeDuringUpdate(t *testing.T) {
count++
return nil
})
}, func() {})
require.Nil(t, err)
require.Equal(t, count, 2)

@ -10,23 +10,34 @@ import (
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned.
func Update(db Backend, f func(tx RwTx) error) error {
// returned. As callers may expect retries of the f closure (depending on the
// database backend used), the reset function will be called before each retry
// respectively.
func Update(db Backend, f func(tx RwTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.Update(f)
return extendedDB.Update(f, reset)
}
reset()
return walletdb.Update(db, f)
}
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur).
func View(db Backend, f func(tx RTx) error) error {
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func View(db Backend, f func(tx RTx) error, reset func()) error {
if extendedDB, ok := db.(ExtendedBackend); ok {
return extendedDB.View(f)
return extendedDB.View(f, reset)
}
// Since we know that walletdb simply calls into bbolt which never
// retries transactions, we'll call the reset function here before View.
reset()
return walletdb.View(db, f)
}
@ -55,19 +66,25 @@ type ExtendedBackend interface {
// PrintStats returns all collected stats pretty printed into a string.
PrintStats() string
// View opens a database read transaction and executes the function f with
// the transaction passed as a parameter. After f exits, the transaction is
// rolled back. If f errors, its error is returned, not a rollback error
// (if any occur).
View(f func(tx walletdb.ReadTx) error) error
// View opens a database read transaction and executes the function f
// with the transaction passed as a parameter. After f exits, the
// transaction is rolled back. If f errors, its error is returned, not a
// rollback error (if any occur). The passed reset function is called
// before the start of the transaction and can be used to reset
// intermediate state. As callers may expect retries of the f closure
// (depending on the database backend used), the reset function will be
//called before each retry respectively.
View(f func(tx walletdb.ReadTx) error, reset func()) error
// Update opens a database read/write transaction and executes the function
// f with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned.
Update(f func(tx walletdb.ReadWriteTx) error) error
// Update opens a database read/write transaction and executes the
// function f with the transaction passed as a parameter. After f exits,
// if f did not error, the transaction is committed. Otherwise, if f did
// error, the transaction is rolled back. If the rollback fails, the
// original error returned by f is still returned. If the commit fails,
// the commit error is returned. As callers may expect retries of the f
// closure (depending on the database backend used), the reset function
// will be called before each retry respectively.
Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error
}
// Open opens an existing database for the specified type. The arguments are

@ -23,10 +23,12 @@ type Meta struct {
// FetchMeta fetches the meta data from boltdb and returns filled meta
// structure.
func (d *DB) FetchMeta(tx kvdb.RTx) (*Meta, error) {
meta := &Meta{}
var meta *Meta
err := kvdb.View(d, func(tx kvdb.RTx) error {
return fetchMeta(meta, tx)
}, func() {
meta = &Meta{}
})
if err != nil {
return nil, err
@ -58,7 +60,7 @@ func fetchMeta(meta *Meta, tx kvdb.RTx) error {
func (d *DB) PutMeta(meta *Meta) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
return putMeta(meta, tx)
})
}, func() {})
}
// putMeta is an internal helper function used in order to allow callers to

@ -209,7 +209,7 @@ func TestMigrationWithPanic(t *testing.T) {
}
return bucket.Put(keyPrefix, beforeMigration)
})
}, func() {})
if err != nil {
t.Fatalf("unable to insert: %v", err)
}
@ -251,7 +251,7 @@ func TestMigrationWithPanic(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -283,7 +283,7 @@ func TestMigrationWithFatal(t *testing.T) {
}
return bucket.Put(keyPrefix, beforeMigration)
})
}, func() {})
if err != nil {
t.Fatalf("unable to insert pre migration key: %v", err)
}
@ -326,7 +326,7 @@ func TestMigrationWithFatal(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -359,7 +359,7 @@ func TestMigrationWithoutErrors(t *testing.T) {
}
return bucket.Put(keyPrefix, beforeMigration)
})
}, func() {})
if err != nil {
t.Fatalf("unable to update db pre migration: %v", err)
}
@ -401,7 +401,7 @@ func TestMigrationWithoutErrors(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -448,7 +448,7 @@ func TestMigrationReversion(t *testing.T) {
}
return putMeta(newMeta, tx)
})
}, func() {})
// Close the database. Even if we succeeded, our next step is to reopen.
cdb.Close()
@ -492,7 +492,7 @@ func TestMigrationDryRun(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatalf("unable to apply after func: %v", err)
}

@ -152,7 +152,7 @@ func createChannelDB(dbPath string) error {
DbVersionNumber: 0,
}
return putMeta(meta, tx)
})
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channeldb")
}
@ -203,6 +203,8 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}

@ -190,6 +190,8 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err
@ -242,7 +244,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
// Finally, we commit the information of the lightning node
// itself.
return addLightningNode(tx, node)
})
}, func() {})
}
func addLightningNode(tx kvdb.RwTx, node *LightningNode) error {

@ -282,6 +282,8 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
return nil
})
}, func() {
invoices = nil
})
if err != nil {
return nil, err

@ -51,7 +51,7 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
// Apply migration.
err = kvdb.Update(cdb, func(tx kvdb.RwTx) error {
return migrationFunc(tx)
})
}, func() {})
if err != nil {
log.Error(err)
}

@ -95,7 +95,7 @@ func (db *DB) addPayment(payment *outgoingPayment) error {
binary.BigEndian.PutUint64(paymentIDBytes, paymentID)
return payments.Put(paymentIDBytes, paymentBytes)
})
}, func() {})
}
// fetchAllPayments returns all outgoing payments in DB.
@ -126,6 +126,8 @@ func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) {
payments = append(payments, payment)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
@ -144,6 +146,8 @@ func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var err error
paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash)
return err
}, func() {
paymentStatus = StatusUnknown
})
if err != nil {
return StatusUnknown, err
@ -424,6 +428,8 @@ func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) {
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err

@ -55,7 +55,7 @@ func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatal(err)
}

@ -125,7 +125,7 @@ func TestPaymentStatusesMigration(t *testing.T) {
}
return circuits.Put(inFlightKey, inFlightCircuit)
})
}, func() {})
if err != nil {
t.Fatalf("unable to add circuit map entry: %v", err)
}
@ -385,7 +385,7 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
return err
}
return closedChanBucket.Put(chanID, old)
})
}, func() {})
if err != nil {
t.Fatalf("unable to add old serialization: %v",
err)
@ -418,6 +418,8 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) {
"serialization")
}
return nil
}, func() {
dbSummary = nil
})
if err != nil {
t.Fatalf("unable to view DB: %v", err)
@ -491,7 +493,7 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
}
return messageStore.Put(oldMsgKey[:], b.Bytes())
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -521,6 +523,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) {
}
return nil
}, func() {
rawMsg = nil
})
if err != nil {
t.Fatal(err)
@ -679,7 +683,7 @@ func TestOutgoingPaymentsMigration(t *testing.T) {
}
return nil
})
}, func() {})
if err != nil {
t.Fatal(err)
}
@ -855,6 +859,8 @@ func TestPaymentRouteSerialization(t *testing.T) {
}
return nil
}, func() {
oldPayments = nil
})
if err != nil {
t.Fatalf("unable to create test payments: %v", err)

@ -303,6 +303,8 @@ func (db *DB) FetchPayments() ([]*Payment, error) {
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err

@ -47,7 +47,7 @@ func ApplyMigration(t *testing.T,
// beforeMigration usually used for populating the database
// with test data.
err = kvdb.Update(cdb, beforeMigration)
err = kvdb.Update(cdb, beforeMigration, func() {})
if err != nil {
t.Fatal(err)
}
@ -65,14 +65,14 @@ func ApplyMigration(t *testing.T,
// afterMigration usually used for checking the database state and
// throwing the error if something went wrong.
err = kvdb.Update(cdb, afterMigration)
err = kvdb.Update(cdb, afterMigration, func() {})
if err != nil {
t.Fatal(err)
}
}()
// Apply migration.
err = kvdb.Update(cdb, migrationFunc)
err = kvdb.Update(cdb, migrationFunc, func() {})
if err != nil {
t.Logf("migration error: %v", err)
}

@ -108,7 +108,7 @@ func (l *LinkNode) Sync() error {
}
return putLinkNode(nodeMetaBucket, l)
})
}, func() {})
}
// putLinkNode serializes then writes the encoded version of the passed link
@ -132,7 +132,7 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error {
func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(db, func(tx kvdb.RwTx) error {
return db.deleteLinkNode(tx, identity)
})
}, func() {})
}
func (db *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
@ -158,6 +158,8 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
linkNode = node
return nil
}, func() {
linkNode = nil
})
return linkNode, err
@ -199,6 +201,8 @@ func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
linkNodes = nodes
return nil
}, func() {
linkNodes = nil
})
if err != nil {
return nil, err

@ -550,6 +550,8 @@ func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
payment, err = fetchPayment(bucket)
return err
}, func() {
payment = nil
})
if err != nil {
return nil, err
@ -716,6 +718,8 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
inFlights = append(inFlights, inFlight)
return nil
})
}, func() {
inFlights = nil
})
if err != nil {
return nil, err

@ -502,7 +502,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
indexCount++
return nil
})
})
}, func() { indexCount = 0 })
require.NoError(t, err)
require.Equal(t, 1, indexCount)
@ -989,7 +989,8 @@ func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
var err error
hash, err = deserializePaymentIndex(r)
return err
}, func() {
hash = lntypes.Hash{}
}); err != nil {
return nil, err
}

@ -269,6 +269,8 @@ func (db *DB) FetchPayments() ([]*MPPayment, error) {
payments = append(payments, duplicatePayments...)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
@ -572,6 +574,8 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
}
return nil
}, func() {
resp = PaymentsResponse{}
}); err != nil {
return resp, err
}
@ -745,7 +749,7 @@ func (db *DB) DeletePayments() error {
}
return nil
})
}, func() {})
}
// fetchSequenceNumbers fetches all the sequence numbers associated with a

@ -183,7 +183,7 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64)
// Delete the index that references this payment.
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Delete(key)
})
}, func() {})
if err != nil {
t.Fatalf("could not delete "+
@ -622,7 +622,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) {
tx, test.paymentHash, seqNrBytes[:],
)
return err
})
}, func() {})
require.Equal(t, test.expectedErr, err)
})
}
@ -666,7 +666,7 @@ func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash,
require.NoError(t, err)
return nil
})
}, func() {})
if err != nil {
t.Fatalf("could not create payment: %v", err)
}

@ -6,7 +6,7 @@ import (
"fmt"
"time"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/routing/route"
)
@ -50,7 +50,7 @@ type FlapCount struct {
// bucket for the peer's pubkey if necessary. Note that this function overwrites
// the current value.
func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
return d.Update(func(tx walletdb.ReadWriteTx) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
// Run through our set of flap counts and record them for
// each peer, creating a bucket for the peer pubkey if required.
for peer, flapCount := range flapCounts {
@ -80,7 +80,7 @@ func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
}
return nil
})
}, func() {})
}
// ReadFlapCount attempts to read the flap count for a peer, failing if the
@ -88,7 +88,7 @@ func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
var flapCount FlapCount
if err := d.View(func(tx walletdb.ReadTx) error {
if err := kvdb.View(d, func(tx kvdb.RTx) error {
peers := tx.ReadBucket(peersBucket)
peerBucket := peers.NestedReadBucket(pubkey[:])
@ -113,6 +113,8 @@ func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
}
return ReadElements(r, &flapCount.Count)
}, func() {
flapCount = FlapCount{}
}); err != nil {
return nil, err
}

@ -130,7 +130,7 @@ func (d *DB) PutResolverReport(tx kvdb.RwTx, chainHash chainhash.Hash,
// If the transaction is nil, we'll create a new one.
if tx == nil {
return kvdb.Update(d, putReportFunc)
return kvdb.Update(d, putReportFunc, func() {})
}
// Otherwise, we can write the report to disk using the existing
@ -250,6 +250,8 @@ func (d DB) FetchChannelReports(chainHash chainhash.Hash,
return nil
})
}, func() {
reports = nil
}); err != nil {
return nil, err
}

@ -202,7 +202,7 @@ func TestFetchChannelWriteBucket(t *testing.T) {
defer cleanup()
// Update our db to the starting state we expect.
err = kvdb.Update(db, test.setup)
err = kvdb.Update(db, test.setup, func() {})
require.NoError(t, err)
// Try to get our report bucket.
@ -211,7 +211,7 @@ func TestFetchChannelWriteBucket(t *testing.T) {
tx, testChainHash, &testChanPoint1,
)
return err
})
}, func() {})
require.NoError(t, err)
})
}

@ -42,13 +42,14 @@ type WaitingProofStore struct {
// NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
s := &WaitingProofStore{
db: db,
cache: make(map[WaitingProofKey]struct{}),
db: db,
}
if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{}
return nil
}, func() {
s.cache = make(map[WaitingProofKey]struct{})
}); err != nil && err != ErrWaitingProofNotFound {
return nil, err
}
@ -79,7 +80,7 @@ func (s *WaitingProofStore) Add(proof *WaitingProof) error {
key := proof.Key()
return bucket.Put(key[:], b.Bytes())
})
}, func() {})
if err != nil {
return err
}
@ -108,7 +109,7 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
}
return bucket.Delete(key[:])
})
}, func() {})
if err != nil {
return err
}
@ -122,7 +123,9 @@ func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
// ForAll iterates thought all waiting proofs and passing the waiting proof
// in the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error,
reset func()) error {
return kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil {
@ -144,12 +147,12 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
return cb(proof)
})
})
}, reset)
}
// Get returns the object which corresponds to the given index.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
proof := &WaitingProof{}
var proof *WaitingProof
s.mu.RLock()
defer s.mu.RUnlock()
@ -172,6 +175,8 @@ func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
r := bytes.NewReader(v)
return proof.Decode(r)
}, func() {
proof = &WaitingProof{}
})
return proof, err

@ -53,7 +53,7 @@ func TestWaitingProofStore(t *testing.T) {
if err := store.ForAll(func(proof *WaitingProof) error {
return errors.New("storage should be empty")
}); err != nil && err != ErrWaitingProofNotFound {
}, func() {}); err != nil && err != ErrWaitingProofNotFound {
t.Fatal(err)
}
}

@ -174,6 +174,8 @@ func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]by
copy(witness[:], dbWitness)
return nil
}, func() {
witness = nil
})
if err != nil {
return nil, err

@ -430,6 +430,8 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) {
s = ArbitratorState(stateBytes[0])
return nil
}, func() {
s = 0
})
if err != nil && err != errScopeBucketNoExist {
return s, err
@ -521,6 +523,8 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro
contracts = append(contracts, res)
return nil
})
}, func() {
contracts = nil
})
if err != nil && err != errScopeBucketNoExist && err != errNoContracts {
return nil, err
@ -685,7 +689,7 @@ func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) {
c := &ContractResolutions{}
var c *ContractResolutions
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
@ -769,6 +773,8 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
}
return nil
}, func() {
c = &ContractResolutions{}
})
if err != nil {
return nil, err
@ -783,7 +789,7 @@ func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, er
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
actionsMap := make(ChainActionMap)
var actionsMap ChainActionMap
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
@ -813,6 +819,8 @@ func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
return nil
})
}, func() {
actionsMap = make(ChainActionMap)
})
if err != nil {
return nil, err
@ -866,6 +874,8 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) {
c = commitSet
return nil
}, func() {
c = nil
})
if err != nil {
return nil, err
@ -914,7 +924,7 @@ func (b *boltArbitratorLog) WipeHistory() error {
// Finally, we'll delete the enclosing bucket itself.
return tx.DeleteTopLevelBucket(b.scopeKey[:])
})
}, func() {})
}
// checkpointContract is a private method that will be fed into
@ -941,7 +951,7 @@ func (b *boltArbitratorLog) checkpointContract(c ContractResolver,
}
return nil
})
}, func() {})
}
func encodeIncomingResolution(w io.Writer, i *lnwallet.IncomingHtlcResolution) error {

@ -1117,6 +1117,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1150,6 +1153,9 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1219,6 +1225,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1355,6 +1364,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1466,6 +1478,9 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1570,6 +1585,9 @@ out:
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -1754,6 +1772,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -2583,6 +2604,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil {
t.Fatalf("unable to retrieve objects from store: %v", err)
}
@ -2612,6 +2636,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) {
number++
return nil
},
func() {
number = 0
},
); err != nil && err != channeldb.ErrWaitingProofNotFound {
t.Fatalf("unable to retrieve objects from store: %v", err)
}

@ -199,7 +199,7 @@ func readMessage(msgBytes []byte) (lnwire.Message, error) {
// Messages returns the total set of messages that exist within the store for
// all peers.
func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs := make(map[[33]byte][]lnwire.Message)
var msgs map[[33]byte][]lnwire.Message
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil {
@ -224,6 +224,8 @@ func (s *MessageStore) Messages() (map[[33]byte][]lnwire.Message, error) {
msgs[pubKey] = append(msgs[pubKey], msg)
return nil
})
}, func() {
msgs = make(map[[33]byte][]lnwire.Message)
})
if err != nil {
return nil, err
@ -262,6 +264,8 @@ func (s *MessageStore) MessagesForPeer(
}
return nil
}, func() {
msgs = nil
})
if err != nil {
return nil, err
@ -272,7 +276,7 @@ func (s *MessageStore) MessagesForPeer(
// Peers returns the public key of all peers with messages within the store.
func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers := make(map[[33]byte]struct{})
var peers map[[33]byte]struct{}
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
messageStore := tx.ReadBucket(messageStoreBucket)
if messageStore == nil {
@ -285,6 +289,8 @@ func (s *MessageStore) Peers() (map[[33]byte]struct{}, error) {
peers[pubKey] = struct{}{}
return nil
})
}, func() {
peers = make(map[[33]byte]struct{})
})
if err != nil {
return nil, err

@ -239,7 +239,7 @@ func TestMessageStoreUnsupportedMessage(t *testing.T) {
err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket)
return messageStore.Put(msgKey, rawMsg.Bytes())
})
}, func() {})
if err != nil {
t.Fatalf("unable to add unsupported message to store: %v", err)
}

@ -3505,7 +3505,7 @@ func (f *fundingManager) saveChannelOpeningState(chanPoint *wire.OutPoint,
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
return bucket.Put(outpointBytes.Bytes(), scratch)
})
}, func() {})
}
// getChannelOpeningState fetches the channelOpeningState for the provided
@ -3538,7 +3538,7 @@ func (f *fundingManager) getChannelOpeningState(chanPoint *wire.OutPoint) (
state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil
})
}, func() {})
if err != nil {
return 0, nil, err
}
@ -3560,5 +3560,5 @@ func (f *fundingManager) deleteChannelOpeningState(chanPoint *wire.OutPoint) err
}
return bucket.Delete(outpointBytes.Bytes())
})
}, func() {})
}

@ -227,7 +227,7 @@ func (cm *circuitMap) initBuckets() error {
_, err = tx.CreateTopLevelBucket(circuitAddKey)
return err
})
}, func() {})
}
// restoreMemState loads the contents of the half circuit and full circuit
@ -240,8 +240,8 @@ func (cm *circuitMap) restoreMemState() error {
log.Infof("Restoring in-memory circuit state from disk")
var (
opened = make(map[CircuitKey]*PaymentCircuit)
pending = make(map[CircuitKey]*PaymentCircuit)
opened map[CircuitKey]*PaymentCircuit
pending map[CircuitKey]*PaymentCircuit
)
if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
@ -331,6 +331,9 @@ func (cm *circuitMap) restoreMemState() error {
return nil
}, func() {
opened = make(map[CircuitKey]*PaymentCircuit)
pending = make(map[CircuitKey]*PaymentCircuit)
}); err != nil {
return err
}
@ -483,7 +486,7 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
}
return nil
})
}, func() {})
}
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
@ -730,7 +733,7 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
}
return nil
})
}, func() {})
if err != nil {
return err

@ -131,7 +131,7 @@ func (d *DecayedLog) initBuckets() error {
}
return nil
})
}, func() {})
}
// Stop halts the garbage collector and closes boltdb.
@ -280,6 +280,8 @@ func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
value = uint32(binary.BigEndian.Uint32(valueBytes))
return nil
}, func() {
value = 0
})
if err != nil {
return value, err

@ -197,6 +197,8 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
default:
return nil
}
}, func() {
result = nil
})
if err != nil {
return nil, err
@ -230,6 +232,8 @@ func (store *networkResultStore) getResult(pid uint64) (
var err error
result, err = fetchResult(tx, pid)
return err
}, func() {
result = nil
})
if err != nil {
return nil, err
@ -301,5 +305,5 @@ func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
}
return nil
})
}, func() {})
}

@ -100,6 +100,8 @@ func (s *persistentSequencer) NextID() (uint64, error) {
nextIDBkt.SetSequence(nextHorizonID)
return nil
}, func() {
nextHorizonID = 0
}); err != nil {
return 0, err
}
@ -124,5 +126,5 @@ func (s *persistentSequencer) initDB() error {
return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(nextPaymentIDKey)
return err
})
}, func() {})
}

@ -1833,6 +1833,8 @@ func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.
tx, source,
)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}

@ -47,6 +47,9 @@ type AddInvoiceConfig struct {
// channel graph.
ChanDB *channeldb.DB
// Graph holds a reference to the ChannelGraph database.
Graph *channeldb.ChannelGraph
// GenInvoiceFeatures returns a feature containing feature bits that
// should be advertised on freshly generated invoices.
GenInvoiceFeatures func() *lnwire.FeatureVector
@ -330,9 +333,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
// chanCanBeHopHint returns true if the target channel is eligible to be a hop
// hint.
func chanCanBeHopHint(channel *channeldb.OpenChannel,
graph *channeldb.ChannelGraph,
cfg *AddInvoiceConfig) (*channeldb.ChannelEdgePolicy, bool) {
func chanCanBeHopHint(channel *channeldb.OpenChannel, cfg *AddInvoiceConfig) (
*channeldb.ChannelEdgePolicy, bool) {
// Since we're only interested in our private channels, we'll skip
// public ones.
@ -359,7 +361,7 @@ func chanCanBeHopHint(channel *channeldb.OpenChannel,
// channels.
var remotePub [33]byte
copy(remotePub[:], channel.IdentityPub.SerializeCompressed())
isRemoteNodePublic, err := graph.IsPublicNode(remotePub)
isRemoteNodePublic, err := cfg.Graph.IsPublicNode(remotePub)
if err != nil {
log.Errorf("Unable to determine if node %x "+
"is advertised: %v", remotePub, err)
@ -375,7 +377,7 @@ func chanCanBeHopHint(channel *channeldb.OpenChannel,
// Fetch the policies for each end of the channel.
chanID := channel.ShortChanID().ToUint64()
info, p1, p2, err := graph.FetchChannelEdgesByID(chanID)
info, p1, p2, err := cfg.Graph.FetchChannelEdgesByID(chanID)
if err != nil {
log.Errorf("Unable to fetch the routing "+
"policies for the edges of the channel "+
@ -423,8 +425,6 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
openChannels []*channeldb.OpenChannel,
numMaxHophints int) []func(*zpay32.Invoice) {
graph := cfg.ChanDB.ChannelGraph()
// We'll add our hop hints in two passes, first we'll add all channels
// that are eligible to be hop hints, and also have a local balance
// above the payment amount.
@ -433,9 +433,7 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
hopHints := make([]func(*zpay32.Invoice), 0, numMaxHophints)
for _, channel := range openChannels {
// If this channel can't be a hop hint, then skip it.
edgePolicy, canBeHopHint := chanCanBeHopHint(
channel, graph, cfg,
)
edgePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg)
if edgePolicy == nil || !canBeHopHint {
continue
}
@ -485,9 +483,7 @@ func selectHopHints(amtMSat lnwire.MilliSatoshi, cfg *AddInvoiceConfig,
// If the channel can't be a hop hint, then we'll skip it.
// Otherwise, we'll use the policy information to populate the
// hop hint.
remotePolicy, canBeHopHint := chanCanBeHopHint(
channel, graph, cfg,
)
remotePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg)
if !canBeHopHint || remotePolicy == nil {
continue
}

@ -62,7 +62,7 @@ func NewRootKeyStorage(db kvdb.Backend) (*RootKeyStorage, error) {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(rootKeyBucketName)
return err
})
}, func() {})
if err != nil {
return nil, err
}
@ -123,7 +123,7 @@ func (r *RootKeyStorage) CreateUnlock(password *[]byte) error {
r.encKey = encKey
return nil
})
}, func() {})
}
// Get implements the Get method for the bakery.RootKeyStorage interface.
@ -150,6 +150,8 @@ func (r *RootKeyStorage) Get(_ context.Context, id []byte) ([]byte, error) {
rootKey = make([]byte, len(decKey))
copy(rootKey[:], decKey)
return nil
}, func() {
rootKey = nil
})
if err != nil {
return nil, err
@ -209,6 +211,8 @@ func (r *RootKeyStorage) RootKey(ctx context.Context) ([]byte, []byte, error) {
return err
}
return ns.Put(id, encKey)
}, func() {
rootKey = nil
})
if err != nil {
return nil, nil, err
@ -257,6 +261,8 @@ func (r *RootKeyStorage) ListMacaroonIDs(_ context.Context) ([][]byte, error) {
}
return tx.ReadBucket(rootKeyBucketName).ForEach(appendRootKey)
}, func() {
rootKeySlice = nil
})
if err != nil {
return nil, err
@ -306,6 +312,8 @@ func (r *RootKeyStorage) DeleteMacaroonID(
rootKeyIDDeleted = rootKeyID
return nil
}, func() {
rootKeyIDDeleted = nil
})
if err != nil {
return nil, err

@ -129,7 +129,7 @@ type NurseryStore interface {
// the caller to process each key-value pair. The key will be a prefixed
// outpoint, and the value will be the serialized bytes for an output,
// whose type should be inferred from the key's prefix.
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error) error
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error
// ListChannels returns all channels the nursery is currently tracking.
ListChannels() ([]wire.OutPoint, error)
@ -283,7 +283,7 @@ func (ns *nurseryStore) Incubate(kids []kidOutput, babies []babyOutput) error {
}
return nil
})
}, func() {})
}
// CribToKinder atomically moves a babyOutput in the crib bucket to the
@ -365,7 +365,7 @@ func (ns *nurseryStore) CribToKinder(bby *babyOutput) error {
// This informs the utxo nursery that it should attempt to spend
// this output when the blockchain reaches the maturity height.
return hghtChanBucketCsv.Put(pfxOutputKey, []byte{})
})
}, func() {})
}
// PreschoolToKinder atomically moves a kidOutput from the preschool bucket to
@ -463,7 +463,7 @@ func (ns *nurseryStore) PreschoolToKinder(kid *kidOutput,
// that this CSV delayed output will be ready to broadcast at
// the maturity height, after a brief period of incubation.
return hghtChanBucket.Put(pfxOutputKey, []byte{})
})
}, func() {})
}
// GraduateKinder atomically moves an output at the provided height into the
@ -525,7 +525,7 @@ func (ns *nurseryStore) GraduateKinder(height uint32, kid *kidOutput) error {
// using graduate-prefixed key.
return chanBucket.Put(pfxOutputKey,
gradBuffer.Bytes())
})
}, func() {})
}
// FetchClass returns a list of babyOutputs in the crib bucket whose CLTV
@ -582,6 +582,9 @@ func (ns *nurseryStore) FetchClass(
})
}, func() {
kids = nil
babies = nil
}); err != nil {
return nil, nil, err
}
@ -655,6 +658,8 @@ func (ns *nurseryStore) FetchPreschools() ([]kidOutput, error) {
}
return nil
}, func() {
kids = nil
}); err != nil {
return nil, err
}
@ -693,6 +698,8 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
}
return nil
}, func() {
activeHeights = nil
})
if err != nil {
return nil, err
@ -709,11 +716,11 @@ func (ns *nurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
// NOTE: The callback should not modify the provided byte slices and is
// preferably non-blocking.
func (ns *nurseryStore) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error {
callback func([]byte, []byte) error, reset func()) error {
return kvdb.View(ns.db, func(tx kvdb.RTx) error {
return ns.forChanOutputs(tx, chanPoint, callback)
})
}, reset)
}
// ListChannels returns all channels the nursery is currently tracking.
@ -743,6 +750,8 @@ func (ns *nurseryStore) ListChannels() ([]wire.OutPoint, error) {
return nil
})
}, func() {
activeChannels = nil
}); err != nil {
return nil, err
}
@ -765,7 +774,7 @@ func (ns *nurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error)
return nil
})
})
}, func() {})
if err != nil && err != ErrImmatureChannel {
return false, err
}
@ -835,7 +844,7 @@ func (ns *nurseryStore) RemoveChannel(chanPoint *wire.OutPoint) error {
}
return removeBucketIfExists(chanIndex, chanBytes)
})
}, func() {})
}
// Helper Methods

@ -370,6 +370,8 @@ func assertNumChanOutputs(t *testing.T, ns NurseryStore,
err := ns.ForChanOutputs(chanPoint, func([]byte, []byte) error {
count++
return nil
}, func() {
count = 0
})
if count == 0 && err == ErrContractNotFound {

@ -41,10 +41,7 @@ type missionControlStore struct {
}
func newMissionControlStore(db kvdb.Backend, maxRecords int) (*missionControlStore, error) {
store := &missionControlStore{
db: db,
maxRecords: maxRecords,
}
var store *missionControlStore
// Create buckets if not yet existing.
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
@ -64,6 +61,11 @@ func newMissionControlStore(db kvdb.Backend, maxRecords int) (*missionControlSto
}
return nil
}, func() {
store = &missionControlStore{
db: db,
maxRecords: maxRecords,
}
})
if err != nil {
return nil, err
@ -81,7 +83,7 @@ func (b *missionControlStore) clear() error {
_, err := tx.CreateTopLevelBucket(resultsKey)
return err
})
}, func() {})
}
// fetchAll returns all results currently stored in the database.
@ -103,6 +105,8 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
return nil
})
}, func() {
results = nil
})
if err != nil {
return nil, err
@ -249,7 +253,7 @@ func (b *missionControlStore) AddResult(rp *paymentResult) error {
// Put into results bucket.
return bucket.Put(k, v)
})
}, func() {})
}
// getResultKey returns a byte slice representing a unique key for this payment

@ -4736,6 +4736,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
NodeSigner: r.server.nodeSigner,
DefaultCLTVExpiry: defaultDelta,
ChanDB: r.server.remoteChanDB,
Graph: r.server.localChanDB.ChannelGraph(),
GenInvoiceFeatures: func() *lnwire.FeatureVector {
return r.server.featureMgr.Get(feature.SetInvoice)
},

@ -92,7 +92,7 @@ func NewSweeperStore(db kvdb.Backend, chainHash *chainhash.Hash) (
err = migrateTxHashes(tx, txHashesBucket, chainHash)
return err
})
}, func() {})
if err != nil {
return nil, err
}
@ -193,7 +193,7 @@ func (s *sweeperStore) NotifyPublishTx(sweepTx *wire.MsgTx) error {
hash := sweepTx.TxHash()
return txHashesBucket.Put(hash[:], []byte{})
})
}, func() {})
}
// GetLastPublishedTx returns the last tx that we called NotifyPublishTx
@ -219,6 +219,8 @@ func (s *sweeperStore) GetLastPublishedTx() (*wire.MsgTx, error) {
}
return nil
}, func() {
sweepTx = nil
})
if err != nil {
return nil, err
@ -241,6 +243,8 @@ func (s *sweeperStore) IsOurTx(hash chainhash.Hash) (bool, error) {
ours = txHashesBucket.Get(hash[:]) != nil
return nil
}, func() {
ours = false
})
if err != nil {
return false, err
@ -269,6 +273,8 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) {
return nil
})
}, func() {
sweepTxns = nil
}); err != nil {
return nil, err
}

@ -477,7 +477,7 @@ func (u *utxoNursery) NurseryReport(
utxnLog.Debugf("NurseryReport: building nursery report for channel %v",
chanPoint)
report := &contractMaturityReport{}
var report *contractMaturityReport
if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error {
switch {
@ -576,6 +576,8 @@ func (u *utxoNursery) NurseryReport(
}
return nil
}, func() {
report = &contractMaturityReport{}
}); err != nil {
return nil, err
}

@ -931,9 +931,9 @@ func (i *nurseryStoreInterceptor) HeightsBelowOrEqual(height uint32) (
}
func (i *nurseryStoreInterceptor) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error {
callback func([]byte, []byte) error, reset func()) error {
return i.ns.ForChanOutputs(chanPoint, callback)
return i.ns.ForChanOutputs(chanPoint, callback, reset)
}
func (i *nurseryStoreInterceptor) ListChannels() ([]wire.OutPoint, error) {

@ -146,7 +146,7 @@ func OpenClientDB(dbPath string) (*ClientDB, error) {
// initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error.
err = kvdb.Update(clientDB.db, initClientDBBuckets)
err = kvdb.Update(clientDB.db, initClientDBBuckets, func() {})
if err != nil {
bdb.Close()
return nil, err
@ -192,6 +192,8 @@ func (c *ClientDB) Version() (uint32, error) {
var err error
version, err = getDBVersion(tx)
return err
}, func() {
version = 0
})
if err != nil {
return 0, err
@ -291,6 +293,8 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) {
// Store the new or updated tower under its tower id.
return putTower(towers, tower)
}, func() {
tower = nil
})
if err != nil {
return nil, err
@ -377,7 +381,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
}
return nil
})
}, func() {})
}
// LoadTowerByID retrieves a tower by its tower ID.
@ -392,6 +396,8 @@ func (c *ClientDB) LoadTowerByID(towerID TowerID) (*Tower, error) {
var err error
tower, err = getTower(towers, towerID.Bytes())
return err
}, func() {
tower = nil
})
if err != nil {
return nil, err
@ -421,6 +427,8 @@ func (c *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*Tower, error) {
var err error
tower, err = getTower(towers, towerIDBytes)
return err
}, func() {
tower = nil
})
if err != nil {
return nil, err
@ -446,6 +454,8 @@ func (c *ClientDB) ListTowers() ([]*Tower, error) {
towers = append(towers, tower)
return nil
})
}, func() {
towers = nil
})
if err != nil {
return nil, err
@ -498,6 +508,8 @@ func (c *ClientDB) NextSessionKeyIndex(towerID TowerID) (uint32, error) {
// Record the reserved session key index under this tower's id.
return keyIndex.Put(towerIDBytes, indexBuf[:])
}, func() {
index = 0
})
if err != nil {
return 0, err
@ -550,7 +562,7 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error {
// Finally, write the client session's body in the sessions
// bucket.
return putClientSessionBody(sessions, session)
})
}, func() {})
}
// ListClientSessions returns the set of all client sessions known to the db. An
@ -566,6 +578,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (map[SessionID]*ClientSession
var err error
clientSessions, err = listClientSessions(sessions, id)
return err
}, func() {
clientSessions = nil
})
if err != nil {
return nil, err
@ -611,7 +625,7 @@ func listClientSessions(sessions kvdb.RBucket,
// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
summaries := make(map[lnwire.ChannelID]ClientChanSummary)
var summaries map[lnwire.ChannelID]ClientChanSummary
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
chanSummaries := tx.ReadBucket(cChanSummaryBkt)
if chanSummaries == nil {
@ -632,6 +646,8 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
return nil
})
}, func() {
summaries = make(map[lnwire.ChannelID]ClientChanSummary)
})
if err != nil {
return nil, err
@ -674,7 +690,7 @@ func (c *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
}
return putChanSummary(chanSummaries, chanID, &summary)
})
}, func() {})
}
// MarkBackupIneligible records that the state identified by the (channel id,
@ -782,6 +798,8 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
return nil
}, func() {
lastApplied = 0
})
if err != nil {
return 0, err
@ -887,7 +905,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
// Finally, insert the ack into the sessionAcks sub-bucket.
return sessionAcks.Put(seqNumBuf[:], b.Bytes())
})
}, func() {})
}
// getClientSessionBody loads the body of a ClientSession from the sessions

@ -80,6 +80,8 @@ func createDBIfNotExist(dbPath, name string) (kvdb.Backend, bool, error) {
err = kvdb.View(bdb, func(tx kvdb.RTx) error {
metadataExists = tx.ReadBucket(metadataBkt) != nil
return nil
}, func() {
metadataExists = false
})
if err != nil {
return nil, false, err

@ -88,7 +88,7 @@ func OpenTowerDB(dbPath string) (*TowerDB, error) {
// initialized. This allows us to assume their presence throughout all
// operations. If an known top-level bucket is expected to exist but is
// missing, this will trigger a ErrUninitializedDB error.
err = kvdb.Update(towerDB.db, initTowerDBBuckets)
err = kvdb.Update(towerDB.db, initTowerDBBuckets, func() {})
if err != nil {
bdb.Close()
return nil, err
@ -133,6 +133,8 @@ func (t *TowerDB) Version() (uint32, error) {
var err error
version, err = getDBVersion(tx)
return err
}, func() {
version = 0
})
if err != nil {
return 0, err
@ -159,6 +161,8 @@ func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) {
var err error
session, err = getSession(sessions, id[:])
return err
}, func() {
session = nil
})
if err != nil {
return nil, err
@ -210,7 +214,7 @@ func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error {
// be deleted without needing to iterate over the entire
// database.
return touchSessionHintBkt(updateIndex, &session.ID)
})
}, func() {})
}
// InsertStateUpdate stores an update sent by the client after validating that
@ -292,6 +296,8 @@ func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error)
// hint under its session id. This will allow us to delete the
// entries efficiently if the session is ever removed.
return putHintForSession(updateIndex, &update.ID, update.Hint)
}, func() {
lastApplied = 0
})
if err != nil {
return 0, err
@ -381,7 +387,7 @@ func (t *TowerDB) DeleteSession(target SessionID) error {
// Finally, remove this session from the update index, which
// also removes any of the indexed hints beneath it.
return removeSessionHintBkt(updateIndex, &target)
})
}, func() {})
}
// QueryMatches searches against all known state updates for any that match the
@ -460,6 +466,8 @@ func (t *TowerDB) QueryMatches(breachHints []blob.BreachHint) ([]Match, error) {
}
return nil
}, func() {
matches = nil
})
if err != nil {
return nil, err
@ -478,7 +486,7 @@ func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
}
return putLookoutEpoch(lookoutTip, epoch)
})
}, func() {})
}
// GetLookoutTip retrieves the current lookout tip block epoch from the tower
@ -494,6 +502,8 @@ func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
epoch = getLookoutEpoch(lookoutTip)
return nil
}, func() {
epoch = nil
})
if err != nil {
return nil, err

@ -107,7 +107,7 @@ func initOrSyncVersions(db versionedDB, init bool, versions []version) error {
if init {
return kvdb.Update(db.bdb(), func(tx kvdb.RwTx) error {
return initDBVersion(tx, getLatestDBVersion(versions))
})
}, func() {})
}
// Otherwise, ensure that any migrations are applied to ensure the data
@ -159,5 +159,5 @@ func syncVersions(db versionedDB, versions []version) error {
}
return putDBVersion(tx, latestVersion)
})
}, func() {})
}