brontide/noise: add Flush method

This commit adds a Flush method to the brontide.Machine, which can write
out a buffered message to an io.Writer. This is a preliminary change
which will allow the encryption of the plaintext to be done in a
distinct method from actually writing the bytes to the wire.
This commit is contained in:
Conner Fromknecht 2019-04-22 16:03:39 -07:00
parent 6e3cf06bd4
commit ed8fe4bc82
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7

@ -367,6 +367,16 @@ type Machine struct {
// next ciphertext header from the wire. The header is a 2 byte length
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [lengthHeaderSize + macSize]byte
// nextHeaderSend holds a reference to the remaining header bytes to
// write out for a pending message. This allows us to tolerate timeout
// errors that cause partial writes.
nextHeaderSend []byte
// nextHeaderBody holds a reference to the remaining body bytes to write
// out for a pending message. This allows us to tolerate timeout errors
// that cause partial writes.
nextBodySend []byte
}
// NewBrontideMachine creates a new instance of the brontide state-machine. If
@ -715,6 +725,81 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
return err
}
// Flush attempts to write a message buffered using WriteMessage to the provided
// io.Writer. If no buffered message exists, this will result in a NOP.
// Otherwise, it will continue to write the remaining bytes, picking up where
// the byte stream left off in the event of a partial write. The number of bytes
// returned reflects the number of plaintext bytes in the payload, and does not
// account for the overhead of the header or MACs.
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (b *Machine) Flush(w io.Writer) (int, error) {
// First, write out the pending header bytes, if any exist. Any header
// bytes written will not count towards the total amount flushed.
if len(b.nextHeaderSend) > 0 {
// Write any remaining header bytes and shift the slice to point
// to the next segment of unwritten bytes. If an error is
// encountered, we can continue to write the header from where
// we left off on a subsequent call to Flush.
n, err := w.Write(b.nextHeaderSend)
b.nextHeaderSend = b.nextHeaderSend[n:]
if err != nil {
return 0, err
}
}
// Next, write the pending body bytes, if any exist. Only the number of
// bytes written that correspond to the ciphertext will be included in
// the total bytes written, bytes written as part of the MAC will not be
// counted.
var nn int
if len(b.nextBodySend) > 0 {
// Write out all bytes excluding the mac and shift the body
// slice depending on the number of actual bytes written.
n, err := w.Write(b.nextBodySend)
b.nextBodySend = b.nextBodySend[n:]
// If we partially or fully wrote any of the body's MAC, we'll
// subtract that contribution from the total amount flushed to
// preserve the abstraction of returning the number of plaintext
// bytes written by the connection.
//
// There are three possible scenarios we must handle to ensure
// the returned value is correct. In the first case, the write
// straddles both payload and MAC bytes, and we must subtract
// the number of MAC bytes written from n. In the second, only
// payload bytes are written, thus we can return n unmodified.
// The final scenario pertains to the case where only MAC bytes
// are written, none of which count towards the total.
//
// |-----------Payload------------|----MAC----|
// Straddle: S---------------------------------E--------0
// Payload-only: S------------------------E-----------------0
// MAC-only: S-------E-0
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
switch {
// Straddles payload and MAC bytes, subtract number of MAC bytes
// written from the actual number written.
case start > macSize && end <= macSize:
nn = n - (macSize - end)
// Only payload bytes are written, return n directly.
case start > macSize && end > macSize:
nn = n
// Only MAC bytes are written, return 0 bytes written.
default:
}
if err != nil {
return nn, err
}
}
return nn, nil
}
// ReadMessage attempts to read the next message from the passed io.Reader. In
// the case of an authentication error, a non-nil error is returned.
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {