server: refactors w/ single mutex and sync disconnect

This commit alters the synchronization patterns used in the server
  such that the internal state is protected by a single mutex.  Overall,
  this simplifies the ability to reason about the behavior and
  manipulation of the internal state, which has resolved a few flakes
  related to race conditions that were observed beforehand.

  Invoking DisconnectPeer is now fully synchronous, and waits until
  the provided peer's peerTerminationWatcher has exited before
  returning.  Currently this is done by tracking the watcher using the
  peer's WaitGroup, and locking until the peer has shutdown.

  The server's API has also been refactored such that all public methods
  are safe for concurrent use. Therefore, other subsystems should be
  sure to make use of these endpoints to avoid corrupting the internal
  state.
This commit is contained in:
Conner Fromknecht 2017-08-08 16:51:41 -07:00 committed by Olaoluwa Osuntokun
parent 80a8cc0add
commit 91d6b0492e

712
server.go

@ -48,13 +48,15 @@ type server struct {
// long-term identity private key. // long-term identity private key.
lightningID [32]byte lightningID [32]byte
peersMtx sync.RWMutex mu sync.Mutex
peersByID map[int32]*peer peersByID map[int32]*peer
peersByPub map[string]*peer peersByPub map[string]*peer
persistentPeers map[string]struct{} inboundPeers map[string]*peer
inboundPeers map[string]*peer outboundPeers map[string]*peer
outboundPeers map[string]*peer
persistentPeers map[string]struct{}
persistentConnReqs map[string][]*connmgr.ConnReq
cc *chainControl cc *chainControl
@ -76,16 +78,6 @@ type server struct {
connMgr *connmgr.ConnManager connMgr *connmgr.ConnManager
pendingConnMtx sync.RWMutex
persistentConnReqs map[string][]*connmgr.ConnReq
broadcastRequests chan *broadcastReq
sendRequests chan *sendReq
newPeers chan *peer
donePeers chan *peer
queries chan interface{}
// globalFeatures feature vector which affects HTLCs and thus are also // globalFeatures feature vector which affects HTLCs and thus are also
// advertised to other nodes. // advertised to other nodes.
globalFeatures *lnwire.FeatureVector globalFeatures *lnwire.FeatureVector
@ -97,11 +89,9 @@ type server struct {
// currentNodeAnn is the node announcement that has been broadcast to // currentNodeAnn is the node announcement that has been broadcast to
// the network upon startup, if the attributes of the node (us) has // the network upon startup, if the attributes of the node (us) has
// changed since last start. // changed since last start.
annMtx sync.Mutex
currentNodeAnn *lnwire.NodeAnnouncement currentNodeAnn *lnwire.NodeAnnouncement
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{}
} }
// newServer creates a new instance of the server which is to listen using the // newServer creates a new instance of the server which is to listen using the
@ -145,17 +135,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
inboundPeers: make(map[string]*peer), inboundPeers: make(map[string]*peer),
outboundPeers: make(map[string]*peer), outboundPeers: make(map[string]*peer),
newPeers: make(chan *peer, 10),
donePeers: make(chan *peer, 10),
broadcastRequests: make(chan *broadcastReq),
sendRequests: make(chan *sendReq),
globalFeatures: globalFeatures, globalFeatures: globalFeatures,
localFeatures: localFeatures, localFeatures: localFeatures,
queries: make(chan interface{}),
quit: make(chan struct{}),
} }
// If the debug HTLC flag is on, then we invoice a "master debug" // If the debug HTLC flag is on, then we invoice a "master debug"
@ -172,17 +153,24 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
LocalChannelClose: func(pubKey []byte, LocalChannelClose: func(pubKey []byte,
request *htlcswitch.ChanClose) { request *htlcswitch.ChanClose) {
s.peersMtx.RLock() peer, err := s.FindPeerByPubStr(string(pubKey))
peer, ok := s.peersByPub[string(pubKey)] if err != nil {
s.peersMtx.RUnlock() srvrLog.Errorf("unable to close channel, peer"+
" with %v id can't be found: %v",
if !ok { pubKey, err,
srvrLog.Error("unable to close channel, peer"+ )
" with %v id can't be found", pubKey)
return return
} }
peer.localCloseChanReqs <- request select {
case peer.localCloseChanReqs <- request:
srvrLog.Infof("local close channel request "+
"delivered to peer: %v", string(pubKey))
case <-peer.quit:
srvrLog.Errorf("unable to deliver local close "+
"channel request to peer %v, err: %v",
string(pubKey), err)
}
}, },
UpdateTopology: func(msg *lnwire.ChannelUpdate) error { UpdateTopology: func(msg *lnwire.ChannelUpdate) error {
s.discoverSrv.ProcessRemoteAnnouncement(msg, nil) s.discoverSrv.ProcessRemoteAnnouncement(msg, nil)
@ -276,10 +264,10 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
} }
s.discoverSrv, err = discovery.New(discovery.Config{ s.discoverSrv, err = discovery.New(discovery.Config{
Broadcast: s.broadcastMessage, Broadcast: s.BroadcastMessage,
Notifier: s.cc.chainNotifier, Notifier: s.cc.chainNotifier,
Router: s.chanRouter, Router: s.chanRouter,
SendToPeer: s.sendToPeer, SendToPeer: s.SendToPeer,
TrickleDelay: time.Millisecond * 300, TrickleDelay: time.Millisecond * 300,
ProofMatureDelta: 0, ProofMatureDelta: 0,
DB: chanDB, DB: chanDB,
@ -299,12 +287,12 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
// incoming connections // incoming connections
cmgr, err := connmgr.New(&connmgr.Config{ cmgr, err := connmgr.New(&connmgr.Config{
Listeners: listeners, Listeners: listeners,
OnAccept: s.inboundPeerConnected, OnAccept: s.InboundPeerConnected,
RetryDuration: time.Second * 5, RetryDuration: time.Second * 5,
TargetOutbound: 100, TargetOutbound: 100,
GetNewAddress: nil, GetNewAddress: nil,
Dial: noiseDial(s.identityPriv), Dial: noiseDial(s.identityPriv),
OnConnection: s.outboundPeerConnected, OnConnection: s.OutboundPeerConnected,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -315,12 +303,14 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
} }
// Started returns true if the server has been started, and false otherwise. // Started returns true if the server has been started, and false otherwise.
// NOTE: This function is safe for concurrent access.
func (s *server) Started() bool { func (s *server) Started() bool {
return atomic.LoadInt32(&s.started) != 0 return atomic.LoadInt32(&s.started) != 0
} }
// Start starts the main daemon server, all requested listeners, and any helper // Start starts the main daemon server, all requested listeners, and any helper
// goroutines. // goroutines.
// NOTE: This function is safe for concurrent access.
func (s *server) Start() error { func (s *server) Start() error {
// Already running? // Already running?
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
@ -352,9 +342,6 @@ func (s *server) Start() error {
return err return err
} }
s.wg.Add(1)
go s.queryHandler()
// With all the relevant sub-systems started, we'll now attempt to // With all the relevant sub-systems started, we'll now attempt to
// establish persistent connections to our direct channel collaborators // establish persistent connections to our direct channel collaborators
// within the network. // within the network.
@ -362,12 +349,15 @@ func (s *server) Start() error {
return err return err
} }
go s.connMgr.Start()
return nil return nil
} }
// Stop gracefully shutsdown the main daemon server. This function will signal // Stop gracefully shutsdown the main daemon server. This function will signal
// any active goroutines, or helper objects to exit, then blocks until they've // any active goroutines, or helper objects to exit, then blocks until they've
// all successfully exited. Additionally, any/all listeners are closed. // all successfully exited. Additionally, any/all listeners are closed.
// NOTE: This function is safe for concurrent access.
func (s *server) Stop() error { func (s *server) Stop() error {
// Bail if we're already shutting down. // Bail if we're already shutting down.
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
@ -383,20 +373,40 @@ func (s *server) Stop() error {
s.discoverSrv.Stop() s.discoverSrv.Stop()
s.cc.wallet.Shutdown() s.cc.wallet.Shutdown()
s.cc.chainView.Stop() s.cc.chainView.Stop()
s.connMgr.Stop()
// Signal all the lingering goroutines to quit. // Disconnect from each active peers to ensure that
close(s.quit) // peerTerminationWatchers signal completion to each peer.
peers := s.Peers()
for _, peer := range peers {
s.DisconnectPeer(peer.addr.IdentityKey)
}
// Wait for all lingering goroutines to quit.
s.wg.Wait() s.wg.Wait()
return nil return nil
} }
// Stopped returns true if the server has been instructed to shutdown.
// NOTE: This function is safe for concurrent access.
func (s *server) Stopped() bool {
return atomic.LoadInt32(&s.shutdown) != 0
}
// WaitForShutdown blocks until all goroutines have been stopped.
func (s *server) WaitForShutdown() {
s.wg.Wait()
}
// genNodeAnnouncement generates and returns the current fully signed node // genNodeAnnouncement generates and returns the current fully signed node
// announcement. If refresh is true, then the time stamp of the announcement // announcement. If refresh is true, then the time stamp of the announcement
// will be updated in order to ensure it propagates through the network. // will be updated in order to ensure it propagates through the network.
func (s *server) genNodeAnnouncement(refresh bool) (lnwire.NodeAnnouncement, error) { func (s *server) genNodeAnnouncement(
s.annMtx.Lock() refresh bool) (lnwire.NodeAnnouncement, error) {
defer s.annMtx.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
if !refresh { if !refresh {
return *s.currentNodeAnn, nil return *s.currentNodeAnn, nil
@ -460,8 +470,10 @@ func (s *server) establishPersistentConnections() error {
} }
// TODO(roasbeef): instead iterate over link nodes and query graph for // TODO(roasbeef): instead iterate over link nodes and query graph for
// each of the nodes. // each of the nodes.
err = sourceNode.ForEachChannel(nil, func(_ *bolt.Tx, err = sourceNode.ForEachChannel(nil, func(
_ *channeldb.ChannelEdgeInfo, policy *channeldb.ChannelEdgePolicy) error { _ *bolt.Tx,
_ *channeldb.ChannelEdgeInfo,
policy *channeldb.ChannelEdgePolicy) error {
pubStr := string(policy.Node.PubKey.SerializeCompressed()) pubStr := string(policy.Node.PubKey.SerializeCompressed())
@ -531,10 +543,8 @@ func (s *server) establishPersistentConnections() error {
Permanent: true, Permanent: true,
} }
s.pendingConnMtx.Lock() s.persistentConnReqs[pubStr] = append(
s.persistentConnReqs[pubStr] = append(s.persistentConnReqs[pubStr], s.persistentConnReqs[pubStr], connReq)
connReq)
s.pendingConnMtx.Unlock()
go s.connMgr.Connect(connReq) go s.connMgr.Connect(connReq)
} }
@ -543,54 +553,56 @@ func (s *server) establishPersistentConnections() error {
return nil return nil
} }
// WaitForShutdown blocks all goroutines have been stopped. // BroadcastMessage sends a request to the server to broadcast a set of
func (s *server) WaitForShutdown() {
s.wg.Wait()
}
// broadcastReq is a message sent to the server by a related subsystem when it
// wishes to broadcast one or more messages to all connected peers. Thi
type broadcastReq struct {
ignore *btcec.PublicKey
msgs []lnwire.Message
errChan chan error // MUST be buffered.
}
// broadcastMessage sends a request to the server to broadcast a set of
// messages to all peers other than the one specified by the `skip` parameter. // messages to all peers other than the one specified by the `skip` parameter.
func (s *server) broadcastMessage(skip *btcec.PublicKey, msgs ...lnwire.Message) error { // NOTE: This function is safe for concurrent access.
errChan := make(chan error, 1) func (s *server) BroadcastMessage(
skip *btcec.PublicKey,
msgs ...lnwire.Message) error {
msgsToSend := make([]lnwire.Message, 0, len(msgs)) msgsToSend := make([]lnwire.Message, 0, len(msgs))
msgsToSend = append(msgsToSend, msgs...) msgsToSend = append(msgsToSend, msgs...)
broadcastReq := &broadcastReq{
ignore: skip,
msgs: msgsToSend,
errChan: errChan,
}
select { s.mu.Lock()
case s.broadcastRequests <- broadcastReq: defer s.mu.Unlock()
case <-s.quit:
return errors.New("server shutting down")
}
select { return s.broadcastMessages(skip, msgsToSend)
case err := <-errChan:
return err
case <-s.quit:
return errors.New("server shutting down")
}
} }
// sendReq is message sent to the server by a related subsystem which it // broadcastMessages is an internal method that delivers messages to all active
// wishes to send a set of messages to a specified peer. // peers except the one specified by `skip`.
type sendReq struct { // NOTE: This method MUST be called while the server's mutex is locked.
target *btcec.PublicKey func (s *server) broadcastMessages(
msgs []lnwire.Message skip *btcec.PublicKey,
msgs []lnwire.Message) error {
errChan chan error srvrLog.Debugf("Broadcasting %v messages", len(msgs))
// Iterate over all known peers, dispatching a go routine to enqueue all
// messages to each of peers. We synchronize access to peersByPub
// throughout this process to ensure we deliver messages to exact set of
// peers present at the time of invocation.
var wg sync.WaitGroup
for _, sPeer := range s.peersByPub {
if skip != nil &&
sPeer.addr.IdentityKey.IsEqual(skip) {
srvrLog.Debugf("Skipping %v in broadcast",
skip.SerializeCompressed())
continue
}
// Dispatch a go routine to enqueue all messages to this peer.
wg.Add(1)
s.wg.Add(1)
go s.sendPeerMessages(sPeer, msgs, &wg)
}
// Wait for all messages to have been dispatched before returning to
// caller.
wg.Wait()
return nil
} }
type nodeAddresses struct { type nodeAddresses struct {
@ -598,47 +610,113 @@ type nodeAddresses struct {
addresses []*net.TCPAddr addresses []*net.TCPAddr
} }
// sendToPeer send a message to the server telling it to send the specific set // SendToPeer send a message to the server telling it to send the specific set
// of message to a particular peer. If the peer connect be found, then this // of message to a particular peer. If the peer connect be found, then this
// method will return a non-nil error. // method will return a non-nil error.
func (s *server) sendToPeer(target *btcec.PublicKey, msgs ...lnwire.Message) error { // NOTE: This function is safe for concurrent access.
errChan := make(chan error, 1) func (s *server) SendToPeer(
target *btcec.PublicKey,
msgs ...lnwire.Message) error {
msgsToSend := make([]lnwire.Message, 0, len(msgs)) msgsToSend := make([]lnwire.Message, 0, len(msgs))
msgsToSend = append(msgsToSend, msgs...) msgsToSend = append(msgsToSend, msgs...)
sMsg := &sendReq{
target: target, s.mu.Lock()
msgs: msgsToSend, defer s.mu.Unlock()
errChan: errChan,
return s.sendToPeer(target, msgsToSend)
}
// sendToPeer is an internal method that delivers messages to the specified
// `target` peer.
func (s *server) sendToPeer(
target *btcec.PublicKey,
msgs []lnwire.Message) error {
// Compute the target peer's identifier.
targetPubBytes := target.SerializeCompressed()
srvrLog.Infof("Attempting to send msgs %v to: %x",
len(msgs), targetPubBytes)
// Lookup intended target in peersByPub, returning an error to the
// caller if the peer is unknown. Access to peersByPub is synchronized
// here to ensure we consider the exact set of peers present at the time
// of invocation.
targetPeer, ok := s.peersByPub[string(targetPubBytes)]
if !ok {
srvrLog.Errorf("unable to send message to %x, "+
"peer not found", targetPubBytes)
return errors.New("peer not found")
} }
select { s.sendPeerMessages(targetPeer, msgs, nil)
case s.sendRequests <- sMsg:
case <-s.quit: return nil
return errors.New("server shutting down") }
// sendPeerMessages enqueues a list of messages into the outgoingQueue of the
// `targetPeer`. This method supports additional broadcast-level
// synchronization by using the additional `wg` to coordinate a particular
// broadcast.
//
// NOTE: This method must be invoked with a non-nil `wg` if it is spawned as a
// go routine--both `wg` and the server's WaitGroup should be incremented
// beforehand. If this method is not spawned as a go routine, the provided `wg`
// should be nil, and the server's WaitGroup should not be tracking this
// invocation.
func (s *server) sendPeerMessages(
targetPeer *peer,
msgs []lnwire.Message,
wg *sync.WaitGroup) {
// If a WaitGroup is provided, we assume that this method was spawned as
// a go routine, and that it is being tracked by both the server's
// WaitGroup, as well as the broadcast-level WaitGroup `wg`. In this
// event, we defer a call to Done on both WaitGroups to 1) ensure that
// server will be able to shutdown after its go routines exit, and 2) so
// the server can return to the caller of BroadcastMessage.
if wg != nil {
defer s.wg.Done()
defer wg.Done()
} }
select { for _, msg := range msgs {
case err := <-errChan: targetPeer.queueMsg(msg, nil)
return err
case <-s.quit:
return errors.New("server shutting down")
} }
} }
// findPeer will return the peer that corresponds to the passed in public key. // FindPeer will return the peer that corresponds to the passed in public key.
// This function is used by the funding manager, allowing it to update the // This function is used by the funding manager, allowing it to update the
// daemon's local representation of the remote peer. // daemon's local representation of the remote peer.
func (s *server) findPeer(peerKey *btcec.PublicKey) (*peer, error) { // NOTE: This function is safe for concurrent access.
func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) {
s.mu.Lock()
defer s.mu.Unlock()
serializedIDKey := string(peerKey.SerializeCompressed()) serializedIDKey := string(peerKey.SerializeCompressed())
s.peersMtx.RLock() return s.findPeer(serializedIDKey)
peer := s.peersByPub[serializedIDKey] }
s.peersMtx.RUnlock()
// FindPeerByPubStr will return the peer that corresponds to the passed peerID,
// which should be a string representation of the peer's serialized, compressed
// public key.
// NOTE: This function is safe for concurrent access.
func (s *server) FindPeerByPubStr(peerID string) (*peer, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.findPeer(peerID)
}
// findPeer is an internal method that retrieves the specified peer from the
// server's internal state.
func (s *server) findPeer(peerID string) (*peer, error) {
peer := s.peersByPub[peerID]
if peer == nil { if peer == nil {
return nil, errors.New("Peer not found. Pubkey: " + return nil, errors.New("Peer not found. Pubkey: " + peerID)
string(peerKey.SerializeCompressed()))
} }
return peer, nil return peer, nil
@ -648,21 +726,20 @@ func (s *server) findPeer(peerKey *btcec.PublicKey) (*peer, error) {
// cleans up all resources allocated to the peer, notifies relevant sub-systems // cleans up all resources allocated to the peer, notifies relevant sub-systems
// of its demise, and finally handles re-connecting to the peer if it's // of its demise, and finally handles re-connecting to the peer if it's
// persistent. // persistent.
// // NOTE: This MUST be launched as a goroutine AND the _peer's_ WaitGroup should
// NOTE: This MUST be launched as a goroutine. // be incremented before spawning this method, as it will signal to the peer's
// WaitGroup upon completion.
func (s *server) peerTerminationWatcher(p *peer) { func (s *server) peerTerminationWatcher(p *peer) {
defer p.wg.Done()
p.WaitForDisconnect() p.WaitForDisconnect()
srvrLog.Debugf("Peer %v has been disconnected", p) srvrLog.Debugf("Peer %v has been disconnected", p)
// If the server is exiting then we can bail out early ourselves as all // If the server is exiting then we can bail out early ourselves as all
// the other sub-systems will already be shutting down. // the other sub-systems will already be shutting down.
select { if s.Stopped() {
case <-s.quit:
return return
default:
// If we aren't shutting down, then we'll fall through this
// this empty default case.
} }
// Tell the switch to remove all links associated with this peer. // Tell the switch to remove all links associated with this peer.
@ -684,7 +761,7 @@ func (s *server) peerTerminationWatcher(p *peer) {
} }
// Send the peer to be garbage collected by the server. // Send the peer to be garbage collected by the server.
p.server.donePeers <- p s.removePeer(p)
// If this peer had an active persistent connection request, then we // If this peer had an active persistent connection request, then we
// can remove this as we manually decide below if we should attempt to // can remove this as we manually decide below if we should attempt to
@ -695,9 +772,7 @@ func (s *server) peerTerminationWatcher(p *peer) {
// Next, check to see if this is a persistent peer or not. // Next, check to see if this is a persistent peer or not.
pubStr := string(p.addr.IdentityKey.SerializeCompressed()) pubStr := string(p.addr.IdentityKey.SerializeCompressed())
s.pendingConnMtx.RLock()
_, ok := s.persistentPeers[pubStr] _, ok := s.persistentPeers[pubStr]
s.pendingConnMtx.RUnlock()
if ok { if ok {
srvrLog.Debugf("Attempting to re-establish persistent "+ srvrLog.Debugf("Attempting to re-establish persistent "+
"connection to peer %v", p) "connection to peer %v", p)
@ -710,7 +785,6 @@ func (s *server) peerTerminationWatcher(p *peer) {
Permanent: true, Permanent: true,
} }
s.pendingConnMtx.Lock()
// We'll only need to re-launch a connection requests if one // We'll only need to re-launch a connection requests if one
// isn't already currently pending. // isn't already currently pending.
if _, ok := s.persistentConnReqs[pubStr]; ok { if _, ok := s.persistentConnReqs[pubStr]; ok {
@ -720,9 +794,8 @@ func (s *server) peerTerminationWatcher(p *peer) {
// Otherwise, we'll launch a new connection requests in order // Otherwise, we'll launch a new connection requests in order
// to attempt to maintain a persistent connection with this // to attempt to maintain a persistent connection with this
// peer. // peer.
s.persistentConnReqs[pubStr] = append(s.persistentConnReqs[pubStr], s.persistentConnReqs[pubStr] = append(
connReq) s.persistentConnReqs[pubStr], connReq)
s.pendingConnMtx.Unlock()
go s.connMgr.Connect(connReq) go s.connMgr.Connect(connReq)
} }
@ -731,7 +804,11 @@ func (s *server) peerTerminationWatcher(p *peer) {
// peerConnected is a function that handles initialization a newly connected // peerConnected is a function that handles initialization a newly connected
// peer by adding it to the server's global list of all active peers, and // peer by adding it to the server's global list of all active peers, and
// starting all the goroutines the peer needs to function properly. // starting all the goroutines the peer needs to function properly.
func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, inbound bool) { func (s *server) peerConnected(
conn net.Conn,
connReq *connmgr.ConnReq,
inbound bool) {
brontideConn := conn.(*brontide.Conn) brontideConn := conn.(*brontide.Conn)
peerAddr := &lnwire.NetAddress{ peerAddr := &lnwire.NetAddress{
IdentityKey: brontideConn.RemotePub(), IdentityKey: brontideConn.RemotePub(),
@ -757,7 +834,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, inbound
return return
} }
s.newPeers <- p s.addPeer(p)
} }
// shouldDropConnection determines if our local connection to a remote peer // shouldDropConnection determines if our local connection to a remote peer
@ -770,22 +847,31 @@ func shouldDropLocalConnection(local, remote *btcec.PublicKey) bool {
localPubBytes := local.SerializeCompressed() localPubBytes := local.SerializeCompressed()
remotePubPbytes := remote.SerializeCompressed() remotePubPbytes := remote.SerializeCompressed()
// The connection that comes from the node with a "smaller" pubkey should // The connection that comes from the node with a "smaller" pubkey
// be kept. Therefore, if our pubkey is "greater" than theirs, we should // should be kept. Therefore, if our pubkey is "greater" than theirs, we
// drop our established connection. // should drop our established connection.
return bytes.Compare(localPubBytes, remotePubPbytes) > 0 return bytes.Compare(localPubBytes, remotePubPbytes) > 0
} }
// inboundPeerConnected initializes a new peer in response to a new inbound // InboundPeerConnected initializes a new peer in response to a new inbound
// connection. // connection.
func (s *server) inboundPeerConnected(conn net.Conn) { //
s.peersMtx.Lock() // NOTE: This function is safe for concurrent access.
defer s.peersMtx.Unlock() func (s *server) InboundPeerConnected(conn net.Conn) {
// Exit early if we have already been instructed to shutdown, this
// prevents any delayed callbacks from accidentally registering peers.
if s.Stopped() {
return
}
nodePub := conn.(*brontide.Conn).RemotePub()
pubStr := string(nodePub.SerializeCompressed())
s.mu.Lock()
defer s.mu.Unlock()
// If we already have an inbound connection to this peer, then ignore // If we already have an inbound connection to this peer, then ignore
// this new connection. // this new connection.
nodePub := conn.(*brontide.Conn).RemotePub()
pubStr := string(nodePub.SerializeCompressed())
if _, ok := s.inboundPeers[pubStr]; ok { if _, ok := s.inboundPeers[pubStr]; ok {
srvrLog.Debugf("Ignoring duplicate inbound connection") srvrLog.Debugf("Ignoring duplicate inbound connection")
conn.Close() conn.Close()
@ -817,34 +903,40 @@ func (s *server) inboundPeerConnected(conn net.Conn) {
srvrLog.Debugf("Disconnecting stale connection to %v", srvrLog.Debugf("Disconnecting stale connection to %v",
connectedPeer) connectedPeer)
connectedPeer.Disconnect(errors.New("remove stale connection")) connectedPeer.Disconnect(errors.New("remove stale connection"))
s.donePeers <- connectedPeer
s.removePeer(connectedPeer)
} }
// Next, check to see if we have any outstanding persistent connection // Next, check to see if we have any outstanding persistent connection
// requests to this peer. If so, then we'll remove all of these // requests to this peer. If so, then we'll remove all of these
// connection requests, and also delete the entry from the map. // connection requests, and also delete the entry from the map.
s.pendingConnMtx.Lock()
if connReqs, ok := s.persistentConnReqs[pubStr]; ok { if connReqs, ok := s.persistentConnReqs[pubStr]; ok {
for _, connReq := range connReqs { for _, connReq := range connReqs {
s.connMgr.Remove(connReq.ID()) s.connMgr.Remove(connReq.ID())
} }
delete(s.persistentConnReqs, pubStr) delete(s.persistentConnReqs, pubStr)
} }
s.pendingConnMtx.Unlock()
go s.peerConnected(conn, nil, false) s.peerConnected(conn, nil, false)
} }
// outboundPeerConnected initializes a new peer in response to a new outbound // OutboundPeerConnected initializes a new peer in response to a new outbound
// connection. // connection.
func (s *server) outboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) { // NOTE: This function is safe for concurrent access.
s.peersMtx.Lock() func (s *server) OutboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn) {
defer s.peersMtx.Unlock() // Exit early if we have already been instructed to shutdown, this
// prevents any delayed callbacks from accidentally registering peers.
if s.Stopped() {
return
}
localPub := s.identityPriv.PubKey() localPub := s.identityPriv.PubKey()
nodePub := conn.(*brontide.Conn).RemotePub() nodePub := conn.(*brontide.Conn).RemotePub()
pubStr := string(nodePub.SerializeCompressed()) pubStr := string(nodePub.SerializeCompressed())
s.mu.Lock()
defer s.mu.Unlock()
// If we already have an outbound connection to this peer, then ignore // If we already have an outbound connection to this peer, then ignore
// this new connection. // this new connection.
if _, ok := s.outboundPeers[pubStr]; ok { if _, ok := s.outboundPeers[pubStr]; ok {
@ -864,7 +956,6 @@ func (s *server) outboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
// As we've just established an outbound connection to this peer, we'll // As we've just established an outbound connection to this peer, we'll
// cancel all other persistent connection requests and eliminate the // cancel all other persistent connection requests and eliminate the
// entry for this peer from the map. // entry for this peer from the map.
s.pendingConnMtx.Lock()
if connReqs, ok := s.persistentConnReqs[pubStr]; ok { if connReqs, ok := s.persistentConnReqs[pubStr]; ok {
for _, pConnReq := range connReqs { for _, pConnReq := range connReqs {
if pConnReq.ID() != connReq.ID() { if pConnReq.ID() != connReq.ID() {
@ -873,7 +964,6 @@ func (s *server) outboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
} }
delete(s.persistentConnReqs, pubStr) delete(s.persistentConnReqs, pubStr)
} }
s.pendingConnMtx.Unlock()
// If we already have an inbound connection from this peer, then we'll // If we already have an inbound connection from this peer, then we'll
// check to see _which_ of our connections should be dropped. // check to see _which_ of our connections should be dropped.
@ -896,10 +986,11 @@ func (s *server) outboundPeerConnected(connReq *connmgr.ConnReq, conn net.Conn)
srvrLog.Debugf("Disconnecting stale connection to %v", srvrLog.Debugf("Disconnecting stale connection to %v",
connectedPeer) connectedPeer)
connectedPeer.Disconnect(errors.New("remove stale connection")) connectedPeer.Disconnect(errors.New("remove stale connection"))
s.donePeers <- connectedPeer
s.removePeer(connectedPeer)
} }
go s.peerConnected(conn, connReq, true) s.peerConnected(conn, connReq, true)
} }
// addPeer adds the passed peer to the server's global state of all active // addPeer adds the passed peer to the server's global state of all active
@ -919,7 +1010,6 @@ func (s *server) addPeer(p *peer) {
// according to its public key, or it's peer ID. // according to its public key, or it's peer ID.
// TODO(roasbeef): pipe all requests through to the // TODO(roasbeef): pipe all requests through to the
// queryHandler/peerManager // queryHandler/peerManager
s.peersMtx.Lock()
pubStr := string(p.addr.IdentityKey.SerializeCompressed()) pubStr := string(p.addr.IdentityKey.SerializeCompressed())
@ -932,11 +1022,12 @@ func (s *server) addPeer(p *peer) {
s.outboundPeers[pubStr] = p s.outboundPeers[pubStr] = p
} }
s.peersMtx.Unlock()
// Launch a goroutine to watch for the termination of this peer so we // Launch a goroutine to watch for the termination of this peer so we
// can ensure all resources are properly cleaned up and if need be // can ensure all resources are properly cleaned up and if need be
// connections are re-established. // connections are re-established. The go routine is tracked by the
// _peer's_ WaitGroup so that a call to Disconnect will block until the
// `peerTerminationWatcher` has exited.
p.wg.Add(1)
go s.peerTerminationWatcher(p) go s.peerTerminationWatcher(p)
// Once the peer has been added to our indexes, send a message to the // Once the peer has been added to our indexes, send a message to the
@ -948,15 +1039,12 @@ func (s *server) addPeer(p *peer) {
// removePeer removes the passed peer from the server's state of all active // removePeer removes the passed peer from the server's state of all active
// peers. // peers.
func (s *server) removePeer(p *peer) { func (s *server) removePeer(p *peer) {
s.peersMtx.Lock()
defer s.peersMtx.Unlock()
srvrLog.Debugf("removing peer %v", p)
if p == nil { if p == nil {
return return
} }
srvrLog.Debugf("removing peer %v", p)
// As the peer is now finished, ensure that the TCP connection is // As the peer is now finished, ensure that the TCP connection is
// closed and all of its related goroutines have exited. // closed and all of its related goroutines have exited.
p.Disconnect(errors.New("remove peer")) p.Disconnect(errors.New("remove peer"))
@ -978,24 +1066,6 @@ func (s *server) removePeer(p *peer) {
} }
} }
// connectPeerMsg is a message requesting the server to open a connection to a
// particular peer. This message also houses an error channel which will be
// used to report success/failure.
type connectPeerMsg struct {
addr *lnwire.NetAddress
persistent bool
err chan error
}
// disconnectPeerMsg is a message requesting the server to disconnect from an
// active peer.
type disconnectPeerMsg struct {
pubKey *btcec.PublicKey
err chan error
}
// openChanReq is a message sent to the server in order to request the // openChanReq is a message sent to the server in order to request the
// initiation of a channel funding workflow to the peer with either the // initiation of a channel funding workflow to the peer with either the
// specified relative peer ID, or a global lightning ID. // specified relative peer ID, or a global lightning ID.
@ -1016,211 +1086,125 @@ type openChanReq struct {
err chan error err chan error
} }
// queryHandler handles any requests to modify the server's internal state of // ConnectToPeer requests that the server connect to a Lightning Network peer
// all active peers, or query/mutate the server's global state. Additionally, // at the specified address. This function will *block* until either a
// any queries directed at peers will be handled by this goroutine. // connection is established, or the initial handshake process fails.
// // NOTE: This function is safe for concurrent access.
// NOTE: This MUST be run as a goroutine. func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
func (s *server) queryHandler() {
go s.connMgr.Start()
out: targetPub := string(addr.IdentityKey.SerializeCompressed())
for {
select {
// New peers.
case p := <-s.newPeers:
s.addPeer(p)
// Finished peers. // Acquire mutex, but use explicit unlocking instead of defer for better
case p := <-s.donePeers: // granularity. In certain conditions, this method requires making an
s.removePeer(p) // outbound connection to a remote peer, which requires the lock to be
// released, and subsequently reacquired.
case bMsg := <-s.broadcastRequests: s.mu.Lock()
ignore := bMsg.ignore
srvrLog.Debugf("Broadcasting %v messages", len(bMsg.msgs))
// Launch a new goroutine to handle the broadcast
// request, this allows us process this request
// asynchronously without blocking subsequent broadcast
// requests.
go func() {
s.peersMtx.RLock()
for _, sPeer := range s.peersByPub {
if ignore != nil &&
sPeer.addr.IdentityKey.IsEqual(ignore) {
srvrLog.Debugf("Skipping %v in broadcast",
ignore.SerializeCompressed())
continue
}
go func(p *peer) {
for _, msg := range bMsg.msgs {
p.queueMsg(msg, nil)
}
}(sPeer)
}
s.peersMtx.RUnlock()
bMsg.errChan <- nil
}()
case sMsg := <-s.sendRequests:
// TODO(roasbeef): use [33]byte everywhere instead
// * eliminate usage of mutexes, funnel all peer
// mutation to this goroutine
target := sMsg.target.SerializeCompressed()
srvrLog.Debugf("Attempting to send msgs %v to: %x",
len(sMsg.msgs), target)
// Launch a new goroutine to handle this send request,
// this allows us process this request asynchronously
// without blocking future send requests.
go func() {
s.peersMtx.RLock()
targetPeer, ok := s.peersByPub[string(target)]
if !ok {
s.peersMtx.RUnlock()
srvrLog.Errorf("unable to send message to %x, "+
"peer not found", target)
sMsg.errChan <- errors.New("peer not found")
return
}
s.peersMtx.RUnlock()
sMsg.errChan <- nil
for _, msg := range sMsg.msgs {
targetPeer.queueMsg(msg, nil)
}
}()
case query := <-s.queries:
switch msg := query.(type) {
case *disconnectPeerMsg:
s.handleDisconnectPeer(msg)
case *connectPeerMsg:
s.handleConnectPeer(msg)
case *openChanReq:
s.handleOpenChanReq(msg)
}
case <-s.quit:
break out
}
}
s.connMgr.Stop()
s.wg.Done()
}
// handleConnectPeer attempts to establish a connection to the address enclosed
// within the passed connectPeerMsg. This function is *async*, a goroutine will
// be spawned in order to finish the request, and respond to the caller.
func (s *server) handleConnectPeer(msg *connectPeerMsg) {
addr := msg.addr
targetPub := string(msg.addr.IdentityKey.SerializeCompressed())
// Ensure we're not already connected to this // Ensure we're not already connected to this
// peer. // peer.
s.peersMtx.RLock()
peer, ok := s.peersByPub[targetPub] peer, ok := s.peersByPub[targetPub]
if ok { if ok {
s.peersMtx.RUnlock() s.mu.Unlock()
msg.err <- fmt.Errorf("already connected to peer: %v", peer)
return return fmt.Errorf("already connected to peer: %v", peer)
} }
s.peersMtx.RUnlock()
// If there's already a pending connection request for this pubkey, // If there's already a pending connection request for this pubkey,
// then we ignore this request to ensure we don't create a redundant // then we ignore this request to ensure we don't create a redundant
// connection. // connection.
s.pendingConnMtx.RLock()
if _, ok := s.persistentConnReqs[targetPub]; ok { if _, ok := s.persistentConnReqs[targetPub]; ok {
s.pendingConnMtx.RUnlock() s.mu.Unlock()
msg.err <- fmt.Errorf("connection attempt to %v is pending",
addr) return fmt.Errorf("connection attempt to %v is pending", addr)
return
} }
s.pendingConnMtx.RUnlock()
// If there's not already a pending or active connection to this node, // If there's not already a pending or active connection to this node,
// then instruct the connection manager to attempt to establish a // then instruct the connection manager to attempt to establish a
// persistent connection to the peer. // persistent connection to the peer.
srvrLog.Debugf("Connecting to %v", addr) srvrLog.Debugf("Connecting to %v", addr)
if msg.persistent { if perm {
connReq := &connmgr.ConnReq{ connReq := &connmgr.ConnReq{
Addr: addr, Addr: addr,
Permanent: true, Permanent: true,
} }
s.pendingConnMtx.Lock()
s.persistentPeers[targetPub] = struct{}{} s.persistentPeers[targetPub] = struct{}{}
s.persistentConnReqs[targetPub] = append(s.persistentConnReqs[targetPub], s.persistentConnReqs[targetPub] = append(
connReq) s.persistentConnReqs[targetPub], connReq)
s.pendingConnMtx.Unlock() s.mu.Unlock()
go s.connMgr.Connect(connReq) go s.connMgr.Connect(connReq)
msg.err <- nil
} else {
// If we're not making a persistent connection, then we'll
// attempt to connect o the target peer, returning an error
// which indicates success of failure.
go func() {
// Attempt to connect to the remote node. If the we
// can't make the connection, or the crypto negotiation
// breaks down, then return an error to the caller.
conn, err := brontide.Dial(s.identityPriv, addr)
if err != nil {
msg.err <- err
return
}
s.outboundPeerConnected(nil, conn) return nil
msg.err <- nil
}()
} }
s.mu.Unlock()
// If we're not making a persistent connection, then we'll attempt to
// connect to the target peer. If the we can't make the connection, or
// the crypto negotiation breaks down, then return an error to the
// caller.
conn, err := brontide.Dial(s.identityPriv, addr)
if err != nil {
return err
}
// Once the connection has been made, we can notify the server of the
// new connection via our public endpoint, which will require the lock
// an add the peer to the server's internal state.
s.OutboundPeerConnected(nil, conn)
return nil
} }
// handleDisconnectPeer attempts to disconnect one peer from another // DisconnectPeer sends the request to server to close the connection with peer
func (s *server) handleDisconnectPeer(msg *disconnectPeerMsg) { // identified by public key.
pubBytes := msg.pubKey.SerializeCompressed() // NOTE: This function is safe for concurrent access.
func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
pubBytes := pubKey.SerializeCompressed()
pubStr := string(pubBytes) pubStr := string(pubBytes)
s.mu.Lock()
defer s.mu.Unlock()
// Check that were actually connected to this peer. If not, then we'll // Check that were actually connected to this peer. If not, then we'll
// exit in an error as we can't disconnect from a peer that we're not // exit in an error as we can't disconnect from a peer that we're not
// currently connected to. // currently connected to.
s.peersMtx.RLock()
peer, ok := s.peersByPub[pubStr] peer, ok := s.peersByPub[pubStr]
s.peersMtx.RUnlock()
if !ok { if !ok {
msg.err <- fmt.Errorf("unable to find peer %x", pubBytes) return fmt.Errorf("unable to find peer %x", pubBytes)
return
} }
// If this peer was formerly a persistent connection, then we'll remove // If this peer was formerly a persistent connection, then we'll remove
// them from this map so we don't attempt to re-connect after we // them from this map so we don't attempt to re-connect after we
// disconnect. // disconnect.
s.pendingConnMtx.Lock()
if _, ok := s.persistentPeers[pubStr]; ok { if _, ok := s.persistentPeers[pubStr]; ok {
delete(s.persistentPeers, pubStr) delete(s.persistentPeers, pubStr)
} }
s.pendingConnMtx.Unlock()
// Now that we know the peer is actually connected, we'll disconnect // Now that we know the peer is actually connected, we'll disconnect
// from the peer. // from the peer. The lock is held until after Disconnect to ensure
// that the peer's `peerTerminationWatcher` has fully exited before
// returning to the caller.
srvrLog.Infof("Disconnecting from %v", peer) srvrLog.Infof("Disconnecting from %v", peer)
peer.Disconnect(errors.New("received user command to disconnect the peer")) peer.Disconnect(
errors.New("received user command to disconnect the peer"),
)
msg.err <- nil return nil
} }
// handleOpenChanReq first locates the target peer, and if found hands off the // OpenChannel sends a request to the server to open a channel to the specified
// request to the funding manager allowing it to initiate the channel funding // peer identified by ID with the passed channel funding parameters.
// workflow. // NOTE: This function is safe for concurrent access.
func (s *server) handleOpenChanReq(req *openChanReq) { func (s *server) OpenChannel(
peerID int32,
nodeKey *btcec.PublicKey,
localAmt btcutil.Amount,
pushAmt btcutil.Amount) (chan *lnrpc.OpenStatusUpdate, chan error) {
updateChan := make(chan *lnrpc.OpenStatusUpdate, 1)
errChan := make(chan error, 1)
var ( var (
targetPeer *peer targetPeer *peer
pubKeyBytes []byte pubKeyBytes []byte
@ -1229,24 +1213,24 @@ func (s *server) handleOpenChanReq(req *openChanReq) {
// If the user is targeting the peer by public key, then we'll need to // If the user is targeting the peer by public key, then we'll need to
// convert that into a string for our map. Otherwise, we expect them to // convert that into a string for our map. Otherwise, we expect them to
// target by peer ID instead. // target by peer ID instead.
if req.targetPubkey != nil { if nodeKey != nil {
pubKeyBytes = req.targetPubkey.SerializeCompressed() pubKeyBytes = nodeKey.SerializeCompressed()
} }
// First attempt to locate the target peer to open a channel with, if // First attempt to locate the target peer to open a channel with, if
// we're unable to locate the peer then this request will fail. // we're unable to locate the peer then this request will fail.
s.peersMtx.RLock() s.mu.Lock()
if peer, ok := s.peersByID[req.targetPeerID]; ok { if peer, ok := s.peersByID[peerID]; ok {
targetPeer = peer targetPeer = peer
} else if peer, ok := s.peersByPub[string(pubKeyBytes)]; ok { } else if peer, ok := s.peersByPub[string(pubKeyBytes)]; ok {
targetPeer = peer targetPeer = peer
} }
s.peersMtx.RUnlock() s.mu.Unlock()
if targetPeer == nil { if targetPeer == nil {
req.err <- fmt.Errorf("unable to find peer nodeID(%x), "+ errChan <- fmt.Errorf("unable to find peer nodeID(%x), "+
"peerID(%v)", pubKeyBytes, req.targetPeerID) "peerID(%v)", pubKeyBytes, peerID)
return return updateChan, errChan
} }
// Spawn a goroutine to send the funding workflow request to the // Spawn a goroutine to send the funding workflow request to the
@ -1255,48 +1239,6 @@ func (s *server) handleOpenChanReq(req *openChanReq) {
// synchronous request to the outside world. // synchronous request to the outside world.
// TODO(roasbeef): pass in chan that's closed if/when funding succeeds // TODO(roasbeef): pass in chan that's closed if/when funding succeeds
// so can track as persistent peer? // so can track as persistent peer?
go s.fundingMgr.initFundingWorkflow(targetPeer.addr, req)
}
// ConnectToPeer requests that the server connect to a Lightning Network peer
// at the specified address. This function will *block* until either a
// connection is established, or the initial handshake process fails.
func (s *server) ConnectToPeer(addr *lnwire.NetAddress,
perm bool) error {
errChan := make(chan error, 1)
s.queries <- &connectPeerMsg{
addr: addr,
persistent: perm,
err: errChan,
}
return <-errChan
}
// DisconnectPeer sends the request to server to close the connection with peer
// identified by public key.
func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error {
errChan := make(chan error, 1)
s.queries <- &disconnectPeerMsg{
pubKey: pubKey,
err: errChan,
}
return <-errChan
}
// OpenChannel sends a request to the server to open a channel to the specified
// peer identified by ID with the passed channel funding parameters.
func (s *server) OpenChannel(peerID int32, nodeKey *btcec.PublicKey,
localAmt, pushAmt btcutil.Amount) (chan *lnrpc.OpenStatusUpdate, chan error) {
errChan := make(chan error, 1)
updateChan := make(chan *lnrpc.OpenStatusUpdate, 1)
req := &openChanReq{ req := &openChanReq{
targetPeerID: peerID, targetPeerID: peerID,
targetPubkey: nodeKey, targetPubkey: nodeKey,
@ -1307,21 +1249,21 @@ func (s *server) OpenChannel(peerID int32, nodeKey *btcec.PublicKey,
err: errChan, err: errChan,
} }
s.queries <- req go s.fundingMgr.initFundingWorkflow(targetPeer.addr, req)
return updateChan, errChan return updateChan, errChan
} }
// Peers returns a slice of all active peers. // Peers returns a slice of all active peers.
// NOTE: This function is safe for concurrent access.
func (s *server) Peers() []*peer { func (s *server) Peers() []*peer {
s.peersMtx.RLock() s.mu.Lock()
defer s.mu.Unlock()
peers := make([]*peer, 0, len(s.peersByID)) peers := make([]*peer, 0, len(s.peersByID))
for _, peer := range s.peersByID { for _, peer := range s.peersByID {
peers = append(peers, peer) peers = append(peers, peer)
} }
s.peersMtx.RUnlock()
return peers return peers
} }