brontide: exclude MAC length from cipher text packet length prefix
Pervasively we would include the length of the MAC in the length prefix for cipher text packets. As a result, the MAC would eat into the total payload size. To remedy this, we now exclude the MAC from the length prefix for cipher text packets, and instead account for the length of the MAC on the packet when reading messages.
This commit is contained in:
parent
387d41e5df
commit
d046efb502
@ -117,14 +117,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||||
// If the message doesn't require any chunking, then we can go ahead
|
// If the message doesn't require any chunking, then we can go ahead
|
||||||
// with a single write.
|
// with a single write.
|
||||||
if len(b)+macSize <= math.MaxUint16 {
|
if len(b) <= math.MaxUint16 {
|
||||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
return len(b), c.noise.WriteMessage(c.conn, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we need to split the message into fragments, then we'll write
|
// If we need to split the message into fragments, then we'll write
|
||||||
// chunks which maximize usage of the available payload. To do so, we
|
// chunks which maximize usage of the available payload.
|
||||||
// subtract the added overhead of the MAC at the end of the message.
|
chunkSize := math.MaxUint16
|
||||||
chunkSize := math.MaxUint16 - macSize
|
|
||||||
|
|
||||||
bytesToWrite := len(b)
|
bytesToWrite := len(b)
|
||||||
bytesWritten := 0
|
bytesWritten := 0
|
||||||
|
@ -641,12 +641,13 @@ func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error {
|
|||||||
// The total length of each message payload including the MAC size
|
// The total length of each message payload including the MAC size
|
||||||
// payload exceed the largest number encodable within a 16-bit unsigned
|
// payload exceed the largest number encodable within a 16-bit unsigned
|
||||||
// integer.
|
// integer.
|
||||||
if len(p)+macSize > math.MaxUint16 {
|
if len(p) > math.MaxUint16 {
|
||||||
return ErrMaxMessageLengthExceeded
|
return ErrMaxMessageLengthExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
// The full length of the packet includes the 16 byte MAC.
|
// The full length of the packet is only the packet length, and does
|
||||||
fullLength := uint16(len(p) + macSize)
|
// NOT include the MAC.
|
||||||
|
fullLength := uint16(len(p))
|
||||||
|
|
||||||
var pktLen [2]byte
|
var pktLen [2]byte
|
||||||
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
||||||
@ -684,11 +685,11 @@ func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) {
|
|||||||
|
|
||||||
// Next, using the length read from the packet header, read the
|
// Next, using the length read from the packet header, read the
|
||||||
// encrypted packet itself.
|
// encrypted packet itself.
|
||||||
pktLen := binary.BigEndian.Uint16(pktLenBytes)
|
pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
|
||||||
ciperText := make([]byte, pktLen)
|
cipherText := make([]byte, pktLen)
|
||||||
if _, err := io.ReadFull(r, ciperText[:]); err != nil {
|
if _, err := io.ReadFull(r, cipherText[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.recvCipher.Decrypt(nil, nil, ciperText)
|
return b.recvCipher.Decrypt(nil, nil, cipherText)
|
||||||
}
|
}
|
||||||
|
@ -63,6 +63,7 @@ func establishTestConnection() (net.Conn, net.Conn, error) {
|
|||||||
|
|
||||||
return localConn, remoteConn, nil
|
return localConn, remoteConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectionCorrectness(t *testing.T) {
|
func TestConnectionCorrectness(t *testing.T) {
|
||||||
// Create a test connection, grabbing either side of the connection
|
// Create a test connection, grabbing either side of the connection
|
||||||
// into local variables. If the initial crypto handshake fails, then
|
// into local variables. If the initial crypto handshake fails, then
|
||||||
@ -130,9 +131,9 @@ func TestMaxPayloadLength(t *testing.T) {
|
|||||||
"should have been rejected")
|
"should have been rejected")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate another payload which with the MAC acounted for, should be
|
// Generate another payload which should be accepted as a valid
|
||||||
// accepted as a valid payload.
|
// payload.
|
||||||
payloadToAccept := make([]byte, math.MaxUint16-macSize)
|
payloadToAccept := make([]byte, math.MaxUint16-1)
|
||||||
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
|
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
|
||||||
t.Fatalf("write for payload was rejected, should have been " +
|
t.Fatalf("write for payload was rejected, should have been " +
|
||||||
"accepted")
|
"accepted")
|
||||||
@ -140,7 +141,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
|||||||
|
|
||||||
// Generate a final payload which is juuust over the max payload length
|
// Generate a final payload which is juuust over the max payload length
|
||||||
// when the MAC is accounted for.
|
// when the MAC is accounted for.
|
||||||
payloadToReject = make([]byte, math.MaxUint16-macSize+1)
|
payloadToReject = make([]byte, math.MaxUint16+1)
|
||||||
|
|
||||||
// This payload should be rejected.
|
// This payload should be rejected.
|
||||||
err = b.WriteMessage(&buf, payloadToReject)
|
err = b.WriteMessage(&buf, payloadToReject)
|
||||||
@ -171,7 +172,7 @@ func TestWriteMessageChunking(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
bytesWritten, err := localConn.Write(largeMessage)
|
bytesWritten, err := localConn.Write(largeMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to write message")
|
t.Fatalf("unable to write message: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The entire message should have been written out to the remote
|
// The entire message should have been written out to the remote
|
||||||
@ -186,7 +187,7 @@ func TestWriteMessageChunking(t *testing.T) {
|
|||||||
// Attempt to read the entirety of the message generated above.
|
// Attempt to read the entirety of the message generated above.
|
||||||
buf := make([]byte, len(largeMessage))
|
buf := make([]byte, len(largeMessage))
|
||||||
if _, err := io.ReadFull(remoteConn, buf); err != nil {
|
if _, err := io.ReadFull(remoteConn, buf); err != nil {
|
||||||
t.Fatalf("unable to read message")
|
t.Fatalf("unable to read message: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
Loading…
Reference in New Issue
Block a user