brontide/conn: migrate to WriteMessage + Flush

This commit modifies WriteMessage to only perform encryption on the
passed plaintext, and buffer the ciphertext within the connection
object. We then modify internal uses of WriteMessage to follow with a
call to Flush, which actually writes the message to the wire.
Additionally, since WriteMessage does not actually perform the write
itself, the io.Writer argument is removed from the function signature
and all call sites.
This commit is contained in:
Conner Fromknecht 2019-04-22 16:03:56 -07:00
parent ed8fe4bc82
commit 73cf352daa
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
3 changed files with 43 additions and 25 deletions

View File

@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// If the message doesn't require any chunking, then we can go ahead // If the message doesn't require any chunking, then we can go ahead
// with a single write. // with a single write.
if len(b) <= math.MaxUint16 { if len(b) <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b) err = c.noise.WriteMessage(b)
if err != nil {
return 0, err
}
return c.noise.Flush(c.conn)
} }
// If we need to split the message into fragments, then we'll write // If we need to split the message into fragments, then we'll write
@ -185,11 +189,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// Slice off the next chunk to be written based on our running // Slice off the next chunk to be written based on our running
// counter and next chunk size. // counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize] chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(c.conn, chunk); err != nil { if err := c.noise.WriteMessage(chunk); err != nil {
return bytesWritten, err return bytesWritten, err
} }
bytesWritten += len(chunk) n, err := c.noise.Flush(c.conn)
bytesWritten += n
if err != nil {
return bytesWritten, err
}
} }
return bytesWritten, nil return bytesWritten, nil

View File

@ -48,6 +48,10 @@ var (
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " + ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
"the max allowed message length of (2^16)-1") "the max allowed message length of (2^16)-1")
// ErrMessageNotFlushed signals that the connection cannot accept a new
// message because the prior message has not been fully flushed.
ErrMessageNotFlushed = errors.New("prior message not flushed")
// lightningPrologue is the noise prologue that is used to initialize // lightningPrologue is the noise prologue that is used to initialize
// the brontide noise handshake. // the brontide noise handshake.
lightningPrologue = []byte("lightning") lightningPrologue = []byte("lightning")
@ -692,11 +696,13 @@ func (b *Machine) split() {
} }
} }
// WriteMessage writes the next message p to the passed io.Writer. The // WriteMessage encrypts and buffers the next message p. The ciphertext of the
// ciphertext of the message is prepended with an encrypt+auth'd length which // message is prepended with an encrypt+auth'd length which must be used as the
// must be used as the AD to the AEAD construction when being decrypted by the // AD to the AEAD construction when being decrypted by the other side.
// other side. //
func (b *Machine) WriteMessage(w io.Writer, p []byte) error { // NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (b *Machine) WriteMessage(p []byte) error {
// The total length of each message payload including the MAC size // The total length of each message payload including the MAC size
// payload exceed the largest number encodable within a 16-bit unsigned // payload exceed the largest number encodable within a 16-bit unsigned
// integer. // integer.
@ -704,6 +710,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
return ErrMaxMessageLengthExceeded return ErrMaxMessageLengthExceeded
} }
// If a prior message was written but it hasn't been fully flushed,
// return an error as we only support buffering of one message at a
// time.
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
return ErrMessageNotFlushed
}
// The full length of the packet is only the packet length, and does // The full length of the packet is only the packet length, and does
// NOT include the MAC. // NOT include the MAC.
fullLength := uint16(len(p)) fullLength := uint16(len(p))
@ -711,18 +724,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
var pktLen [2]byte var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength) binary.BigEndian.PutUint16(pktLen[:], fullLength)
// First, write out the encrypted+MAC'd length prefix for the packet. // First, generate the encrypted+MAC'd length prefix for the packet.
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:]) b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
if _, err := w.Write(cipherLen); err != nil {
return err
}
// Finally, write out the encrypted packet itself. We only write out a // Finally, generate the encrypted packet itself.
// single packet, as any fragmentation should have taken place at a b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
// higher level.
cipherText := b.sendCipher.Encrypt(nil, nil, p) return nil
_, err := w.Write(cipherText)
return err
} }
// Flush attempts to write a message buffered using WriteMessage to the provided // Flush attempts to write a message buffered using WriteMessage to the provided

View File

@ -220,11 +220,9 @@ func TestMaxPayloadLength(t *testing.T) {
// payload length. // payload length.
payloadToReject := make([]byte, math.MaxUint16+1) payloadToReject := make([]byte, math.MaxUint16+1)
var buf bytes.Buffer
// A write of the payload generated above to the state machine should // A write of the payload generated above to the state machine should
// be rejected as it's over the max payload length. // be rejected as it's over the max payload length.
err := b.WriteMessage(&buf, payloadToReject) err := b.WriteMessage(payloadToReject)
if err != ErrMaxMessageLengthExceeded { if err != ErrMaxMessageLengthExceeded {
t.Fatalf("payload is over the max allowed length, the write " + t.Fatalf("payload is over the max allowed length, the write " +
"should have been rejected") "should have been rejected")
@ -233,7 +231,7 @@ func TestMaxPayloadLength(t *testing.T) {
// Generate another payload which should be accepted as a valid // Generate another payload which should be accepted as a valid
// payload. // payload.
payloadToAccept := make([]byte, math.MaxUint16-1) payloadToAccept := make([]byte, math.MaxUint16-1)
if err := b.WriteMessage(&buf, payloadToAccept); err != nil { if err := b.WriteMessage(payloadToAccept); err != nil {
t.Fatalf("write for payload was rejected, should have been " + t.Fatalf("write for payload was rejected, should have been " +
"accepted") "accepted")
} }
@ -243,7 +241,7 @@ func TestMaxPayloadLength(t *testing.T) {
payloadToReject = make([]byte, math.MaxUint16+1) payloadToReject = make([]byte, math.MaxUint16+1)
// This payload should be rejected. // This payload should be rejected.
err = b.WriteMessage(&buf, payloadToReject) err = b.WriteMessage(payloadToReject)
if err != ErrMaxMessageLengthExceeded { if err != ErrMaxMessageLengthExceeded {
t.Fatalf("payload is over the max allowed length, the write " + t.Fatalf("payload is over the max allowed length, the write " +
"should have been rejected") "should have been rejected")
@ -502,10 +500,14 @@ func TestBolt0008TestVectors(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
for i := 0; i < 1002; i++ { for i := 0; i < 1002; i++ {
err = initiator.WriteMessage(&buf, payload) err = initiator.WriteMessage(payload)
if err != nil { if err != nil {
t.Fatalf("could not write message %s", payload) t.Fatalf("could not write message %s", payload)
} }
_, err = initiator.Flush(&buf)
if err != nil {
t.Fatalf("could not flush message: %v", err)
}
if val, ok := transportMessageVectors[i]; ok { if val, ok := transportMessageVectors[i]; ok {
binaryVal, err := hex.DecodeString(val) binaryVal, err := hex.DecodeString(val)
if err != nil { if err != nil {