diff --git a/brontide/noise.go b/brontide/noise.go index ef5ecf29..7d5f42d4 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -31,6 +31,10 @@ const ( // length of a message payload. lengthHeaderSize = 2 + // encHeaderSize is the number of bytes required to hold an encrypted + // header and it's MAC. + encHeaderSize = lengthHeaderSize + macSize + // keyRotationInterval is the number of messages sent on a single // cipher stream before the keys are rotated forwards. keyRotationInterval = 1000 @@ -370,7 +374,7 @@ type Machine struct { // nextCipherHeader is a static buffer that we'll use to read in the // next ciphertext header from the wire. The header is a 2 byte length // (of the next ciphertext), followed by a 16 byte MAC. - nextCipherHeader [lengthHeaderSize + macSize]byte + nextCipherHeader [encHeaderSize]byte // nextHeaderSend holds a reference to the remaining header bytes to // write out for a pending message. This allows us to tolerate timeout diff --git a/brontide/noise_test.go b/brontide/noise_test.go index c0614f50..55ea4659 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -561,3 +561,152 @@ func (t *timeoutWriter) Write(p []byte) (int, error) { } return n, err } + +const payloadSize = 10 + +type flushChunk struct { + errAfter int64 + expN int + expErr error +} + +type flushTest struct { + name string + chunks []flushChunk +} + +var flushTests = []flushTest{ + { + name: "partial header write", + chunks: []flushChunk{ + // Write 18-byte header in two parts, 16 then 2. + { + errAfter: encHeaderSize - 2, + expN: 0, + expErr: iotest.ErrTimeout, + }, + { + errAfter: 2, + expN: 0, + expErr: iotest.ErrTimeout, + }, + // Write payload and MAC in one go. + { + errAfter: -1, + expN: payloadSize, + }, + }, + }, + { + name: "full payload then full mac", + chunks: []flushChunk{ + // Write entire header and entire payload w/o MAC. + { + errAfter: encHeaderSize + payloadSize, + expN: payloadSize, + expErr: iotest.ErrTimeout, + }, + // Write the entire MAC. + { + errAfter: -1, + expN: 0, + }, + }, + }, + { + name: "payload-only, straddle, mac-only", + chunks: []flushChunk{ + // Write header and all but last byte of payload. + { + errAfter: encHeaderSize + payloadSize - 1, + expN: payloadSize - 1, + expErr: iotest.ErrTimeout, + }, + // Write last byte of payload and first byte of MAC. + { + errAfter: 2, + expN: 1, + expErr: iotest.ErrTimeout, + }, + // Write 10 bytes of the MAC. + { + errAfter: 10, + expN: 0, + expErr: iotest.ErrTimeout, + }, + // Write the remaining 5 MAC bytes. + { + errAfter: -1, + expN: 0, + }, + }, + }, +} + +// TestFlush asserts a Machine's ability to handle timeouts during Flush that +// cause partial writes, and that the machine can properly resume writes on +// subsequent calls to Flush. +func TestFlush(t *testing.T) { + // Run each test individually, to assert that they pass in isolation. + for _, test := range flushTests { + t.Run(test.name, func(t *testing.T) { + var ( + w bytes.Buffer + b Machine + ) + b.split() + testFlush(t, test, &b, &w) + }) + } + + // Finally, run the tests serially as if all on one connection. + t.Run("flush serial", func(t *testing.T) { + var ( + w bytes.Buffer + b Machine + ) + b.split() + for _, test := range flushTests { + testFlush(t, test, &b, &w) + } + }) +} + +// testFlush buffers a message on the Machine, then flushes it to the io.Writer +// in chunks. Once complete, a final call to flush is made to assert that Write +// is not called again. +func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) { + payload := make([]byte, payloadSize) + if err := b.WriteMessage(payload); err != nil { + t.Fatalf("unable to write message: %v", err) + } + + for _, chunk := range test.chunks { + assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr) + } + + // We should always be able to call Flush after a message has been + // successfully written, and it should result in a NOP. + assertFlush(t, b, w, 0, 0, nil) +} + +// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a +// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after +// n bytes. The method asserts that the returned error matches expErr and that +// the number of bytes written by Flush matches expN. +func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int, + expErr error) { + + t.Helper() + + if n >= 0 { + w = NewTimeoutWriter(w, n) + } + nn, err := b.Flush(w) + if err != expErr { + t.Fatalf("expected flush err: %v, got: %v", expErr, err) + } + if nn != expN { + t.Fatalf("expected n: %d, got: %d", expN, nn) + } +}