brontide: implement message chunking for the net.Conn implementation

This commit implements message chunking within the implementation of
net.Conn which implements our initial handshake, then uses the crypto
to read/write messages.

With this change it’s now possible to send message larger than 65535
bytes over a p2p crypto connection by properly chunking the messages on
the side of the connection that’s writing.
This commit is contained in:
Olaoluwa Osuntokun 2016-11-07 19:45:00 -08:00
parent 49f9f496fb
commit 767c550d65
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
2 changed files with 106 additions and 13 deletions

@ -3,6 +3,7 @@ package brontide
import (
"bytes"
"io"
"math"
"net"
"time"
@ -108,7 +109,37 @@ func (c *Conn) Read(b []byte) (n int, err error) {
//
// Part of the net.Conn interface.
func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), c.noise.WriteMessage(c.conn, b)
// If the message doesn't require any chunking, then we can go ahead
// with a single write.
if len(b)+macSize <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b)
}
// If we need to split the message into fragments, then we'll write
// chunks which maximize usage of the available payload. To do so, we
// subtract the added overhead of the MAC at the end of the message.
chunkSize := math.MaxUint16 - macSize
bytesToWrite := len(b)
bytesWritten := 0
for bytesWritten < bytesToWrite {
// If we're on the last chunk, then truncate the chunk size as
// necessary to avoid an out-of-bounds array memory access.
if bytesWritten+chunkSize > len(b) {
chunkSize = len(b) - bytesWritten
}
// Slice off the next chunk to be written based on our running
// counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
return bytesWritten, err
}
bytesWritten += len(chunk)
}
return bytesWritten, nil
}
// Close closes the connection. Any blocked Read or Write operations will be

@ -2,24 +2,26 @@ package brontide
import (
"bytes"
"io"
"math"
"net"
"sync"
"testing"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
)
func TestConnectionCorrectness(t *testing.T) {
func establishTestConnection() (net.Conn, net.Conn, error) {
// First, generate the long-term private keys both ends of the connection
// within our test.
localPriv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("unable to generate local priv key: %v", err)
return nil, nil, err
}
remotePriv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
t.Fatalf("unable to generate remote priv key: %v", err)
return nil, nil, err
}
// Having a port of ":0" means a random port, and interface will be
@ -29,7 +31,7 @@ func TestConnectionCorrectness(t *testing.T) {
// Our listener will be local, and the connection remote.
listener, err := NewListener(localPriv, addr)
if err != nil {
t.Fatalf("unable to create listener: %v", err)
return nil, nil, err
}
defer listener.Close()
@ -50,24 +52,36 @@ func TestConnectionCorrectness(t *testing.T) {
localConn, listenErr := listener.Accept()
if listenErr != nil {
t.Fatalf("unable to accept connection: %v", listenErr)
return nil, nil, err
}
if dialErr := <-errChan; err != nil {
t.Fatalf("unable to establish connection: %v", dialErr)
return nil, nil, dialErr
}
remoteConn := <-connChan
return localConn, remoteConn, nil
}
func TestConnectionCorrectness(t *testing.T) {
// Create a test connection, grabbing either side of the connection
// into local variables. If the initial crypto handshake fails, then
// we'll get a non-nil error here.
localConn, remoteConn, err := establishTestConnection()
if err != nil {
t.Fatalf("unable to establish test connection: %v", err)
}
conn := <-connChan
// Test out some message full-message reads.
for i := 0; i < 10; i++ {
msg := []byte("hello" + string(i))
if _, err := conn.Write(msg); err != nil {
if _, err := localConn.Write(msg); err != nil {
t.Fatalf("remote conn failed to write: %v", err)
}
readBuf := make([]byte, len(msg))
if _, err := localConn.Read(readBuf); err != nil {
if _, err := remoteConn.Read(readBuf); err != nil {
t.Fatalf("local conn failed to read: %v", err)
}
@ -80,15 +94,15 @@ func TestConnectionCorrectness(t *testing.T) {
// Now try incremental message reads. This simulates first writing a
// message header, then a message body.
outMsg := []byte("hello world")
if _, err := conn.Write(outMsg); err != nil {
if _, err := localConn.Write(outMsg); err != nil {
t.Fatalf("remote conn failed to write: %v", err)
}
readBuf := make([]byte, len(outMsg))
if _, err := localConn.Read(readBuf[:len(outMsg)/2]); err != nil {
if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
t.Fatalf("local conn failed to read: %v", err)
}
if _, err := localConn.Read(readBuf[len(outMsg)/2:]); err != nil {
if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
t.Fatalf("local conn failed to read: %v", err)
}
@ -136,6 +150,54 @@ func TestMaxPayloadLength(t *testing.T) {
}
}
func TestWriteMessageChunking(t *testing.T) {
// Create a test connection, grabbing either side of the connection
// into local variables. If the initial crypto handshake fails, then
// we'll get a non-nil error here.
localConn, remoteConn, err := establishTestConnection()
if err != nil {
t.Fatalf("unable to establish test connection: %v", err)
}
// Attempt to write a message which is over 3x the max allowed payload
// size.
largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)
// Launch a new goroutine to write the lerge message generated above in
// chunks. We spawn a new goroutine because otherwise, we may block as
// the kernal waits for the buffer to flush.
var wg sync.WaitGroup
wg.Add(1)
go func() {
bytesWritten, err := localConn.Write(largeMessage)
if err != nil {
t.Fatalf("unable to write message")
}
// The entire message should have been written out to the remote
// connection.
if bytesWritten != len(largeMessage) {
t.Fatalf("bytes not fully written!")
}
wg.Done()
}()
// Attempt to read the entirety of the message generated above.
buf := make([]byte, len(largeMessage))
if _, err := io.ReadFull(remoteConn, buf); err != nil {
t.Fatalf("unable to read message")
}
wg.Wait()
// Finally, the message the remote end of the connection received
// should be identical to what we sent from the local connection.
if !bytes.Equal(buf, largeMessage) {
t.Fatalf("bytes don't match")
}
}
func TestNoiseIdentityHiding(t *testing.T) {
// TODO(roasbeef): fin
}