server: fixes race condition during unexpected peer disconnect
parent f20cb89982
commit 96ff63d219
server.go: 186 changed lines
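The race being fixed: an unexpected remote disconnect triggers cleanup in peerTerminationWatcher, while a server-initiated disconnect (DisconnectPeer, or replacing a stale connection) performs its own cleanup, so both paths could previously run for the same peer and race over the server's peer maps and connection requests. The commit serializes cleanup behind the server mutex and adds an ignorePeerTermination set: when the server disconnects a peer on purpose it removes the peer's state itself and records the peer in the set, and the watcher short-circuits. A minimal, self-contained sketch of that pattern (illustrative names only, not the lnd code itself):

package sketch

import "sync"

type peer struct {
	quit chan struct{}
}

// WaitForDisconnect blocks until the peer's connection is torn down.
func (p *peer) WaitForDisconnect() { <-p.quit }

type server struct {
	mu     sync.Mutex
	wg     sync.WaitGroup
	peers  map[*peer]struct{}
	ignore map[*peer]struct{} // peers the server has already cleaned up
}

// watchPeer is launched once per peer; s.wg.Add(1) must precede it.
func (s *server) watchPeer(p *peer) {
	defer s.wg.Done()

	p.WaitForDisconnect()

	s.mu.Lock()
	defer s.mu.Unlock()

	// Short circuit: the server disconnected this peer on purpose and has
	// already removed its state in disconnectPeer below.
	if _, ok := s.ignore[p]; ok {
		delete(s.ignore, p)
		return
	}

	// Unexpected disconnect: the watcher owns the cleanup.
	delete(s.peers, p)
}

// disconnectPeer is the server-initiated path: remove state and tell the
// watcher to stand down, all within one critical section.
func (s *server) disconnectPeer(p *peer) {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.peers, p)
	s.ignore[p] = struct{}{}
	close(p.quit)
}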
@@ -30,6 +30,16 @@ import (
 	"github.com/lightningnetwork/lnd/htlcswitch"
 )
 
+var (
+	// ErrPeerNotFound signals that the server has no connection to the
+	// given peer.
+	ErrPeerNotFound = errors.New("server: peer not found")
+
+	// ErrServerShuttingDown indicates that the server is in the process of
+	// gracefully exiting.
+	ErrServerShuttingDown = errors.New("server: shutting down")
+)
+
 // server is the main server of the Lightning Network Daemon. The server houses
 // global state pertaining to the wallet, database, and the rpcserver.
 // Additionally, the server is also used as a central messaging bus to interact
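The two sentinel errors above are meant to be compared directly (with `==` or a switch on the error value) rather than by matching message strings; the hunks below switch on ErrPeerNotFound when deciding whether an existing connection must be torn down. A small self-contained sketch of that calling convention (the lookup callback is hypothetical, standing in for the server's findPeerByPubStr):

package sketch

import "errors"

// Re-declared here so the sketch stands alone; callers branch on the error
// value itself rather than on its message text.
var (
	ErrPeerNotFound       = errors.New("server: peer not found")
	ErrServerShuttingDown = errors.New("server: shutting down")
)

// connectIfAbsent shows the intended calling convention. The lookup callback
// is hypothetical, standing in for the server's findPeerByPubStr.
func connectIfAbsent(lookup func(pubStr string) error, pubStr string) error {
	switch err := lookup(pubStr); err {
	case ErrPeerNotFound:
		// No existing connection with this peer; proceed to connect.
		return nil
	case nil:
		// Already connected; the caller decides which connection survives.
		return nil
	default:
		// Anything else (e.g. ErrServerShuttingDown) propagates unchanged.
		return err
	}
}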
@@ -62,6 +72,12 @@ type server struct {
 	persistentPeers    map[string]struct{}
 	persistentConnReqs map[string][]*connmgr.ConnReq
 
+	// ignorePeerTermination tracks peers for which the server has initiated
+	// a disconnect. Adding a peer to this map causes the peer termination
+	// watcher to short circuit in the event that peers are purposefully
+	// disconnected.
+	ignorePeerTermination map[*peer]struct{}
+
 	cc *chainControl
 
 	fundingMgr *fundingManager
@@ -135,6 +151,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
 
 		persistentPeers:       make(map[string]struct{}),
 		persistentConnReqs:    make(map[string][]*connmgr.ConnReq),
+		ignorePeerTermination: make(map[*peer]struct{}),
 
 		peersByID:  make(map[int32]*peer),
 		peersByPub: make(map[string]*peer),
@@ -424,8 +441,7 @@ func (s *server) Stop() error {
 
 	// Disconnect from each active peers to ensure that
 	// peerTerminationWatchers signal completion to each peer.
-	peers := s.Peers()
-	for _, peer := range peers {
+	for _, peer := range s.Peers() {
 		s.DisconnectPeer(peer.addr.IdentityKey)
 	}
 
@@ -826,13 +842,9 @@ func (s *server) broadcastMessages(
 	// throughout this process to ensure we deliver messages to exact set
 	// of peers present at the time of invocation.
 	var wg sync.WaitGroup
-	for _, sPeer := range s.peersByPub {
-		if skip != nil &&
-			sPeer.addr.IdentityKey.IsEqual(skip) {
-
-			srvrLog.Debugf("Skipping %v in broadcast",
-				skip.SerializeCompressed())
-
+	for pubStr, sPeer := range s.peersByPub {
+		if skip != nil && sPeer.addr.IdentityKey.IsEqual(skip) {
+			srvrLog.Debugf("Skipping %v in broadcast", pubStr)
 			continue
 		}
 
@@ -905,12 +917,11 @@ func (s *server) sendToPeer(target *btcec.PublicKey,
 	// caller if the peer is unknown. Access to peersByPub is synchronized
 	// here to ensure we consider the exact set of peers present at the
 	// time of invocation.
-	targetPeer, ok := s.peersByPub[string(targetPubBytes)]
-	if !ok {
+	targetPeer, err := s.findPeerByPubStr(string(targetPubBytes))
+	if err != nil {
 		srvrLog.Errorf("unable to send message to %x, "+
 			"peer not found", targetPubBytes)
-
-		return errors.New("peer not found")
+		return err
 	}
 
 	s.sendPeerMessages(targetPeer, msgs, nil)
@@ -958,9 +969,9 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	serializedIDKey := string(peerKey.SerializeCompressed())
+	pubStr := string(peerKey.SerializeCompressed())
 
-	return s.findPeer(serializedIDKey)
+	return s.findPeerByPubStr(pubStr)
 }
 
 // FindPeerByPubStr will return the peer that corresponds to the passed peerID,
@@ -968,34 +979,34 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 // public key.
 //
 // NOTE: This function is safe for concurrent access.
-func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
+func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	return s.findPeer(peerID)
+	return s.findPeerByPubStr(pubStr)
 }
 
-// findPeer is an internal method that retrieves the specified peer from the
-// server's internal state.
-func (s *server) findPeer(peerID string) (*peer, error) {
-	peer := s.peersByPub[peerID]
-	if peer == nil {
-		return nil, errors.New("Peer not found. Pubkey: " + peerID)
+// findPeerByPubStr is an internal method that retrieves the specified peer from
+// the server's internal state using.
+func (s *server) findPeerByPubStr(pubStr string) (*peer, error) {
+	peer, ok := s.peersByPub[pubStr]
+	if !ok {
+		return nil, ErrPeerNotFound
 	}
 
 	return peer, nil
 }
 
-// peerTerminationWatcher waits until a peer has been disconnected, and then
-// cleans up all resources allocated to the peer, notifies relevant sub-systems
-// of its demise, and finally handles re-connecting to the peer if it's
-// persistent.
+// peerTerminationWatcher waits until a peer has been disconnected unexpectedly,
+// and then cleans up all resources allocated to the peer, notifies relevant
+// sub-systems of its demise, and finally handles re-connecting to the peer if
+// it's persistent. If the server intentionally disconnects a peer, it should
+// have a corresponding entry in the ignorePeerTermination map which will cause
+// the cleanup routine to exit early.
 //
-// NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
-// be incremented before spawning this method, as it will signal to the peer's
-// WaitGroup upon completion.
+// NOTE: This MUST be launched as a goroutine.
 func (s *server) peerTerminationWatcher(p *peer) {
-	defer p.wg.Done()
+	defer s.wg.Done()
 
 	p.WaitForDisconnect()
 
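Note the WaitGroup ownership change visible above: the watcher goroutine now signals the server's WaitGroup (s.wg) rather than the peer's, so server shutdown can wait for every watcher while a peer's Disconnect no longer has to. A minimal sketch of a goroutine tracked by its owner's WaitGroup (illustrative names only):

package sketch

import "sync"

// watcherOwner sketches the ownership change: each watcher goroutine counts
// against the owning component's WaitGroup, so shutdown waits for all of them.
type watcherOwner struct {
	wg sync.WaitGroup
}

// startWatcher mirrors s.wg.Add(1) followed by go s.peerTerminationWatcher(p):
// Add happens before the goroutine is launched so stop can never miss it.
func (o *watcherOwner) startWatcher(wait func()) {
	o.wg.Add(1)
	go func() {
		defer o.wg.Done()
		wait() // e.g. p.WaitForDisconnect()
	}()
}

// stop blocks until every watcher started above has returned.
func (o *watcherOwner) stop() {
	o.wg.Wait()
}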
@@ -1025,16 +1036,20 @@ func (s *server) peerTerminationWatcher(p *peer) {
 		}
 	}
 
-	// Send the peer to be garbage collected by the server.
-	s.removePeer(p)
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-	// If this peer had an active persistent connection request, then we
-	// can remove this as we manually decide below if we should attempt to
-	// re-connect.
-	if p.connReq != nil {
-		s.connMgr.Remove(p.connReq.ID())
+	// If the server has already removed this peer, we can short circuit the
+	// peer termination watcher and skip cleanup.
+	if _, ok := s.ignorePeerTermination[p]; ok {
+		delete(s.ignorePeerTermination, p)
+		return
 	}
 
+	// First, cleanup any remaining state the server has regarding the peer
+	// in question.
+	s.removePeer(p)
+
 	// Next, check to see if this is a persistent peer or not.
 	pubStr := string(p.addr.IdentityKey.SerializeCompressed())
 	_, ok := s.persistentPeers[pubStr]
@@ -1148,7 +1163,14 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 	// Check to see if we should drop our connection, if not, then we'll
 	// close out this connection with the remote peer. This
 	// prevents us from having duplicate connections, or none.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
 		// If the connection we've already established should be kept,
 		// then we'll close out this connection s.t there's only a
 		// single connection between us.
@@ -1165,9 +1187,12 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 		// peer to the peer garbage collection goroutine.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
 		connectedPeer.Disconnect(errors.New("remove stale connection"))
+
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	// Next, check to see if we have any outstanding persistent connection
@@ -1231,7 +1256,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 
 	// If we already have an inbound connection from this peer, then we'll
 	// check to see _which_ of our connections should be dropped.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
 		// We already have a connection open with the target peer.
 		// If our (this) connection should be dropped, then we'll do
 		// so, in order to ensure we don't have any duplicate
 		// connections.
@@ -1239,11 +1272,9 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 			srvrLog.Warnf("Established outbound connection to "+
 				"peer %x, but already connected, dropping conn",
 				nodePub.SerializeCompressed())
-
 			if connReq != nil {
 				s.connMgr.Remove(connReq.ID())
 			}
-
 			conn.Close()
 			return
 		}
@@ -1253,9 +1284,12 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 		// server for garbage collection.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
 		connectedPeer.Disconnect(errors.New("remove stale connection"))
+
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	s.peerConnected(conn, connReq, true)
@@ -1269,8 +1303,8 @@ func (s *server) addPeer(p *peer) {
 	}
 
 	// Ignore new peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
-		p.Disconnect(errors.New("server is shutting down"))
+	if s.Stopped() {
+		p.Disconnect(ErrServerShuttingDown)
 		return
 	}
 
@@ -1290,12 +1324,13 @@ func (s *server) addPeer(p *peer) {
 		s.outboundPeers[pubStr] = p
 	}
 
-	// Launch a goroutine to watch for the termination of this peer so we
-	// can ensure all resources are properly cleaned up and if need be
-	// connections are re-established. The go routine is tracked by the
-	// _peer's_ WaitGroup so that a call to Disconnect will block until the
-	// `peerTerminationWatcher` has exited.
-	p.wg.Add(1)
+	// Launch a goroutine to watch for the unexpected termination of this
+	// peer, which will ensure all resources are properly cleaned up, and
+	// re-establish persistent connections when necessary. The peer
+	// termination watcher will be short circuited if the peer is ever added
+	// to the ignorePeerTermination map, indicating that the server has
+	// already handled the removal of this peer.
+	s.wg.Add(1)
 	go s.peerTerminationWatcher(p)
 
 	// Once the peer has been added to our indexes, send a message to the
@@ -1321,10 +1356,15 @@ func (s *server) removePeer(p *peer) {
 
 	// As the peer is now finished, ensure that the TCP connection is
 	// closed and all of its related goroutines have exited.
-	p.Disconnect(errors.New("remove peer"))
+	p.Disconnect(fmt.Errorf("server: disconnecting peer %v", p))
+
+	// If this peer had an active persistent connection request, remove it.
+	if p.connReq != nil {
+		s.connMgr.Remove(p.connReq.ID())
+	}
 
 	// Ignore deleting peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
+	if s.Stopped() {
 		return
 	}
 
@@ -1376,10 +1416,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	s.mu.Lock()
 
 	// Ensure we're not already connected to this peer.
-	peer, ok := s.peersByPub[targetPub]
-	if ok {
-		s.mu.Unlock()
+	peer, err := s.findPeerByPubStr(targetPub)
+	switch err {
+	case ErrPeerNotFound:
+		// Peer was not found, continue to pursue connection with peer.
+		break
+
+	case nil:
+		s.mu.Unlock()
 		return fmt.Errorf("already connected to peer: %v", peer)
 	}
 
@@ -1388,7 +1432,6 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	// connection.
 	if _, ok := s.persistentConnReqs[targetPub]; ok {
 		s.mu.Unlock()
-
 		return fmt.Errorf("connection attempt to %v is pending", addr)
 	}
 
@@ -1443,27 +1486,26 @@ func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
 
 	// Check that were actually connected to this peer. If not, then we'll
 	// exit in an error as we can't disconnect from a peer that we're not
-	// currently connected to.
-	peer, ok := s.peersByPub[pubStr]
-	if !ok {
-		return fmt.Errorf("unable to find peer %x", pubBytes)
+	// currently connected to. This will also return an error if we already
+	// have a pending disconnect request for this peer, ensuring the
+	// operation only happens once.
+	peer, err := s.findPeerByPubStr(pubStr)
+	if err != nil {
+		return err
 	}
 
+	srvrLog.Infof("Disconnecting from %v", peer)
+
 	// If this peer was formerly a persistent connection, then we'll remove
 	// them from this map so we don't attempt to re-connect after we
 	// disconnect.
 	if _, ok := s.persistentPeers[pubStr]; ok {
 		delete(s.persistentPeers, pubStr)
 	}
 
-	// Now that we know the peer is actually connected, we'll disconnect
-	// from the peer. The lock is held until after Disconnect to ensure
-	// that the peer's `peerTerminationWatcher` has fully exited before
-	// returning to the caller.
-	srvrLog.Infof("Disconnecting from %v", peer)
-	peer.Disconnect(
-		errors.New("received user command to disconnect the peer"),
-	)
+	// Remove the current peer from the server's internal state and signal
+	// that the peer termination watcher does not need to execute for this
+	// peer.
+	s.removePeer(peer)
+	s.ignorePeerTermination[peer] = struct{}{}
 
 	return nil
 }
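With DisconnectPeer now routed through findPeerByPubStr and paired with removePeer plus an ignorePeerTermination entry under the lock, a repeated disconnect for the same peer surfaces ErrPeerNotFound instead of racing the termination watcher. A hypothetical caller-side sketch of treating that as a benign outcome (the sentinel is re-declared so the sketch stands alone):

package sketch

import "errors"

var ErrPeerNotFound = errors.New("server: peer not found")

// disconnectQuietly is a hypothetical helper: callers that consider an
// already-gone peer a success can filter the sentinel and surface only
// genuine failures.
func disconnectQuietly(disconnect func() error) error {
	if err := disconnect(); err != nil && err != ErrPeerNotFound {
		return err
	}
	return nil
}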