diff --git a/channeldb/db.go b/channeldb/db.go index d71656a7..c8ce7fb8 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "net" "os" "path/filepath" "sync" @@ -772,6 +773,59 @@ func (d *DB) PruneLinkNodes() error { }) } +// AddrsForNode consults the graph and channel database for all addresses known +// to the passed node public key. +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { + var ( + linkNode *LinkNode + graphNode LightningNode + ) + + dbErr := d.View(func(tx *bbolt.Tx) error { + var err error + + linkNode, err = fetchLinkNode(tx, nodePub) + if err != nil { + return err + } + + // We'll also query the graph for this peer to see if they have + // any addresses that we don't currently have stored within the + // link node database. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + compressedPubKey := nodePub.SerializeCompressed() + graphNode, err = fetchLightningNode(nodes, compressedPubKey) + if err != nil { + return err + } + + return nil + }) + if dbErr != nil { + return nil, dbErr + } + + // Now that we have both sources of addrs for this node, we'll use a + // map to de-duplicate any addresses between the two sources, and + // produce a final list of the combined addrs. + addrs := make(map[string]net.Addr) + for _, addr := range linkNode.Addresses { + addrs[addr.String()] = addr + } + for _, addr := range graphNode.Addresses { + addrs[addr.String()] = addr + } + dedupedAddrs := make([]net.Addr, 0, len(addrs)) + for _, addr := range addrs { + dedupedAddrs = append(dedupedAddrs, addr) + } + + return dedupedAddrs, nil +} + // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 794b1fcf..dbd255e5 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -2,10 +2,12 @@ package channeldb import ( "io/ioutil" + "net" "os" "path/filepath" "testing" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnwire" ) @@ -147,3 +149,64 @@ func TestFetchClosedChannelForID(t *testing.T) { t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) } } + +// TestAddrsForNode tests the we're able to properly obtain all the addresses +// for a target node. +func TestAddrsForNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + graph := cdb.ChannelGraph() + + // We'll make a test vertex to insert into the database, as the source + // node, but this node will only have half the number of addresses it + // usually does. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + testNode.Addresses = []net.Addr{testAddr} + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // Next, we'll make a link node with the same pubkey, but with an + // additional address. + nodePub, err := testNode.PubKey() + if err != nil { + t.Fatalf("unable to recv node pub: %v", err) + } + linkNode := cdb.NewLinkNode( + wire.MainNet, nodePub, anotherAddr, + ) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to sync link node: %v", err) + } + + // Now that we've created a link node, as well as a vertex for the + // node, we'll query for all its addresses. + nodeAddrs, err := cdb.AddrsForNode(nodePub) + if err != nil { + t.Fatalf("unable to obtain node addrs: %v", err) + } + + expectedAddrs := make(map[string]struct{}) + expectedAddrs[testAddr.String()] = struct{}{} + expectedAddrs[anotherAddr.String()] = struct{}{} + + // Finally, ensure that all the expected addresses are found. + if len(nodeAddrs) != len(expectedAddrs) { + t.Fatalf("expected %v addrs, got %v", + len(expectedAddrs), len(nodeAddrs)) + } + for _, addr := range nodeAddrs { + if _, ok := expectedAddrs[addr.String()]; !ok { + t.Fatalf("unexpected addr: %v", addr) + } + } +} diff --git a/channeldb/nodes.go b/channeldb/nodes.go index 43729d33..95f6f7a2 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -62,13 +62,13 @@ type LinkNode struct { // NewLinkNode creates a new LinkNode from the provided parameters, which is // backed by an instance of channeldb. func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, - addr net.Addr) *LinkNode { + addrs ...net.Addr) *LinkNode { return &LinkNode{ Network: bitNet, IdentityPub: pub, LastSeen: time.Now(), - Addresses: []net.Addr{addr}, + Addresses: addrs, db: db, } } @@ -149,40 +149,44 @@ func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error { // identity public key. If a particular LinkNode for the passed identity public // key cannot be found, then ErrNodeNotFound if returned. func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { - var ( - node *LinkNode - err error - ) - - err = db.View(func(tx *bbolt.Tx) error { - // First fetch the bucket for storing node metadata, bailing - // out early if it hasn't been created yet. - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound + var linkNode *LinkNode + err := db.View(func(tx *bbolt.Tx) error { + node, err := fetchLinkNode(tx, identity) + if err != nil { + return err } - // If a link node for that particular public key cannot be - // located, then exit early with an ErrNodeNotFound. - pubKey := identity.SerializeCompressed() - nodeBytes := nodeMetaBucket.Get(pubKey) - if nodeBytes == nil { - return ErrNodeNotFound - } - - // Finally, decode an allocate a fresh LinkNode object to be - // returned to the caller. - nodeReader := bytes.NewReader(nodeBytes) - node, err = deserializeLinkNode(nodeReader) - return err + linkNode = node + return nil }) - if err != nil { - return nil, err + + return linkNode, err +} + +func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) { + // First fetch the bucket for storing node metadata, bailing out early + // if it hasn't been created yet. + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound } - return node, nil + // If a link node for that particular public key cannot be located, + // then exit early with an ErrNodeNotFound. + pubKey := targetPub.SerializeCompressed() + nodeBytes := nodeMetaBucket.Get(pubKey) + if nodeBytes == nil { + return nil, ErrNodeNotFound + } + + // Finally, decode and allocate a fresh LinkNode object to be returned + // to the caller. + nodeReader := bytes.NewReader(nodeBytes) + return deserializeLinkNode(nodeReader) } +// TODO(roasbeef): update link node addrs in server upon connection + // FetchAllLinkNodes starts a new database transaction to fetch all nodes with // whom we have active channels with. func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {