lnd: Add CORS support to the WalletUnlocker proxy

This commit adds the same CORS functionality that's currently in the main gRPC proxy to the WalletUnlocker proxy. This ensures the CORS configuration is carried through all API endpoints
This commit is contained in:
Graham Krizek 2020-08-19 23:06:28 -05:00
parent fc12656a1a
commit 3f944dd337
No known key found for this signature in database
GPG Key ID: CE60579CC185BC56
2 changed files with 9 additions and 8 deletions

2
lnd.go
View File

@ -1011,7 +1011,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr,
return nil, err
}
srv := &http.Server{Handler: mux}
srv := &http.Server{Handler: allowCORS(mux, cfg.RestCORS)}
for _, restEndpoint := range restEndpoints {
lis, err := lncfg.TLSListenOnAddress(restEndpoint, tlsConf)

View File

@ -810,12 +810,6 @@ func (r *rpcServer) Start() error {
// Wrap the default grpc-gateway handler with the WebSocket handler.
restHandler := lnrpc.NewWebSocketProxy(restMux, rpcsLog)
// Set the CORS headers if configured. This wraps the HTTP handler with
// another handler.
if len(r.cfg.RestCORS) > 0 {
restHandler = allowCORS(restHandler, r.cfg.RestCORS)
}
// With our custom REST proxy mux created, register our main RPC and
// give all subservers a chance to register as well.
err := lnrpc.RegisterLightningHandlerFromEndpoint(
@ -871,7 +865,8 @@ func (r *rpcServer) Start() error {
// through the following chain:
// req ---> CORS handler --> WS proxy --->
// REST proxy --> gRPC endpoint
err := http.Serve(lis, restHandler)
corsHandler := allowCORS(restHandler, r.cfg.RestCORS)
err := http.Serve(lis, corsHandler)
if err != nil && !lnrpc.IsClosedConnError(err) {
rpcsLog.Error(err)
}
@ -944,6 +939,12 @@ func allowCORS(handler http.Handler, origins []string) http.Handler {
allowMethods := "Access-Control-Allow-Methods"
allowOrigin := "Access-Control-Allow-Origin"
// If the user didn't supply any origins that means CORS is disabled
// and we should return the original handler.
if len(origins) == 0 {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")