diff --git a/brontide/conn.go b/brontide/conn.go index 1aa1b3a2..eb773548 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -3,6 +3,7 @@ package brontide import ( "bytes" "io" + "math" "net" "time" @@ -108,7 +109,37 @@ func (c *Conn) Read(b []byte) (n int, err error) { // // Part of the net.Conn interface. func (c *Conn) Write(b []byte) (n int, err error) { - return len(b), c.noise.WriteMessage(c.conn, b) + // If the message doesn't require any chunking, then we can go ahead + // with a single write. + if len(b)+macSize <= math.MaxUint16 { + return len(b), c.noise.WriteMessage(c.conn, b) + } + + // If we need to split the message into fragments, then we'll write + // chunks which maximize usage of the available payload. To do so, we + // subtract the added overhead of the MAC at the end of the message. + chunkSize := math.MaxUint16 - macSize + + bytesToWrite := len(b) + bytesWritten := 0 + for bytesWritten < bytesToWrite { + // If we're on the last chunk, then truncate the chunk size as + // necessary to avoid an out-of-bounds array memory access. + if bytesWritten+chunkSize > len(b) { + chunkSize = len(b) - bytesWritten + } + + // Slice off the next chunk to be written based on our running + // counter and next chunk size. + chunk := b[bytesWritten : bytesWritten+chunkSize] + if err := c.noise.WriteMessage(c.conn, chunk); err != nil { + return bytesWritten, err + } + + bytesWritten += len(chunk) + } + + return bytesWritten, nil } // Close closes the connection. Any blocked Read or Write operations will be diff --git a/brontide/noise_test.go b/brontide/noise_test.go index cf3f1ac2..6d277179 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -2,24 +2,26 @@ package brontide import ( "bytes" + "io" "math" "net" + "sync" "testing" "github.com/lightningnetwork/lnd/lnwire" "github.com/roasbeef/btcd/btcec" ) -func TestConnectionCorrectness(t *testing.T) { +func establishTestConnection() (net.Conn, net.Conn, error) { // First, generate the long-term private keys both ends of the connection // within our test. localPriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - t.Fatalf("unable to generate local priv key: %v", err) + return nil, nil, err } remotePriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - t.Fatalf("unable to generate remote priv key: %v", err) + return nil, nil, err } // Having a port of ":0" means a random port, and interface will be @@ -29,7 +31,7 @@ func TestConnectionCorrectness(t *testing.T) { // Our listener will be local, and the connection remote. listener, err := NewListener(localPriv, addr) if err != nil { - t.Fatalf("unable to create listener: %v", err) + return nil, nil, err } defer listener.Close() @@ -50,24 +52,36 @@ func TestConnectionCorrectness(t *testing.T) { localConn, listenErr := listener.Accept() if listenErr != nil { - t.Fatalf("unable to accept connection: %v", listenErr) + return nil, nil, err } if dialErr := <-errChan; err != nil { - t.Fatalf("unable to establish connection: %v", dialErr) + return nil, nil, dialErr + } + remoteConn := <-connChan + + return localConn, remoteConn, nil +} + +func TestConnectionCorrectness(t *testing.T) { + // Create a test connection, grabbing either side of the connection + // into local variables. If the initial crypto handshake fails, then + // we'll get a non-nil error here. + localConn, remoteConn, err := establishTestConnection() + if err != nil { + t.Fatalf("unable to establish test connection: %v", err) } - conn := <-connChan // Test out some message full-message reads. for i := 0; i < 10; i++ { msg := []byte("hello" + string(i)) - if _, err := conn.Write(msg); err != nil { + if _, err := localConn.Write(msg); err != nil { t.Fatalf("remote conn failed to write: %v", err) } readBuf := make([]byte, len(msg)) - if _, err := localConn.Read(readBuf); err != nil { + if _, err := remoteConn.Read(readBuf); err != nil { t.Fatalf("local conn failed to read: %v", err) } @@ -80,15 +94,15 @@ func TestConnectionCorrectness(t *testing.T) { // Now try incremental message reads. This simulates first writing a // message header, then a message body. outMsg := []byte("hello world") - if _, err := conn.Write(outMsg); err != nil { + if _, err := localConn.Write(outMsg); err != nil { t.Fatalf("remote conn failed to write: %v", err) } readBuf := make([]byte, len(outMsg)) - if _, err := localConn.Read(readBuf[:len(outMsg)/2]); err != nil { + if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil { t.Fatalf("local conn failed to read: %v", err) } - if _, err := localConn.Read(readBuf[len(outMsg)/2:]); err != nil { + if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil { t.Fatalf("local conn failed to read: %v", err) } @@ -136,6 +150,54 @@ func TestMaxPayloadLength(t *testing.T) { } } +func TestWriteMessageChunking(t *testing.T) { + // Create a test connection, grabbing either side of the connection + // into local variables. If the initial crypto handshake fails, then + // we'll get a non-nil error here. + localConn, remoteConn, err := establishTestConnection() + if err != nil { + t.Fatalf("unable to establish test connection: %v", err) + } + + // Attempt to write a message which is over 3x the max allowed payload + // size. + largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3) + + // Launch a new goroutine to write the lerge message generated above in + // chunks. We spawn a new goroutine because otherwise, we may block as + // the kernal waits for the buffer to flush. + var wg sync.WaitGroup + wg.Add(1) + go func() { + bytesWritten, err := localConn.Write(largeMessage) + if err != nil { + t.Fatalf("unable to write message") + } + + // The entire message should have been written out to the remote + // connection. + if bytesWritten != len(largeMessage) { + t.Fatalf("bytes not fully written!") + } + + wg.Done() + }() + + // Attempt to read the entirety of the message generated above. + buf := make([]byte, len(largeMessage)) + if _, err := io.ReadFull(remoteConn, buf); err != nil { + t.Fatalf("unable to read message") + } + + wg.Wait() + + // Finally, the message the remote end of the connection received + // should be identical to what we sent from the local connection. + if !bytes.Equal(buf, largeMessage) { + t.Fatalf("bytes don't match") + } +} + func TestNoiseIdentityHiding(t *testing.T) { // TODO(roasbeef): fin }