From 3f944dd3371ddeebecffe07732f27c831afe07f2 Mon Sep 17 00:00:00 2001 From: Graham Krizek Date: Wed, 19 Aug 2020 23:06:28 -0500 Subject: [PATCH] lnd: Add CORS support to the WalletUnlocker proxy This commit adds the same CORS functionality that's currently in the main gRPC proxy to the WalletUnlocker proxy. This ensures the CORS configuration is carried through all API endpoints --- lnd.go | 2 +- rpcserver.go | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lnd.go b/lnd.go index 21a6923d..1722061b 100644 --- a/lnd.go +++ b/lnd.go @@ -1011,7 +1011,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, return nil, err } - srv := &http.Server{Handler: mux} + srv := &http.Server{Handler: allowCORS(mux, cfg.RestCORS)} for _, restEndpoint := range restEndpoints { lis, err := lncfg.TLSListenOnAddress(restEndpoint, tlsConf) diff --git a/rpcserver.go b/rpcserver.go index 088f4a86..9348088d 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -810,12 +810,6 @@ func (r *rpcServer) Start() error { // Wrap the default grpc-gateway handler with the WebSocket handler. restHandler := lnrpc.NewWebSocketProxy(restMux, rpcsLog) - // Set the CORS headers if configured. This wraps the HTTP handler with - // another handler. - if len(r.cfg.RestCORS) > 0 { - restHandler = allowCORS(restHandler, r.cfg.RestCORS) - } - // With our custom REST proxy mux created, register our main RPC and // give all subservers a chance to register as well. err := lnrpc.RegisterLightningHandlerFromEndpoint( @@ -871,7 +865,8 @@ func (r *rpcServer) Start() error { // through the following chain: // req ---> CORS handler --> WS proxy ---> // REST proxy --> gRPC endpoint - err := http.Serve(lis, restHandler) + corsHandler := allowCORS(restHandler, r.cfg.RestCORS) + err := http.Serve(lis, corsHandler) if err != nil && !lnrpc.IsClosedConnError(err) { rpcsLog.Error(err) } @@ -944,6 +939,12 @@ func allowCORS(handler http.Handler, origins []string) http.Handler { allowMethods := "Access-Control-Allow-Methods" allowOrigin := "Access-Control-Allow-Origin" + // If the user didn't supply any origins that means CORS is disabled + // and we should return the original handler. + if len(origins) == 0 { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin")