diff --git a/brontide/noise.go b/brontide/noise.go index 7b6bdb43..e6b9045c 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -367,6 +367,16 @@ type Machine struct { // next ciphertext header from the wire. The header is a 2 byte length // (of the next ciphertext), followed by a 16 byte MAC. nextCipherHeader [lengthHeaderSize + macSize]byte + + // nextHeaderSend holds a reference to the remaining header bytes to + // write out for a pending message. This allows us to tolerate timeout + // errors that cause partial writes. + nextHeaderSend []byte + + // nextHeaderBody holds a reference to the remaining body bytes to write + // out for a pending message. This allows us to tolerate timeout errors + // that cause partial writes. + nextBodySend []byte } // NewBrontideMachine creates a new instance of the brontide state-machine. If @@ -715,6 +725,81 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error { return err } +// Flush attempts to write a message buffered using WriteMessage to the provided +// io.Writer. If no buffered message exists, this will result in a NOP. +// Otherwise, it will continue to write the remaining bytes, picking up where +// the byte stream left off in the event of a partial write. The number of bytes +// returned reflects the number of plaintext bytes in the payload, and does not +// account for the overhead of the header or MACs. +// +// NOTE: It is safe to call this method again iff a timeout error is returned. +func (b *Machine) Flush(w io.Writer) (int, error) { + // First, write out the pending header bytes, if any exist. Any header + // bytes written will not count towards the total amount flushed. + if len(b.nextHeaderSend) > 0 { + // Write any remaining header bytes and shift the slice to point + // to the next segment of unwritten bytes. If an error is + // encountered, we can continue to write the header from where + // we left off on a subsequent call to Flush. + n, err := w.Write(b.nextHeaderSend) + b.nextHeaderSend = b.nextHeaderSend[n:] + if err != nil { + return 0, err + } + } + + // Next, write the pending body bytes, if any exist. Only the number of + // bytes written that correspond to the ciphertext will be included in + // the total bytes written, bytes written as part of the MAC will not be + // counted. + var nn int + if len(b.nextBodySend) > 0 { + // Write out all bytes excluding the mac and shift the body + // slice depending on the number of actual bytes written. + n, err := w.Write(b.nextBodySend) + b.nextBodySend = b.nextBodySend[n:] + + // If we partially or fully wrote any of the body's MAC, we'll + // subtract that contribution from the total amount flushed to + // preserve the abstraction of returning the number of plaintext + // bytes written by the connection. + // + // There are three possible scenarios we must handle to ensure + // the returned value is correct. In the first case, the write + // straddles both payload and MAC bytes, and we must subtract + // the number of MAC bytes written from n. In the second, only + // payload bytes are written, thus we can return n unmodified. + // The final scenario pertains to the case where only MAC bytes + // are written, none of which count towards the total. + // + // |-----------Payload------------|----MAC----| + // Straddle: S---------------------------------E--------0 + // Payload-only: S------------------------E-----------------0 + // MAC-only: S-------E-0 + start, end := n+len(b.nextBodySend), len(b.nextBodySend) + switch { + + // Straddles payload and MAC bytes, subtract number of MAC bytes + // written from the actual number written. + case start > macSize && end <= macSize: + nn = n - (macSize - end) + + // Only payload bytes are written, return n directly. + case start > macSize && end > macSize: + nn = n + + // Only MAC bytes are written, return 0 bytes written. + default: + } + + if err != nil { + return nn, err + } + } + + return nn, nil +} + // ReadMessage attempts to read the next message from the passed io.Reader. In // the case of an authentication error, a non-nil error is returned. func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {