server: fixes race condition during unexpected peer disconnect

commit 96ff63d219
parent f20cb89982

 server.go | 192
@@ -30,6 +30,16 @@ import (
 	"github.com/lightningnetwork/lnd/htlcswitch"
 )
 
+var (
+	// ErrPeerNotFound signals that the server has no connection to the
+	// given peer.
+	ErrPeerNotFound = errors.New("server: peer not found")
+
+	// ErrServerShuttingDown indicates that the server is in the process of
+	// gracefully exiting.
+	ErrServerShuttingDown = errors.New("server: shutting down")
+)
+
 // server is the main server of the Lightning Network Daemon. The server houses
 // global state pertaining to the wallet, database, and the rpcserver.
 // Additionally, the server is also used as a central messaging bus to interact
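Note: these are package-level sentinel errors; unlike the ad-hoc errors.New values they replace, a single allocated value can be matched by identity with == or a switch, which is exactly how the connection handlers below consume ErrPeerNotFound. A minimal, self-contained sketch of the pattern (the lookup function is invented for illustration):

package main

import (
	"errors"
	"fmt"
)

// Package-level sentinels: allocated once, comparable by identity.
var (
	ErrPeerNotFound       = errors.New("server: peer not found")
	ErrServerShuttingDown = errors.New("server: shutting down")
)

func lookup(ok bool) error {
	if !ok {
		return ErrPeerNotFound
	}
	return nil
}

func main() {
	// A switch on the returned error distinguishes "not found" from
	// success, mirroring the connection handlers later in this diff.
	switch err := lookup(false); err {
	case ErrPeerNotFound:
		fmt.Println("no existing connection, proceed to connect")
	case nil:
		fmt.Println("already connected")
	default:
		fmt.Println("unexpected error:", err)
	}
}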
@@ -62,6 +72,12 @@ type server struct {
 	persistentPeers    map[string]struct{}
 	persistentConnReqs map[string][]*connmgr.ConnReq
 
+	// ignorePeerTermination tracks peers for which the server has initiated
+	// a disconnect. Adding a peer to this map causes the peer termination
+	// watcher to short circuit in the event that peers are purposefully
+	// disconnected.
+	ignorePeerTermination map[*peer]struct{}
+
 	cc *chainControl
 
 	fundingMgr *fundingManager
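Note: map[*peer]struct{} is Go's idiomatic set type; the empty struct carries no data, and keying by pointer means each connected peer object is tracked individually. A tiny, hypothetical demonstration of the mark-and-consume usage the watcher relies on:

package main

import "fmt"

type peer struct{ id string }

func main() {
	// Empty struct values occupy no space; the map acts as a set.
	ignore := make(map[*peer]struct{})

	p := &peer{id: "alice"}
	ignore[p] = struct{}{} // mark a deliberate disconnect

	if _, ok := ignore[p]; ok {
		delete(ignore, p) // consume the mark exactly once
		fmt.Println("short circuit: peer was deliberately removed")
	}
}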
@@ -133,8 +149,9 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
 			sphinx.NewRouter(privKey, activeNetParams.Params)),
 		lightningID: sha256.Sum256(serializedPubKey),
 
 		persistentPeers:    make(map[string]struct{}),
 		persistentConnReqs: make(map[string][]*connmgr.ConnReq),
+		ignorePeerTermination: make(map[*peer]struct{}),
 
 		peersByID:  make(map[int32]*peer),
 		peersByPub: make(map[string]*peer),
@@ -424,8 +441,7 @@ func (s *server) Stop() error {
 
 	// Disconnect from each active peers to ensure that
 	// peerTerminationWatchers signal completion to each peer.
-	peers := s.Peers()
-	for _, peer := range peers {
+	for _, peer := range s.Peers() {
 		s.DisconnectPeer(peer.addr.IdentityKey)
 	}
 
@@ -826,13 +842,9 @@ func (s *server) broadcastMessages(
 	// throughout this process to ensure we deliver messages to exact set
 	// of peers present at the time of invocation.
 	var wg sync.WaitGroup
-	for _, sPeer := range s.peersByPub {
-		if skip != nil &&
-			sPeer.addr.IdentityKey.IsEqual(skip) {
-
-			srvrLog.Debugf("Skipping %v in broadcast",
-				skip.SerializeCompressed())
-
+	for pubStr, sPeer := range s.peersByPub {
+		if skip != nil && sPeer.addr.IdentityKey.IsEqual(skip) {
+			srvrLog.Debugf("Skipping %v in broadcast", pubStr)
 			continue
 		}
 
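Note: the rewritten loop logs the map key (pubStr) it already has instead of re-serializing the skipped public key, and the surrounding WaitGroup fans the sends out concurrently. A condensed, hypothetical model of that broadcast shape, with peers stubbed as plain map entries:

package main

import (
	"fmt"
	"sync"
)

func main() {
	peers := map[string]int{"peerA": 1, "peerB": 2, "peerC": 3}
	skip := "peerB"

	var wg sync.WaitGroup
	for pubStr, conn := range peers {
		if pubStr == skip {
			fmt.Printf("Skipping %v in broadcast\n", pubStr)
			continue
		}

		// One goroutine per peer; the WaitGroup lets the caller
		// block until every send has completed.
		wg.Add(1)
		go func(id string, conn int) {
			defer wg.Done()
			fmt.Printf("sent message to %v (conn %d)\n", id, conn)
		}(pubStr, conn)
	}
	wg.Wait()
}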
@@ -905,12 +917,11 @@ func (s *server) sendToPeer(target *btcec.PublicKey,
 	// caller if the peer is unknown. Access to peersByPub is synchronized
 	// here to ensure we consider the exact set of peers present at the
 	// time of invocation.
-	targetPeer, ok := s.peersByPub[string(targetPubBytes)]
-	if !ok {
+	targetPeer, err := s.findPeerByPubStr(string(targetPubBytes))
+	if err != nil {
 		srvrLog.Errorf("unable to send message to %x, "+
 			"peer not found", targetPubBytes)
-
-		return errors.New("peer not found")
+		return err
 	}
 
 	s.sendPeerMessages(targetPeer, msgs, nil)
@@ -958,9 +969,9 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	serializedIDKey := string(peerKey.SerializeCompressed())
+	pubStr := string(peerKey.SerializeCompressed())
 
-	return s.findPeer(serializedIDKey)
+	return s.findPeerByPubStr(pubStr)
 }
 
 // FindPeerByPubStr will return the peer that corresponds to the passed peerID,
@@ -968,34 +979,34 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
 // public key.
 //
 // NOTE: This function is safe for concurrent access.
-func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
+func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	return s.findPeer(peerID)
+	return s.findPeerByPubStr(pubStr)
 }
 
-// findPeer is an internal method that retrieves the specified peer from the
-// server's internal state.
-func (s *server) findPeer(peerID string) (*peer, error) {
-	peer := s.peersByPub[peerID]
-	if peer == nil {
-		return nil, errors.New("Peer not found. Pubkey: " + peerID)
+// findPeerByPubStr is an internal method that retrieves the specified peer from
+// the server's internal state using.
+func (s *server) findPeerByPubStr(pubStr string) (*peer, error) {
+	peer, ok := s.peersByPub[pubStr]
+	if !ok {
+		return nil, ErrPeerNotFound
 	}
 
 	return peer, nil
 }
 
-// peerTerminationWatcher waits until a peer has been disconnected, and then
-// cleans up all resources allocated to the peer, notifies relevant sub-systems
-// of its demise, and finally handles re-connecting to the peer if it's
-// persistent.
+// peerTerminationWatcher waits until a peer has been disconnected unexpectedly,
+// and then cleans up all resources allocated to the peer, notifies relevant
+// sub-systems of its demise, and finally handles re-connecting to the peer if
+// it's persistent. If the server intentionally disconnects a peer, it should
+// have a corresponding entry in the ignorePeerTermination map which will cause
+// the cleanup routine to exit early.
 //
-// NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
-// be incremented before spawning this method, as it will signal to the peer's
-// WaitGroup upon completion.
+// NOTE: This MUST be launched as a goroutine.
 func (s *server) peerTerminationWatcher(p *peer) {
-	defer p.wg.Done()
+	defer s.wg.Done()
 
 	p.WaitForDisconnect()
 
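Note: the exported FindPeerByPubStr takes s.mu and delegates to the unexported findPeerByPubStr, which assumes the lock is already held; later hunks call the lower-case variant from inside their own critical sections. A self-contained sketch of this two-layer locking convention (the registry type is invented for illustration):

package main

import (
	"errors"
	"fmt"
	"sync"
)

var ErrPeerNotFound = errors.New("server: peer not found")

type registry struct {
	mu    sync.Mutex
	peers map[string]string
}

// Find is safe for concurrent access: it acquires the lock itself.
func (r *registry) Find(pubStr string) (string, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.find(pubStr)
}

// find assumes r.mu is already held by the caller.
func (r *registry) find(pubStr string) (string, error) {
	p, ok := r.peers[pubStr]
	if !ok {
		return "", ErrPeerNotFound
	}
	return p, nil
}

func main() {
	r := &registry{peers: map[string]string{"ab01": "alice"}}
	fmt.Println(r.Find("ab01"))
	fmt.Println(r.Find("ff02"))
}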
@@ -1025,16 +1036,20 @@ func (s *server) peerTerminationWatcher(p *peer) {
 		}
 	}
 
-	// Send the peer to be garbage collected by the server.
-	s.removePeer(p)
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-	// If this peer had an active persistent connection request, then we
-	// can remove this as we manually decide below if we should attempt to
-	// re-connect.
-	if p.connReq != nil {
-		s.connMgr.Remove(p.connReq.ID())
+	// If the server has already removed this peer, we can short circuit the
+	// peer termination watcher and skip cleanup.
+	if _, ok := s.ignorePeerTermination[p]; ok {
+		delete(s.ignorePeerTermination, p)
+		return
 	}
 
+	// First, cleanup any remaining state the server has regarding the peer
+	// in question.
+	s.removePeer(p)
+
 	// Next, check to see if this is a persistent peer or not.
 	pubStr := string(p.addr.IdentityKey.SerializeCompressed())
 	_, ok := s.persistentPeers[pubStr]
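Note: this hunk is the heart of the race fix. The watcher now consults server state under s.mu, and an ignorePeerTermination entry tells it the disconnect was deliberate and already cleaned up, so it must not tear the peer down a second time. A toy, self-contained reproduction of the pattern (everything apart from the map name is hypothetical):

package main

import (
	"fmt"
	"sync"
)

type peer struct{ id string }

type server struct {
	mu                    sync.Mutex
	wg                    sync.WaitGroup
	ignorePeerTermination map[*peer]struct{}
}

// disconnect is the deliberate path: clean up under the lock and mark
// the peer so the termination watcher short circuits.
func (s *server) disconnect(p *peer, done chan<- struct{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	fmt.Println("server removed", p.id)
	s.ignorePeerTermination[p] = struct{}{}
	close(done) // stands in for p.Disconnect(...) waking the watcher
}

// watcher is the unexpected-disconnect path.
func (s *server) watcher(p *peer, disconnected <-chan struct{}) {
	defer s.wg.Done()
	<-disconnected // stands in for p.WaitForDisconnect()

	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.ignorePeerTermination[p]; ok {
		delete(s.ignorePeerTermination, p)
		fmt.Println("watcher short circuits for", p.id)
		return
	}
	fmt.Println("watcher cleans up", p.id) // only on unexpected exits
}

func main() {
	s := &server{ignorePeerTermination: make(map[*peer]struct{})}
	p := &peer{id: "alice"}
	done := make(chan struct{})

	s.wg.Add(1)
	go s.watcher(p, done)

	s.disconnect(p, done)
	s.wg.Wait()
}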
@@ -1148,7 +1163,14 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 	// Check to see if we should drop our connection, if not, then we'll
 	// close out this connection with the remote peer. This
 	// prevents us from having duplicate connections, or none.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
 		// If the connection we've already established should be kept,
 		// then we'll close out this connection s.t there's only a
 		// single connection between us.
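Note: the explicit break in the ErrPeerNotFound case is a no-op, since Go switch cases never fall through; it appears to be kept purely to signpost that control continues with the connection logic below the switch. A two-case demonstration:

package main

import "fmt"

func main() {
	switch x := 1; x {
	case 1:
		fmt.Println("case 1 runs") // no fallthrough: case 2 is skipped
	case 2:
		fmt.Println("never reached without an explicit fallthrough")
	}
}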
@@ -1165,9 +1187,12 @@ func (s *server) InboundPeerConnected(conn net.Conn) {
 		// peer to the peer garbage collection goroutine.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	// Next, check to see if we have any outstanding persistent connection
@@ -1231,7 +1256,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 
 	// If we already have an inbound connection from this peer, then we'll
 	// check to see _which_ of our connections should be dropped.
-	if connectedPeer, ok := s.peersByPub[pubStr]; ok {
+	connectedPeer, err := s.findPeerByPubStr(pubStr)
+	switch err {
+	case ErrPeerNotFound:
+		// We were unable to locate an existing connection with the
+		// target peer, proceed to connect.
+		break
+
+	case nil:
+		// We already have a connection open with the target peer.
 		// If our (this) connection should be dropped, then we'll do
 		// so, in order to ensure we don't have any duplicate
 		// connections.
@@ -1239,11 +1272,9 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 		srvrLog.Warnf("Established outbound connection to "+
 			"peer %x, but already connected, dropping conn",
 			nodePub.SerializeCompressed())
-
 		if connReq != nil {
 			s.connMgr.Remove(connReq.ID())
 		}
-
 		conn.Close()
 		return
 	}
@@ -1253,9 +1284,12 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
 		// server for garbage collection.
 		srvrLog.Debugf("Disconnecting stale connection to %v",
 			connectedPeer)
-		connectedPeer.Disconnect(errors.New("remove stale connection"))
 
+		// Remove the current peer from the server's internal state and
+		// signal that the peer termination watcher does not need to
+		// execute for this peer.
 		s.removePeer(connectedPeer)
+		s.ignorePeerTermination[connectedPeer] = struct{}{}
 	}
 
 	s.peerConnected(conn, connReq, true)
@@ -1269,8 +1303,8 @@ func (s *server) addPeer(p *peer) {
 	}
 
 	// Ignore new peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
-		p.Disconnect(errors.New("server is shutting down"))
+	if s.Stopped() {
+		p.Disconnect(ErrServerShuttingDown)
 		return
 	}
 
@@ -1290,12 +1324,13 @@ func (s *server) addPeer(p *peer) {
 		s.outboundPeers[pubStr] = p
 	}
 
-	// Launch a goroutine to watch for the termination of this peer so we
-	// can ensure all resources are properly cleaned up and if need be
-	// connections are re-established. The go routine is tracked by the
-	// _peer's_ WaitGroup so that a call to Disconnect will block until the
-	// `peerTerminationWatcher` has exited.
-	p.wg.Add(1)
+	// Launch a goroutine to watch for the unexpected termination of this
+	// peer, which will ensure all resources are properly cleaned up, and
+	// re-establish persistent connections when necessary. The peer
+	// termination watcher will be short circuited if the peer is ever added
+	// to the ignorePeerTermination map, indicating that the server has
+	// already handled the removal of this peer.
+	s.wg.Add(1)
 	go s.peerTerminationWatcher(p)
 
 	// Once the peer has been added to our indexes, send a message to the
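Note: the watcher goroutine moves from the peer's WaitGroup to the server's. A plausible reading: the watcher may now itself call removePeer, which calls p.Disconnect, so tying the watcher's lifetime to the peer risks waiting on itself, whereas s.wg lets server shutdown wait for all watchers. A minimal sketch of server-owned tracking (all names hypothetical):

package main

import (
	"fmt"
	"sync"
)

type server struct{ wg sync.WaitGroup }

func (s *server) watch(id string, disconnected <-chan struct{}) {
	defer s.wg.Done()
	<-disconnected
	fmt.Println("cleaned up after", id)
}

func main() {
	s := &server{}
	done := make(chan struct{})

	// The server, not the peer, tracks the watcher's lifetime, so a
	// peer-side Disconnect never has to wait on the watcher goroutine.
	s.wg.Add(1)
	go s.watch("alice", done)

	close(done)
	s.wg.Wait() // shutdown waits for every watcher to exit
	fmt.Println("server stopped")
}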
@@ -1321,10 +1356,15 @@ func (s *server) removePeer(p *peer) {
 
 	// As the peer is now finished, ensure that the TCP connection is
 	// closed and all of its related goroutines have exited.
-	p.Disconnect(errors.New("remove peer"))
+	p.Disconnect(fmt.Errorf("server: disconnecting peer %v", p))
+
+	// If this peer had an active persistent connection request, remove it.
+	if p.connReq != nil {
+		s.connMgr.Remove(p.connReq.ID())
+	}
 
 	// Ignore deleting peers if we're shutting down.
-	if atomic.LoadInt32(&s.shutdown) != 0 {
+	if s.Stopped() {
 		return
 	}
 
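Note: after this hunk, removePeer owns the whole teardown: it disconnects the peer itself and cancels any outstanding connmgr request, which is why the explicit Disconnect and connMgr.Remove calls vanish from the connection handlers and the watcher above. A hypothetical condensation of that single-owner shape:

package main

import "fmt"

type connReq struct{ id uint64 }

type peer struct {
	id      string
	connReq *connReq // non-nil only for outbound, connmgr-tracked peers
}

type server struct{ stopped bool }

func (s *server) removePeer(p *peer) {
	// Tear down the connection exactly once, from one place.
	fmt.Printf("server: disconnecting peer %v\n", p.id)

	// If this peer had an active persistent connection request, remove it.
	if p.connReq != nil {
		fmt.Println("connmgr: removed request", p.connReq.id)
	}

	if s.stopped { // ignore deleting peers if we're shutting down
		return
	}
	fmt.Println("server: indexes updated for", p.id)
}

func main() {
	s := &server{}
	s.removePeer(&peer{id: "alice", connReq: &connReq{id: 7}})
	s.removePeer(&peer{id: "bob"})
}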
@@ -1376,10 +1416,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	s.mu.Lock()
 
 	// Ensure we're not already connected to this peer.
-	peer, ok := s.peersByPub[targetPub]
-	if ok {
-		s.mu.Unlock()
+	peer, err := s.findPeerByPubStr(targetPub)
+	switch err {
+	case ErrPeerNotFound:
+		// Peer was not found, continue to pursue connection with peer.
+		break
+
+	case nil:
+		s.mu.Unlock()
 		return fmt.Errorf("already connected to peer: %v", peer)
 	}
 
@@ -1388,7 +1432,6 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
 	// connection.
 	if _, ok := s.persistentConnReqs[targetPub]; ok {
 		s.mu.Unlock()
-
 		return fmt.Errorf("connection attempt to %v is pending", addr)
 	}
 
@@ -1443,27 +1486,26 @@ func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
 
 	// Check that were actually connected to this peer. If not, then we'll
 	// exit in an error as we can't disconnect from a peer that we're not
-	// currently connected to.
-	peer, ok := s.peersByPub[pubStr]
-	if !ok {
-		return fmt.Errorf("unable to find peer %x", pubBytes)
+	// currently connected to. This will also return an error if we already
+	// have a pending disconnect request for this peer, ensuring the
+	// operation only happens once.
+	peer, err := s.findPeerByPubStr(pubStr)
+	if err != nil {
+		return err
 	}
 
+	srvrLog.Infof("Disconnecting from %v", peer)
+
 	// If this peer was formerly a persistent connection, then we'll remove
 	// them from this map so we don't attempt to re-connect after we
 	// disconnect.
-	if _, ok := s.persistentPeers[pubStr]; ok {
-		delete(s.persistentPeers, pubStr)
-	}
+	delete(s.persistentPeers, pubStr)
 
-	// Now that we know the peer is actually connected, we'll disconnect
-	// from the peer. The lock is held until after Disconnect to ensure
-	// that the peer's `peerTerminationWatcher` has fully exited before
-	// returning to the caller.
-	srvrLog.Infof("Disconnecting from %v", peer)
-	peer.Disconnect(
-		errors.New("received user command to disconnect the peer"),
-	)
+	// Remove the current peer from the server's internal state and signal
+	// that the peer termination watcher does not need to execute for this
+	// peer.
+	s.removePeer(peer)
+	s.ignorePeerTermination[peer] = struct{}{}
 
 	return nil
 }
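Note: taken together, DisconnectPeer now runs the deliberate-disconnect sequence under one critical section: resolve the peer (a second call fails with ErrPeerNotFound, making the operation effectively once-only), drop its persistent entry, remove it, and mark ignorePeerTermination so the watcher stands down. A self-contained sketch of that flow with all lnd types stubbed out:

package main

import (
	"errors"
	"fmt"
	"sync"
)

var ErrPeerNotFound = errors.New("server: peer not found")

type peer struct{ id string }

type server struct {
	mu                    sync.Mutex
	peersByPub            map[string]*peer
	persistentPeers       map[string]struct{}
	ignorePeerTermination map[*peer]struct{}
}

// DisconnectPeer mirrors the post-patch control flow: every step happens
// under one critical section, and a repeated call reports ErrPeerNotFound.
func (s *server) DisconnectPeer(pubStr string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	p, ok := s.peersByPub[pubStr]
	if !ok {
		return ErrPeerNotFound
	}
	fmt.Println("Disconnecting from", p.id)

	delete(s.persistentPeers, pubStr)       // never reconnect
	delete(s.peersByPub, pubStr)            // stands in for removePeer
	s.ignorePeerTermination[p] = struct{}{} // watcher will stand down
	return nil
}

func main() {
	s := &server{
		peersByPub:            map[string]*peer{"ab01": {id: "alice"}},
		persistentPeers:       map[string]struct{}{"ab01": {}},
		ignorePeerTermination: make(map[*peer]struct{}),
	}
	fmt.Println(s.DisconnectPeer("ab01")) // <nil>
	fmt.Println(s.DisconnectPeer("ab01")) // server: peer not found
}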