From 959618d596c3420b7e21f9995cf86c6c8f666019 Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Thu, 14 Jun 2018 19:47:00 -0700 Subject: [PATCH] channeldb: refactor methods to allow using existing db transaction --- channeldb/db.go | 96 ++++++++++++++++++++++++++-------------------- channeldb/graph.go | 89 +++++++++++++++++++++++++----------------- channeldb/nodes.go | 54 +++++++++++++++++--------- 3 files changed, 144 insertions(+), 95 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index fa844a2d..6764b6de 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -233,56 +233,70 @@ func fileExists(path string) bool { return true } -// FetchOpenChannels returns all stored currently active/open channels -// associated with the target nodeID. In the case that no active channels are -// known to have been created with this node, then a zero-length slice is -// returned. +// FetchOpenChannels starts a new database transaction and returns all stored +// currently active/open channels associated with the target nodeID. In the case +// that no active channels are known to have been created with this node, then a +// zero-length slice is returned. func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { var channels []*OpenChannel err := d.View(func(tx *bolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { + var err error + channels, err = d.fetchOpenChannels(tx, nodeID) + return err + }) + + return channels, err +} + +// fetchOpenChannels uses and existing database transaction and returns all +// stored currently active/open channels associated with the target nodeID. In +// the case that no active channels are known to have been created with this +// node, then a zero-length slice is returned. +func (d *DB) fetchOpenChannels(tx *bolt.Tx, + nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + + // Get the bucket dedicated to storing the metadata for open channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + pub := nodeID.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(pub) + if nodeChanBucket == nil { + return nil, nil + } + + // Next, we'll need to go down an additional layer in order to retrieve + // the channels for each chain the node knows of. + var channels []*OpenChannel + err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { return nil } - // Within this top level bucket, fetch the bucket dedicated to - // storing open channel data specific to the remote node. - pub := nodeID.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(pub) - if nodeChanBucket == nil { - return nil + // If we've found a valid chainhash bucket, then we'll retrieve + // that so we can extract all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read bucket for chain=%x", + chainHash[:]) } - // Next, we'll need to go down an additional layer in order to - // retrieve the channels for each chain the node knows of. - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } + // Finally, we both of the necessary buckets retrieved, fetch + // all the active channels related to this node. + nodeChannels, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read channel for "+ + "chain_hash=%x, node_key=%x: %v", + chainHash[:], pub, err) + } - // If we've found a valid chainhash bucket, then we'll - // retrieve that so we can extract all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read bucket for "+ - "chain=%x", chainHash[:]) - } - - // Finally, we both of the necessary buckets retrieved, - // fetch all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read channel for "+ - "chain_hash=%x, node_key=%x: %v", - chainHash[:], pub, err) - } - - channels = nodeChannels - return nil - }) + channels = nodeChannels + return nil }) return channels, err diff --git a/channeldb/graph.go b/channeldb/graph.go index 886ceee4..682d1256 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -259,27 +259,12 @@ func (c *ChannelGraph) ForEachNode(tx *bolt.Tx, cb func(*bolt.Tx, *LightningNode func (c *ChannelGraph) SourceNode() (*LightningNode, error) { var source *LightningNode err := c.db.View(func(tx *bolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - selfPub := nodes.Get(sourceKey) - if selfPub == nil { - return ErrSourceNodeNotSet - } - - // With the pubKey of the source node retrieved, we're able to - // fetch the full node information. - node, err := fetchLightningNode(nodes, selfPub) + node, err := c.sourceNode(tx) if err != nil { return err } + source = node - source = &node - source.db = c.db return nil }) if err != nil { @@ -289,6 +274,34 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { return source, nil } +// sourceNode uses an existing database transaction and returns the source node +// of the graph. The source node is treated as the center node within a +// star-graph. This method may be used to kick off a path finding algorithm in +// order to explore the reachability of another node based off the source node. +func (c *ChannelGraph) sourceNode(tx *bolt.Tx) (*LightningNode, error) { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return nil, ErrGraphNotFound + } + + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return nil, ErrSourceNodeNotSet + } + + // With the pubKey of the source node retrieved, we're able to + // fetch the full node information. + node, err := fetchLightningNode(nodes, selfPub) + if err != nil { + return nil, err + } + node.db = c.db + + return &node, nil +} + // SetSourceNode sets the source node within the graph database. The source // node is to be used as the center of a star-graph within path finding // algorithms. @@ -384,30 +397,36 @@ func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { return alias, nil } -// DeleteLightningNode removes a vertex/node from the database according to the -// node's public key. +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { - pub := nodePub.SerializeCompressed() - // TODO(roasbeef): ensure dangling edges are removed... return c.db.Update(func(tx *bolt.Tx) error { - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - - aliases, err := tx.CreateBucketIfNotExists(aliasIndexBucket) - if err != nil { - return err - } - - if err := aliases.Delete(pub); err != nil { - return err - } - return nodes.Delete(pub) + return c.deleteLightningNode(tx, nodePub.SerializeCompressed()) }) } +// deleteLightningNode uses an existing database transaction to remove a +// vertex/node from the database according to the node's public key. +func (c *ChannelGraph) deleteLightningNode(tx *bolt.Tx, + compressedPubKey []byte) error { + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + if err := aliases.Delete(compressedPubKey); err != nil { + return err + } + return nodes.Delete(compressedPubKey) +} + // AddChannelEdge adds a new (undirected, blank) edge to the graph database. An // undirected edge from the two target nodes are created. The information // stored denotes the static attributes of the channel, such as the channelID, diff --git a/channeldb/nodes.go b/channeldb/nodes.go index a845744d..59c67d41 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -6,9 +6,9 @@ import ( "net" "time" - "github.com/coreos/bbolt" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" ) var ( @@ -183,32 +183,48 @@ func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { return node, nil } -// FetchAllLinkNodes attempts to fetch all active LinkNodes from the database. -// If there haven't been any channels explicitly linked to LinkNodes written to -// the database, then this function will return an empty slice. +// FetchAllLinkNodes starts a new database transaction to fetch all nodes with +// whom we have active channels with. func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { var linkNodes []*LinkNode - err := db.View(func(tx *bolt.Tx) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound + nodes, err := db.fetchAllLinkNodes(tx) + if err != nil { + return err } - return nodeMetaBucket.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } + linkNodes = nodes + return nil + }) + if err != nil { + return nil, err + } - nodeReader := bytes.NewReader(v) - linkNode, err := deserializeLinkNode(nodeReader) - if err != nil { - return err - } + return linkNodes, nil +} - linkNodes = append(linkNodes, linkNode) +// fetchAllLinkNodes uses an existing database transaction to fetch all nodes +// with whom we have active channels with. +func (db *DB) fetchAllLinkNodes(tx *bolt.Tx) ([]*LinkNode, error) { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + var linkNodes []*LinkNode + err := nodeMetaBucket.ForEach(func(k, v []byte) error { + if v == nil { return nil - }) + } + + nodeReader := bytes.NewReader(v) + linkNode, err := deserializeLinkNode(nodeReader) + if err != nil { + return err + } + + linkNodes = append(linkNodes, linkNode) + return nil }) if err != nil { return nil, err