brontide: implement message chunking for the net.Conn implementation
This commit implements message chunking within the implementation of net.Conn which implements our initial handshake, then uses the crypto to read/write messages. With this change it’s now possible to send message larger than 65535 bytes over a p2p crypto connection by properly chunking the messages on the side of the connection that’s writing.
This commit is contained in:
parent
49f9f496fb
commit
767c550d65
@ -3,6 +3,7 @@ package brontide
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@ -108,7 +109,37 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
//
|
||||
// Part of the net.Conn interface.
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
||||
// If the message doesn't require any chunking, then we can go ahead
|
||||
// with a single write.
|
||||
if len(b)+macSize <= math.MaxUint16 {
|
||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
||||
}
|
||||
|
||||
// If we need to split the message into fragments, then we'll write
|
||||
// chunks which maximize usage of the available payload. To do so, we
|
||||
// subtract the added overhead of the MAC at the end of the message.
|
||||
chunkSize := math.MaxUint16 - macSize
|
||||
|
||||
bytesToWrite := len(b)
|
||||
bytesWritten := 0
|
||||
for bytesWritten < bytesToWrite {
|
||||
// If we're on the last chunk, then truncate the chunk size as
|
||||
// necessary to avoid an out-of-bounds array memory access.
|
||||
if bytesWritten+chunkSize > len(b) {
|
||||
chunkSize = len(b) - bytesWritten
|
||||
}
|
||||
|
||||
// Slice off the next chunk to be written based on our running
|
||||
// counter and next chunk size.
|
||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
bytesWritten += len(chunk)
|
||||
}
|
||||
|
||||
return bytesWritten, nil
|
||||
}
|
||||
|
||||
// Close closes the connection. Any blocked Read or Write operations will be
|
||||
|
@ -2,24 +2,26 @@ package brontide
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/roasbeef/btcd/btcec"
|
||||
)
|
||||
|
||||
func TestConnectionCorrectness(t *testing.T) {
|
||||
func establishTestConnection() (net.Conn, net.Conn, error) {
|
||||
// First, generate the long-term private keys both ends of the connection
|
||||
// within our test.
|
||||
localPriv, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate local priv key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
remotePriv, err := btcec.NewPrivateKey(btcec.S256())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate remote priv key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Having a port of ":0" means a random port, and interface will be
|
||||
@ -29,7 +31,7 @@ func TestConnectionCorrectness(t *testing.T) {
|
||||
// Our listener will be local, and the connection remote.
|
||||
listener, err := NewListener(localPriv, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create listener: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
@ -50,24 +52,36 @@ func TestConnectionCorrectness(t *testing.T) {
|
||||
|
||||
localConn, listenErr := listener.Accept()
|
||||
if listenErr != nil {
|
||||
t.Fatalf("unable to accept connection: %v", listenErr)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if dialErr := <-errChan; err != nil {
|
||||
t.Fatalf("unable to establish connection: %v", dialErr)
|
||||
return nil, nil, dialErr
|
||||
}
|
||||
remoteConn := <-connChan
|
||||
|
||||
return localConn, remoteConn, nil
|
||||
}
|
||||
|
||||
func TestConnectionCorrectness(t *testing.T) {
|
||||
// Create a test connection, grabbing either side of the connection
|
||||
// into local variables. If the initial crypto handshake fails, then
|
||||
// we'll get a non-nil error here.
|
||||
localConn, remoteConn, err := establishTestConnection()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to establish test connection: %v", err)
|
||||
}
|
||||
conn := <-connChan
|
||||
|
||||
// Test out some message full-message reads.
|
||||
for i := 0; i < 10; i++ {
|
||||
msg := []byte("hello" + string(i))
|
||||
|
||||
if _, err := conn.Write(msg); err != nil {
|
||||
if _, err := localConn.Write(msg); err != nil {
|
||||
t.Fatalf("remote conn failed to write: %v", err)
|
||||
}
|
||||
|
||||
readBuf := make([]byte, len(msg))
|
||||
if _, err := localConn.Read(readBuf); err != nil {
|
||||
if _, err := remoteConn.Read(readBuf); err != nil {
|
||||
t.Fatalf("local conn failed to read: %v", err)
|
||||
}
|
||||
|
||||
@ -80,15 +94,15 @@ func TestConnectionCorrectness(t *testing.T) {
|
||||
// Now try incremental message reads. This simulates first writing a
|
||||
// message header, then a message body.
|
||||
outMsg := []byte("hello world")
|
||||
if _, err := conn.Write(outMsg); err != nil {
|
||||
if _, err := localConn.Write(outMsg); err != nil {
|
||||
t.Fatalf("remote conn failed to write: %v", err)
|
||||
}
|
||||
|
||||
readBuf := make([]byte, len(outMsg))
|
||||
if _, err := localConn.Read(readBuf[:len(outMsg)/2]); err != nil {
|
||||
if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
|
||||
t.Fatalf("local conn failed to read: %v", err)
|
||||
}
|
||||
if _, err := localConn.Read(readBuf[len(outMsg)/2:]); err != nil {
|
||||
if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
|
||||
t.Fatalf("local conn failed to read: %v", err)
|
||||
}
|
||||
|
||||
@ -136,6 +150,54 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteMessageChunking(t *testing.T) {
|
||||
// Create a test connection, grabbing either side of the connection
|
||||
// into local variables. If the initial crypto handshake fails, then
|
||||
// we'll get a non-nil error here.
|
||||
localConn, remoteConn, err := establishTestConnection()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to establish test connection: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to write a message which is over 3x the max allowed payload
|
||||
// size.
|
||||
largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)
|
||||
|
||||
// Launch a new goroutine to write the lerge message generated above in
|
||||
// chunks. We spawn a new goroutine because otherwise, we may block as
|
||||
// the kernal waits for the buffer to flush.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
bytesWritten, err := localConn.Write(largeMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write message")
|
||||
}
|
||||
|
||||
// The entire message should have been written out to the remote
|
||||
// connection.
|
||||
if bytesWritten != len(largeMessage) {
|
||||
t.Fatalf("bytes not fully written!")
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Attempt to read the entirety of the message generated above.
|
||||
buf := make([]byte, len(largeMessage))
|
||||
if _, err := io.ReadFull(remoteConn, buf); err != nil {
|
||||
t.Fatalf("unable to read message")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Finally, the message the remote end of the connection received
|
||||
// should be identical to what we sent from the local connection.
|
||||
if !bytes.Equal(buf, largeMessage) {
|
||||
t.Fatalf("bytes don't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoiseIdentityHiding(t *testing.T) {
|
||||
// TODO(roasbeef): fin
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user