// The code in this file is a heavily modified version of
// https://github.com/tmc/grpc-websocket-proxy/

package lnrpc

import (
	"bufio"
	"io"
	"net/http"
	"net/textproto"
	"regexp"
	"strings"
	"time"

	"github.com/btcsuite/btclog"
	"github.com/gorilla/websocket"
	"golang.org/x/net/context"
)

const (
	// HeaderWebSocketProtocol names the header field of the WebSocket
	// protocol exchange. Besides its regular purpose we use it as a
	// carrier for additional header fields that browsers can't set
	// directly on a WebSocket handshake.
	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"

	// WebSocketProtocolDelimiter separates the name of an additional
	// header field from its value inside HeaderWebSocketProtocol. The
	// plus symbol is used since the default delimiters aren't allowed in
	// protocol names.
	WebSocketProtocolDelimiter = "+"

	// MethodOverrideParam is the name of the GET query parameter that
	// selects the HTTP method of the forwarded REST request. A WebSocket
	// handshake is always a GET request, so the real method of the target
	// call has to be transported some other way.
	MethodOverrideParam = "method"
)

var (
	// defaultHeadersToForward is a map of all HTTP header fields that are
	// forwarded by default. The keys must be in the canonical MIME header
	// format, as the lookup is done with the canonicalized field name.
	defaultHeadersToForward = map[string]bool{
		"Origin":                 true,
		"Referer":                true,
		"Grpc-Metadata-Macaroon": true,
	}

	// defaultProtocolsToAllow are additional header fields that we allow
	// to be transported inside of the Sec-Websocket-Protocol field to be
	// forwarded to the backend.
	defaultProtocolsToAllow = map[string]bool{
		"Grpc-Metadata-Macaroon": true,
	}

	// DefaultPingInterval is the default duration to wait between sending
	// two consecutive ping requests.
	DefaultPingInterval = 30 * time.Second

	// DefaultPongWait is the maximum duration we wait for a pong response
	// to a ping we sent before we assume the connection died.
	DefaultPongWait = 5 * time.Second
)

// NewWebSocketProxy wraps the handler h so that WebSocket upgrade requests are
// served as a response-streaming WebSocket stream that uses newline-delimited
// JSON as its content encoding. All other requests go to h unchanged.
//
// Ping/pong keep-alive is only switched on when both pingInterval and
// pongWait are positive. In that case a ping is sent every pingInterval and
// the client is expected to answer with a pong. clientStreamingURIs lists
// the URI patterns that map to client-streaming RPC methods. The proxy needs
// them to initialize the request body correctly for the underlying
// grpc-gateway library.
func NewWebSocketProxy(h http.Handler, logger btclog.Logger,
	pingInterval, pongWait time.Duration,
	clientStreamingURIs []*regexp.Regexp) http.Handler {

	// The origin check unconditionally accepts every origin.
	// NOTE(review): presumably authentication relies on the forwarded
	// macaroon rather than on the origin; confirm this is intended.
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin: func(*http.Request) bool {
			return true
		},
	}

	proxy := &WebsocketProxy{
		backend:             h,
		logger:              logger,
		upgrader:            upgrader,
		clientStreamingURIs: clientStreamingURIs,
	}

	// Without both durations being positive, ping/pong stays disabled and
	// both fields keep their zero value.
	if pingInterval <= 0 || pongWait <= 0 {
		return proxy
	}

	proxy.pingInterval = pingInterval
	proxy.pongWait = pongWait

	return proxy
}

// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
type WebsocketProxy struct {
	backend  http.Handler
	logger   btclog.Logger
	upgrader *websocket.Upgrader

	// clientStreamingURIs holds a list of all patterns for URIs that are
	// mapped to client-streaming RPC methods. We need to keep track of
	// those to make sure we initialize the request body correctly for the
	// underlying grpc-gateway library.
	clientStreamingURIs []*regexp.Regexp

	pingInterval time.Duration
	pongWait     time.Duration
}

// pingPongEnabled reports whether regular ping/pong messages should be
// exchanged. This requires both a positive ping interval and a positive pong
// wait duration.
func (p *WebsocketProxy) pingPongEnabled() bool {
	if p.pingInterval <= 0 {
		return false
	}

	return p.pongWait > 0
}

