diff --git a/peer.go b/peer.go index e141cd6c..9a6b792e 100644 --- a/peer.go +++ b/peer.go @@ -458,7 +458,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error { // disconnected if the local or remote side terminating the connection, or an // irrecoverable protocol error has been encountered. func (p *peer) WaitForDisconnect() { - <-p.quit + p.wg.Wait() } // Disconnect terminates the connection with the remote peer. Additionally, a @@ -475,8 +475,6 @@ func (p *peer) Disconnect(reason error) { p.conn.Close() close(p.quit) - - p.wg.Wait() } // String returns the string representation of this peer. diff --git a/server.go b/server.go index 945fd9a6..519f7ebe 100644 --- a/server.go +++ b/server.go @@ -93,6 +93,13 @@ type server struct { // disconnected. ignorePeerTermination map[*peer]struct{} + // scheduledPeerConnection maps a pubkey string to a callback that + // should be executed in the peerTerminationWatcher once the prior peer with + // the same pubkey exits. This allows the server to wait until the + // prior peer has cleaned up successfully, before adding the new peer + // intended to replace it. 
+ scheduledPeerConnection map[string]func() + cc *chainControl fundingMgr *fundingManager @@ -178,11 +185,12 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, sphinx: htlcswitch.NewOnionProcessor(sphinxRouter), lightningID: sha256.Sum256(serializedPubKey[:]), - persistentPeers: make(map[string]struct{}), - persistentPeersBackoff: make(map[string]time.Duration), - persistentConnReqs: make(map[string][]*connmgr.ConnReq), - persistentRetryCancels: make(map[string]chan struct{}), - ignorePeerTermination: make(map[*peer]struct{}), + persistentPeers: make(map[string]struct{}), + persistentPeersBackoff: make(map[string]time.Duration), + persistentConnReqs: make(map[string][]*connmgr.ConnReq), + persistentRetryCancels: make(map[string]chan struct{}), + ignorePeerTermination: make(map[*peer]struct{}), + scheduledPeerConnection: make(map[string]func()), peersByPub: make(map[string]*peer), inboundPeers: make(map[string]*peer), @@ -1299,6 +1307,20 @@ func (s *server) peerTerminationWatcher(p *peer) { // peer termination watcher and skip cleanup. if _, ok := s.ignorePeerTermination[p]; ok { delete(s.ignorePeerTermination, p) + + pubKey := p.PubKey() + pubStr := string(pubKey[:]) + + // If a connection callback is present, we'll go ahead and + // execute it now that previous peer has fully disconnected. If + // the callback is not present, this likely implies the peer was + // purposefully disconnected via RPC, and that no reconnect + // should be attempted. + connCallback, ok := s.scheduledPeerConnection[pubStr] + if ok { + delete(s.scheduledPeerConnection, pubStr) + connCallback() + } return } @@ -1489,8 +1511,23 @@ func (s *server) InboundPeerConnected(conn net.Conn) { return } + // If we already have a valid connection that is scheduled to take + // precedence once the prior peer has finished disconnecting, we'll + // ignore this connection. 
+ if _, ok := s.scheduledPeerConnection[pubStr]; ok { + srvrLog.Debugf("Ignoring connection, peer already scheduled") + conn.Close() + return + } + srvrLog.Infof("New inbound connection from %v", conn.RemoteAddr()) + // Cancel all pending connection requests, we either already have an + // outbound connection, or this incoming connection will become our + // primary connection. The incoming connection will not have an + // associated connection request, so we pass nil. + s.cancelConnReqs(pubStr, nil) + // Check to see if we already have a connection with this peer. If so, // we may need to drop our existing connection. This prevents us from // having duplicate connections to the same peer. We forgo adding a @@ -1501,6 +1538,7 @@ func (s *server) InboundPeerConnected(conn net.Conn) { case ErrPeerNotConnected: // We were unable to locate an existing connection with the // target peer, proceed to connect. + s.peerConnected(conn, nil, false) case nil: // We already have a connection with the incoming peer. If the @@ -1526,13 +1564,10 @@ func (s *server) InboundPeerConnected(conn net.Conn) { // execute for this peer. s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} + s.scheduledPeerConnection[pubStr] = func() { + s.peerConnected(conn, nil, false) + } } - - // Lastly, cancel all pending requests. The incoming connection will not - // have an associated connection request. - s.cancelConnReqs(pubStr, nil) - - s.peerConnected(conn, nil, false) } // OutboundPeerConnected initializes a new peer in response to a new outbound @@ -1568,6 +1603,15 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) return } + // If we already have a valid connection that is scheduled to take + // precedence once the prior peer has finished disconnecting, we'll + // ignore this connection. 
+ if _, ok := s.scheduledPeerConnection[pubStr]; ok { + srvrLog.Debugf("Ignoring connection, peer already scheduled") + conn.Close() + return + } + srvrLog.Infof("Established connection to: %v", conn.RemoteAddr()) if connReq != nil { @@ -1591,6 +1635,7 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) case ErrPeerNotConnected: // We were unable to locate an existing connection with the // target peer, proceed to connect. + s.peerConnected(conn, connReq, true) case nil: // We already have a connection open with the target peer. @@ -1620,9 +1665,10 @@ func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) // execute for this peer. s.removePeer(connectedPeer) s.ignorePeerTermination[connectedPeer] = struct{}{} + s.scheduledPeerConnection[pubStr] = func() { + s.peerConnected(conn, connReq, true) + } } - - s.peerConnected(conn, connReq, true) } // UnassignedConnID is the default connection ID that a request can have before