From 96ff63d219d106bab7a86c6d2ac57dcd5bf00ad9 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 31 Aug 2017 01:15:39 -0700
Subject: [PATCH] server: fixes race condition during unexpected peer disconnect

---
 server.go | 192 +++++++++++++++++++++++++++++++++---------------------
 1 file changed, 117 insertions(+), 75 deletions(-)

diff --git a/server.go b/server.go
index 2c906ad0..4e8ec31d 100644
--- a/server.go
+++ b/server.go
@@ -30,6 +30,16 @@ import (
 	"github.com/lightningnetwork/lnd/htlcswitch"
 )
 
+var (
+	// ErrPeerNotFound signals that the server has no connection to the
+	// given peer.
+	ErrPeerNotFound = errors.New("server: peer not found")
+
+	// ErrServerShuttingDown indicates that the server is in the process of
+	// gracefully exiting.
+	ErrServerShuttingDown = errors.New("server: shutting down")
+)
+
 // server is the main server of the Lightning Network Daemon. The server houses
 // global state pertaining to the wallet, database, and the rpcserver.
 // Additionally, the server is also used as a central messaging bus to interact
@@ -62,6 +72,12 @@ type server struct {
 	persistentPeers    map[string]struct{}
 	persistentConnReqs map[string][]*connmgr.ConnReq
 
+	// ignorePeerTermination tracks peers for which the server has initiated
+	// a disconnect. Adding a peer to this map causes the peer termination
+	// watcher to short-circuit in the event that peers are purposefully
+	// disconnected.
+	ignorePeerTermination map[*peer]struct{}
+
 	cc *chainControl
 
 	fundingMgr *fundingManager
@@ -133,8 +149,9 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
 			sphinx.NewRouter(privKey, activeNetParams.Params)),
 
 		lightningID: sha256.Sum256(serializedPubKey),
-		persistentPeers:    make(map[string]struct{}),
-		persistentConnReqs: make(map[string][]*connmgr.ConnReq),
+		persistentPeers:       make(map[string]struct{}),
+		persistentConnReqs:    make(map[string][]*connmgr.ConnReq),
+		ignorePeerTermination: make(map[*peer]struct{}),
 
 		peersByID:  make(map[int32]*peer),
 		peersByPub: make(map[string]*peer),
@@ -424,8 +441,7 @@ func (s *server) Stop() error {
 
 	// Disconnect from each active peer to ensure that
 	// peerTerminationWatchers signal completion to each peer.
-	peers := s.Peers()
-	for _, peer := range peers {
+	for _, peer := range s.Peers() {
 		s.DisconnectPeer(peer.addr.IdentityKey)
 	}
 
@@ -826,13 +842,9 @@ func (s *server) broadcastMessages(
 	// throughout this process to ensure we deliver messages to the exact
 	// set of peers present at the time of invocation.
 	var wg sync.WaitGroup
-	for _, sPeer := range s.peersByPub {
-		if skip != nil &&
-			sPeer.addr.IdentityKey.IsEqual(skip) {
-
-			srvrLog.Debugf("Skipping %v in broadcast",
-				skip.SerializeCompressed())
-
+	for pubStr, sPeer := range s.peersByPub {
+		if skip != nil && sPeer.addr.IdentityKey.IsEqual(skip) {
+			srvrLog.Debugf("Skipping %v in broadcast", pubStr)
 			continue
 		}
 
@@ -905,12 +917,11 @@ func (s *server) sendToPeer(target *btcec.PublicKey,
 	// caller if the peer is unknown. Access to peersByPub is synchronized
 	// here to ensure we consider the exact set of peers present at the
 	// time of invocation.
-	targetPeer, ok := s.peersByPub[string(targetPubBytes)]
-	if !ok {
+	targetPeer, err := s.findPeerByPubStr(string(targetPubBytes))
+	if err != nil {
 		srvrLog.Errorf("unable to send message to %x, "+
 			"peer not found", targetPubBytes)
-
-		return errors.New("peer not found")
+		return err
 	}
 
 	s.sendPeerMessages(targetPeer, msgs, nil)
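
The sentinel errors introduced above replace ad-hoc errors.New values at each
return site, so callers can branch on the failure mode instead of matching
error strings. A minimal, self-contained sketch of the pattern, using
hypothetical stand-ins rather than lnd's actual types:

package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins for the server's sentinel errors.
var (
	ErrPeerNotFound       = errors.New("server: peer not found")
	ErrServerShuttingDown = errors.New("server: shutting down")
)

// lookupPeer mimics findPeerByPubStr: it returns the shared sentinel rather
// than a freshly allocated error, so callers can test for identity.
func lookupPeer(peers map[string]int, pubStr string) (int, error) {
	id, ok := peers[pubStr]
	if !ok {
		return 0, ErrPeerNotFound
	}
	return id, nil
}

func main() {
	peers := map[string]int{"02abc...": 7}

	_, err := lookupPeer(peers, "03def...")
	switch err {
	case ErrPeerNotFound:
		fmt.Println("no connection to peer; safe to dial")
	case nil:
		fmt.Println("already connected")
	default:
		fmt.Println("unexpected failure:", err)
	}
}

Because sendToPeer now propagates the sentinel unchanged, callers can apply
the same switch without inspecting the error text.
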
@@ -958,9 +969,9 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	serializedIDKey := string(peerKey.SerializeCompressed())
+	pubStr := string(peerKey.SerializeCompressed())
 
-	return s.findPeer(serializedIDKey)
+	return s.findPeerByPubStr(pubStr)
 }
 
 // FindPeerByPubStr will return the peer that corresponds to the passed peerID,
@@ -968,34 +979,34 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 // public key.
 //
 // NOTE: This function is safe for concurrent access.
-func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
+func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	return s.findPeer(peerID)
+	return s.findPeerByPubStr(pubStr)
 }
 
-// findPeer is an internal method that retrieves the specified peer from the
-// server's internal state.
-func (s *server) findPeer(peerID string) (*peer, error) {
-	peer := s.peersByPub[peerID]
-	if peer == nil {
-		return nil, errors.New("Peer not found. Pubkey: " + peerID)
+// findPeerByPubStr is an internal method that retrieves the specified peer from
+// the server's internal state using the peer's public key string.
+func (s *server) findPeerByPubStr(pubStr string) (*peer, error) {
+	peer, ok := s.peersByPub[pubStr]
+	if !ok {
+		return nil, ErrPeerNotFound
 	}
 
 	return peer, nil
 }
 
-// peerTerminationWatcher waits until a peer has been disconnected, and then
-// cleans up all resources allocated to the peer, notifies relevant sub-systems
-// of its demise, and finally handles re-connecting to the peer if it's
-// persistent.
+// peerTerminationWatcher waits until a peer has been disconnected unexpectedly,
+// and then cleans up all resources allocated to the peer, notifies relevant
+// sub-systems of its demise, and finally handles re-connecting to the peer if
+// it's persistent. If the server intentionally disconnects a peer, it should
+// have a corresponding entry in the ignorePeerTermination map which will cause
+// the cleanup routine to exit early.
 //
-// NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
-// be incremented before spawning this method, as it will signal to the peer's
-// WaitGroup upon completion.
+// NOTE: This MUST be launched as a goroutine.
 func (s *server) peerTerminationWatcher(p *peer) {
-	defer p.wg.Done()
+	defer s.wg.Done()
 
 	p.WaitForDisconnect()
@@ -1025,16 +1036,20 @@ func (s *server) peerTerminationWatcher(p *peer) {
 		}
 	}
 
-	// Send the peer to be garbage collected by the server.
-	s.removePeer(p)
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-	// If this peer had an active persistent connection request, then we
-	// can remove this as we manually decide below if we should attempt to
-	// re-connect.
-	if p.connReq != nil {
-		s.connMgr.Remove(p.connReq.ID())
+	// If the server has already removed this peer, we can short-circuit the
+	// peer termination watcher and skip cleanup.
+	if _, ok := s.ignorePeerTermination[p]; ok {
+		delete(s.ignorePeerTermination, p)
+		return
 	}
 
+	// First, clean up any remaining state the server has regarding the peer
+	// in question.
+	s.removePeer(p)
+
+	// Next, check to see if this is a persistent peer or not.
 	pubStr := string(p.addr.IdentityKey.SerializeCompressed())
 	_, ok := s.persistentPeers[pubStr]
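
The short-circuit above is the heart of the fix: an intentional disconnect
first marks the peer under the same mutex the watcher acquires, so the watcher
can tell a purposeful removal apart from an unexpected one and skip cleanup
and reconnection. A runnable sketch of the pattern with hypothetical names,
not lnd's actual API:

package main

import (
	"fmt"
	"sync"
)

type peer struct{ quit chan struct{} }

type server struct {
	mu                sync.Mutex
	wg                sync.WaitGroup
	peers             map[*peer]struct{}
	ignoreTermination map[*peer]struct{}
}

// watch blocks until the peer's connection dies, then decides whether the
// disconnect was expected before running any cleanup.
func (s *server) watch(p *peer) {
	defer s.wg.Done()

	<-p.quit // stand-in for p.WaitForDisconnect()

	s.mu.Lock()
	defer s.mu.Unlock()

	if _, ok := s.ignoreTermination[p]; ok {
		delete(s.ignoreTermination, p)
		fmt.Println("intentional disconnect: watcher short-circuits")
		return
	}

	delete(s.peers, p)
	fmt.Println("unexpected disconnect: clean up, maybe reconnect")
}

// disconnect removes the peer itself and tells the watcher to stand down.
func (s *server) disconnect(p *peer) {
	s.mu.Lock()
	delete(s.peers, p)
	s.ignoreTermination[p] = struct{}{}
	s.mu.Unlock()

	close(p.quit) // sever the connection after releasing the lock
}

func main() {
	s := &server{
		peers:             make(map[*peer]struct{}),
		ignoreTermination: make(map[*peer]struct{}),
	}
	p := &peer{quit: make(chan struct{})}
	s.peers[p] = struct{}{}

	s.wg.Add(1)
	go s.watch(p)

	s.disconnect(p)
	s.wg.Wait()
}

Marking the map before closing the connection is what guarantees the watcher,
which only wakes after the disconnect, always observes the mark.
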
@@ -1148,7 +1163,14 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 	// Check to see if we should drop our connection, if not, then we'll
 	// close out this connection with the remote peer. This
 	// prevents us from having duplicate connections, or none.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
 		// If the connection we've already established should be kept,
 		// then we'll close out this connection s.t there's only a
 		// single connection between us.
@@ -1165,9 +1187,12 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 		// peer to the peer garbage collection goroutine.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	// Next, check to see if we have any outstanding persistent connection
@@ -1231,7 +1256,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 
 	// If we already have an inbound connection from this peer, then we'll
 	// check to see _which_ of our connections should be dropped.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
+		// We already have a connection open with the target peer.
 		// If our (this) connection should be dropped, then we'll do
 		// so, in order to ensure we don't have any duplicate
 		// connections.
@@ -1239,11 +1272,9 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 			srvrLog.Warnf("Established outbound connection to "+
 				"peer %x, but already connected, dropping conn",
 				nodePub.SerializeCompressed())
-
 			if connReq != nil {
 				s.connMgr.Remove(connReq.ID())
 			}
-
 			conn.Close()
 			return
 		}
@@ -1253,9 +1284,12 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 		// server for garbage collection.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	s.peerConnected(conn, connReq, true)
@@ -1269,8 +1303,8 @@ func (s *server) addPeer(p *peer) {
 	}
 
 	// Ignore new peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
-		p.Disconnect(errors.New("server is shutting down"))
+	if s.Stopped() {
+		p.Disconnect(ErrServerShuttingDown)
 		return
 	}
 
@@ -1290,12 +1324,13 @@ func (s *server) addPeer(p *peer) {
 		s.outboundPeers[pubStr] = p
 	}
 
-	// Launch a goroutine to watch for the termination of this peer so we
-	// can ensure all resources are properly cleaned up and if need be
-	// connections are re-established. The go routine is tracked by the
-	// _peer's_ WaitGroup so that a call to Disconnect will block until the
-	// `peerTerminationWatcher` has exited.
-	p.wg.Add(1)
+	// Launch a goroutine to watch for the unexpected termination of this
+	// peer, which will ensure all resources are properly cleaned up, and
+	// re-establish persistent connections when necessary. The peer
+	// termination watcher will be short-circuited if the peer is ever added
+	// to the ignorePeerTermination map, indicating that the server has
+	// already handled the removal of this peer.
+	s.wg.Add(1)
 	go s.peerTerminationWatcher(p)
 
 	// Once the peer has been added to our indexes, send a message to the
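
Moving the watcher from the peer's WaitGroup to the server's is what removes
the race: Disconnect no longer has to block on a goroutine that itself needs
server state, while Stop can still wait for every watcher to drain. A generic
sketch of that WaitGroup discipline, with invented names:

package main

import (
	"fmt"
	"sync"
)

type watcherPool struct {
	wg   sync.WaitGroup
	quit chan struct{}
}

func (w *watcherPool) launch(id int) {
	w.wg.Add(1) // increment BEFORE the goroutine starts, never inside it
	go func() {
		defer w.wg.Done()
		<-w.quit
		fmt.Printf("watcher %d done\n", id)
	}()
}

func (w *watcherPool) stop() {
	close(w.quit)
	w.wg.Wait() // blocks until every watcher has signalled completion
}

func main() {
	w := &watcherPool{quit: make(chan struct{})}
	for i := 0; i < 3; i++ {
		w.launch(i)
	}
	w.stop()
}
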
@@ -1321,10 +1356,15 @@ func (s *server) removePeer(p *peer) {
 
 	// As the peer is now finished, ensure that the TCP connection is
 	// closed and all of its related goroutines have exited.
-	p.Disconnect(errors.New("remove peer"))
+	p.Disconnect(fmt.Errorf("server: disconnecting peer %v", p))
+
+	// If this peer had an active persistent connection request, remove it.
+	if p.connReq != nil {
+		s.connMgr.Remove(p.connReq.ID())
+	}
 
 	// Ignore deleting peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
+	if s.Stopped() {
 		return
 	}
 
@@ -1376,10 +1416,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	s.mu.Lock()
 
 	// Ensure we're not already connected to this peer.
-	peer, ok := s.peersByPub[targetPub]
-	if ok {
-		s.mu.Unlock()
+	peer, err := s.findPeerByPubStr(targetPub)
+	switch err {
+	case ErrPeerNotFound:
+		// Peer was not found, continue to pursue connection with peer.
+		break
 
+	case nil:
+		s.mu.Unlock()
 		return fmt.Errorf("already connected to peer: %v", peer)
 	}
 
@@ -1388,7 +1432,6 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	// connection.
 	if _, ok := s.persistentConnReqs[targetPub]; ok {
 		s.mu.Unlock()
-
 		return fmt.Errorf("connection attempt to %v is pending", addr)
 	}
 
@@ -1443,27 +1486,26 @@ func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
 
 	// Check that we're actually connected to this peer. If not, then we'll
 	// exit in an error as we can't disconnect from a peer that we're not
-	// currently connected to.
-	peer, ok := s.peersByPub[pubStr]
-	if !ok {
-		return fmt.Errorf("unable to find peer %x", pubBytes)
+	// currently connected to. This will also return an error if we already
+	// have a pending disconnect request for this peer, ensuring the
+	// operation only happens once.
+	peer, err := s.findPeerByPubStr(pubStr)
+	if err != nil {
+		return err
 	}
 
+	srvrLog.Infof("Disconnecting from %v", peer)
+
 	// If this peer was formerly a persistent connection, then we'll remove
 	// them from this map so we don't attempt to re-connect after we
 	// disconnect.
-	if _, ok := s.persistentPeers[pubStr]; ok {
-		delete(s.persistentPeers, pubStr)
-	}
+	delete(s.persistentPeers, pubStr)
 
-	// Now that we know the peer is actually connected, we'll disconnect
-	// from the peer. The lock is held until after Disconnect to ensure
-	// that the peer's `peerTerminationWatcher` has fully exited before
-	// returning to the caller.
-	srvrLog.Infof("Disconnecting from %v", peer)
-	peer.Disconnect(
-		errors.New("received user command to disconnect the peer"),
-	)
+	// Remove the current peer from the server's internal state and signal
+	// that the peer termination watcher does not need to execute for this
+	// peer.
+	s.removePeer(peer)
+	s.ignorePeerTermination[peer] = struct{}{}
 
 	return nil
 }
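
Because removePeer deletes the peersByPub entry while the lock is held, a
second DisconnectPeer for the same key observes ErrPeerNotFound, which is what
the comment above means by the operation only happening once. A small sketch
of that idempotence, using hypothetical types rather than the real server:

package main

import (
	"errors"
	"fmt"
	"sync"
)

var ErrPeerNotFound = errors.New("server: peer not found")

type server struct {
	mu    sync.Mutex
	peers map[string]struct{}
}

func (s *server) disconnectPeer(pubStr string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, ok := s.peers[pubStr]; !ok {
		return ErrPeerNotFound
	}
	delete(s.peers, pubStr) // the check and removal happen atomically
	return nil
}

func main() {
	s := &server{peers: map[string]struct{}{"02abc": {}}}

	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			fmt.Printf("call %d: %v\n", n, s.disconnectPeer("02abc"))
		}(i)
	}
	wg.Wait() // exactly one call succeeds; the other sees ErrPeerNotFound
}
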