// ServeHTTP handles the incoming HTTP request. The WS proxy takes over
// requests whose header fields mark them as WebSocket upgrades. Every other
// request is served directly by the underlying REST proxy.
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if websocket.IsWebSocketUpgrade(r) {
		p.upgradeToWebSocketProxy(w, r)
		return
	}

	p.backend.ServeHTTP(w, r)
}

// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket and
// proxies it to the backend REST handler. Every message read from the
// WebSocket is written to the backend request body, followed by a newline.
// Every newline-delimited line the backend writes is sent back as a WebSocket
// text message. This continues until either the client or the server quits
// the connection.
//
// Up to four activities run concurrently and share one cancelable context
// that is derived from the incoming request:
//   - the backend handler, reading from requestForwarder and writing to
//     responseForwarder,
//   - the read loop, copying WebSocket messages into requestForwarder,
//   - the ping loop (only if ping/pong is enabled), and
//   - the write loop in this goroutine, copying lines from
//     responseForwarder to the WebSocket.
//
// The context is canceled when the backend handler, the read loop or this
// function finishes. Cancellation closes both forwarding pipes so that the
// remaining activities terminate as well.
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
	r *http.Request) {

	conn, err := p.upgrader.Upgrade(w, r, nil)
	if err != nil {
		p.logger.Errorf("error upgrading websocket: %v", err)
		return
	}
	defer func() {
		err := conn.Close()
		if err != nil && !IsClosedConnError(err) {
			p.logger.Errorf("WS: error closing upgraded conn: %v",
				err)
		}
	}()

	// Derive the proxy context from the incoming request and use it for
	// the backend request too. The backend call is then canceled as soon
	// as the proxy shuts down.
	ctx, cancelFn := context.WithCancel(r.Context())
	defer cancelFn()

	requestForwarder := newRequestForwardingReader()
	request, err := http.NewRequestWithContext(
		ctx, r.Method, r.URL.String(), requestForwarder,
	)
	if err != nil {
		p.logger.Errorf("WS: error preparing request: %v", err)
		return
	}

	// Allow certain headers to be forwarded, either from source headers
	// or the special Sec-Websocket-Protocol header field.
	forwardHeaders(r.Header, request.Header)

	// Also allow the target request method to be overwritten, as all
	// WebSocket establishment calls MUST be GET requests.
	if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
		request.Method = m
	}

	// Is this a call to a client-streaming RPC method? This decides
	// whether the request body is closed after the first message in the
	// read loop below.
	clientStreaming := false
	for _, pattern := range p.clientStreamingURIs {
		if pattern.MatchString(r.URL.Path) {
			clientStreaming = true
			break
		}
	}

	// Once the context is canceled, close both forwarding pipes. Closing
	// the response pipe ends the write loop below and makes further
	// backend writes fail. Closing the request pipe signals EOF to a
	// backend still reading the body. It also unblocks the read loop if
	// that loop is stuck writing a message nobody will consume anymore.
	responseForwarder := newResponseForwardingWriter()
	go func() {
		<-ctx.Done()
		responseForwarder.Close()
		requestForwarder.CloseWriter()
	}()

	go func() {
		defer cancelFn()
		p.backend.ServeHTTP(responseForwarder, request)
	}()

	// Ping/pong keep-alive. The read deadline and the pong handler belong
	// to the read side of the connection, and gorilla/websocket doesn't
	// allow read methods to be called concurrently. Therefore both are
	// configured before the read loop is started.
	if p.pingPongEnabled() {
		// We'll send out our first ping in pingInterval. So the initial
		// deadline is that interval plus the time we allow for a
		// response to be sent.
		initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
		_ = conn.SetReadDeadline(initialDeadline)

		// Whenever a pong message comes in, we extend the deadline
		// until the next read is expected by the interval plus pong
		// wait time. The handler is invoked from within ReadMessage,
		// i.e. on the read loop goroutine.
		conn.SetPongHandler(func(appData string) error {
			nextDeadline := time.Now().Add(
				p.pingInterval + p.pongWait,
			)
			_ = conn.SetReadDeadline(nextDeadline)
			return nil
		})

		// Ping write loop: Send a ping message regularly.
		go func() {
			ticker := time.NewTicker(p.pingInterval)
			defer ticker.Stop()

			for {
				select {
				case <-ctx.Done():
					p.logger.Debug("WS: ping loop done")
					return

				case <-ticker.C:
					// WriteControl may be called
					// concurrently with the WriteMessage
					// calls of the write loop below. It
					// applies its deadline to this ping
					// only, instead of setting a
					// connection-wide write deadline that
					// would later make regular data writes
					// time out. Writing the ping shouldn't
					// take any longer than we'll wait for a
					// response in the first place.
					writeDeadline := time.Now().Add(
						p.pongWait,
					)
					err := conn.WriteControl(
						websocket.PingMessage, nil,
						writeDeadline,
					)
					if err != nil {
						p.logger.Warnf("WS: could not "+
							"send ping message: %v",
							err)
						return
					}
				}
			}
		}()
	}

	// Read loop: Take messages from websocket and write to http request.
	go func() {
		defer cancelFn()
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}

			_, payload, err := conn.ReadMessage()
			if err != nil {
				if IsClosedConnError(err) {
					p.logger.Tracef("WS: socket "+
						"closed: %v", err)
					return
				}
				p.logger.Errorf("error reading message: %v",
					err)
				return
			}

			// Terminate each message with a newline, matching the
			// newline-delimited encoding of the proxy, and write
			// it in one go so a failure is never silently
			// dropped.
			payload = append(payload, '\n')
			_, err = requestForwarder.Write(payload)
			if err != nil {
				p.logger.Errorf("WS: error writing message "+
					"to upstream http server: %v", err)
				return
			}

			// The grpc-gateway library uses a different request
			// reader depending on whether it is a client streaming
			// RPC or not. For a non-streaming request we need to
			// close with EOF to signal the request was completed.
			if !clientStreaming {
				requestForwarder.CloseWriter()
			}
		}
	}()

	// Write loop: Take messages from the response forwarder and write them
	// to the WebSocket.
	for responseForwarder.Scan() {
		if len(responseForwarder.Bytes()) == 0 {
			p.logger.Errorf("WS: empty scan: %v",
				responseForwarder.Err())

			continue
		}

		err = conn.WriteMessage(
			websocket.TextMessage, responseForwarder.Bytes(),
		)
		if err != nil {
			p.logger.Errorf("WS: error writing message: %v", err)
			return
		}
	}
	if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
		p.logger.Errorf("WS: scanner err: %v", err)
	}
}

