diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index d9499a4a..541e8869 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -196,6 +196,18 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return chanInfo, edges[0], edges[1], nil } +func (r *mockGraphSource) FetchLightningNode( + nodePub routing.Vertex) (*channeldb.LightningNode, error) { + + for _, node := range r.nodes { + if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) { + return node, nil + } + } + + return nil, channeldb.ErrGraphNodeNotFound +} + // IsStaleNode returns true if the graph source has a node announcement for the // target node with a more recent timestamp. func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool { diff --git a/routing/router.go b/routing/router.go index a3c91388..31a269e4 100644 --- a/routing/router.go +++ b/routing/router.go @@ -100,6 +100,11 @@ type ChannelGraphSource interface { GetChannelByID(chanID lnwire.ShortChannelID) (*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) + // FetchLightningNode attempts to look up a target node by its identity + // public key. channeldb.ErrGraphNodeNotFound is returned if the node + // doesn't exist within the graph. + FetchLightningNode(Vertex) (*channeldb.LightningNode, error) + // ForEachNode is used to iterate over every node in the known graph. ForEachNode(func(node *channeldb.LightningNode) error) error @@ -2163,6 +2168,19 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( return r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) } +// FetchLightningNode attempts to look up a target node by its identity public +// key. channeldb.ErrGraphNodeNotFound is returned if the node doesn't exist +// within the graph. +// +// NOTE: This method is part of the ChannelGraphSource interface. +func (r *ChannelRouter) FetchLightningNode(node Vertex) (*channeldb.LightningNode, error) { + pubKey, err := btcec.ParsePubKey(node[:], btcec.S256()) + if err != nil { + return nil, fmt.Errorf("unable to parse raw public key: %v", err) + } + return r.cfg.Graph.FetchLightningNode(pubKey) +} + // ForEachNode is used to iterate over every node in router topology. // // NOTE: This method is part of the ChannelGraphSource interface.