diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index d1de9ccc..2c0b2280 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -1546,3 +1546,156 @@ func TestAgentSkipPendingConns(t *testing.T) { t.Fatalf("agent should have attempted connection") } } + +// TestAgentQuitWhenPendingConns tests that we are able to stop the autopilot +// agent even though there are pending connections to nodes. +func TestAgentQuitWhenPendingConns(t *testing.T) { + t.Parallel() + + // First, we'll create all the dependencies that we'll need in order to + // create the autopilot agent. + self, err := randKey() + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + + quit := make(chan struct{}) + defer close(quit) + + heuristic := &mockHeuristic{ + nodeScoresArgs: make(chan directiveArg), + nodeScoresResps: make(chan map[NodeID]*NodeScore), + quit: quit, + } + constraints := &mockConstraints{ + moreChansResps: make(chan moreChansResp), + quit: quit, + } + + chanController := &mockChanController{ + openChanSignals: make(chan openChanIntent), + } + memGraph, _, _ := newMemChanGraph() + + // The wallet will start with 6 BTC available. + const walletBalance = btcutil.SatoshiPerBitcoin * 6 + + connect := make(chan chan error) + + // With the dependencies we created, we can now create the initial + // agent itself. + testCfg := Config{ + Self: self, + Heuristic: heuristic, + ChanController: chanController, + WalletBalance: func() (btcutil.Amount, error) { + return walletBalance, nil + }, + ConnectToPeer: func(*btcec.PublicKey, []net.Addr) (bool, error) { + errChan := make(chan error) + + select { + case connect <- errChan: + case <-quit: + return false, errors.New("quit") + } + + select { + case err := <-errChan: + return false, err + case <-quit: + return false, errors.New("quit") + } + }, + DisconnectPeer: func(*btcec.PublicKey) error { + return nil + }, + Graph: memGraph, + Constraints: constraints, + } + initialChans := []Channel{} + agent, err := New(testCfg, initialChans) + if err != nil { + t.Fatalf("unable to create agent: %v", err) + } + + // To ensure the heuristic doesn't block on quitting the agent, we'll + // use the agent's quit chan to signal when it should also stop. + heuristic.quit = agent.quit + + // With the autopilot agent and all its dependencies we'll start the + // primary controller goroutine. + if err := agent.Start(); err != nil { + t.Fatalf("unable to start agent: %v", err) + } + defer agent.Stop() + + // We'll only return a single directive for a pre-chosen node. + nodeKey, err := memGraph.addRandNode() + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + nodeID := NewNodeID(nodeKey) + nodeDirective := &NodeScore{ + NodeID: nodeID, + Score: 0.5, + } + + // We'll send an initial "yes" response to advance the agent past its + // initial check. This will cause it to try to get directives from the + // graph. + select { + case constraints.moreChansResps <- moreChansResp{ + numMore: 1, + amt: walletBalance, + }: + case <-time.After(time.Second * 10): + t.Fatalf("heuristic wasn't queried in time") + } + + // Check the args. + select { + case req := <-heuristic.nodeScoresArgs: + if len(req.nodes) != 1 { + t.Fatalf("expected %v nodes, instead "+ + "had %v", 1, len(req.nodes)) + } + if _, ok := req.nodes[nodeID]; !ok { + t.Fatalf("node not included in arguments") + } + case <-time.After(time.Second * 10): + t.Fatalf("select wasn't queried in time") + } + + // Respond with a scored directive. + select { + case heuristic.nodeScoresResps <- map[NodeID]*NodeScore{ + NewNodeID(nodeKey): nodeDirective, + }: + case <-time.After(time.Second * 10): + t.Fatalf("heuristic wasn't queried in time") + } + + // The agent should attempt connection to the node. + select { + case <-connect: + case <-time.After(time.Second * 10): + t.Fatalf("agent did not attempt connection") + } + + // Make sure that we are able to stop the agent, even though there is a + // pending connection. + stopped := make(chan error) + go func() { + stopped <- agent.Stop() + }() + + select { + case err := <-stopped: + if err != nil { + t.Fatalf("error stopping agent: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("unable to stop agent") + } +}