// forwardHeaders forwards certain allowed header fields from the source request
// to the target request. Because browsers are limited in what header fields
// they can send on the WebSocket setup call, we also allow additional fields to
// be transported in the special Sec-Websocket-Protocol field.
//
// That field holds a comma separated list of protocols. An entry is set as a
// header field on the target if it has the form "<header name>+<value>" and
// its header name is listed in defaultProtocolsToAllow. Such a field
// overrides a value that was forwarded directly. Entries without the
// delimiter, or with a name that isn't allowed, are ignored.
func forwardHeaders(source, target http.Header) {
	// Forward allowed header fields directly.
	for header := range source {
		headerName := textproto.CanonicalMIMEHeaderKey(header)
		if defaultHeadersToForward[headerName] {
			target.Set(headerName, source.Get(header))
		}
	}

	// Browsers aren't allowed to set custom header fields on WebSocket
	// requests. We need to allow them to submit the macaroon as a WS
	// protocol, which is the only allowed header. Set any "protocols" we
	// declare valid as header fields on the forwarded request.
	protocols := strings.Split(source.Get(HeaderWebSocketProtocol), ",")
	for _, protocol := range protocols {
		// The format is "<protocol name>+<value>". Split only at the
		// first delimiter so the value stays intact. Skip entries
		// without a delimiter instead of indexing past the end of the
		// split result.
		parts := strings.SplitN(
			strings.TrimSpace(protocol), WebSocketProtocolDelimiter, 2,
		)
		if len(parts) != 2 {
			continue
		}

		name, value := parts[0], parts[1]
		if defaultProtocolsToAllow[name] {
			target.Set(name, value)
		}
	}
}

