brontide: exclude MAC length from cipher text packet length prefix

Pervasively we would include the length of the MAC in the length prefix
for cipher text packets. As a result, the MAC would eat into the total
payload size. To remedy this, we now exclude the MAC from the length
prefix for cipher text packets, and instead account for the length of
the MAC on the packet when reading messages.
This commit is contained in:
Olaoluwa Osuntokun 2017-01-07 19:15:58 -08:00
parent 387d41e5df
commit d046efb502
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
3 changed files with 18 additions and 17 deletions

@ -117,14 +117,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (n int, err error) {
// If the message doesn't require any chunking, then we can go ahead // If the message doesn't require any chunking, then we can go ahead
// with a single write. // with a single write.
if len(b)+macSize <= math.MaxUint16 { if len(b) <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b) return len(b), c.noise.WriteMessage(c.conn, b)
} }
// If we need to split the message into fragments, then we'll write // If we need to split the message into fragments, then we'll write
// chunks which maximize usage of the available payload. To do so, we // chunks which maximize usage of the available payload.
// subtract the added overhead of the MAC at the end of the message. chunkSize := math.MaxUint16
chunkSize := math.MaxUint16 - macSize
bytesToWrite := len(b) bytesToWrite := len(b)
bytesWritten := 0 bytesWritten := 0

@ -641,12 +641,13 @@ func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error {
// The total length of each message payload including the MAC size // The total length of each message payload including the MAC size
// payload exceed the largest number encodable within a 16-bit unsigned // payload exceed the largest number encodable within a 16-bit unsigned
// integer. // integer.
if len(p)+macSize > math.MaxUint16 { if len(p) > math.MaxUint16 {
return ErrMaxMessageLengthExceeded return ErrMaxMessageLengthExceeded
} }
// The full length of the packet includes the 16 byte MAC. // The full length of the packet is only the packet length, and does
fullLength := uint16(len(p) + macSize) // NOT include the MAC.
fullLength := uint16(len(p))
var pktLen [2]byte var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength) binary.BigEndian.PutUint16(pktLen[:], fullLength)
@ -684,11 +685,11 @@ func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) {
// Next, using the length read from the packet header, read the // Next, using the length read from the packet header, read the
// encrypted packet itself. // encrypted packet itself.
pktLen := binary.BigEndian.Uint16(pktLenBytes) pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
ciperText := make([]byte, pktLen) cipherText := make([]byte, pktLen)
if _, err := io.ReadFull(r, ciperText[:]); err != nil { if _, err := io.ReadFull(r, cipherText[:]); err != nil {
return nil, err return nil, err
} }
return b.recvCipher.Decrypt(nil, nil, ciperText) return b.recvCipher.Decrypt(nil, nil, cipherText)
} }

@ -63,6 +63,7 @@ func establishTestConnection() (net.Conn, net.Conn, error) {
return localConn, remoteConn, nil return localConn, remoteConn, nil
} }
func TestConnectionCorrectness(t *testing.T) { func TestConnectionCorrectness(t *testing.T) {
// Create a test connection, grabbing either side of the connection // Create a test connection, grabbing either side of the connection
// into local variables. If the initial crypto handshake fails, then // into local variables. If the initial crypto handshake fails, then
@ -130,9 +131,9 @@ func TestMaxPayloadLength(t *testing.T) {
"should have been rejected") "should have been rejected")
} }
// Generate another payload which with the MAC acounted for, should be // Generate another payload which should be accepted as a valid
// accepted as a valid payload. // payload.
payloadToAccept := make([]byte, math.MaxUint16-macSize) payloadToAccept := make([]byte, math.MaxUint16-1)
if err := b.WriteMessage(&buf, payloadToAccept); err != nil { if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
t.Fatalf("write for payload was rejected, should have been " + t.Fatalf("write for payload was rejected, should have been " +
"accepted") "accepted")
@ -140,7 +141,7 @@ func TestMaxPayloadLength(t *testing.T) {
// Generate a final payload which is juuust over the max payload length // Generate a final payload which is juuust over the max payload length
// when the MAC is accounted for. // when the MAC is accounted for.
payloadToReject = make([]byte, math.MaxUint16-macSize+1) payloadToReject = make([]byte, math.MaxUint16+1)
// This payload should be rejected. // This payload should be rejected.
err = b.WriteMessage(&buf, payloadToReject) err = b.WriteMessage(&buf, payloadToReject)
@ -171,7 +172,7 @@ func TestWriteMessageChunking(t *testing.T) {
go func() { go func() {
bytesWritten, err := localConn.Write(largeMessage) bytesWritten, err := localConn.Write(largeMessage)
if err != nil { if err != nil {
t.Fatalf("unable to write message") t.Fatalf("unable to write message: %v", err)
} }
// The entire message should have been written out to the remote // The entire message should have been written out to the remote
@ -186,7 +187,7 @@ func TestWriteMessageChunking(t *testing.T) {
// Attempt to read the entirety of the message generated above. // Attempt to read the entirety of the message generated above.
buf := make([]byte, len(largeMessage)) buf := make([]byte, len(largeMessage))
if _, err := io.ReadFull(remoteConn, buf); err != nil { if _, err := io.ReadFull(remoteConn, buf); err != nil {
t.Fatalf("unable to read message") t.Fatalf("unable to read message: %v", err)
} }
wg.Wait() wg.Wait()