diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index b635b9b2..916d44f2 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/btcsuite/btcutil" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" @@ -27,6 +26,11 @@ type RouterBackend struct { // capacity of a channel to populate in responses. FetchChannelCapacity func(chanID uint64) (btcutil.Amount, error) + // FetchChannelEndpoints returns the pubkeys of both endpoints of the + // given channel id. + FetchChannelEndpoints func(chanID uint64) (route.Vertex, + route.Vertex, error) + // FindRoutes is a closure that abstracts away how we locate/query for // routes. FindRoute func(source, target route.Vertex, @@ -224,22 +228,21 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) *lnrpc.Route { // not known. This function will query the channel graph with channel id to // retrieve both endpoints and determine the hop pubkey using the previous hop // pubkey. If the channel is unknown, an error is returned. -func UnmarshallHopByChannelLookup(graph *channeldb.ChannelGraph, hop *lnrpc.Hop, +func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop, prevPubKeyBytes [33]byte) (*route.Hop, error) { // Discard edge policies, because they may be nil. - edgeInfo, _, _, err := graph.FetchChannelEdgesByID(hop.ChanId) + node1, node2, err := r.FetchChannelEndpoints(hop.ChanId) if err != nil { - return nil, fmt.Errorf("unable to fetch channel edges by "+ - "channel ID %d: %v", hop.ChanId, err) + return nil, err } var pubKeyBytes [33]byte switch { - case prevPubKeyBytes == edgeInfo.NodeKey1Bytes: - pubKeyBytes = edgeInfo.NodeKey2Bytes - case prevPubKeyBytes == edgeInfo.NodeKey2Bytes: - pubKeyBytes = edgeInfo.NodeKey1Bytes + case prevPubKeyBytes == node1: + pubKeyBytes = node2 + case prevPubKeyBytes == node2: + pubKeyBytes = node1 default: return nil, fmt.Errorf("channel edge does not match expected node") } @@ -248,7 +251,7 @@ func UnmarshallHopByChannelLookup(graph *channeldb.ChannelGraph, hop *lnrpc.Hop, OutgoingTimeLock: hop.Expiry, AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat), PubKeyBytes: pubKeyBytes, - ChannelID: edgeInfo.ChannelID, + ChannelID: hop.ChanId, }, nil } @@ -274,14 +277,14 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) { // UnmarshallHop unmarshalls an rpc hop that may or may not contain a node // pubkey. -func UnmarshallHop(graph *channeldb.ChannelGraph, hop *lnrpc.Hop, +func (r *RouterBackend) UnmarshallHop(hop *lnrpc.Hop, prevNodePubKey [33]byte) (*route.Hop, error) { if hop.PubKey == "" { // If no pub key is given of the hop, the local channel // graph needs to be queried to complete the information // necessary for routing. - return UnmarshallHopByChannelLookup(graph, hop, prevNodePubKey) + return r.UnmarshallHopByChannelLookup(hop, prevNodePubKey) } return UnmarshallKnownPubkeyHop(hop) @@ -289,21 +292,14 @@ func UnmarshallHop(graph *channeldb.ChannelGraph, hop *lnrpc.Hop, // UnmarshallRoute unmarshalls an rpc route. For hops that don't specify a // pubkey, the channel graph is queried. -func UnmarshallRoute(rpcroute *lnrpc.Route, - graph *channeldb.ChannelGraph) (*route.Route, error) { +func (r *RouterBackend) UnmarshallRoute(rpcroute *lnrpc.Route) ( + *route.Route, error) { - sourceNode, err := graph.SourceNode() - if err != nil { - return nil, fmt.Errorf("unable to fetch source node from graph "+ - "while unmarshaling route. %v", err) - } - - prevNodePubKey := sourceNode.PubKeyBytes + prevNodePubKey := r.SelfNode hops := make([]*route.Hop, len(rpcroute.Hops)) for i, hop := range rpcroute.Hops { - routeHop, err := UnmarshallHop(graph, - hop, prevNodePubKey) + routeHop, err := r.UnmarshallHop(hop, prevNodePubKey) if err != nil { return nil, err } @@ -316,7 +312,7 @@ func UnmarshallRoute(rpcroute *lnrpc.Route, route, err := route.NewRouteFromHops( lnwire.MilliSatoshi(rpcroute.TotalAmtMsat), rpcroute.TotalTimeLock, - sourceNode.PubKeyBytes, + r.SelfNode, hops, ) if err != nil { diff --git a/rpcserver.go b/rpcserver.go index 9b9fc1f1..858d6867 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -462,6 +462,21 @@ func newRPCServer(s *server, macService *macaroons.Service, } return info.Capacity, nil }, + FetchChannelEndpoints: func(chanID uint64) (route.Vertex, + route.Vertex, error) { + + info, _, _, err := graph.FetchChannelEdgesByID( + chanID, + ) + if err != nil { + return route.Vertex{}, route.Vertex{}, + fmt.Errorf("unable to fetch channel "+ + "edges by channel ID %d: %v", + chanID, err) + } + + return info.NodeKey1Bytes, info.NodeKey2Bytes, nil + }, FindRoute: s.chanRouter.FindRoute, } @@ -2837,9 +2852,7 @@ func (r *rpcServer) SendToRoute(stream lnrpc.Lightning_SendToRouteServer) error return nil, err } - graph := r.server.chanDB.ChannelGraph() - - return unmarshallSendToRouteRequest(req, graph) + return r.unmarshallSendToRouteRequest(req) }, send: func(r *lnrpc.SendResponse) error { // Calling stream.Send concurrently is not safe. @@ -2851,14 +2864,14 @@ func (r *rpcServer) SendToRoute(stream lnrpc.Lightning_SendToRouteServer) error } // unmarshallSendToRouteRequest unmarshalls an rpc sendtoroute request -func unmarshallSendToRouteRequest(req *lnrpc.SendToRouteRequest, - graph *channeldb.ChannelGraph) (*rpcPaymentRequest, error) { +func (r *rpcServer) unmarshallSendToRouteRequest( + req *lnrpc.SendToRouteRequest) (*rpcPaymentRequest, error) { if req.Route == nil { return nil, fmt.Errorf("unable to send, no route provided") } - route, err := routerrpc.UnmarshallRoute(req.Route, graph) + route, err := r.routerBackend.UnmarshallRoute(req.Route) if err != nil { return nil, err } @@ -3308,9 +3321,7 @@ func (r *rpcServer) SendToRouteSync(ctx context.Context, return nil, fmt.Errorf("unable to send, no routes provided") } - graph := r.server.chanDB.ChannelGraph() - - paymentRequest, err := unmarshallSendToRouteRequest(req, graph) + paymentRequest, err := r.unmarshallSendToRouteRequest(req) if err != nil { return nil, err }