From 4b685e4d641f77b84efbf5f38e761e3a853d86f5 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 27 Apr 2021 15:47:32 +0200 Subject: [PATCH] lnd+lnrpc: enable WebSocket ping/pong messages Fixes #4497 by sending out ping messages in a regular interval to make sure the connection is still alive. --- lnd.go | 3 +- lnrpc/websocket_proxy.go | 81 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/lnd.go b/lnd.go index 589ec38c..53d60601 100644 --- a/lnd.go +++ b/lnd.go @@ -1266,7 +1266,8 @@ func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialO // Wrap the default grpc-gateway handler with the WebSocket handler. restHandler := lnrpc.NewWebSocketProxy( - mux, rpcsLog, lnrpc.LndClientStreamingURIs, + mux, rpcsLog, lnrpc.DefaultPingInterval, lnrpc.DefaultPongWait, + lnrpc.LndClientStreamingURIs, ) // Use a WaitGroup so we can be sure the instructions on how to input the diff --git a/lnrpc/websocket_proxy.go b/lnrpc/websocket_proxy.go index fb43b923..8c5001da 100644 --- a/lnrpc/websocket_proxy.go +++ b/lnrpc/websocket_proxy.go @@ -10,6 +10,7 @@ import ( "net/textproto" "regexp" "strings" + "time" "github.com/btcsuite/btclog" "github.com/gorilla/websocket" @@ -50,15 +51,26 @@ var ( defaultProtocolsToAllow = map[string]bool{ "Grpc-Metadata-Macaroon": true, } + + // DefaultPingInterval is the default number of seconds to wait between + // sending ping requests. + DefaultPingInterval = time.Second * 30 + + // DefaultPongWait is the maximum duration we wait for a pong response + // to a ping we sent before we assume the connection died. + DefaultPongWait = time.Second * 5 ) // NewWebSocketProxy attempts to expose the underlying handler as a response- // streaming WebSocket stream with newline-delimited JSON as the content -// encoding. The clientStreamingURIs parameter can hold a list of all patterns +// encoding. If pingInterval is a non-zero duration, a ping message will be +// sent out periodically and a pong response message is expected from the +// client. The clientStreamingURIs parameter can hold a list of all patterns // for URIs that are mapped to client-streaming RPC methods. We need to keep // track of those to make sure we initialize the request body correctly for the // underlying grpc-gateway library. func NewWebSocketProxy(h http.Handler, logger btclog.Logger, + pingInterval, pongWait time.Duration, clientStreamingURIs []*regexp.Regexp) http.Handler { p := &WebsocketProxy{ @@ -73,6 +85,12 @@ func NewWebSocketProxy(h http.Handler, logger btclog.Logger, }, clientStreamingURIs: clientStreamingURIs, } + + if pingInterval > 0 && pongWait > 0 { + p.pingInterval = pingInterval + p.pongWait = pongWait + } + return p } @@ -87,6 +105,15 @@ type WebsocketProxy struct { // those to make sure we initialize the request body correctly for the // underlying grpc-gateway library. clientStreamingURIs []*regexp.Regexp + + pingInterval time.Duration + pongWait time.Duration +} + +// pingPongEnabled returns true if a ping interval is set to enable sending and +// expecting regular ping/pong messages. +func (p *WebsocketProxy) pingPongEnabled() bool { + return p.pingInterval > 0 && p.pongWait > 0 } // ServeHTTP handles the incoming HTTP request. If the request is an @@ -200,6 +227,58 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, } }() + // Ping write loop: Send a ping message regularly if ping/pong is + // enabled. + if p.pingPongEnabled() { + // We'll send out our first ping in pingInterval. So the initial + // deadline is that interval plus the time we allow for a + // response to be sent. + initialDeadline := time.Now().Add(p.pingInterval + p.pongWait) + _ = conn.SetReadDeadline(initialDeadline) + + // Whenever a pong message comes in, we extend the deadline + // until the next read is expected by the interval plus pong + // wait time. + conn.SetPongHandler(func(appData string) error { + nextDeadline := time.Now().Add( + p.pingInterval + p.pongWait, + ) + _ = conn.SetReadDeadline(nextDeadline) + return nil + }) + go func() { + ticker := time.NewTicker(p.pingInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + p.logger.Debug("WS: ping loop done") + return + + case <-ticker.C: + // Writing the ping shouldn't take any + // longer than we'll wait for a response + // in the first place. + writeDeadline := time.Now().Add( + p.pongWait, + ) + _ = conn.SetWriteDeadline(writeDeadline) + + err := conn.WriteMessage( + websocket.PingMessage, nil, + ) + if err != nil { + p.logger.Warnf("WS: could not "+ + "send ping message: %v", + err) + return + } + } + } + }() + } + // Write loop: Take messages from the response forwarder and write them // to the WebSocket. for responseForwarder.Scan() {