diff --git a/lntest/itest/rest_api.go b/lntest/itest/rest_api.go index 5eee3d85..c296e3a0 100644 --- a/lntest/itest/rest_api.go +++ b/lntest/itest/rest_api.go @@ -201,7 +201,116 @@ func testRestApi(net *lntest.NetworkHarness, ht *harnessTest) { Height: uint32(height), } url := "/v2/chainnotifier/register/blocks" - c, err := openWebSocket(a, url, "POST", req) + c, err := openWebSocket(a, url, "POST", req, nil) + require.Nil(t, err, "websocket") + defer func() { + _ = c.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage( + websocket.CloseNormalClosure, + "done", + ), + ) + _ = c.Close() + }() + + msgChan := make(chan *chainrpc.BlockEpoch) + errChan := make(chan error) + timeout := time.After(defaultTimeout) + + // We want to read exactly one message. + go func() { + defer close(msgChan) + + _, msg, err := c.ReadMessage() + if err != nil { + errChan <- err + return + } + + // The chunked/streamed responses come wrapped + // in either a {"result":{}} or {"error":{}} + // wrapper which we'll get rid of here. + msgStr := string(msg) + if !strings.Contains(msgStr, "\"result\":") { + errChan <- fmt.Errorf("invalid msg: %s", + msgStr) + return + } + msgStr = resultPattern.ReplaceAllString( + msgStr, "${1}", + ) + + // Make sure we can parse the unwrapped message + // into the expected proto message. + protoMsg := &chainrpc.BlockEpoch{} + err = jsonpb.UnmarshalString( + msgStr, protoMsg, + ) + if err != nil { + errChan <- err + return + } + + select { + case msgChan <- protoMsg: + case <-timeout: + } + }() + + // Mine a block and make sure we get a message for it. + blockHashes, err := net.Miner.Node.Generate(1) + require.Nil(t, err, "generate blocks") + assert.Equal(t, 1, len(blockHashes), "num blocks") + select { + case msg := <-msgChan: + assert.Equal( + t, blockHashes[0].CloneBytes(), + msg.Hash, "block hash", + ) + + case err := <-errChan: + t.Fatalf("Received error from WS: %v", err) + + case <-timeout: + t.Fatalf("Timeout before message was received") + } + }, + }, { + name: "websocket subscription with macaroon in protocol", + run: func(t *testing.T, a, b *lntest.HarnessNode) { + // Find out the current best block so we can subscribe + // to the next one. + hash, height, err := net.Miner.Node.GetBestBlock() + require.Nil(t, err, "get best block") + + // Create a new subscription to get block epoch events. + req := &chainrpc.BlockEpoch{ + Hash: hash.CloneBytes(), + Height: uint32(height), + } + url := "/v2/chainnotifier/register/blocks" + + // This time we send the macaroon in the special header + // Sec-Websocket-Protocol which is the only header field + // available to browsers when opening a WebSocket. + mac, err := a.ReadMacaroon( + a.AdminMacPath(), defaultTimeout, + ) + require.NoError(t, err, "read admin mac") + macBytes, err := mac.MarshalBinary() + require.NoError(t, err, "marshal admin mac") + + customHeader := make(http.Header) + customHeader.Set( + lnrpc.HeaderWebSocketProtocol, fmt.Sprintf( + "Grpc-Metadata-Macaroon+%s", + hex.EncodeToString(macBytes), + ), + ) + c, err := openWebSocket( + a, url, "POST", req, customHeader, + ) require.Nil(t, err, "websocket") defer func() { _ = c.WriteMessage( @@ -364,14 +473,17 @@ func makeRequest(node *lntest.HarnessNode, url, method string, // openWebSocket opens a new WebSocket connection to the given URL with the // appropriate macaroon headers and sends the request message over the socket. func openWebSocket(node *lntest.HarnessNode, url, method string, - req proto.Message) (*websocket.Conn, error) { + req proto.Message, customHeader http.Header) (*websocket.Conn, error) { // Prepare our macaroon headers and assemble the full URL from the // node's listening address. WebSockets always work over GET so we need // to append the target request method as a query parameter. - header := make(http.Header) - if err := addAdminMacaroon(node, header); err != nil { - return nil, err + header := customHeader + if header == nil { + header = make(http.Header) + if err := addAdminMacaroon(node, header); err != nil { + return nil, err + } } fullURL := fmt.Sprintf( "wss://%s%s?method=%s", node.Cfg.RESTAddr(), url, method,