channeldb: modify ForEachNode/ForEachChannel to accept a db txn
This commit modifies the ForEachNode on the ChannelGraph and ForEachChannel on the LightningNode struct to accept a database transaction as its first argument. With this change, it’ll now be possible to implement graph traversals that typically required a nested loop with all the vertex loaded into memory using the callback API instead: c.ForEachNode(nil, func(tx, node) { node.ForEachChannel(tx, func(…) { }) })
This commit is contained in:
parent
c8bf521c75
commit
b96b180b0b
@ -217,11 +217,16 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
|
||||
// executing the passed callback with each node encountered. If the callback
|
||||
// returns an error, then the transaction is aborted and the iteration stops
|
||||
// early.
|
||||
func (c *ChannelGraph) ForEachNode(cb func(*LightningNode) error) error {
|
||||
// TODO(roasbeef): need to also pass in a transaction? or reverse order
|
||||
// to get all in memory THEN execute callback?
|
||||
|
||||
return c.db.View(func(tx *bolt.Tx) error {
|
||||
//
|
||||
// If the caller wishes to re-use an existing boltdb transaction, then it
|
||||
// should be passed as the first argument. Otherwise the first argument should
|
||||
// be nil and a fresh transaction will be created to execute the graph
|
||||
// traversal
|
||||
//
|
||||
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
|
||||
// traversal when graph gets mega
|
||||
func (c *ChannelGraph) ForEachNode(tx *bolt.Tx, cb func(*bolt.Tx, *LightningNode) error) error {
|
||||
traversal := func(tx *bolt.Tx) error {
|
||||
// First grab the nodes bucket which stores the mapping from
|
||||
// pubKey to node information.
|
||||
nodes := tx.Bucket(nodeBucket)
|
||||
@ -246,9 +251,19 @@ func (c *ChannelGraph) ForEachNode(cb func(*LightningNode) error) error {
|
||||
|
||||
// Execute the callback, the transaction will abort if
|
||||
// this returns an error.
|
||||
return cb(node)
|
||||
})
|
||||
return cb(tx, node)
|
||||
})
|
||||
}
|
||||
|
||||
// If no transaction was provided, then we'll create a new transaction
|
||||
// to execute the transaction within.
|
||||
if tx == nil {
|
||||
return c.db.View(traversal)
|
||||
}
|
||||
|
||||
// Otherwise, we re-use the existing transaction to execute the graph
|
||||
// traversal.
|
||||
return traversal(tx)
|
||||
}
|
||||
|
||||
// SourceNode returns the source node of the graph. The source node is treated
|
||||
@ -962,13 +977,15 @@ func (c *ChannelGraph) HasLightningNode(pub *btcec.PublicKey) (time.Time, bool,
|
||||
// ForEachChannel iterates through all the outgoing channel edges from this
|
||||
// node, executing the passed callback with each edge as its sole argument. If
|
||||
// the callback returns an error, then the iteration is halted with the error
|
||||
// propagated back up to the caller. If the caller wishes to re-use an existing
|
||||
// boltdb transaction, then it should be passed as the first argument.
|
||||
// Otherwise the first argument should be nil and a fresh transaction will be
|
||||
// created to execute the graph traversal.
|
||||
func (l *LightningNode) ForEachChannel(tx *bolt.Tx, cb func(*ChannelEdgeInfo, *ChannelEdgePolicy) error) error {
|
||||
// TODO(roasbeef): remove the option to pass in a transaction after
|
||||
// all?
|
||||
// propagated back up to the caller.
|
||||
//
|
||||
// If the caller wishes to re-use an existing boltdb transaction, then it
|
||||
// should be passed as the first argument. Otherwise the first argument should
|
||||
// be nil and a fresh transaction will be created to execute the graph
|
||||
// traversal.
|
||||
func (l *LightningNode) ForEachChannel(tx *bolt.Tx,
|
||||
cb func(*bolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy) error) error {
|
||||
|
||||
nodePub := l.PubKey.SerializeCompressed()
|
||||
|
||||
traversal := func(tx *bolt.Tx) error {
|
||||
@ -1021,7 +1038,7 @@ func (l *LightningNode) ForEachChannel(tx *bolt.Tx, cb func(*ChannelEdgeInfo, *C
|
||||
}
|
||||
|
||||
// Finally, we execute the callback.
|
||||
if err := cb(edgeInfo, edgePolicy); err != nil {
|
||||
if err := cb(tx, edgeInfo, edgePolicy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/roasbeef/btcd/btcec"
|
||||
@ -533,7 +534,7 @@ func TestGraphTraversal(t *testing.T) {
|
||||
|
||||
// Iterate over each node as returned by the graph, if all nodes are
|
||||
// reached, then the map created above should be empty.
|
||||
err = graph.ForEachNode(func(node *LightningNode) error {
|
||||
err = graph.ForEachNode(nil, func(_ *bolt.Tx, node *LightningNode) error {
|
||||
delete(nodeIndex, node.Alias)
|
||||
return nil
|
||||
})
|
||||
@ -630,7 +631,9 @@ func TestGraphTraversal(t *testing.T) {
|
||||
// Finally, we want to test the ability to iterate over all the
|
||||
// outgoing channels for a particular node.
|
||||
numNodeChans := 0
|
||||
err = firstNode.ForEachChannel(nil, func(_ *ChannelEdgeInfo, c *ChannelEdgePolicy) error {
|
||||
err = firstNode.ForEachChannel(nil, func(_ *bolt.Tx, _ *ChannelEdgeInfo,
|
||||
c *ChannelEdgePolicy) error {
|
||||
|
||||
// Each each should indicate that it's outgoing (pointed
|
||||
// towards the second node).
|
||||
if !c.Node.PubKey.IsEqual(secondNode.PubKey) {
|
||||
|
Loading…
Reference in New Issue
Block a user