diff --git a/autopilot/manager.go b/autopilot/manager.go index 17ce145c..f078b2a6 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -270,7 +270,9 @@ func (m *Manager) StopAgent() error { } // QueryHeuristics queries the available autopilot heuristics for node scores. -func (m *Manager) QueryHeuristics(nodes []NodeID) (HeuristicScores, error) { +func (m *Manager) QueryHeuristics(nodes []NodeID, localState bool) ( + HeuristicScores, error) { + m.Lock() defer m.Unlock() @@ -280,7 +282,7 @@ func (m *Manager) QueryHeuristics(nodes []NodeID) (HeuristicScores, error) { } log.Debugf("Querying heuristics for %d nodes", len(n)) - return m.queryHeuristics(n) + return m.queryHeuristics(n, localState) } // HeuristicScores is an alias for a map that maps heuristic names to a map of @@ -291,25 +293,32 @@ type HeuristicScores map[string]map[NodeID]float64 // the agent's current active heuristic. // // NOTE: Must be called with the manager's lock. -func (m *Manager) queryHeuristics(nodes map[NodeID]struct{}) ( +func (m *Manager) queryHeuristics(nodes map[NodeID]struct{}, localState bool) ( HeuristicScores, error) { - // Fetch the current set of channels. - totalChans, err := m.cfg.ChannelState() - if err != nil { - return nil, err - } + // If we want to take the local state into action when querying the + // heuristics, we fetch it. If not we'll just pass an emply slice to + // the heuristic. + var totalChans []Channel + var err error + if localState { + // Fetch the current set of channels. + totalChans, err = m.cfg.ChannelState() + if err != nil { + return nil, err + } - // If the agent is active, we can merge the channel state with the - // channels pending open. - if m.pilot != nil { - m.pilot.chanStateMtx.Lock() - m.pilot.pendingMtx.Lock() - totalChans = mergeChanState( - m.pilot.pendingOpens, m.pilot.chanState, - ) - m.pilot.pendingMtx.Unlock() - m.pilot.chanStateMtx.Unlock() + // If the agent is active, we can merge the channel state with + // the channels pending open. + if m.pilot != nil { + m.pilot.chanStateMtx.Lock() + m.pilot.pendingMtx.Lock() + totalChans = mergeChanState( + m.pilot.pendingOpens, m.pilot.chanState, + ) + m.pilot.pendingMtx.Unlock() + m.pilot.chanStateMtx.Unlock() + } } // As channel size we'll use the maximum size. diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 5f31384b..4a938072 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -180,7 +180,9 @@ func (s *Server) QueryScores(ctx context.Context, in *QueryScoresRequest) ( } // Query the heuristics. - heuristicScores, err := s.manager.QueryHeuristics(nodes) + heuristicScores, err := s.manager.QueryHeuristics( + nodes, !in.IgnoreLocalState, + ) if err != nil { return nil, err }