channeldb: modify ForEachNode/ForEachChannel to accept a db txn

This commit modifies the ForEachNode on the ChannelGraph and
ForEachChannel on the LightningNode struct to accept a database
transaction as its first argument. With this change, it’ll now be
possible to implement graph traversals that typically required a nested
loop with all the vertex loaded into memory using the callback API
instead:
c.ForEachNode(nil, func(tx, node) {
    node.ForEachChannel(tx, func(…) {
    })
})
This commit is contained in:
Olaoluwa Osuntokun 2017-04-14 13:14:02 -07:00
parent c8bf521c75
commit b96b180b0b
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
2 changed files with 37 additions and 17 deletions

@ -217,11 +217,16 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
func (c *ChannelGraph) ForEachNode(cb func(*LightningNode) error) error {
// TODO(roasbeef): need to also pass in a transaction? or reverse order
// to get all in memory THEN execute callback?
return c.db.View(func(tx *bolt.Tx) error {
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal
//
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
// traversal when graph gets mega
func (c *ChannelGraph) ForEachNode(tx *bolt.Tx, cb func(*bolt.Tx, *LightningNode) error) error {
traversal := func(tx *bolt.Tx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.Bucket(nodeBucket)
@ -246,9 +251,19 @@ func (c *ChannelGraph) ForEachNode(cb func(*LightningNode) error) error {
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(node)
})
return cb(tx, node)
})
}
// If no transaction was provided, then we'll create a new transaction
// to execute the transaction within.
if tx == nil {
return c.db.View(traversal)
}
// Otherwise, we re-use the existing transaction to execute the graph
// traversal.
return traversal(tx)
}
// SourceNode returns the source node of the graph. The source node is treated
@ -962,13 +977,15 @@ func (c *ChannelGraph) HasLightningNode(pub *btcec.PublicKey) (time.Time, bool,
// ForEachChannel iterates through all the outgoing channel edges from this
// node, executing the passed callback with each edge as its sole argument. If
// the callback returns an error, then the iteration is halted with the error
// propagated back up to the caller. If the caller wishes to re-use an existing
// boltdb transaction, then it should be passed as the first argument.
// Otherwise the first argument should be nil and a fresh transaction will be
// created to execute the graph traversal.
func (l *LightningNode) ForEachChannel(tx *bolt.Tx, cb func(*ChannelEdgeInfo, *ChannelEdgePolicy) error) error {
// TODO(roasbeef): remove the option to pass in a transaction after
// all?
// propagated back up to the caller.
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (l *LightningNode) ForEachChannel(tx *bolt.Tx,
cb func(*bolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy) error) error {
nodePub := l.PubKey.SerializeCompressed()
traversal := func(tx *bolt.Tx) error {
@ -1021,7 +1038,7 @@ func (l *LightningNode) ForEachChannel(tx *bolt.Tx, cb func(*ChannelEdgeInfo, *C
}
// Finally, we execute the callback.
if err := cb(edgeInfo, edgePolicy); err != nil {
if err := cb(tx, edgeInfo, edgePolicy); err != nil {
return err
}
}

@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/boltdb/bolt"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
@ -533,7 +534,7 @@ func TestGraphTraversal(t *testing.T) {
// Iterate over each node as returned by the graph, if all nodes are
// reached, then the map created above should be empty.
err = graph.ForEachNode(func(node *LightningNode) error {
err = graph.ForEachNode(nil, func(_ *bolt.Tx, node *LightningNode) error {
delete(nodeIndex, node.Alias)
return nil
})
@ -630,7 +631,9 @@ func TestGraphTraversal(t *testing.T) {
// Finally, we want to test the ability to iterate over all the
// outgoing channels for a particular node.
numNodeChans := 0
err = firstNode.ForEachChannel(nil, func(_ *ChannelEdgeInfo, c *ChannelEdgePolicy) error {
err = firstNode.ForEachChannel(nil, func(_ *bolt.Tx, _ *ChannelEdgeInfo,
c *ChannelEdgePolicy) error {
// Each each should indicate that it's outgoing (pointed
// towards the second node).
if !c.Node.PubKey.IsEqual(secondNode.PubKey) {