lnd+lnrpc: enable WebSocket ping/pong messages

Fixes #4497 by sending out ping messages in a regular interval to make
sure the connection is still alive.
This commit is contained in:
Oliver Gugger 2021-04-27 15:47:32 +02:00
parent 5e7b905f19
commit 4b685e4d64
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
2 changed files with 82 additions and 2 deletions

3
lnd.go

@ -1266,7 +1266,8 @@ func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialO
// Wrap the default grpc-gateway handler with the WebSocket handler.
restHandler := lnrpc.NewWebSocketProxy(
mux, rpcsLog, lnrpc.LndClientStreamingURIs,
mux, rpcsLog, lnrpc.DefaultPingInterval, lnrpc.DefaultPongWait,
lnrpc.LndClientStreamingURIs,
)
// Use a WaitGroup so we can be sure the instructions on how to input the

@ -10,6 +10,7 @@ import (
"net/textproto"
"regexp"
"strings"
"time"
"github.com/btcsuite/btclog"
"github.com/gorilla/websocket"
@ -50,15 +51,26 @@ var (
defaultProtocolsToAllow = map[string]bool{
"Grpc-Metadata-Macaroon": true,
}
// DefaultPingInterval is the default number of seconds to wait between
// sending ping requests.
DefaultPingInterval = time.Second * 30
// DefaultPongWait is the maximum duration we wait for a pong response
// to a ping we sent before we assume the connection died.
DefaultPongWait = time.Second * 5
)
// NewWebSocketProxy attempts to expose the underlying handler as a response-
// streaming WebSocket stream with newline-delimited JSON as the content
// encoding. The clientStreamingURIs parameter can hold a list of all patterns
// encoding. If pingInterval is a non-zero duration, a ping message will be
// sent out periodically and a pong response message is expected from the
// client. The clientStreamingURIs parameter can hold a list of all patterns
// for URIs that are mapped to client-streaming RPC methods. We need to keep
// track of those to make sure we initialize the request body correctly for the
// underlying grpc-gateway library.
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
pingInterval, pongWait time.Duration,
clientStreamingURIs []*regexp.Regexp) http.Handler {
p := &WebsocketProxy{
@ -73,6 +85,12 @@ func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
},
clientStreamingURIs: clientStreamingURIs,
}
if pingInterval > 0 && pongWait > 0 {
p.pingInterval = pingInterval
p.pongWait = pongWait
}
return p
}
@ -87,6 +105,15 @@ type WebsocketProxy struct {
// those to make sure we initialize the request body correctly for the
// underlying grpc-gateway library.
clientStreamingURIs []*regexp.Regexp
pingInterval time.Duration
pongWait time.Duration
}
// pingPongEnabled returns true if a ping interval is set to enable sending and
// expecting regular ping/pong messages.
func (p *WebsocketProxy) pingPongEnabled() bool {
return p.pingInterval > 0 && p.pongWait > 0
}
// ServeHTTP handles the incoming HTTP request. If the request is an
@ -200,6 +227,58 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
}
}()
// Ping write loop: Send a ping message regularly if ping/pong is
// enabled.
if p.pingPongEnabled() {
// We'll send out our first ping in pingInterval. So the initial
// deadline is that interval plus the time we allow for a
// response to be sent.
initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
_ = conn.SetReadDeadline(initialDeadline)
// Whenever a pong message comes in, we extend the deadline
// until the next read is expected by the interval plus pong
// wait time.
conn.SetPongHandler(func(appData string) error {
nextDeadline := time.Now().Add(
p.pingInterval + p.pongWait,
)
_ = conn.SetReadDeadline(nextDeadline)
return nil
})
go func() {
ticker := time.NewTicker(p.pingInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
p.logger.Debug("WS: ping loop done")
return
case <-ticker.C:
// Writing the ping shouldn't take any
// longer than we'll wait for a response
// in the first place.
writeDeadline := time.Now().Add(
p.pongWait,
)
_ = conn.SetWriteDeadline(writeDeadline)
err := conn.WriteMessage(
websocket.PingMessage, nil,
)
if err != nil {
p.logger.Warnf("WS: could not "+
"send ping message: %v",
err)
return
}
}
}
}()
}
// Write loop: Take messages from the response forwarder and write them
// to the WebSocket.
for responseForwarder.Scan() {