brontide/noise_test: add TestFlush
This commit is contained in:
parent
333caac09c
commit
e3728da478
@ -31,6 +31,10 @@ const (
|
|||||||
// length of a message payload.
|
// length of a message payload.
|
||||||
lengthHeaderSize = 2
|
lengthHeaderSize = 2
|
||||||
|
|
||||||
|
// encHeaderSize is the number of bytes required to hold an encrypted
|
||||||
|
// header and it's MAC.
|
||||||
|
encHeaderSize = lengthHeaderSize + macSize
|
||||||
|
|
||||||
// keyRotationInterval is the number of messages sent on a single
|
// keyRotationInterval is the number of messages sent on a single
|
||||||
// cipher stream before the keys are rotated forwards.
|
// cipher stream before the keys are rotated forwards.
|
||||||
keyRotationInterval = 1000
|
keyRotationInterval = 1000
|
||||||
@ -370,7 +374,7 @@ type Machine struct {
|
|||||||
// nextCipherHeader is a static buffer that we'll use to read in the
|
// nextCipherHeader is a static buffer that we'll use to read in the
|
||||||
// next ciphertext header from the wire. The header is a 2 byte length
|
// next ciphertext header from the wire. The header is a 2 byte length
|
||||||
// (of the next ciphertext), followed by a 16 byte MAC.
|
// (of the next ciphertext), followed by a 16 byte MAC.
|
||||||
nextCipherHeader [lengthHeaderSize + macSize]byte
|
nextCipherHeader [encHeaderSize]byte
|
||||||
|
|
||||||
// nextHeaderSend holds a reference to the remaining header bytes to
|
// nextHeaderSend holds a reference to the remaining header bytes to
|
||||||
// write out for a pending message. This allows us to tolerate timeout
|
// write out for a pending message. This allows us to tolerate timeout
|
||||||
|
@ -561,3 +561,152 @@ func (t *timeoutWriter) Write(p []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const payloadSize = 10
|
||||||
|
|
||||||
|
type flushChunk struct {
|
||||||
|
errAfter int64
|
||||||
|
expN int
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type flushTest struct {
|
||||||
|
name string
|
||||||
|
chunks []flushChunk
|
||||||
|
}
|
||||||
|
|
||||||
|
var flushTests = []flushTest{
|
||||||
|
{
|
||||||
|
name: "partial header write",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write 18-byte header in two parts, 16 then 2.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize - 2,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
errAfter: 2,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write payload and MAC in one go.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: payloadSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full payload then full mac",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write entire header and entire payload w/o MAC.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize + payloadSize,
|
||||||
|
expN: payloadSize,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write the entire MAC.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "payload-only, straddle, mac-only",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write header and all but last byte of payload.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize + payloadSize - 1,
|
||||||
|
expN: payloadSize - 1,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write last byte of payload and first byte of MAC.
|
||||||
|
{
|
||||||
|
errAfter: 2,
|
||||||
|
expN: 1,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write 10 bytes of the MAC.
|
||||||
|
{
|
||||||
|
errAfter: 10,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write the remaining 5 MAC bytes.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFlush asserts a Machine's ability to handle timeouts during Flush that
|
||||||
|
// cause partial writes, and that the machine can properly resume writes on
|
||||||
|
// subsequent calls to Flush.
|
||||||
|
func TestFlush(t *testing.T) {
|
||||||
|
// Run each test individually, to assert that they pass in isolation.
|
||||||
|
for _, test := range flushTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
w bytes.Buffer
|
||||||
|
b Machine
|
||||||
|
)
|
||||||
|
b.split()
|
||||||
|
testFlush(t, test, &b, &w)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, run the tests serially as if all on one connection.
|
||||||
|
t.Run("flush serial", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
w bytes.Buffer
|
||||||
|
b Machine
|
||||||
|
)
|
||||||
|
b.split()
|
||||||
|
for _, test := range flushTests {
|
||||||
|
testFlush(t, test, &b, &w)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testFlush buffers a message on the Machine, then flushes it to the io.Writer
|
||||||
|
// in chunks. Once complete, a final call to flush is made to assert that Write
|
||||||
|
// is not called again.
|
||||||
|
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
if err := b.WriteMessage(payload); err != nil {
|
||||||
|
t.Fatalf("unable to write message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chunk := range test.chunks {
|
||||||
|
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should always be able to call Flush after a message has been
|
||||||
|
// successfully written, and it should result in a NOP.
|
||||||
|
assertFlush(t, b, w, 0, 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
|
||||||
|
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
|
||||||
|
// n bytes. The method asserts that the returned error matches expErr and that
|
||||||
|
// the number of bytes written by Flush matches expN.
|
||||||
|
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
|
||||||
|
expErr error) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if n >= 0 {
|
||||||
|
w = NewTimeoutWriter(w, n)
|
||||||
|
}
|
||||||
|
nn, err := b.Flush(w)
|
||||||
|
if err != expErr {
|
||||||
|
t.Fatalf("expected flush err: %v, got: %v", expErr, err)
|
||||||
|
}
|
||||||
|
if nn != expN {
|
||||||
|
t.Fatalf("expected n: %d, got: %d", expN, nn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user