server: fixes race condition during unexpected peer disconnect

This commit is contained in:
Conner Fromknecht 2017-08-31 01:15:39 -07:00 committed by Olaoluwa Osuntokun
parent f20cb89982
commit 96ff63d219

186
server.go

@ -30,6 +30,16 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch"
)
var (
// ErrPeerNotFound signals that the server has no connection to the
// given peer.
ErrPeerNotFound = errors.New("server: peer not found")
// ErrServerShuttingDown indicates that the server is in the process of
// gracefully exiting.
ErrServerShuttingDown = errors.New("server: shutting down")
)
// server is the main server of the Lightning Network Daemon. The server houses
// global state pertaining to the wallet, database, and the rpcserver.
// Additionally, the server is also used as a central messaging bus to interact
@ -62,6 +72,12 @@ type server struct {
persistentPeers map[string]struct{}
persistentConnReqs map[string][]*connmgr.ConnReq
// ignorePeerTermination tracks peers for which the server has initiated
// a disconnect. Adding a peer to this map causes the peer termination
// watcher to short circuit in the event that peers are purposefully
// disconnected.
ignorePeerTermination map[*peer]struct{}
cc *chainControl
fundingMgr *fundingManager
@ -135,6 +151,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
persistentPeers: make(map[string]struct{}),
persistentConnReqs: make(map[string][]*connmgr.ConnReq),
ignorePeerTermination: make(map[*peer]struct{}),
peersByID: make(map[int32]*peer),
peersByPub: make(map[string]*peer),
@ -424,8 +441,7 @@ func (s *server) Stop() error {
// Disconnect from each active peers to ensure that
// peerTerminationWatchers signal completion to each peer.
peers := s.Peers()
for _, peer := range peers {
for _, peer := range s.Peers() {
s.DisconnectPeer(peer.addr.IdentityKey)
}
@ -826,13 +842,9 @@ func (s *server) broadcastMessages(
// throughout this process to ensure we deliver messages to exact set
// of peers present at the time of invocation.
var wg sync.WaitGroup
for _, sPeer := range s.peersByPub {
if skip != nil &&
sPeer.addr.IdentityKey.IsEqual(skip) {
srvrLog.Debugf("Skipping %v in broadcast",
skip.SerializeCompressed())
for pubStr, sPeer := range s.peersByPub {
if skip != nil && sPeer.addr.IdentityKey.IsEqual(skip) {
srvrLog.Debugf("Skipping %v in broadcast", pubStr)
continue
}
@ -905,12 +917,11 @@ func (s *server) sendToPeer(target *btcec.PublicKey,
// caller if the peer is unknown. Access to peersByPub is synchronized
// here to ensure we consider the exact set of peers present at the
// time of invocation.
targetPeer, ok := s.peersByPub[string(targetPubBytes)]
if !ok {
targetPeer, err := s.findPeerByPubStr(string(targetPubBytes))
if err != nil {
srvrLog.Errorf("unable to send message to %x, "+
"peer not found", targetPubBytes)
return errors.New("peer not found")
return err
}
s.sendPeerMessages(targetPeer, msgs, nil)
@ -958,9 +969,9 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
s.mu.Lock()
defer s.mu.Unlock()
serializedIDKey := string(peerKey.SerializeCompressed())
pubStr := string(peerKey.SerializeCompressed())
return s.findPeer(serializedIDKey)
return s.findPeerByPubStr(pubStr)
}
// FindPeerByPubStr will return the peer that corresponds to the passed peerID,
@ -968,34 +979,34 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
// public key.
//
// NOTE: This function is safe for concurrent access.
func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.findPeer(peerID)
return s.findPeerByPubStr(pubStr)
}
// findPeer is an internal method that retrieves the specified peer from the
// server's internal state.
func (s *server) findPeer(peerID string) (*peer, error) {
peer := s.peersByPub[peerID]
if peer == nil {
return nil, errors.New("Peer not found. Pubkey: " + peerID)
// findPeerByPubStr is an internal method that retrieves the specified peer from
// the server's internal state using.
func (s *server) findPeerByPubStr(pubStr string) (*peer, error) {
peer, ok := s.peersByPub[pubStr]
if !ok {
return nil, ErrPeerNotFound
}
return peer, nil
}
// peerTerminationWatcher waits until a peer has been disconnected, and then
// cleans up all resources allocated to the peer, notifies relevant sub-systems
// of its demise, and finally handles re-connecting to the peer if it's
// persistent.
// peerTerminationWatcher waits until a peer has been disconnected unexpectedly,
// and then cleans up all resources allocated to the peer, notifies relevant
// sub-systems of its demise, and finally handles re-connecting to the peer if
// it's persistent. If the server intentionally disconnects a peer, it should
// have a corresponding entry in the ignorePeerTermination map which will cause
// the cleanup routine to exit early.
//
// NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
// be incremented before spawning this method, as it will signal to the peer's
// WaitGroup upon completion.
// NOTE: This MUST be launched as a goroutine.
func (s *server) peerTerminationWatcher(p *peer) {
defer p.wg.Done()
defer s.wg.Done()
p.WaitForDisconnect()
@ -1025,16 +1036,20 @@ func (s *server) peerTerminationWatcher(p *peer) {
}
}
// Send the peer to be garbage collected by the server.
s.removePeer(p)
s.mu.Lock()
defer s.mu.Unlock()
// If this peer had an active persistent connection request, then we
// can remove this as we manually decide below if we should attempt to
// re-connect.
if p.connReq != nil {
s.connMgr.Remove(p.connReq.ID())
// If the server has already removed this peer, we can short circuit the
// peer termination watcher and skip cleanup.
if _, ok := s.ignorePeerTermination[p]; ok {
delete(s.ignorePeerTermination, p)
return
}
// First, cleanup any remaining state the server has regarding the peer
// in question.
s.removePeer(p)
// Next, check to see if this is a persistent peer or not.
pubStr := string(p.addr.IdentityKey.SerializeCompressed())
_, ok := s.persistentPeers[pubStr]
@ -1148,7 +1163,14 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
// Check to see if we should drop our connection, if not, then we'll
// close out this connection with the remote peer. This
// prevents us from having duplicate connections, or none.
if connectedPeer, ok := s.peersByPub[pubStr]; ok {
connectedPeer, err := s.findPeerByPubStr(pubStr)
switch err {
case ErrPeerNotFound:
// We were unable to locate an existing connection with the
// target peer, proceed to connect.
break
case nil:
// If the connection we've already established should be kept,
// then we'll close out this connection s.t there's only a
// single connection between us.
@ -1165,9 +1187,12 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
// peer to the peer garbage collection goroutine.
srvrLog.Debugf("Disconnecting stale connection to %v",
connectedPeer)
connectedPeer.Disconnect(errors.New("remove stale connection"))
// Remove the current peer from the server's internal state and
// signal that the peer termination watcher does not need to
// execute for this peer.
s.removePeer(connectedPeer)
s.ignorePeerTermination[connectedPeer] = struct{}{}
}
// Next, check to see if we have any outstanding persistent connection
@ -1231,7 +1256,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
// If we already have an inbound connection from this peer, then we'll
// check to see _which_ of our connections should be dropped.
if connectedPeer, ok := s.peersByPub[pubStr]; ok {
connectedPeer, err := s.findPeerByPubStr(pubStr)
switch err {
case ErrPeerNotFound:
// We were unable to locate an existing connection with the
// target peer, proceed to connect.
break
case nil:
// We already have a connection open with the target peer.
// If our (this) connection should be dropped, then we'll do
// so, in order to ensure we don't have any duplicate
// connections.
@ -1239,11 +1272,9 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
srvrLog.Warnf("Established outbound connection to "+
"peer %x, but already connected, dropping conn",
nodePub.SerializeCompressed())
if connReq != nil {
s.connMgr.Remove(connReq.ID())
}
conn.Close()
return
}
@ -1253,9 +1284,12 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
// server for garbage collection.
srvrLog.Debugf("Disconnecting stale connection to %v",
connectedPeer)
connectedPeer.Disconnect(errors.New("remove stale connection"))
// Remove the current peer from the server's internal state and
// signal that the peer termination watcher does not need to
// execute for this peer.
s.removePeer(connectedPeer)
s.ignorePeerTermination[connectedPeer] = struct{}{}
}
s.peerConnected(conn, connReq, true)
@ -1269,8 +1303,8 @@ func (s *server) addPeer(p *peer) {
}
// Ignore new peers if we're shutting down.
if atomic.LoadInt32(&s.shutdown) != 0 {
p.Disconnect(errors.New("server is shutting down"))
if s.Stopped() {
p.Disconnect(ErrServerShuttingDown)
return
}
@ -1290,12 +1324,13 @@ func (s *server) addPeer(p *peer) {
s.outboundPeers[pubStr] = p
}
// Launch a goroutine to watch for the termination of this peer so we
// can ensure all resources are properly cleaned up and if need be
// connections are re-established. The go routine is tracked by the
// _peer's_ WaitGroup so that a call to Disconnect will block until the
// `peerTerminationWatcher` has exited.
p.wg.Add(1)
// Launch a goroutine to watch for the unexpected termination of this
// peer, which will ensure all resources are properly cleaned up, and
// re-establish persistent connections when necessary. The peer
// termination watcher will be short circuited if the peer is ever added
// to the ignorePeerTermination map, indicating that the server has
// already handled the removal of this peer.
s.wg.Add(1)
go s.peerTerminationWatcher(p)
// Once the peer has been added to our indexes, send a message to the
@ -1321,10 +1356,15 @@ func (s *server) removePeer(p *peer) {
// As the peer is now finished, ensure that the TCP connection is
// closed and all of its related goroutines have exited.
p.Disconnect(errors.New("remove peer"))
p.Disconnect(fmt.Errorf("server: disconnecting peer %v", p))
// If this peer had an active persistent connection request, remove it.
if p.connReq != nil {
s.connMgr.Remove(p.connReq.ID())
}
// Ignore deleting peers if we're shutting down.
if atomic.LoadInt32(&s.shutdown) != 0 {
if s.Stopped() {
return
}
@ -1376,10 +1416,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
s.mu.Lock()
// Ensure we're not already connected to this peer.
peer, ok := s.peersByPub[targetPub]
if ok {
s.mu.Unlock()
peer, err := s.findPeerByPubStr(targetPub)
switch err {
case ErrPeerNotFound:
// Peer was not found, continue to pursue connection with peer.
break
case nil:
s.mu.Unlock()
return fmt.Errorf("already connected to peer: %v", peer)
}
@ -1388,7 +1432,6 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
// connection.
if _, ok := s.persistentConnReqs[targetPub]; ok {
s.mu.Unlock()
return fmt.Errorf("connection attempt to %v is pending", addr)
}
@ -1443,27 +1486,26 @@ func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
// Check that were actually connected to this peer. If not, then we'll
// exit in an error as we can't disconnect from a peer that we're not
// currently connected to.
peer, ok := s.peersByPub[pubStr]
if !ok {
return fmt.Errorf("unable to find peer %x", pubBytes)
// currently connected to. This will also return an error if we already
// have a pending disconnect request for this peer, ensuring the
// operation only happens once.
peer, err := s.findPeerByPubStr(pubStr)
if err != nil {
return err
}
srvrLog.Infof("Disconnecting from %v", peer)
// If this peer was formerly a persistent connection, then we'll remove
// them from this map so we don't attempt to re-connect after we
// disconnect.
if _, ok := s.persistentPeers[pubStr]; ok {
delete(s.persistentPeers, pubStr)
}
// Now that we know the peer is actually connected, we'll disconnect
// from the peer. The lock is held until after Disconnect to ensure
// that the peer's `peerTerminationWatcher` has fully exited before
// returning to the caller.
srvrLog.Infof("Disconnecting from %v", peer)
peer.Disconnect(
errors.New("received user command to disconnect the peer"),
)
// Remove the current peer from the server's internal state and signal
// that the peer termination watcher does not need to execute for this
// peer.
s.removePeer(peer)
s.ignorePeerTermination[peer] = struct{}{}
return nil
}