diff --git a/config.go b/config.go index 67abc61d..a9ab511d 100644 --- a/config.go +++ b/config.go @@ -31,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lnrpc/signrpc" "github.com/lightningnetwork/lnd/lnwallet" @@ -250,6 +251,8 @@ type Config struct { DisableListen bool `long:"nolisten" description:"Disable listening for incoming peer connections"` DisableRest bool `long:"norest" description:"Disable REST API"` DisableRestTLS bool `long:"no-rest-tls" description:"Disable TLS for REST connections"` + WSPingInterval time.Duration `long:"ws-ping-interval" description:"The ping interval for REST based WebSocket connections, set to 0 to disable sending ping messages from the server side"` + WSPongWait time.Duration `long:"ws-pong-wait" description:"The time we wait for a pong response message on REST based WebSocket connections before the connection is closed as inactive"` NAT bool `long:"nat" description:"Toggle NAT traversal support (using either UPnP or NAT-PMP) to automatically advertise your external IP address to the network -- NOTE this does not support devices behind multiple NATs"` MinBackoff time.Duration `long:"minbackoff" description:"Shortest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."` MaxBackoff time.Duration `long:"maxbackoff" description:"Longest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."` @@ -396,6 +399,8 @@ func DefaultConfig() Config { MaxLogFiles: defaultMaxLogFiles, MaxLogFileSize: defaultMaxLogFileSize, AcceptorTimeout: defaultAcceptorTimeout, + WSPingInterval: lnrpc.DefaultPingInterval, + WSPongWait: lnrpc.DefaultPongWait, Bitcoin: &lncfg.Chain{ MinHTLCIn: chainreg.DefaultBitcoinMinHTLCInMSat, MinHTLCOut: chainreg.DefaultBitcoinMinHTLCOutMSat, diff --git a/docs/rest/websockets.md b/docs/rest/websockets.md index 705a4c73..f2c85769 100644 --- a/docs/rest/websockets.md +++ b/docs/rest/websockets.md @@ -97,3 +97,55 @@ ws.on('message', function(body) { // "height": , // } ``` + +## Request-streaming RPCs + +Starting with `lnd v0.13.0-beta` all RPCs can be used through REST, even those +that are fully bi-directional (e.g. the client can also send multiple request +messages to the stream). + +**Example**: + +As an example we show how one can use the bi-directional channel acceptor RPC. +Through that RPC each incoming channel open request (another peer opening a +channel to our node) will be passed in for inspection. We can decide +programmatically whether to accept or reject the channel. + +```javascript +// -------------------------- +// Example with websockets: +// -------------------------- +const WebSocket = require('ws'); +const fs = require('fs'); +const macaroon = fs.readFileSync('LND_DIR/data/chain/bitcoin/simnet/admin.macaroon').toString('hex'); +let ws = new WebSocket('wss://localhost:8080/v1/channels/acceptor?method=POST', { + // Work-around for self-signed certificates. + rejectUnauthorized: false, + headers: { + 'Grpc-Metadata-Macaroon': macaroon, + }, +}); +ws.on('open', function() { + // We always _need_ to send an initial message to kickstart the request. + // This empty message will be ignored by the channel acceptor though, this + // is just for telling the grpc-gateway library that it can forward the + // request to the gRPC interface now. If this were an RPC where the client + // always sends the first message (for example the streaming payment RPC + // /v1/channels/transaction-stream), we'd simply send the first "real" + // message here when needed. + ws.send('{}'); +}); +ws.on('error', function(err) { + console.log('Error: ' + err); +}); +ws.on('ping', function ping(event) { + console.log('Received ping from server: ' + JSON.stringify(event)); +}); +ws.on('message', function incoming(event) { + console.log('New channel accept message: ' + event); + const result = JSON.parse(event).result; + + // Accept the channel after inspecting it. + ws.send(JSON.stringify({accept: true, pending_chan_id: result.pending_chan_id})); +}); +``` diff --git a/lnd.go b/lnd.go index 2559d0aa..ad03a4a4 100644 --- a/lnd.go +++ b/lnd.go @@ -1272,7 +1272,10 @@ func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialO } // Wrap the default grpc-gateway handler with the WebSocket handler. - restHandler := lnrpc.NewWebSocketProxy(mux, rpcsLog) + restHandler := lnrpc.NewWebSocketProxy( + mux, rpcsLog, cfg.WSPingInterval, cfg.WSPongWait, + lnrpc.LndClientStreamingURIs, + ) // Use a WaitGroup so we can be sure the instructions on how to input the // password is the last thing to be printed to the console. diff --git a/lnrpc/metadata.go b/lnrpc/metadata.go new file mode 100644 index 00000000..fc54560b --- /dev/null +++ b/lnrpc/metadata.go @@ -0,0 +1,17 @@ +package lnrpc + +import "regexp" + +var ( + // LndClientStreamingURIs is a list of all lnd RPCs that use a request- + // streaming interface. Those request-streaming RPCs need to be handled + // differently in the WebsocketProxy because of how the request body + // parsing is implemented in the grpc-gateway library. Unfortunately + // there is no straightforward way of obtaining this information on + // runtime so we need to keep a hard coded list here. + LndClientStreamingURIs = []*regexp.Regexp{ + regexp.MustCompile("^/v1/channels/acceptor$"), + regexp.MustCompile("^/v1/channels/transaction-stream$"), + regexp.MustCompile("^/v2/router/htlcinterceptor$"), + } +) diff --git a/lnrpc/rest-annotations.yaml b/lnrpc/rest-annotations.yaml index c10b2605..7b388da0 100644 --- a/lnrpc/rest-annotations.yaml +++ b/lnrpc/rest-annotations.yaml @@ -1,6 +1,10 @@ type: google.api.Service config_version: 3 +# Mapping for the grpc-gateway REST proxy. +# Please make sure to also update the `metadata.go` file when editing this file +# and adding a new client-streaming RPC! + http: rules: # rpc.proto @@ -61,12 +65,15 @@ http: post: "/v1/funding/step" body: "*" - selector: lnrpc.Lightning.ChannelAcceptor - # request streaming RPC, REST not supported + post: "/v1/channels/acceptor" + body: "*" - selector: lnrpc.Lightning.CloseChannel delete: "/v1/channels/{channel_point.funding_txid_str}/{channel_point.output_index}" - selector: lnrpc.Lightning.AbandonChannel delete: "/v1/channels/abandon/{channel_point.funding_txid_str}/{channel_point.output_index}" - selector: lnrpc.Lightning.SendPayment + post: "/v1/channels/transaction-stream" + body: "*" - selector: lnrpc.Lightning.SendPaymentSync post: "/v1/channels/transactions" body: "*" @@ -228,7 +235,8 @@ http: - selector: routerrpc.Router.TrackPayment # deprecated, no REST endpoint - selector: routerrpc.HtlcInterceptor - # request streaming RPC, REST not supported + post: "/v2/router/htlcinterceptor" + body: "*" - selector: routerrpc.UpdateChanStatus post: "/v2/router/updatechanstatus" body: "*" diff --git a/lnrpc/rpc.pb.gw.go b/lnrpc/rpc.pb.gw.go index 41021334..ce69a3d6 100644 --- a/lnrpc/rpc.pb.gw.go +++ b/lnrpc/rpc.pb.gw.go @@ -731,6 +731,58 @@ func local_request_Lightning_FundingStateStep_0(ctx context.Context, marshaler r } +func request_Lightning_ChannelAcceptor_0(ctx context.Context, marshaler runtime.Marshaler, client LightningClient, req *http.Request, pathParams map[string]string) (Lightning_ChannelAcceptorClient, runtime.ServerMetadata, error) { + var metadata runtime.ServerMetadata + stream, err := client.ChannelAcceptor(ctx) + if err != nil { + grpclog.Infof("Failed to start streaming: %v", err) + return nil, metadata, err + } + dec := marshaler.NewDecoder(req.Body) + handleSend := func() error { + var protoReq ChannelAcceptResponse + err := dec.Decode(&protoReq) + if err == io.EOF { + return err + } + if err != nil { + grpclog.Infof("Failed to decode request: %v", err) + return err + } + if err := stream.Send(&protoReq); err != nil { + grpclog.Infof("Failed to send request: %v", err) + return err + } + return nil + } + if err := handleSend(); err != nil { + if cerr := stream.CloseSend(); cerr != nil { + grpclog.Infof("Failed to terminate client stream: %v", cerr) + } + if err == io.EOF { + return stream, metadata, nil + } + return nil, metadata, err + } + go func() { + for { + if err := handleSend(); err != nil { + break + } + } + if err := stream.CloseSend(); err != nil { + grpclog.Infof("Failed to terminate client stream: %v", err) + } + }() + header, err := stream.Header() + if err != nil { + grpclog.Infof("Failed to get header from client: %v", err) + return nil, metadata, err + } + metadata.HeaderMD = header + return stream, metadata, nil +} + var ( filter_Lightning_CloseChannel_0 = &utilities.DoubleArray{Encoding: map[string]int{"channel_point": 0, "funding_txid_str": 1, "output_index": 2}, Base: []int{1, 1, 1, 2, 0, 0}, Check: []int{0, 1, 2, 2, 3, 4}} ) @@ -879,6 +931,58 @@ func local_request_Lightning_AbandonChannel_0(ctx context.Context, marshaler run } +func request_Lightning_SendPayment_0(ctx context.Context, marshaler runtime.Marshaler, client LightningClient, req *http.Request, pathParams map[string]string) (Lightning_SendPaymentClient, runtime.ServerMetadata, error) { + var metadata runtime.ServerMetadata + stream, err := client.SendPayment(ctx) + if err != nil { + grpclog.Infof("Failed to start streaming: %v", err) + return nil, metadata, err + } + dec := marshaler.NewDecoder(req.Body) + handleSend := func() error { + var protoReq SendRequest + err := dec.Decode(&protoReq) + if err == io.EOF { + return err + } + if err != nil { + grpclog.Infof("Failed to decode request: %v", err) + return err + } + if err := stream.Send(&protoReq); err != nil { + grpclog.Infof("Failed to send request: %v", err) + return err + } + return nil + } + if err := handleSend(); err != nil { + if cerr := stream.CloseSend(); cerr != nil { + grpclog.Infof("Failed to terminate client stream: %v", cerr) + } + if err == io.EOF { + return stream, metadata, nil + } + return nil, metadata, err + } + go func() { + for { + if err := handleSend(); err != nil { + break + } + } + if err := stream.CloseSend(); err != nil { + grpclog.Infof("Failed to terminate client stream: %v", err) + } + }() + header, err := stream.Header() + if err != nil { + grpclog.Infof("Failed to get header from client: %v", err) + return nil, metadata, err + } + metadata.HeaderMD = header + return stream, metadata, nil +} + func request_Lightning_SendPaymentSync_0(ctx context.Context, marshaler runtime.Marshaler, client LightningClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq SendRequest var metadata runtime.ServerMetadata @@ -2451,6 +2555,13 @@ func RegisterLightningHandlerServer(ctx context.Context, mux *runtime.ServeMux, }) + mux.Handle("POST", pattern_Lightning_ChannelAcceptor_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") + _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + }) + mux.Handle("DELETE", pattern_Lightning_CloseChannel_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) @@ -2478,6 +2589,13 @@ func RegisterLightningHandlerServer(ctx context.Context, mux *runtime.ServeMux, }) + mux.Handle("POST", pattern_Lightning_SendPayment_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") + _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + }) + mux.Handle("POST", pattern_Lightning_SendPaymentSync_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -3560,6 +3678,26 @@ func RegisterLightningHandlerClient(ctx context.Context, mux *runtime.ServeMux, }) + mux.Handle("POST", pattern_Lightning_ChannelAcceptor_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Lightning_ChannelAcceptor_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Lightning_ChannelAcceptor_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) + + }) + mux.Handle("DELETE", pattern_Lightning_CloseChannel_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -3600,6 +3738,26 @@ func RegisterLightningHandlerClient(ctx context.Context, mux *runtime.ServeMux, }) + mux.Handle("POST", pattern_Lightning_SendPayment_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Lightning_SendPayment_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Lightning_SendPayment_0(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) + + }) + mux.Handle("POST", pattern_Lightning_SendPaymentSync_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -4252,10 +4410,14 @@ var ( pattern_Lightning_FundingStateStep_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "funding", "step"}, "", runtime.AssumeColonVerbOpt(true))) + pattern_Lightning_ChannelAcceptor_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "channels", "acceptor"}, "", runtime.AssumeColonVerbOpt(true))) + pattern_Lightning_CloseChannel_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "channels", "channel_point.funding_txid_str", "channel_point.output_index"}, "", runtime.AssumeColonVerbOpt(true))) pattern_Lightning_AbandonChannel_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 1, 0, 4, 1, 5, 4}, []string{"v1", "channels", "abandon", "channel_point.funding_txid_str", "channel_point.output_index"}, "", runtime.AssumeColonVerbOpt(true))) + pattern_Lightning_SendPayment_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "channels", "transaction-stream"}, "", runtime.AssumeColonVerbOpt(true))) + pattern_Lightning_SendPaymentSync_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "channels", "transactions"}, "", runtime.AssumeColonVerbOpt(true))) pattern_Lightning_SendToRouteSync_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "channels", "transactions", "route"}, "", runtime.AssumeColonVerbOpt(true))) @@ -4366,10 +4528,14 @@ var ( forward_Lightning_FundingStateStep_0 = runtime.ForwardResponseMessage + forward_Lightning_ChannelAcceptor_0 = runtime.ForwardResponseStream + forward_Lightning_CloseChannel_0 = runtime.ForwardResponseStream forward_Lightning_AbandonChannel_0 = runtime.ForwardResponseMessage + forward_Lightning_SendPayment_0 = runtime.ForwardResponseStream + forward_Lightning_SendPaymentSync_0 = runtime.ForwardResponseMessage forward_Lightning_SendToRouteSync_0 = runtime.ForwardResponseMessage diff --git a/lnrpc/rpc.swagger.json b/lnrpc/rpc.swagger.json index 8cbe8838..433eb887 100644 --- a/lnrpc/rpc.swagger.json +++ b/lnrpc/rpc.swagger.json @@ -204,6 +204,49 @@ ] } }, + "/v1/channels/acceptor": { + "post": { + "summary": "ChannelAcceptor dispatches a bi-directional streaming RPC in which\nOpenChannel requests are sent to the client and the client responds with\na boolean that tells LND whether or not to accept the channel. This allows\nnode operators to specify their own criteria for accepting inbound channels\nthrough a single persistent connection.", + "operationId": "ChannelAcceptor", + "responses": { + "200": { + "description": "A successful response.(streaming responses)", + "schema": { + "type": "object", + "properties": { + "result": { + "$ref": "#/definitions/lnrpcChannelAcceptRequest" + }, + "error": { + "$ref": "#/definitions/runtimeStreamError" + } + }, + "title": "Stream result of lnrpcChannelAcceptRequest" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "description": " (streaming inputs)", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/lnrpcChannelAcceptResponse" + } + } + ], + "tags": [ + "Lightning" + ] + } + }, "/v1/channels/backup": { "get": { "summary": "ExportAllChannelBackups returns static channel backups for all existing\nchannels known to lnd. A set of regular singular static channel backups for\neach channel are returned. Additionally, a multi-channel backup is returned\nas well, which contains a single encrypted blob containing the backups of\neach channel.", @@ -537,6 +580,49 @@ ] } }, + "/v1/channels/transaction-stream": { + "post": { + "summary": "lncli: `sendpayment`\nDeprecated, use routerrpc.SendPaymentV2. SendPayment dispatches a\nbi-directional streaming RPC for sending payments through the Lightning\nNetwork. A single RPC invocation creates a persistent bi-directional\nstream allowing clients to rapidly send payments through the Lightning\nNetwork with a single persistent connection.", + "operationId": "SendPayment", + "responses": { + "200": { + "description": "A successful response.(streaming responses)", + "schema": { + "type": "object", + "properties": { + "result": { + "$ref": "#/definitions/lnrpcSendResponse" + }, + "error": { + "$ref": "#/definitions/runtimeStreamError" + } + }, + "title": "Stream result of lnrpcSendResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "description": " (streaming inputs)", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/lnrpcSendRequest" + } + } + ], + "tags": [ + "Lightning" + ] + } + }, "/v1/channels/transactions": { "post": { "summary": "SendPaymentSync is the synchronous non-streaming version of SendPayment.\nThis RPC is intended to be consumed by clients of the REST proxy.\nAdditionally, this RPC expects the destination's public key and the payment\nhash (if any) to be encoded as hex strings.", @@ -2929,6 +3015,59 @@ } } }, + "lnrpcChannelAcceptResponse": { + "type": "object", + "properties": { + "accept": { + "type": "boolean", + "format": "boolean", + "description": "Whether or not the client accepts the channel." + }, + "pending_chan_id": { + "type": "string", + "format": "byte", + "description": "The pending channel id to which this response applies." + }, + "error": { + "type": "string", + "description": "An optional error to send the initiating party to indicate why the channel\nwas rejected. This field *should not* contain sensitive information, it will\nbe sent to the initiating party. This field should only be set if accept is\nfalse, the channel will be rejected if an error is set with accept=true\nbecause the meaning of this response is ambiguous. Limited to 500\ncharacters." + }, + "upfront_shutdown": { + "type": "string", + "description": "The upfront shutdown address to use if the initiating peer supports option\nupfront shutdown script (see ListPeers for the features supported). Note\nthat the channel open will fail if this value is set for a peer that does\nnot support this feature bit." + }, + "csv_delay": { + "type": "integer", + "format": "int64", + "description": "The csv delay (in blocks) that we require for the remote party." + }, + "reserve_sat": { + "type": "string", + "format": "uint64", + "description": "The reserve amount in satoshis that we require the remote peer to adhere to.\nWe require that the remote peer always have some reserve amount allocated to\nthem so that there is always a disincentive to broadcast old state (if they\nhold 0 sats on their side of the channel, there is nothing to lose)." + }, + "in_flight_max_msat": { + "type": "string", + "format": "uint64", + "description": "The maximum amount of funds in millisatoshis that we allow the remote peer\nto have in outstanding htlcs." + }, + "max_htlc_count": { + "type": "integer", + "format": "int64", + "description": "The maximum number of htlcs that the remote peer can offer us." + }, + "min_htlc_in": { + "type": "string", + "format": "uint64", + "description": "The minimum value in millisatoshis for incoming htlcs on the channel." + }, + "min_accept_depth": { + "type": "integer", + "format": "int64", + "description": "The number of confirmations we require before we consider the channel open." + } + } + }, "lnrpcChannelBackup": { "type": "object", "properties": { diff --git a/lnrpc/websocket_proxy.go b/lnrpc/websocket_proxy.go index 3cb701be..8c5001da 100644 --- a/lnrpc/websocket_proxy.go +++ b/lnrpc/websocket_proxy.go @@ -8,7 +8,9 @@ import ( "io" "net/http" "net/textproto" + "regexp" "strings" + "time" "github.com/btcsuite/btclog" "github.com/gorilla/websocket" @@ -49,12 +51,28 @@ var ( defaultProtocolsToAllow = map[string]bool{ "Grpc-Metadata-Macaroon": true, } + + // DefaultPingInterval is the default number of seconds to wait between + // sending ping requests. + DefaultPingInterval = time.Second * 30 + + // DefaultPongWait is the maximum duration we wait for a pong response + // to a ping we sent before we assume the connection died. + DefaultPongWait = time.Second * 5 ) // NewWebSocketProxy attempts to expose the underlying handler as a response- // streaming WebSocket stream with newline-delimited JSON as the content -// encoding. -func NewWebSocketProxy(h http.Handler, logger btclog.Logger) http.Handler { +// encoding. If pingInterval is a non-zero duration, a ping message will be +// sent out periodically and a pong response message is expected from the +// client. The clientStreamingURIs parameter can hold a list of all patterns +// for URIs that are mapped to client-streaming RPC methods. We need to keep +// track of those to make sure we initialize the request body correctly for the +// underlying grpc-gateway library. +func NewWebSocketProxy(h http.Handler, logger btclog.Logger, + pingInterval, pongWait time.Duration, + clientStreamingURIs []*regexp.Regexp) http.Handler { + p := &WebsocketProxy{ backend: h, logger: logger, @@ -65,7 +83,14 @@ func NewWebSocketProxy(h http.Handler, logger btclog.Logger) http.Handler { return true }, }, + clientStreamingURIs: clientStreamingURIs, } + + if pingInterval > 0 && pongWait > 0 { + p.pingInterval = pingInterval + p.pongWait = pongWait + } + return p } @@ -74,6 +99,21 @@ type WebsocketProxy struct { backend http.Handler logger btclog.Logger upgrader *websocket.Upgrader + + // clientStreamingURIs holds a list of all patterns for URIs that are + // mapped to client-streaming RPC methods. We need to keep track of + // those to make sure we initialize the request body correctly for the + // underlying grpc-gateway library. + clientStreamingURIs []*regexp.Regexp + + pingInterval time.Duration + pongWait time.Duration +} + +// pingPongEnabled returns true if a ping interval is set to enable sending and +// expecting regular ping/pong messages. +func (p *WebsocketProxy) pingPongEnabled() bool { + return p.pingInterval > 0 && p.pongWait > 0 } // ServeHTTP handles the incoming HTTP request. If the request is an @@ -129,6 +169,14 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, request.Method = m } + // Is this a call to a client-streaming RPC method? + clientStreaming := false + for _, pattern := range p.clientStreamingURIs { + if pattern.MatchString(r.URL.Path) { + clientStreaming = true + } + } + responseForwarder := newResponseForwardingWriter() go func() { <-ctx.Done() @@ -169,13 +217,68 @@ func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, } _, _ = requestForwarder.Write([]byte{'\n'}) - // We currently only support server-streaming messages. - // Therefore we close the request body after the first - // incoming message to trigger a response. - requestForwarder.CloseWriter() + // The grpc-gateway library uses a different request + // reader depending on whether it is a client streaming + // RPC or not. For a non-streaming request we need to + // close with EOF to signal the request was completed. + if !clientStreaming { + requestForwarder.CloseWriter() + } } }() + // Ping write loop: Send a ping message regularly if ping/pong is + // enabled. + if p.pingPongEnabled() { + // We'll send out our first ping in pingInterval. So the initial + // deadline is that interval plus the time we allow for a + // response to be sent. + initialDeadline := time.Now().Add(p.pingInterval + p.pongWait) + _ = conn.SetReadDeadline(initialDeadline) + + // Whenever a pong message comes in, we extend the deadline + // until the next read is expected by the interval plus pong + // wait time. + conn.SetPongHandler(func(appData string) error { + nextDeadline := time.Now().Add( + p.pingInterval + p.pongWait, + ) + _ = conn.SetReadDeadline(nextDeadline) + return nil + }) + go func() { + ticker := time.NewTicker(p.pingInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + p.logger.Debug("WS: ping loop done") + return + + case <-ticker.C: + // Writing the ping shouldn't take any + // longer than we'll wait for a response + // in the first place. + writeDeadline := time.Now().Add( + p.pongWait, + ) + _ = conn.SetWriteDeadline(writeDeadline) + + err := conn.WriteMessage( + websocket.PingMessage, nil, + ) + if err != nil { + p.logger.Warnf("WS: could not "+ + "send ping message: %v", + err) + return + } + } + } + }() + } + // Write loop: Take messages from the response forwarder and write them // to the WebSocket. for responseForwarder.Scan() { @@ -348,6 +451,9 @@ func IsClosedConnError(err error) bool { if strings.Contains(str, "broken pipe") { return true } + if strings.Contains(str, "connection reset by peer") { + return true + } return websocket.IsCloseError( err, websocket.CloseNormalClosure, websocket.CloseGoingAway, ) diff --git a/lntest/itest/lnd_rest_api_test.go b/lntest/itest/lnd_rest_api_test.go index daa5cef3..f8a56c9a 100644 --- a/lntest/itest/lnd_rest_api_test.go +++ b/lntest/itest/lnd_rest_api_test.go @@ -47,6 +47,9 @@ var ( TLSClientConfig: insecureTransport.TLSClientConfig, } resultPattern = regexp.MustCompile("{\"result\":(.*)}") + closeMsg = websocket.FormatCloseMessage( + websocket.CloseNormalClosure, "done", + ) ) // testRestAPI tests that the most important features of the REST API work @@ -185,204 +188,19 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) { ) assert.Equal(t, 0, len(body)) }, - }, { + }} + wsTestCases := []struct { + name string + run func(ht *harnessTest, net *lntest.NetworkHarness) + }{{ name: "websocket subscription", - run: func(t *testing.T, a, b *lntest.HarnessNode) { - // Find out the current best block so we can subscribe - // to the next one. - hash, height, err := net.Miner.Client.GetBestBlock() - require.Nil(t, err, "get best block") - - // Create a new subscription to get block epoch events. - req := &chainrpc.BlockEpoch{ - Hash: hash.CloneBytes(), - Height: uint32(height), - } - url := "/v2/chainnotifier/register/blocks" - c, err := openWebSocket(a, url, "POST", req, nil) - require.Nil(t, err, "websocket") - defer func() { - _ = c.WriteMessage( - websocket.CloseMessage, - websocket.FormatCloseMessage( - websocket.CloseNormalClosure, - "done", - ), - ) - _ = c.Close() - }() - - msgChan := make(chan *chainrpc.BlockEpoch) - errChan := make(chan error) - timeout := time.After(defaultTimeout) - - // We want to read exactly one message. - go func() { - defer close(msgChan) - - _, msg, err := c.ReadMessage() - if err != nil { - errChan <- err - return - } - - // The chunked/streamed responses come wrapped - // in either a {"result":{}} or {"error":{}} - // wrapper which we'll get rid of here. - msgStr := string(msg) - if !strings.Contains(msgStr, "\"result\":") { - errChan <- fmt.Errorf("invalid msg: %s", - msgStr) - return - } - msgStr = resultPattern.ReplaceAllString( - msgStr, "${1}", - ) - - // Make sure we can parse the unwrapped message - // into the expected proto message. - protoMsg := &chainrpc.BlockEpoch{} - err = jsonpb.UnmarshalString( - msgStr, protoMsg, - ) - if err != nil { - errChan <- err - return - } - - select { - case msgChan <- protoMsg: - case <-timeout: - } - }() - - // Mine a block and make sure we get a message for it. - blockHashes, err := net.Miner.Client.Generate(1) - require.Nil(t, err, "generate blocks") - assert.Equal(t, 1, len(blockHashes), "num blocks") - select { - case msg := <-msgChan: - assert.Equal( - t, blockHashes[0].CloneBytes(), - msg.Hash, "block hash", - ) - - case err := <-errChan: - t.Fatalf("Received error from WS: %v", err) - - case <-timeout: - t.Fatalf("Timeout before message was received") - } - }, + run: wsTestCaseSubscription, }, { name: "websocket subscription with macaroon in protocol", - run: func(t *testing.T, a, b *lntest.HarnessNode) { - // Find out the current best block so we can subscribe - // to the next one. - hash, height, err := net.Miner.Client.GetBestBlock() - require.Nil(t, err, "get best block") - - // Create a new subscription to get block epoch events. - req := &chainrpc.BlockEpoch{ - Hash: hash.CloneBytes(), - Height: uint32(height), - } - url := "/v2/chainnotifier/register/blocks" - - // This time we send the macaroon in the special header - // Sec-Websocket-Protocol which is the only header field - // available to browsers when opening a WebSocket. - mac, err := a.ReadMacaroon( - a.AdminMacPath(), defaultTimeout, - ) - require.NoError(t, err, "read admin mac") - macBytes, err := mac.MarshalBinary() - require.NoError(t, err, "marshal admin mac") - - customHeader := make(http.Header) - customHeader.Set( - lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( - "Grpc-Metadata-Macaroon+%s", - hex.EncodeToString(macBytes), - ), - ) - c, err := openWebSocket( - a, url, "POST", req, customHeader, - ) - require.Nil(t, err, "websocket") - defer func() { - _ = c.WriteMessage( - websocket.CloseMessage, - websocket.FormatCloseMessage( - websocket.CloseNormalClosure, - "done", - ), - ) - _ = c.Close() - }() - - msgChan := make(chan *chainrpc.BlockEpoch) - errChan := make(chan error) - timeout := time.After(defaultTimeout) - - // We want to read exactly one message. - go func() { - defer close(msgChan) - - _, msg, err := c.ReadMessage() - if err != nil { - errChan <- err - return - } - - // The chunked/streamed responses come wrapped - // in either a {"result":{}} or {"error":{}} - // wrapper which we'll get rid of here. - msgStr := string(msg) - if !strings.Contains(msgStr, "\"result\":") { - errChan <- fmt.Errorf("invalid msg: %s", - msgStr) - return - } - msgStr = resultPattern.ReplaceAllString( - msgStr, "${1}", - ) - - // Make sure we can parse the unwrapped message - // into the expected proto message. - protoMsg := &chainrpc.BlockEpoch{} - err = jsonpb.UnmarshalString( - msgStr, protoMsg, - ) - if err != nil { - errChan <- err - return - } - - select { - case msgChan <- protoMsg: - case <-timeout: - } - }() - - // Mine a block and make sure we get a message for it. - blockHashes, err := net.Miner.Client.Generate(1) - require.Nil(t, err, "generate blocks") - assert.Equal(t, 1, len(blockHashes), "num blocks") - select { - case msg := <-msgChan: - assert.Equal( - t, blockHashes[0].CloneBytes(), - msg.Hash, "block hash", - ) - - case err := <-errChan: - t.Fatalf("Received error from WS: %v", err) - - case <-timeout: - t.Fatalf("Timeout before message was received") - } - }, + run: wsTestCaseSubscriptionMacaroon, + }, { + name: "websocket bi-directional subscription", + run: wsTestCaseBiDirectionalSubscription, }} // Make sure Alice allows all CORS origins. Bob will keep the default. @@ -401,6 +219,310 @@ func testRestAPI(net *lntest.NetworkHarness, ht *harnessTest) { tc.run(t, net.Alice, net.Bob) }) } + + for _, tc := range wsTestCases { + tc := tc + ht.t.Run(tc.name, func(t *testing.T) { + ht := &harnessTest{ + t: t, testCase: ht.testCase, lndHarness: net, + } + tc.run(ht, net) + }) + } +} + +func wsTestCaseSubscription(ht *harnessTest, net *lntest.NetworkHarness) { + // Find out the current best block so we can subscribe to the next one. + hash, height, err := net.Miner.Client.GetBestBlock() + require.Nil(ht.t, err, "get best block") + + // Create a new subscription to get block epoch events. + req := &chainrpc.BlockEpoch{ + Hash: hash.CloneBytes(), + Height: uint32(height), + } + url := "/v2/chainnotifier/register/blocks" + c, err := openWebSocket(net.Alice, url, "POST", req, nil) + require.Nil(ht.t, err, "websocket") + defer func() { + err := c.WriteMessage(websocket.CloseMessage, closeMsg) + require.NoError(ht.t, err) + _ = c.Close() + }() + + msgChan := make(chan *chainrpc.BlockEpoch) + errChan := make(chan error) + timeout := time.After(defaultTimeout) + + // We want to read exactly one message. + go func() { + defer close(msgChan) + + _, msg, err := c.ReadMessage() + if err != nil { + errChan <- err + return + } + + // The chunked/streamed responses come wrapped in either a + // {"result":{}} or {"error":{}} wrapper which we'll get rid of + // here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errChan <- fmt.Errorf("invalid msg: %s", msgStr) + return + } + msgStr = resultPattern.ReplaceAllString(msgStr, "${1}") + + // Make sure we can parse the unwrapped message into the + // expected proto message. + protoMsg := &chainrpc.BlockEpoch{} + err = jsonpb.UnmarshalString(msgStr, protoMsg) + if err != nil { + errChan <- err + return + } + + select { + case msgChan <- protoMsg: + case <-timeout: + } + }() + + // Mine a block and make sure we get a message for it. + blockHashes, err := net.Miner.Client.Generate(1) + require.Nil(ht.t, err, "generate blocks") + assert.Equal(ht.t, 1, len(blockHashes), "num blocks") + select { + case msg := <-msgChan: + assert.Equal( + ht.t, blockHashes[0].CloneBytes(), msg.Hash, + "block hash", + ) + + case err := <-errChan: + ht.t.Fatalf("Received error from WS: %v", err) + + case <-timeout: + ht.t.Fatalf("Timeout before message was received") + } +} + +func wsTestCaseSubscriptionMacaroon(ht *harnessTest, + net *lntest.NetworkHarness) { + + // Find out the current best block so we can subscribe to the next one. + hash, height, err := net.Miner.Client.GetBestBlock() + require.Nil(ht.t, err, "get best block") + + // Create a new subscription to get block epoch events. + req := &chainrpc.BlockEpoch{ + Hash: hash.CloneBytes(), + Height: uint32(height), + } + url := "/v2/chainnotifier/register/blocks" + + // This time we send the macaroon in the special header + // Sec-Websocket-Protocol which is the only header field available to + // browsers when opening a WebSocket. + mac, err := net.Alice.ReadMacaroon( + net.Alice.AdminMacPath(), defaultTimeout, + ) + require.NoError(ht.t, err, "read admin mac") + macBytes, err := mac.MarshalBinary() + require.NoError(ht.t, err, "marshal admin mac") + + customHeader := make(http.Header) + customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( + "Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes), + )) + c, err := openWebSocket(net.Alice, url, "POST", req, customHeader) + require.Nil(ht.t, err, "websocket") + defer func() { + err := c.WriteMessage(websocket.CloseMessage, closeMsg) + require.NoError(ht.t, err) + _ = c.Close() + }() + + msgChan := make(chan *chainrpc.BlockEpoch) + errChan := make(chan error) + timeout := time.After(defaultTimeout) + + // We want to read exactly one message. + go func() { + defer close(msgChan) + + _, msg, err := c.ReadMessage() + if err != nil { + errChan <- err + return + } + + // The chunked/streamed responses come wrapped in either a + // {"result":{}} or {"error":{}} wrapper which we'll get rid of + // here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errChan <- fmt.Errorf("invalid msg: %s", msgStr) + return + } + msgStr = resultPattern.ReplaceAllString(msgStr, "${1}") + + // Make sure we can parse the unwrapped message into the + // expected proto message. + protoMsg := &chainrpc.BlockEpoch{} + err = jsonpb.UnmarshalString(msgStr, protoMsg) + if err != nil { + errChan <- err + return + } + + select { + case msgChan <- protoMsg: + case <-timeout: + } + }() + + // Mine a block and make sure we get a message for it. + blockHashes, err := net.Miner.Client.Generate(1) + require.Nil(ht.t, err, "generate blocks") + assert.Equal(ht.t, 1, len(blockHashes), "num blocks") + select { + case msg := <-msgChan: + assert.Equal( + ht.t, blockHashes[0].CloneBytes(), msg.Hash, + "block hash", + ) + + case err := <-errChan: + ht.t.Fatalf("Received error from WS: %v", err) + + case <-timeout: + ht.t.Fatalf("Timeout before message was received") + } +} + +func wsTestCaseBiDirectionalSubscription(ht *harnessTest, + net *lntest.NetworkHarness) { + + initialRequest := &lnrpc.ChannelAcceptResponse{} + url := "/v1/channels/acceptor" + + // This time we send the macaroon in the special header + // Sec-Websocket-Protocol which is the only header field available to + // browsers when opening a WebSocket. + mac, err := net.Alice.ReadMacaroon( + net.Alice.AdminMacPath(), defaultTimeout, + ) + require.NoError(ht.t, err, "read admin mac") + macBytes, err := mac.MarshalBinary() + require.NoError(ht.t, err, "marshal admin mac") + + customHeader := make(http.Header) + customHeader.Set(lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( + "Grpc-Metadata-Macaroon+%s", hex.EncodeToString(macBytes), + )) + conn, err := openWebSocket( + net.Alice, url, "POST", initialRequest, customHeader, + ) + require.Nil(ht.t, err, "websocket") + defer func() { + err := conn.WriteMessage(websocket.CloseMessage, closeMsg) + require.NoError(ht.t, err) + _ = conn.Close() + }() + + msgChan := make(chan *lnrpc.ChannelAcceptResponse) + errChan := make(chan error) + done := make(chan struct{}) + timeout := time.After(defaultTimeout) + + // We want to read messages over and over again. We just accept any + // channels that are opened. + go func() { + for { + _, msg, err := conn.ReadMessage() + if err != nil { + errChan <- err + return + } + + // The chunked/streamed responses come wrapped in either + // a {"result":{}} or {"error":{}} wrapper which we'll + // get rid of here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errChan <- fmt.Errorf("invalid msg: %s", msgStr) + return + } + msgStr = resultPattern.ReplaceAllString(msgStr, "${1}") + + // Make sure we can parse the unwrapped message into the + // expected proto message. + protoMsg := &lnrpc.ChannelAcceptRequest{} + err = jsonpb.UnmarshalString(msgStr, protoMsg) + if err != nil { + errChan <- err + return + } + + // Send the response that we accept the channel. + res := &lnrpc.ChannelAcceptResponse{ + Accept: true, + PendingChanId: protoMsg.PendingChanId, + } + resMsg, err := jsonMarshaler.MarshalToString(res) + if err != nil { + errChan <- err + return + } + err = conn.WriteMessage( + websocket.TextMessage, []byte(resMsg), + ) + if err != nil { + errChan <- err + return + } + + // Also send the message on our message channel to make + // sure we count it as successful. + msgChan <- res + + // Are we done or should there be more messages? + select { + case <-done: + return + default: + } + } + }() + + // Before we start opening channels, make sure the two nodes are + // connected. + err = net.EnsureConnected(context.Background(), net.Alice, net.Bob) + require.NoError(ht.t, err) + + // Open 3 channels to make sure multiple requests and responses can be + // sent over the web socket. + const numChannels = 3 + for i := 0; i < numChannels; i++ { + openChannelAndAssert( + context.Background(), ht, net, net.Bob, net.Alice, + lntest.OpenChannelParams{ + Amt: 500000, + }, + ) + + select { + case <-msgChan: + case err := <-errChan: + ht.t.Fatalf("Received error from WS: %v", err) + + case <-timeout: + ht.t.Fatalf("Timeout before message was received") + } + } + close(done) } // invokeGET calls the given URL with the GET method and appropriate macaroon diff --git a/lntest/itest/log_error_whitelist.txt b/lntest/itest/log_error_whitelist.txt index ec164c9d..87ac6f50 100644 --- a/lntest/itest/log_error_whitelist.txt +++ b/lntest/itest/log_error_whitelist.txt @@ -11,6 +11,7 @@