brontide/noise_test: add TestFlush

This commit is contained in:
Conner Fromknecht 2019-04-22 16:04:24 -07:00
parent 333caac09c
commit e3728da478
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 154 additions and 1 deletions

View File

@ -31,6 +31,10 @@ const (
// length of a message payload.
lengthHeaderSize = 2
// encHeaderSize is the number of bytes required to hold an encrypted
// header and it's MAC.
encHeaderSize = lengthHeaderSize + macSize
// keyRotationInterval is the number of messages sent on a single
// cipher stream before the keys are rotated forwards.
keyRotationInterval = 1000
@ -370,7 +374,7 @@ type Machine struct {
// nextCipherHeader is a static buffer that we'll use to read in the
// next ciphertext header from the wire. The header is a 2 byte length
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [lengthHeaderSize + macSize]byte
nextCipherHeader [encHeaderSize]byte
// nextHeaderSend holds a reference to the remaining header bytes to
// write out for a pending message. This allows us to tolerate timeout

View File

@ -561,3 +561,152 @@ func (t *timeoutWriter) Write(p []byte) (int, error) {
}
return n, err
}
const payloadSize = 10
type flushChunk struct {
errAfter int64
expN int
expErr error
}
type flushTest struct {
name string
chunks []flushChunk
}
var flushTests = []flushTest{
{
name: "partial header write",
chunks: []flushChunk{
// Write 18-byte header in two parts, 16 then 2.
{
errAfter: encHeaderSize - 2,
expN: 0,
expErr: iotest.ErrTimeout,
},
{
errAfter: 2,
expN: 0,
expErr: iotest.ErrTimeout,
},
// Write payload and MAC in one go.
{
errAfter: -1,
expN: payloadSize,
},
},
},
{
name: "full payload then full mac",
chunks: []flushChunk{
// Write entire header and entire payload w/o MAC.
{
errAfter: encHeaderSize + payloadSize,
expN: payloadSize,
expErr: iotest.ErrTimeout,
},
// Write the entire MAC.
{
errAfter: -1,
expN: 0,
},
},
},
{
name: "payload-only, straddle, mac-only",
chunks: []flushChunk{
// Write header and all but last byte of payload.
{
errAfter: encHeaderSize + payloadSize - 1,
expN: payloadSize - 1,
expErr: iotest.ErrTimeout,
},
// Write last byte of payload and first byte of MAC.
{
errAfter: 2,
expN: 1,
expErr: iotest.ErrTimeout,
},
// Write 10 bytes of the MAC.
{
errAfter: 10,
expN: 0,
expErr: iotest.ErrTimeout,
},
// Write the remaining 5 MAC bytes.
{
errAfter: -1,
expN: 0,
},
},
},
}
// TestFlush asserts a Machine's ability to handle timeouts during Flush that
// cause partial writes, and that the machine can properly resume writes on
// subsequent calls to Flush.
func TestFlush(t *testing.T) {
// Run each test individually, to assert that they pass in isolation.
for _, test := range flushTests {
t.Run(test.name, func(t *testing.T) {
var (
w bytes.Buffer
b Machine
)
b.split()
testFlush(t, test, &b, &w)
})
}
// Finally, run the tests serially as if all on one connection.
t.Run("flush serial", func(t *testing.T) {
var (
w bytes.Buffer
b Machine
)
b.split()
for _, test := range flushTests {
testFlush(t, test, &b, &w)
}
})
}
// testFlush buffers a message on the Machine, then flushes it to the io.Writer
// in chunks. Once complete, a final call to flush is made to assert that Write
// is not called again.
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
payload := make([]byte, payloadSize)
if err := b.WriteMessage(payload); err != nil {
t.Fatalf("unable to write message: %v", err)
}
for _, chunk := range test.chunks {
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
}
// We should always be able to call Flush after a message has been
// successfully written, and it should result in a NOP.
assertFlush(t, b, w, 0, 0, nil)
}
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
// n bytes. The method asserts that the returned error matches expErr and that
// the number of bytes written by Flush matches expN.
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
expErr error) {
t.Helper()
if n >= 0 {
w = NewTimeoutWriter(w, n)
}
nn, err := b.Flush(w)
if err != expErr {
t.Fatalf("expected flush err: %v, got: %v", expErr, err)
}
if nn != expN {
t.Fatalf("expected n: %d, got: %d", expN, nn)
}
}