server: fixes race condition during unexpected peer disconnect

Conner Fromknecht 2017-08-31 01:15:39 -07:00 committed by Olaoluwa Osuntokun
parent f20cb89982
commit 96ff63d219

server.go (192 changed lines)

@@ -30,6 +30,16 @@ import (
 	"github.com/lightningnetwork/lnd/htlcswitch"
 )
 
+var (
+	// ErrPeerNotFound signals that the server has no connection to the
+	// given peer.
+	ErrPeerNotFound = errors.New("server: peer not found")
+
+	// ErrServerShuttingDown indicates that the server is in the process of
+	// gracefully exiting.
+	ErrServerShuttingDown = errors.New("server: shutting down")
+)
+
 // server is the main server of the Lightning Network Daemon. The server houses
 // global state pertaining to the wallet, database, and the rpcserver.
 // Additionally, the server is also used as a central messaging bus to interact
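
The sentinel values above let call sites branch on error identity instead of matching error strings, a pattern the rest of this commit leans on. A minimal, self-contained sketch of that pattern; the peer struct, map, and key below are illustrative stand-ins, not the server's real state:

package main

import (
	"errors"
	"fmt"
)

// Sentinel error, mirroring the one added in this commit.
var ErrPeerNotFound = errors.New("server: peer not found")

type peer struct{ id string }

// peersByPub is a stand-in for the server's pubkey -> peer index.
var peersByPub = map[string]*peer{}

// findPeerByPubStr returns the sentinel error rather than constructing
// a fresh one, so callers can compare against it directly.
func findPeerByPubStr(pubStr string) (*peer, error) {
	p, ok := peersByPub[pubStr]
	if !ok {
		return nil, ErrPeerNotFound
	}
	return p, nil
}

func main() {
	_, err := findPeerByPubStr("02abcd")
	switch err {
	case ErrPeerNotFound:
		fmt.Println("no existing connection, proceed to connect")
	case nil:
		fmt.Println("already connected")
	}
}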
@@ -62,6 +72,12 @@ type server struct {
 	persistentPeers    map[string]struct{}
 	persistentConnReqs map[string][]*connmgr.ConnReq
 
+	// ignorePeerTermination tracks peers for which the server has initiated
+	// a disconnect. Adding a peer to this map causes the peer termination
+	// watcher to short circuit in the event that peers are purposefully
+	// disconnected.
+	ignorePeerTermination map[*peer]struct{}
+
 	cc         *chainControl
 	fundingMgr *fundingManager
@@ -133,8 +149,9 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
 			sphinx.NewRouter(privKey, activeNetParams.Params)),
 
 		lightningID: sha256.Sum256(serializedPubKey),
-		persistentPeers:    make(map[string]struct{}),
-		persistentConnReqs: make(map[string][]*connmgr.ConnReq),
+		persistentPeers:       make(map[string]struct{}),
+		persistentConnReqs:    make(map[string][]*connmgr.ConnReq),
+		ignorePeerTermination: make(map[*peer]struct{}),
 
 		peersByID:  make(map[int32]*peer),
 		peersByPub: make(map[string]*peer),
@@ -424,8 +441,7 @@ func (s *server) Stop() error {
 
 	// Disconnect from each active peers to ensure that
 	// peerTerminationWatchers signal completion to each peer.
-	peers := s.Peers()
-	for _, peer := range peers {
+	for _, peer := range s.Peers() {
 		s.DisconnectPeer(peer.addr.IdentityKey)
 	}
 
@@ -826,13 +842,9 @@ func (s *server) broadcastMessages(
 	// throughout this process to ensure we deliver messages to exact set
 	// of peers present at the time of invocation.
 	var wg sync.WaitGroup
-	for _, sPeer := range s.peersByPub {
-		if skip != nil &&
-			sPeer.addr.IdentityKey.IsEqual(skip) {
-
-			srvrLog.Debugf("Skipping %v in broadcast",
-				skip.SerializeCompressed())
-
+	for pubStr, sPeer := range s.peersByPub {
+		if skip != nil && sPeer.addr.IdentityKey.IsEqual(skip) {
+			srvrLog.Debugf("Skipping %v in broadcast", pubStr)
 			continue
 		}
@@ -905,12 +917,11 @@ func (s *server) sendToPeer(target *btcec.PublicKey,
 	// caller if the peer is unknown. Access to peersByPub is synchronized
 	// here to ensure we consider the exact set of peers present at the
 	// time of invocation.
-	targetPeer, ok := s.peersByPub[string(targetPubBytes)]
-	if !ok {
+	targetPeer, err := s.findPeerByPubStr(string(targetPubBytes))
+	if err != nil {
 		srvrLog.Errorf("unable to send message to %x, "+
 			"peer not found", targetPubBytes)
-
-		return errors.New("peer not found")
+		return err
 	}
 
 	s.sendPeerMessages(targetPeer, msgs, nil)
@@ -958,9 +969,9 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	serializedIDKey := string(peerKey.SerializeCompressed())
-
-	return s.findPeer(serializedIDKey)
+	pubStr := string(peerKey.SerializeCompressed())
+
+	return s.findPeerByPubStr(pubStr)
 }
 
 // FindPeerByPubStr will return the peer that corresponds to the passed peerID,
@@ -968,34 +979,34 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 // public key.
 //
 // NOTE: This function is safe for concurrent access.
-func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
+func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	return s.findPeer(peerID)
+	return s.findPeerByPubStr(pubStr)
 }
 
-// findPeer is an internal method that retrieves the specified peer from the
-// server's internal state.
-func (s *server) findPeer(peerID string) (*peer, error) {
-	peer := s.peersByPub[peerID]
-	if peer == nil {
-		return nil, errors.New("Peer not found. Pubkey: " + peerID)
+// findPeerByPubStr is an internal method that retrieves the specified peer from
+// the server's internal state using the peer's public key.
+func (s *server) findPeerByPubStr(pubStr string) (*peer, error) {
+	peer, ok := s.peersByPub[pubStr]
+	if !ok {
+		return nil, ErrPeerNotFound
 	}
 
 	return peer, nil
 }
 
-// peerTerminationWatcher waits until a peer has been disconnected, and then
-// cleans up all resources allocated to the peer, notifies relevant sub-systems
-// of its demise, and finally handles re-connecting to the peer if it's
-// persistent.
+// peerTerminationWatcher waits until a peer has been disconnected unexpectedly,
+// and then cleans up all resources allocated to the peer, notifies relevant
+// sub-systems of its demise, and finally handles re-connecting to the peer if
+// it's persistent. If the server intentionally disconnects a peer, it should
+// have a corresponding entry in the ignorePeerTermination map which will cause
+// the cleanup routine to exit early.
 //
-// NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
-// be incremented before spawning this method, as it will signal to the peer's
-// WaitGroup upon completion.
+// NOTE: This MUST be launched as a goroutine.
 func (s *server) peerTerminationWatcher(p *peer) {
-	defer p.wg.Done()
+	defer s.wg.Done()
 
 	p.WaitForDisconnect()
@@ -1025,16 +1036,20 @@ func (s *server) peerTerminationWatcher(p *peer) {
 		}
 	}
 
-	// Send the peer to be garbage collected by the server.
-	s.removePeer(p)
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-	// If this peer had an active persistent connection request, then we
-	// can remove this as we manually decide below if we should attempt to
-	// re-connect.
-	if p.connReq != nil {
-		s.connMgr.Remove(p.connReq.ID())
+	// If the server has already removed this peer, we can short circuit the
+	// peer termination watcher and skip cleanup.
+	if _, ok := s.ignorePeerTermination[p]; ok {
+		delete(s.ignorePeerTermination, p)
+		return
 	}
 
+	// First, cleanup any remaining state the server has regarding the peer
+	// in question.
+	s.removePeer(p)
+
 	// Next, check to see if this is a persistent peer or not.
 	pubStr := string(p.addr.IdentityKey.SerializeCompressed())
 	_, ok := s.persistentPeers[pubStr]
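
The hunks above and below amount to a mark-then-skip handshake: any intentional removal both purges the peer and records it in ignorePeerTermination under the server mutex, while the watcher takes the same mutex before cleaning up, so exactly one side performs the teardown. A stripped-down, runnable sketch of that handshake; all types here are simplified stand-ins for the real server and peer, not lnd's actual definitions:

package main

import (
	"fmt"
	"sync"
)

type peer struct {
	quit chan struct{}
}

// WaitForDisconnect blocks until the peer's connection is torn down.
func (p *peer) WaitForDisconnect() { <-p.quit }

type server struct {
	mu                    sync.Mutex
	wg                    sync.WaitGroup
	peers                 map[*peer]struct{}
	ignorePeerTermination map[*peer]struct{}
}

// removePeer purges the peer from the server's index. It must be
// called with the server mutex held.
func (s *server) removePeer(p *peer) { delete(s.peers, p) }

// peerTerminationWatcher mirrors the commit's scheme: take the mutex
// after the disconnect fires, and skip cleanup if the server already
// handled this peer intentionally.
func (s *server) peerTerminationWatcher(p *peer) {
	defer s.wg.Done()
	p.WaitForDisconnect()

	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.ignorePeerTermination[p]; ok {
		delete(s.ignorePeerTermination, p)
		fmt.Println("watcher: short circuit, server already cleaned up")
		return
	}
	s.removePeer(p)
	fmt.Println("watcher: cleaned up unexpected disconnect")
}

// DisconnectPeer removes the peer and marks it so the watcher does not
// race to perform the same cleanup (or attempt a reconnect).
func (s *server) DisconnectPeer(p *peer) {
	s.mu.Lock()
	s.removePeer(p)
	s.ignorePeerTermination[p] = struct{}{}
	s.mu.Unlock()
	close(p.quit) // tear down the connection
}

func main() {
	p := &peer{quit: make(chan struct{})}
	s := &server{
		peers:                 map[*peer]struct{}{p: {}},
		ignorePeerTermination: make(map[*peer]struct{}),
	}
	s.wg.Add(1)
	go s.peerTerminationWatcher(p)

	s.DisconnectPeer(p)
	s.wg.Wait()
}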
@@ -1148,7 +1163,14 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 	// Check to see if we should drop our connection, if not, then we'll
 	// close out this connection with the remote peer. This
 	// prevents us from having duplicate connections, or none.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
 		// If the connection we've already established should be kept,
 		// then we'll close out this connection s.t there's only a
 		// single connection between us.
@@ -1165,9 +1187,12 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 		// peer to the peer garbage collection goroutine.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	// Next, check to see if we have any outstanding persistent connection
@@ -1231,7 +1256,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 
 	// If we already have an inbound connection from this peer, then we'll
 	// check to see _which_ of our connections should be dropped.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
+		// We already have a connection open with the target peer.
 		// If our (this) connection should be dropped, then we'll do
 		// so, in order to ensure we don't have any duplicate
 		// connections.
@@ -1239,11 +1272,9 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 			srvrLog.Warnf("Established outbound connection to "+
 				"peer %x, but already connected, dropping conn",
 				nodePub.SerializeCompressed())
-
 			if connReq != nil {
 				s.connMgr.Remove(connReq.ID())
 			}
-
 			conn.Close()
 			return
 		}
@@ -1253,9 +1284,12 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 		// server for garbage collection.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	s.peerConnected(conn, connReq, true)
@@ -1269,8 +1303,8 @@ func (s *server) addPeer(p *peer) {
 	}
 
 	// Ignore new peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
-		p.Disconnect(errors.New("server is shutting down"))
+	if s.Stopped() {
+		p.Disconnect(ErrServerShuttingDown)
 		return
 	}
@@ -1290,12 +1324,13 @@ func (s *server) addPeer(p *peer) {
 		s.outboundPeers[pubStr] = p
 	}
 
-	// Launch a goroutine to watch for the termination of this peer so we
-	// can ensure all resources are properly cleaned up and if need be
-	// connections are re-established. The go routine is tracked by the
-	// _peer's_ WaitGroup so that a call to Disconnect will block until the
-	// `peerTerminationWatcher` has exited.
-	p.wg.Add(1)
+	// Launch a goroutine to watch for the unexpected termination of this
+	// peer, which will ensure all resources are properly cleaned up, and
+	// re-establish persistent connections when necessary. The peer
+	// termination watcher will be short circuited if the peer is ever added
+	// to the ignorePeerTermination map, indicating that the server has
+	// already handled the removal of this peer.
+	s.wg.Add(1)
 	go s.peerTerminationWatcher(p)
 
 	// Once the peer has been added to our indexes, send a message to the
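
Note that the watcher is now registered with the server's WaitGroup rather than the peer's, so Disconnect no longer blocks until the watcher exits; instead the server can drain every watcher in one place at shutdown. Extending the stand-in types from the sketch above (this Stop is hypothetical, not the commit's code), the shutdown path could look like:

func (s *server) Stop() error {
	s.mu.Lock()
	for p := range s.peers {
		// Mark and remove each peer so its watcher short circuits
		// instead of repeating the cleanup or reconnecting.
		s.removePeer(p)
		s.ignorePeerTermination[p] = struct{}{}
		close(p.quit) // unblocks WaitForDisconnect
	}
	s.mu.Unlock()

	// Every watcher was registered with s.wg.Add(1) before launch, so
	// this waits for all of them to observe the disconnect and exit.
	s.wg.Wait()
	return nil
}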
@@ -1321,10 +1356,15 @@ func (s *server) removePeer(p *peer) {
 
 	// As the peer is now finished, ensure that the TCP connection is
 	// closed and all of its related goroutines have exited.
-	p.Disconnect(errors.New("remove peer"))
+	p.Disconnect(fmt.Errorf("server: disconnecting peer %v", p))
+
+	// If this peer had an active persistent connection request, remove it.
+	if p.connReq != nil {
+		s.connMgr.Remove(p.connReq.ID())
+	}
 
 	// Ignore deleting peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
+	if s.Stopped() {
 		return
 	}
@@ -1376,10 +1416,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	s.mu.Lock()
 
 	// Ensure we're not already connected to this peer.
-	peer, ok := s.peersByPub[targetPub]
-	if ok {
-		s.mu.Unlock()
+	peer, err := s.findPeerByPubStr(targetPub)
+	switch err {
+	case ErrPeerNotFound:
+		// Peer was not found, continue to pursue connection with peer.
+		break
+
+	case nil:
+		s.mu.Unlock()
 		return fmt.Errorf("already connected to peer: %v", peer)
 	}
@@ -1388,7 +1432,6 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	// connection.
 	if _, ok := s.persistentConnReqs[targetPub]; ok {
 		s.mu.Unlock()
-
 		return fmt.Errorf("connection attempt to %v is pending", addr)
 	}
@@ -1443,27 +1486,26 @@ func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
 	// Check that were actually connected to this peer. If not, then we'll
 	// exit in an error as we can't disconnect from a peer that we're not
-	// currently connected to.
-	peer, ok := s.peersByPub[pubStr]
-	if !ok {
-		return fmt.Errorf("unable to find peer %x", pubBytes)
+	// currently connected to. This will also return an error if we already
+	// have a pending disconnect request for this peer, ensuring the
+	// operation only happens once.
+	peer, err := s.findPeerByPubStr(pubStr)
+	if err != nil {
+		return err
 	}
 
+	srvrLog.Infof("Disconnecting from %v", peer)
+
 	// If this peer was formerly a persistent connection, then we'll remove
 	// them from this map so we don't attempt to re-connect after we
 	// disconnect.
-	if _, ok := s.persistentPeers[pubStr]; ok {
-		delete(s.persistentPeers, pubStr)
-	}
+	delete(s.persistentPeers, pubStr)
 
-	// Now that we know the peer is actually connected, we'll disconnect
-	// from the peer. The lock is held until after Disconnect to ensure
-	// that the peer's `peerTerminationWatcher` has fully exited before
-	// returning to the caller.
-	srvrLog.Infof("Disconnecting from %v", peer)
-	peer.Disconnect(
-		errors.New("received user command to disconnect the peer"),
-	)
+	// Remove the current peer from the server's internal state and signal
+	// that the peer termination watcher does not need to execute for this
+	// peer.
+	s.removePeer(peer)
+	s.ignorePeerTermination[peer] = struct{}{}
 
 	return nil
 }
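
As a usage note, the lookup-based guard makes a repeated disconnect fail cleanly instead of racing the termination watcher. A hypothetical call site, assuming pubKey identifies a currently connected peer:

// First call removes the peer and marks it in ignorePeerTermination.
if err := s.DisconnectPeer(pubKey); err != nil {
	srvrLog.Errorf("disconnect failed: %v", err)
}

// A second call finds no entry in peersByPub and returns
// ErrPeerNotFound rather than touching shared state twice.
if err := s.DisconnectPeer(pubKey); err == ErrPeerNotFound {
	// Already disconnected; nothing left to do.
}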