From e1e805d1b81228d1fa1bafb3b2a8682bc8ab3b7a Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 15 Mar 2019 02:33:33 -0700 Subject: [PATCH] watchtower/wtserver/server: fix race condition on Stop --- watchtower/wtserver/server.go | 81 ++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/watchtower/wtserver/server.go b/watchtower/wtserver/server.go index dd39af98..b9d22275 100644 --- a/watchtower/wtserver/server.go +++ b/watchtower/wtserver/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "sync" - "sync/atomic" "time" "github.com/btcsuite/btcd/btcec" @@ -65,8 +64,8 @@ type Config struct { // is to accept incoming connections, and dispatch processing of the client // message streams. type Server struct { - started int32 // atomic - shutdown int32 // atomic + started sync.Once + stopped sync.Once cfg *Config @@ -75,6 +74,8 @@ type Server struct { clientMtx sync.RWMutex clients map[wtdb.SessionID]Peer + newPeers chan Peer + localInit *wtwire.Init wg sync.WaitGroup @@ -93,6 +94,7 @@ func New(cfg *Config) (*Server, error) { s := &Server{ cfg: cfg, clients: make(map[wtdb.SessionID]Peer), + newPeers: make(chan Peer), localInit: localInit, quit: make(chan struct{}), } @@ -113,36 +115,31 @@ func New(cfg *Config) (*Server, error) { // Start begins listening on the server's listeners. func (s *Server) Start() error { - // Already running? - if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { - return nil - } + s.started.Do(func() { + log.Infof("Starting watchtower server") - log.Infof("Starting watchtower server") + s.wg.Add(1) + go s.peerHandler() - s.connMgr.Start() - - log.Infof("Watchtower server started successfully") + s.connMgr.Start() + log.Infof("Watchtower server started successfully") + }) return nil } // Stop shutdowns down the server's listeners and any active requests. func (s *Server) Stop() error { - // Bail if we're already shutting down. - if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { - return nil - } + s.stopped.Do(func() { + log.Infof("Stopping watchtower server") - log.Infof("Stopping watchtower server") + s.connMgr.Stop() - s.connMgr.Stop() - - close(s.quit) - s.wg.Wait() - - log.Infof("Watchtower server stopped successfully") + close(s.quit) + s.wg.Wait() + log.Infof("Watchtower server stopped successfully") + }) return nil } @@ -167,8 +164,29 @@ func (s *Server) inboundPeerConnected(c net.Conn) { // by the client. This method serves also as a public endpoint for locally // registering new clients with the server. func (s *Server) InboundPeerConnected(peer Peer) { - s.wg.Add(1) - go s.handleClient(peer) + select { + case s.newPeers <- peer: + case <-s.quit: + } +} + +// peerHandler processes newly accepted peers and spawns a client handler for +// each. The peerHandler is used to ensure that waitgrouped client handlers are +// spawned from a waitgrouped goroutine. +func (s *Server) peerHandler() { + defer s.wg.Done() + defer s.removeAllPeers() + + for { + select { + case peer := <-s.newPeers: + s.wg.Add(1) + go s.handleClient(peer) + + case <-s.quit: + return + } + } } // handleClient processes a series watchtower messages sent by a client. The @@ -625,6 +643,21 @@ func (s *Server) removePeer(id *wtdb.SessionID, addr net.Addr) { } } +// removeAllPeers iterates through the server's current set of peers and closes +// all open connections. +func (s *Server) removeAllPeers() { + s.clientMtx.Lock() + defer s.clientMtx.Unlock() + + for id, peer := range s.clients { + log.Infof("Releasing incoming peer %s@%s", id, + peer.RemoteAddr()) + + delete(s.clients, id) + peer.Close() + } +} + // logMessage writes information about a message exchanged with a remote peer, // using directional prepositions to signal whether the message was sent or // received.