Merge pull request #4141 from guggero/rest-api-improvements

REST saga 3/3: REST API for subservers, websocket for streaming responses
This commit is contained in:
Conner Fromknecht 2020-06-17 10:30:29 -07:00 committed by GitHub
commit 87880c0d56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 825 additions and 5 deletions

@ -159,6 +159,7 @@ type Config struct {
RawExternalIPs []string `long:"externalip" description:"Add an ip:port to the list of local addresses we claim to listen on to peers. If a port is not specified, the default (9735) will be used regardless of other parameters"`
RPCListeners []net.Addr
RESTListeners []net.Addr
RestCORS []string `long:"restcors" description:"Add an ip:port/hostname to allow cross origin access from. To allow all origins, set as \"*\"."`
Listeners []net.Addr
ExternalIPs []net.Addr
DisableListen bool `long:"nolisten" description:"Disable listening for incoming peer connections"`

2
go.mod

@ -26,7 +26,7 @@ require (
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/golang/protobuf v1.3.2
github.com/google/btree v1.0.0 // indirect
github.com/gorilla/websocket v1.4.1 // indirect
github.com/gorilla/websocket v1.4.2
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/grpc-ecosystem/grpc-gateway v1.14.3

4
go.sum

@ -132,8 +132,8 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho=

@ -12367,6 +12367,11 @@ type LightningClient interface {
// lncli: `estimatefee`
//EstimateFee asks the chain backend to estimate the fee rate and total fees
//for a transaction that pays to multiple specified outputs.
//
//When using REST, the `AddrToAmount` map type can be set by appending
//`&AddrToAmount[<address>]=<amount_to_send>` to the URL. Unfortunately this
//map type doesn't appear in the REST API documentation because of a bug in
//the grpc-gateway library.
EstimateFee(ctx context.Context, in *EstimateFeeRequest, opts ...grpc.CallOption) (*EstimateFeeResponse, error)
// lncli: `sendcoins`
//SendCoins executes a request to send coins to a particular address. Unlike
@ -12590,6 +12595,11 @@ type LightningClient interface {
//satoshis. The returned route contains the full details required to craft and
//send an HTLC, also including the necessary information that should be
//present within the Sphinx packet encapsulated within the HTLC.
//
//When using REST, the `dest_custom_records` map type can be set by appending
//`&dest_custom_records[<record_number>]=<record_data_base64_url_encoded>`
//to the URL. Unfortunately this map type doesn't appear in the REST API
//documentation because of a bug in the grpc-gateway library.
QueryRoutes(ctx context.Context, in *QueryRoutesRequest, opts ...grpc.CallOption) (*QueryRoutesResponse, error)
// lncli: `getnetworkinfo`
//GetNetworkInfo returns some basic stats about the known channel graph from
@ -13448,6 +13458,11 @@ type LightningServer interface {
// lncli: `estimatefee`
//EstimateFee asks the chain backend to estimate the fee rate and total fees
//for a transaction that pays to multiple specified outputs.
//
//When using REST, the `AddrToAmount` map type can be set by appending
//`&AddrToAmount[<address>]=<amount_to_send>` to the URL. Unfortunately this
//map type doesn't appear in the REST API documentation because of a bug in
//the grpc-gateway library.
EstimateFee(context.Context, *EstimateFeeRequest) (*EstimateFeeResponse, error)
// lncli: `sendcoins`
//SendCoins executes a request to send coins to a particular address. Unlike
@ -13671,6 +13686,11 @@ type LightningServer interface {
//satoshis. The returned route contains the full details required to craft and
//send an HTLC, also including the necessary information that should be
//present within the Sphinx packet encapsulated within the HTLC.
//
//When using REST, the `dest_custom_records` map type can be set by appending
//`&dest_custom_records[<record_number>]=<record_data_base64_url_encoded>`
//to the URL. Unfortunately this map type doesn't appear in the REST API
//documentation because of a bug in the grpc-gateway library.
QueryRoutes(context.Context, *QueryRoutesRequest) (*QueryRoutesResponse, error)
// lncli: `getnetworkinfo`
//GetNetworkInfo returns some basic stats about the known channel graph from

@ -46,6 +46,11 @@ service Lightning {
/* lncli: `estimatefee`
EstimateFee asks the chain backend to estimate the fee rate and total fees
for a transaction that pays to multiple specified outputs.
When using REST, the `AddrToAmount` map type can be set by appending
`&AddrToAmount[<address>]=<amount_to_send>` to the URL. Unfortunately this
map type doesn't appear in the REST API documentation because of a bug in
the grpc-gateway library.
*/
rpc EstimateFee (EstimateFeeRequest) returns (EstimateFeeResponse);
@ -355,6 +360,11 @@ service Lightning {
satoshis. The returned route contains the full details required to craft and
send an HTLC, also including the necessary information that should be
present within the Sphinx packet encapsulated within the HTLC.
When using REST, the `dest_custom_records` map type can be set by appending
`&dest_custom_records[<record_number>]=<record_data_base64_url_encoded>`
to the URL. Unfortunately this map type doesn't appear in the REST API
documentation because of a bug in the grpc-gateway library.
*/
rpc QueryRoutes (QueryRoutesRequest) returns (QueryRoutesResponse);

@ -1001,6 +1001,7 @@
"/v1/graph/routes/{pub_key}/{amt}": {
"get": {
"summary": "lncli: `queryroutes`\nQueryRoutes attempts to query the daemon's Channel Router for a possible\nroute to a target destination capable of carrying a specific amount of\nsatoshis. The returned route contains the full details required to craft and\nsend an HTLC, also including the necessary information that should be\npresent within the Sphinx packet encapsulated within the HTLC.",
"description": "When using REST, the `dest_custom_records` map type can be set by appending\n`\u0026dest_custom_records[\u003crecord_number\u003e]=\u003crecord_data_base64_url_encoded\u003e`\nto the URL. Unfortunately this map type doesn't appear in the REST API\ndocumentation because of a bug in the grpc-gateway library.",
"operationId": "QueryRoutes",
"responses": {
"200": {
@ -1854,6 +1855,7 @@
"/v1/transactions/fee": {
"get": {
"summary": "lncli: `estimatefee`\nEstimateFee asks the chain backend to estimate the fee rate and total fees\nfor a transaction that pays to multiple specified outputs.",
"description": "When using REST, the `AddrToAmount` map type can be set by appending\n`\u0026AddrToAmount[\u003caddress\u003e]=\u003camount_to_send\u003e` to the URL. Unfortunately this\nmap type doesn't appear in the REST API documentation because of a bug in\nthe grpc-gateway library.",
"operationId": "EstimateFee",
"responses": {
"200": {

305
lnrpc/websocket_proxy.go Normal file

@ -0,0 +1,305 @@
// The code in this file is a heavily modified version of
// https://github.com/tmc/grpc-websocket-proxy/
package lnrpc
import (
"bufio"
"io"
"net/http"
"net/textproto"
"strings"
"github.com/btcsuite/btclog"
"github.com/gorilla/websocket"
"golang.org/x/net/context"
)
const (
// MethodOverrideParam is the GET query parameter that specifies what
// HTTP request method should be used for the forwarded REST request.
// This is necessary because the WebSocket API specifies that a
// handshake request must always be done through a GET request.
MethodOverrideParam = "method"
)
var (
// defaultHeadersToForward is a map of all HTTP header fields that are
// forwarded by default. The keys must be in the canonical MIME header
// format.
defaultHeadersToForward = map[string]bool{
"Origin": true,
"Referer": true,
"Grpc-Metadata-Macaroon": true,
}
)
// NewWebSocketProxy attempts to expose the underlying handler as a response-
// streaming WebSocket stream with newline-delimited JSON as the content
// encoding.
func NewWebSocketProxy(h http.Handler, logger btclog.Logger) http.Handler {
p := &WebsocketProxy{
backend: h,
logger: logger,
upgrader: &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
return p
}
// WebsocketProxy provides websocket transport upgrade to compatible endpoints.
type WebsocketProxy struct {
backend http.Handler
logger btclog.Logger
upgrader *websocket.Upgrader
}
// ServeHTTP handles the incoming HTTP request. If the request is an
// "upgradeable" WebSocket request (identified by header fields), then the
// WS proxy handles the request. Otherwise the request is passed directly to the
// underlying REST proxy.
func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) {
p.backend.ServeHTTP(w, r)
return
}
p.upgradeToWebSocketProxy(w, r)
}
// upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
// one incoming message then streams all responses until either the client or
// server quit the connection.
func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
r *http.Request) {
conn, err := p.upgrader.Upgrade(w, r, nil)
if err != nil {
p.logger.Errorf("error upgrading websocket:", err)
return
}
defer func() {
err := conn.Close()
if err != nil && !IsClosedConnError(err) {
p.logger.Errorf("WS: error closing upgraded conn: %v",
err)
}
}()
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
requestForwarder := newRequestForwardingReader()
request, err := http.NewRequestWithContext(
r.Context(), r.Method, r.URL.String(), requestForwarder,
)
if err != nil {
p.logger.Errorf("WS: error preparing request:", err)
return
}
for header := range r.Header {
headerName := textproto.CanonicalMIMEHeaderKey(header)
forward, ok := defaultHeadersToForward[headerName]
if ok && forward {
request.Header.Set(headerName, r.Header.Get(header))
}
}
if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
request.Method = m
}
responseForwarder := newResponseForwardingWriter()
go func() {
<-ctx.Done()
responseForwarder.Close()
}()
go func() {
defer cancelFn()
p.backend.ServeHTTP(responseForwarder, request)
}()
// Read loop: Take messages from websocket and write to http request.
go func() {
defer cancelFn()
for {
select {
case <-ctx.Done():
return
default:
}
_, payload, err := conn.ReadMessage()
if err != nil {
if IsClosedConnError(err) {
p.logger.Tracef("WS: socket "+
"closed: %v", err)
return
}
p.logger.Errorf("error reading message: %v",
err)
return
}
_, err = requestForwarder.Write(payload)
if err != nil {
p.logger.Errorf("WS: error writing message "+
"to upstream http server: %v", err)
return
}
_, _ = requestForwarder.Write([]byte{'\n'})
// We currently only support server-streaming messages.
// Therefore we close the request body after the first
// incoming message to trigger a response.
requestForwarder.CloseWriter()
}
}()
// Write loop: Take messages from the response forwarder and write them
// to the WebSocket.
for responseForwarder.Scan() {
if len(responseForwarder.Bytes()) == 0 {
p.logger.Errorf("WS: empty scan: %v",
responseForwarder.Err())
continue
}
err = conn.WriteMessage(
websocket.TextMessage, responseForwarder.Bytes(),
)
if err != nil {
p.logger.Errorf("WS: error writing message: %v", err)
return
}
}
if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
p.logger.Errorf("WS: scanner err: %v", err)
}
}
// newRequestForwardingReader creates a new request forwarding pipe.
func newRequestForwardingReader() *requestForwardingReader {
r, w := io.Pipe()
return &requestForwardingReader{
Reader: r,
Writer: w,
pipeR: r,
pipeW: w,
}
}
// requestForwardingReader is a wrapper around io.Pipe that embeds both the
// io.Reader and io.Writer interface and can be closed.
type requestForwardingReader struct {
io.Reader
io.Writer
pipeR *io.PipeReader
pipeW *io.PipeWriter
}
// CloseWriter closes the underlying pipe writer.
func (r *requestForwardingReader) CloseWriter() {
_ = r.pipeW.CloseWithError(io.EOF)
}
// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
// what's written to it and presents it through a bufio.Scanner interface.
func newResponseForwardingWriter() *responseForwardingWriter {
r, w := io.Pipe()
return &responseForwardingWriter{
Writer: w,
Scanner: bufio.NewScanner(r),
pipeR: r,
pipeW: w,
header: http.Header{},
closed: make(chan bool, 1),
}
}
// responseForwardingWriter is a type that implements the http.ResponseWriter
// interface but internally forwards what's written to the writer through a pipe
// so it can easily be read again through the bufio.Scanner interface.
type responseForwardingWriter struct {
io.Writer
*bufio.Scanner
pipeR *io.PipeReader
pipeW *io.PipeWriter
header http.Header
code int
closed chan bool
}
// Write writes the given bytes to the internal pipe.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
// Header returns the HTTP header fields intercepted so far.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) Header() http.Header {
return w.header
}
// WriteHeader indicates that the header part of the response is now finished
// and sets the response code.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) WriteHeader(code int) {
w.code = code
}
// CloseNotify returns a channel that indicates if a connection was closed.
//
// NOTE: This is part of the http.CloseNotifier interface.
func (w *responseForwardingWriter) CloseNotify() <-chan bool {
return w.closed
}
// Flush empties all buffers. We implement this to indicate to our backend that
// we support flushing our content. There is no actual implementation because
// all writes happen immediately, there is no internal buffering.
//
// NOTE: This is part of the http.Flusher interface.
func (w *responseForwardingWriter) Flush() {}
func (w *responseForwardingWriter) Close() {
_ = w.pipeR.CloseWithError(io.EOF)
_ = w.pipeW.CloseWithError(io.EOF)
w.closed <- true
}
// IsClosedConnError is a helper function that returns true if the given error
// is an error indicating we are using a closed connection.
func IsClosedConnError(err error) bool {
if err == nil {
return false
}
if err == http.ErrServerClosed {
return true
}
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return true
}
if strings.Contains(str, "closed pipe") {
return true
}
if strings.Contains(str, "broken pipe") {
return true
}
return websocket.IsCloseError(
err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
)
}

@ -14665,6 +14665,10 @@ var testsCases = []*testCase{
name: "send multi path payment",
test: testSendMultiPathPayment,
},
{
name: "REST API",
test: testRestApi,
},
}
// TestLightningNetworkDaemon performs a series of integration tests amongst a

@ -130,6 +130,8 @@
<time> [ERR] PEER: unable to send msg to remote peer: peer exiting
<time> [ERR] PEER: unable to send msg to remote peer: write tcp <ip>-><ip>: write: broken pipe
<time> [ERR] PEER: unable to send msg to remote peer: write tcp <ip>-><ip>: write: connection reset by peer
<time> [ERR] RPCS: [/chainrpc.ChainNotifier/RegisterBlockEpochNtfn]: chain notifier shutting down
<time> [ERR] RPCS: [/chainrpc.ChainNotifier/RegisterBlockEpochNtfn]: context canceled
<time> [ERR] RPCS: [/invoicesrpc.Invoices/SubscribeSingleInvoice]: rpc error: code = Canceled desc = context canceled
<time> [ERR] RPCS: [/lnrpc.Lightning/CloseChannel]: cannot co-op close frozen channel as initiator until height=<height>, (current_height=<height>)
<time> [ERR] RPCS: [/lnrpc.Lightning/CloseChannel]: cannot co-op close frozen channel as initiator until height=3059, (current_height=3055)
@ -185,6 +187,6 @@
<time> [ERR] NTFN: unable to get hash from block with height 790
<time> [ERR] CRTR: Payment with hash <hex> failed: timeout
<time> [ERR] RPCS: [/routerrpc.Route<time> [INF] LTND: Listening on the p2p interface is disabled!
<time> [ERR] FNDG: Unable to advance state(<chan_point>): failed adding to router graph: error sending channel announcement: gossiper is shutting down
<time> [ERR] FNDG: Unable to advance state(<chan_point>): failed adding to router graph: error sending channel announcement: gossiper is shutting down
<time> [ERR] PEER: unable to close channel, ChannelID(<hex>) is unknown
<time> [ERR] HSWC: ChannelLink(<chan>): unable to update signals

412
lntest/itest/rest_api.go Normal file

@ -0,0 +1,412 @@
// +build rpctest
package itest
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strings"
"testing"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/gorilla/websocket"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/autopilotrpc"
"github.com/lightningnetwork/lnd/lnrpc/chainrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnrpc/verrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
insecureTransport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
restClient = &http.Client{
Transport: insecureTransport,
}
jsonMarshaler = &jsonpb.Marshaler{
EmitDefaults: true,
OrigName: true,
Indent: " ",
}
urlEnc = base64.URLEncoding
webSocketDialer = &websocket.Dialer{
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: insecureTransport.TLSClientConfig,
}
resultPattern = regexp.MustCompile("{\"result\":(.*)}")
)
// testRestApi tests that the most important features of the REST API work
// correctly.
func testRestApi(net *lntest.NetworkHarness, ht *harnessTest) {
testCases := []struct {
name string
run func(*testing.T, *lntest.HarnessNode, *lntest.HarnessNode)
}{{
name: "simple GET",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Check that the parsing into the response proto
// message works.
resp := &lnrpc.GetInfoResponse{}
err := invokeGET(a, "/v1/getinfo", resp)
require.Nil(t, err, "getinfo")
assert.Equal(t, "#3399ff", resp.Color, "node color")
// Make sure we get the correct field names (snake
// case).
_, resp2, err := makeRequest(
a, "/v1/getinfo", "GET", nil, nil,
)
require.Nil(t, err, "getinfo")
assert.Contains(
t, string(resp2), "best_header_timestamp",
"getinfo",
)
},
}, {
name: "simple POST and GET with query param",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Add an invoice, testing POST in the process.
req := &lnrpc.Invoice{Value: 1234}
resp := &lnrpc.AddInvoiceResponse{}
err := invokePOST(a, "/v1/invoices", req, resp)
require.Nil(t, err, "add invoice")
assert.Equal(t, 32, len(resp.RHash), "invoice rhash")
// Make sure we can call a GET endpoint with a hex
// encoded URL part.
url := fmt.Sprintf("/v1/invoice/%x", resp.RHash)
resp2 := &lnrpc.Invoice{}
err = invokeGET(a, url, resp2)
require.Nil(t, err, "query invoice")
assert.Equal(t, int64(1234), resp2.Value, "invoice amt")
},
}, {
name: "GET with base64 encoded byte slice in path",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
url := "/v2/router/mc/probability/%s/%s/%d"
url = fmt.Sprintf(
url, urlEnc.EncodeToString(a.PubKey[:]),
urlEnc.EncodeToString(b.PubKey[:]), 1234,
)
resp := &routerrpc.QueryProbabilityResponse{}
err := invokeGET(a, url, resp)
require.Nil(t, err, "query probability")
assert.Greater(t, resp.Probability, 0.5, "probability")
},
}, {
name: "GET with map type query param",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Get a new wallet address from Alice.
ctxb := context.Background()
newAddrReq := &lnrpc.NewAddressRequest{
Type: lnrpc.AddressType_WITNESS_PUBKEY_HASH,
}
addrRes, err := a.NewAddress(ctxb, newAddrReq)
require.Nil(t, err, "get address")
// Create the full URL with the map query param.
url := "/v1/transactions/fee?target_conf=%d&" +
"AddrToAmount[%s]=%d"
url = fmt.Sprintf(url, 2, addrRes.Address, 50000)
resp := &lnrpc.EstimateFeeResponse{}
err = invokeGET(a, url, resp)
require.Nil(t, err, "estimate fee")
assert.Greater(t, resp.FeeSat, int64(253), "fee")
},
}, {
name: "sub RPC servers REST support",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Query autopilot status.
res1 := &autopilotrpc.StatusResponse{}
err := invokeGET(a, "/v2/autopilot/status", res1)
require.Nil(t, err, "autopilot status")
assert.Equal(t, false, res1.Active, "autopilot status")
// Query the version RPC.
res2 := &verrpc.Version{}
err = invokeGET(a, "/v2/versioner/version", res2)
require.Nil(t, err, "version")
assert.Greater(
t, res2.AppMinor, uint32(0), "lnd minor version",
)
// Request a new external address from the wallet kit.
req1 := &walletrpc.AddrRequest{}
res3 := &walletrpc.AddrResponse{}
err = invokePOST(
a, "/v2/wallet/address/next", req1, res3,
)
require.Nil(t, err, "address")
assert.NotEmpty(t, res3.Addr, "address")
},
}, {
name: "CORS headers",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Alice allows all origins. Make sure we get the same
// value back in the CORS header that we send in the
// Origin header.
reqHeaders := make(http.Header)
reqHeaders.Add("Origin", "https://foo.bar:9999")
resHeaders, body, err := makeRequest(
a, "/v1/getinfo", "OPTIONS", nil, reqHeaders,
)
require.Nil(t, err, "getinfo")
assert.Equal(
t, "https://foo.bar:9999",
resHeaders.Get("Access-Control-Allow-Origin"),
"CORS header",
)
assert.Equal(t, 0, len(body))
// Make sure that we don't get a value set for Bob which
// doesn't allow any CORS origin.
resHeaders, body, err = makeRequest(
b, "/v1/getinfo", "OPTIONS", nil, reqHeaders,
)
require.Nil(t, err, "getinfo")
assert.Equal(
t, "",
resHeaders.Get("Access-Control-Allow-Origin"),
"CORS header",
)
assert.Equal(t, 0, len(body))
},
}, {
name: "websocket subscription",
run: func(t *testing.T, a, b *lntest.HarnessNode) {
// Find out the current best block so we can subscribe
// to the next one.
hash, height, err := net.Miner.Node.GetBestBlock()
require.Nil(t, err, "get best block")
// Create a new subscription to get block epoch events.
req := &chainrpc.BlockEpoch{
Hash: hash.CloneBytes(),
Height: uint32(height),
}
url := "/v2/chainnotifier/register/blocks"
c, err := openWebSocket(a, url, "POST", req)
require.Nil(t, err, "websocket")
defer func() {
_ = c.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(
websocket.CloseNormalClosure,
"done",
),
)
_ = c.Close()
}()
msgChan := make(chan *chainrpc.BlockEpoch)
errChan := make(chan error)
timeout := time.After(defaultTimeout)
// We want to read exactly one message.
go func() {
defer close(msgChan)
_, msg, err := c.ReadMessage()
if err != nil {
errChan <- err
return
}
// The chunked/streamed responses come wrapped
// in either a {"result":{}} or {"error":{}}
// wrapper which we'll get rid of here.
msgStr := string(msg)
if !strings.Contains(msgStr, "\"result\":") {
errChan <- fmt.Errorf("invalid msg: %s",
msgStr)
return
}
msgStr = resultPattern.ReplaceAllString(
msgStr, "${1}",
)
// Make sure we can parse the unwrapped message
// into the expected proto message.
protoMsg := &chainrpc.BlockEpoch{}
err = jsonpb.UnmarshalString(
msgStr, protoMsg,
)
if err != nil {
errChan <- err
return
}
select {
case msgChan <- protoMsg:
case <-timeout:
}
}()
// Mine a block and make sure we get a message for it.
blockHashes, err := net.Miner.Node.Generate(1)
require.Nil(t, err, "generate blocks")
assert.Equal(t, 1, len(blockHashes), "num blocks")
select {
case msg := <-msgChan:
assert.Equal(
t, blockHashes[0].CloneBytes(),
msg.Hash, "block hash",
)
case err := <-errChan:
t.Fatalf("Received error from WS: %v", err)
case <-timeout:
t.Fatalf("Timeout before message was received")
}
},
}}
// Make sure Alice allows all CORS origins. Bob will keep the default.
net.Alice.Cfg.ExtraArgs = append(
net.Alice.Cfg.ExtraArgs, "--restcors=\"*\"",
)
err := net.RestartNode(net.Alice, nil)
if err != nil {
ht.t.Fatalf("Could not restart Alice to set CORS config: %v",
err)
}
for _, tc := range testCases {
ht.t.Run(tc.name, func(t *testing.T) {
tc.run(t, net.Alice, net.Bob)
})
}
}
// invokeGET calls the given URL with the GET method and appropriate macaroon
// header fields then tries to unmarshal the response into the given response
// proto message.
func invokeGET(node *lntest.HarnessNode, url string, resp proto.Message) error {
_, rawResp, err := makeRequest(node, url, "GET", nil, nil)
if err != nil {
return err
}
return jsonpb.Unmarshal(bytes.NewReader(rawResp), resp)
}
// invokePOST calls the given URL with the POST method, request body and
// appropriate macaroon header fields then tries to unmarshal the response into
// the given response proto message.
func invokePOST(node *lntest.HarnessNode, url string, req,
resp proto.Message) error {
// Marshal the request to JSON using the jsonpb marshaler to get correct
// field names.
var buf bytes.Buffer
if err := jsonMarshaler.Marshal(&buf, req); err != nil {
return err
}
_, rawResp, err := makeRequest(node, url, "POST", &buf, nil)
if err != nil {
return err
}
return jsonpb.Unmarshal(bytes.NewReader(rawResp), resp)
}
// makeRequest calls the given URL with the given method, request body and
// appropriate macaroon header fields and returns the raw response body.
func makeRequest(node *lntest.HarnessNode, url, method string,
request io.Reader, additionalHeaders http.Header) (http.Header, []byte,
error) {
// Assemble the full URL from the node's listening address then create
// the request so we can set the macaroon on it.
fullURL := fmt.Sprintf("https://%s%s", node.Cfg.RESTAddr(), url)
req, err := http.NewRequest(method, fullURL, request)
if err != nil {
return nil, nil, err
}
if err := addAdminMacaroon(node, req.Header); err != nil {
return nil, nil, err
}
for key, values := range additionalHeaders {
for _, value := range values {
req.Header.Add(key, value)
}
}
// Do the actual call with the completed request object now.
resp, err := restClient.Do(req)
if err != nil {
return nil, nil, err
}
defer func() { _ = resp.Body.Close() }()
data, err := ioutil.ReadAll(resp.Body)
return resp.Header, data, err
}
// openWebSocket opens a new WebSocket connection to the given URL with the
// appropriate macaroon headers and sends the request message over the socket.
func openWebSocket(node *lntest.HarnessNode, url, method string,
req proto.Message) (*websocket.Conn, error) {
// Prepare our macaroon headers and assemble the full URL from the
// node's listening address. WebSockets always work over GET so we need
// to append the target request method as a query parameter.
header := make(http.Header)
if err := addAdminMacaroon(node, header); err != nil {
return nil, err
}
fullURL := fmt.Sprintf(
"wss://%s%s?method=%s", node.Cfg.RESTAddr(), url, method,
)
conn, _, err := webSocketDialer.Dial(fullURL, header)
if err != nil {
return nil, err
}
// Send the given request message as the first message on the socket.
reqMsg, err := jsonMarshaler.MarshalToString(req)
if err != nil {
return nil, err
}
err = conn.WriteMessage(websocket.TextMessage, []byte(reqMsg))
if err != nil {
return nil, err
}
return conn, nil
}
// addAdminMacaroon reads the admin macaroon from the node and appends it to
// the HTTP header fields.
func addAdminMacaroon(node *lntest.HarnessNode, header http.Header) error {
mac, err := node.ReadMacaroon(node.AdminMacPath(), defaultTimeout)
if err != nil {
return err
}
macBytes, err := mac.MarshalBinary()
if err != nil {
return err
}
header.Set("Grpc-Metadata-Macaroon", hex.EncodeToString(macBytes))
return nil
}

@ -206,6 +206,7 @@ func (cfg NodeConfig) genArgs() []string {
args = append(args, fmt.Sprintf("--bitcoin.defaultremotedelay=%v", DefaultCSV))
args = append(args, fmt.Sprintf("--rpclisten=%v", cfg.RPCAddr()))
args = append(args, fmt.Sprintf("--restlisten=%v", cfg.RESTAddr()))
args = append(args, fmt.Sprintf("--restcors=https://%v", cfg.RESTAddr()))
args = append(args, fmt.Sprintf("--listen=%v", cfg.P2PAddr()))
args = append(args, fmt.Sprintf("--externalip=%v", cfg.P2PAddr()))
args = append(args, fmt.Sprintf("--logdir=%v", cfg.LogDir))

@ -799,6 +799,15 @@ func (r *rpcServer) Start() error {
restCtx, restCancel := context.WithCancel(context.Background())
r.listenerCleanUp = append(r.listenerCleanUp, restCancel)
// Wrap the default grpc-gateway handler with the WebSocket handler.
restHandler := lnrpc.NewWebSocketProxy(restMux, rpcsLog)
// Set the CORS headers if configured. This wraps the HTTP handler with
// another handler.
if len(r.cfg.RestCORS) > 0 {
restHandler = allowCORS(restHandler, r.cfg.RestCORS)
}
// With our custom REST proxy mux created, register our main RPC and
// give all subservers a chance to register as well.
err := lnrpc.RegisterLightningHandlerFromEndpoint(
@ -849,7 +858,15 @@ func (r *rpcServer) Start() error {
go func() {
rpcsLog.Infof("gRPC proxy started at %s", lis.Addr())
_ = http.Serve(lis, restMux)
// Create our proxy chain now. A request will pass
// through the following chain:
// req ---> CORS handler --> WS proxy --->
// REST proxy --> gRPC endpoint
err := http.Serve(lis, restHandler)
if err != nil && !lnrpc.IsClosedConnError(err) {
rpcsLog.Error(err)
}
}()
}
@ -912,6 +929,52 @@ func addrPairsToOutputs(addrPairs map[string]int64) ([]*wire.TxOut, error) {
return outputs, nil
}
// allowCORS wraps the given http.Handler with a function that adds the
// Access-Control-Allow-Origin header to the response.
func allowCORS(handler http.Handler, origins []string) http.Handler {
allowHeaders := "Access-Control-Allow-Headers"
allowMethods := "Access-Control-Allow-Methods"
allowOrigin := "Access-Control-Allow-Origin"
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Skip everything if the browser doesn't send the Origin field.
if origin == "" {
handler.ServeHTTP(w, r)
return
}
// Set the static header fields first.
w.Header().Set(
allowHeaders,
"Content-Type, Accept, Grpc-Metadata-Macaroon",
)
w.Header().Set(allowMethods, "GET, POST, DELETE")
// Either we allow all origins or the incoming request matches
// a specific origin in our list of allowed origins.
for _, allowedOrigin := range origins {
if allowedOrigin == "*" || origin == allowedOrigin {
// Only set allowed origin to requested origin.
w.Header().Set(allowOrigin, origin)
break
}
}
// For a pre-flight request we only need to send the headers
// back. No need to call the rest of the chain.
if r.Method == "OPTIONS" {
return
}
// Everything's prepared now, we can pass the request along the
// chain of handlers.
handler.ServeHTTP(w, r)
})
}
// sendCoinsOnChain makes an on-chain transaction in or to send coins to one or
// more addresses specified in the passed payment map. The payment map maps an
// address to a specified output value to be sent to that address.