// requestForwardingReader is a wrapper around io.Pipe that embeds both the
// io.Reader and io.Writer interface and can be closed. Bytes written to it
// are read back from it.
type requestForwardingReader struct {
	io.Reader
	io.Writer

	// pipeR and pipeW are the two ends of the underlying pipe. The
	// embedded Reader and Writer refer to these same objects.
	pipeR *io.PipeReader
	pipeW *io.PipeWriter
}

// newRequestForwardingReader creates a new request forwarding pipe.
func newRequestForwardingReader() *requestForwardingReader {
	pipeReader, pipeWriter := io.Pipe()

	forwarder := &requestForwardingReader{
		pipeR: pipeReader,
		pipeW: pipeWriter,
	}
	forwarder.Reader = pipeReader
	forwarder.Writer = pipeWriter

	return forwarder
}

// CloseWriter closes the writing end of the pipe with io.EOF. Subsequent
// reads from the pipe return io.EOF, and further writes fail.
func (f *requestForwardingReader) CloseWriter() {
	// CloseWithError always returns nil, so there is nothing to handle.
	_ = f.pipeW.CloseWithError(io.EOF)
}

// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
// what's written to it and presents it through a bufio.Scanner interface.
func newResponseForwardingWriter() *responseForwardingWriter {
	// maxLineSize is the largest single newline-delimited line written by
	// the backend that the scanner accepts. bufio.Scanner defaults to
	// bufio.MaxScanTokenSize (64 KiB). A longer line would make scanning
	// fail with bufio.ErrTooLong and end the whole stream.
	const maxLineSize = 4 * 1024 * 1024

	r, w := io.Pipe()
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 4096), maxLineSize)

	return &responseForwardingWriter{
		Writer:  w,
		Scanner: scanner,
		pipeR:   r,
		pipeW:   w,
		header:  http.Header{},
		closed:  make(chan bool, 1),
	}
}

// responseForwardingWriter is a type that implements the http.ResponseWriter
// interface but internally forwards what's written to the writer through a pipe
// so it can easily be read again through the bufio.Scanner interface.
type responseForwardingWriter struct {
	io.Writer
	*bufio.Scanner

	pipeR *io.PipeReader
	pipeW *io.PipeWriter

	header http.Header
	code   int
	closed chan bool
}

// Write writes the given bytes to the internal pipe. It blocks until the
// scanner side has consumed them or the pipe is closed.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
	return w.Writer.Write(b)
}

// Header returns the HTTP header fields intercepted so far.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) Header() http.Header {
	return w.header
}

// WriteHeader indicates that the header part of the response is now finished
// and sets the response code.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) WriteHeader(code int) {
	w.code = code
}

// CloseNotify returns a channel that indicates if a connection was closed.
//
// NOTE: This is part of the http.CloseNotifier interface.
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
	return w.closed
}

// Flush empties all buffers. We implement this to indicate to our backend that
// we support flushing our content. There is no actual implementation because
// all writes happen immediately, there is no internal buffering.
//
// NOTE: This is part of the http.Flusher interface.
func (w *responseForwardingWriter) Flush() {}

// Close closes both ends of the internal pipe, which ends scanning and makes
// further writes fail, and notifies CloseNotify listeners. Close may be called
// more than once. A notification is only queued if none is pending, so
// repeated calls never block.
func (w *responseForwardingWriter) Close() {
	_ = w.pipeR.CloseWithError(io.EOF)
	_ = w.pipeW.CloseWithError(io.EOF)

	select {
	case w.closed <- true:
	default:
	}
}

// IsClosedConnError reports whether err signals that a connection or pipe has
// been closed, reset or shut down. It returns true in these cases:
//   - err is http.ErrServerClosed,
//   - err's message contains one of a few well-known fragments, or
//   - err is a WebSocket close error with a normal-closure or going-away code.
//
// A nil error is never a closed-connection error.
func IsClosedConnError(err error) bool {
	switch {
	case err == nil:
		return false

	case err == http.ErrServerClosed:
		return true
	}

	msg := err.Error()
	for _, fragment := range []string{
		"use of closed network connection",
		"closed pipe",
		"broken pipe",
		"connection reset by peer",
	} {
		if strings.Contains(msg, fragment) {
			return true
		}
	}

	return websocket.IsCloseError(
		err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
	)
}