diff --git a/brontide/conn.go b/brontide/conn.go index 643ff860..7c6a16e2 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) { // If the message doesn't require any chunking, then we can go ahead // with a single write. if len(b) <= math.MaxUint16 { - return len(b), c.noise.WriteMessage(c.conn, b) + err = c.noise.WriteMessage(b) + if err != nil { + return 0, err + } + return c.noise.Flush(c.conn) } // If we need to split the message into fragments, then we'll write @@ -185,11 +189,15 @@ func (c *Conn) Write(b []byte) (n int, err error) { // Slice off the next chunk to be written based on our running // counter and next chunk size. chunk := b[bytesWritten : bytesWritten+chunkSize] - if err := c.noise.WriteMessage(c.conn, chunk); err != nil { + if err := c.noise.WriteMessage(chunk); err != nil { return bytesWritten, err } - bytesWritten += len(chunk) + n, err := c.noise.Flush(c.conn) + bytesWritten += n + if err != nil { + return bytesWritten, err + } } return bytesWritten, nil diff --git a/brontide/noise.go b/brontide/noise.go index e6b9045c..ef5ecf29 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -48,6 +48,10 @@ var ( ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " + "the max allowed message length of (2^16)-1") + // ErrMessageNotFlushed signals that the connection cannot accept a new + // message because the prior message has not been fully flushed. + ErrMessageNotFlushed = errors.New("prior message not flushed") + // lightningPrologue is the noise prologue that is used to initialize // the brontide noise handshake. lightningPrologue = []byte("lightning") @@ -692,11 +696,13 @@ func (b *Machine) split() { } } -// WriteMessage writes the next message p to the passed io.Writer. The -// ciphertext of the message is prepended with an encrypt+auth'd length which -// must be used as the AD to the AEAD construction when being decrypted by the -// other side. -func (b *Machine) WriteMessage(w io.Writer, p []byte) error { +// WriteMessage encrypts and buffers the next message p. The ciphertext of the +// message is prepended with an encrypt+auth'd length which must be used as the +// AD to the AEAD construction when being decrypted by the other side. +// +// NOTE: This DOES NOT write the message to the wire, it should be followed by a +// call to Flush to ensure the message is written. +func (b *Machine) WriteMessage(p []byte) error { // The total length of each message payload including the MAC size // payload exceed the largest number encodable within a 16-bit unsigned // integer. @@ -704,6 +710,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error { return ErrMaxMessageLengthExceeded } + // If a prior message was written but it hasn't been fully flushed, + // return an error as we only support buffering of one message at a + // time. + if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 { + return ErrMessageNotFlushed + } + // The full length of the packet is only the packet length, and does // NOT include the MAC. fullLength := uint16(len(p)) @@ -711,18 +724,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error { var pktLen [2]byte binary.BigEndian.PutUint16(pktLen[:], fullLength) - // First, write out the encrypted+MAC'd length prefix for the packet. - cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:]) - if _, err := w.Write(cipherLen); err != nil { - return err - } + // First, generate the encrypted+MAC'd length prefix for the packet. + b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:]) - // Finally, write out the encrypted packet itself. We only write out a - // single packet, as any fragmentation should have taken place at a - // higher level. - cipherText := b.sendCipher.Encrypt(nil, nil, p) - _, err := w.Write(cipherText) - return err + // Finally, generate the encrypted packet itself. + b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p) + + return nil } // Flush attempts to write a message buffered using WriteMessage to the provided diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 1918a646..1349e653 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -220,11 +220,9 @@ func TestMaxPayloadLength(t *testing.T) { // payload length. payloadToReject := make([]byte, math.MaxUint16+1) - var buf bytes.Buffer - // A write of the payload generated above to the state machine should // be rejected as it's over the max payload length. - err := b.WriteMessage(&buf, payloadToReject) + err := b.WriteMessage(payloadToReject) if err != ErrMaxMessageLengthExceeded { t.Fatalf("payload is over the max allowed length, the write " + "should have been rejected") @@ -233,7 +231,7 @@ func TestMaxPayloadLength(t *testing.T) { // Generate another payload which should be accepted as a valid // payload. payloadToAccept := make([]byte, math.MaxUint16-1) - if err := b.WriteMessage(&buf, payloadToAccept); err != nil { + if err := b.WriteMessage(payloadToAccept); err != nil { t.Fatalf("write for payload was rejected, should have been " + "accepted") } @@ -243,7 +241,7 @@ func TestMaxPayloadLength(t *testing.T) { payloadToReject = make([]byte, math.MaxUint16+1) // This payload should be rejected. - err = b.WriteMessage(&buf, payloadToReject) + err = b.WriteMessage(payloadToReject) if err != ErrMaxMessageLengthExceeded { t.Fatalf("payload is over the max allowed length, the write " + "should have been rejected") @@ -502,10 +500,14 @@ func TestBolt0008TestVectors(t *testing.T) { var buf bytes.Buffer for i := 0; i < 1002; i++ { - err = initiator.WriteMessage(&buf, payload) + err = initiator.WriteMessage(payload) if err != nil { t.Fatalf("could not write message %s", payload) } + _, err = initiator.Flush(&buf) + if err != nil { + t.Fatalf("could not flush message: %v", err) + } if val, ok := transportMessageVectors[i]; ok { binaryVal, err := hex.DecodeString(val) if err != nil {