brontide/conn: migrate to WriteMessage + Flush
This commit modifies WriteMessage to only perform encryption on the passed plaintext, and buffer the ciphertext within the connection object. We then modify internal uses of WriteMessage to follow with a call to Flush, which actually writes the message to the wire. Additionally, since WriteMessage does not actually perform the write itself, the io.Writer argument is removed from the function signature and all call sites.
This commit is contained in:
parent
ed8fe4bc82
commit
73cf352daa
@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// If the message doesn't require any chunking, then we can go ahead
|
||||
// with a single write.
|
||||
if len(b) <= math.MaxUint16 {
|
||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
||||
err = c.noise.WriteMessage(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.noise.Flush(c.conn)
|
||||
}
|
||||
|
||||
// If we need to split the message into fragments, then we'll write
|
||||
@ -185,11 +189,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// Slice off the next chunk to be written based on our running
|
||||
// counter and next chunk size.
|
||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
|
||||
if err := c.noise.WriteMessage(chunk); err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
bytesWritten += len(chunk)
|
||||
n, err := c.noise.Flush(c.conn)
|
||||
bytesWritten += n
|
||||
if err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
}
|
||||
|
||||
return bytesWritten, nil
|
||||
|
@ -48,6 +48,10 @@ var (
|
||||
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
|
||||
"the max allowed message length of (2^16)-1")
|
||||
|
||||
// ErrMessageNotFlushed signals that the connection cannot accept a new
|
||||
// message because the prior message has not been fully flushed.
|
||||
ErrMessageNotFlushed = errors.New("prior message not flushed")
|
||||
|
||||
// lightningPrologue is the noise prologue that is used to initialize
|
||||
// the brontide noise handshake.
|
||||
lightningPrologue = []byte("lightning")
|
||||
@ -692,11 +696,13 @@ func (b *Machine) split() {
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMessage writes the next message p to the passed io.Writer. The
|
||||
// ciphertext of the message is prepended with an encrypt+auth'd length which
|
||||
// must be used as the AD to the AEAD construction when being decrypted by the
|
||||
// other side.
|
||||
func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
// WriteMessage encrypts and buffers the next message p. The ciphertext of the
|
||||
// message is prepended with an encrypt+auth'd length which must be used as the
|
||||
// AD to the AEAD construction when being decrypted by the other side.
|
||||
//
|
||||
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||
// call to Flush to ensure the message is written.
|
||||
func (b *Machine) WriteMessage(p []byte) error {
|
||||
// The total length of each message payload including the MAC size
|
||||
// payload exceed the largest number encodable within a 16-bit unsigned
|
||||
// integer.
|
||||
@ -704,6 +710,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
return ErrMaxMessageLengthExceeded
|
||||
}
|
||||
|
||||
// If a prior message was written but it hasn't been fully flushed,
|
||||
// return an error as we only support buffering of one message at a
|
||||
// time.
|
||||
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
|
||||
return ErrMessageNotFlushed
|
||||
}
|
||||
|
||||
// The full length of the packet is only the packet length, and does
|
||||
// NOT include the MAC.
|
||||
fullLength := uint16(len(p))
|
||||
@ -711,18 +724,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
var pktLen [2]byte
|
||||
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
||||
|
||||
// First, write out the encrypted+MAC'd length prefix for the packet.
|
||||
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
||||
if _, err := w.Write(cipherLen); err != nil {
|
||||
return err
|
||||
}
|
||||
// First, generate the encrypted+MAC'd length prefix for the packet.
|
||||
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
||||
|
||||
// Finally, write out the encrypted packet itself. We only write out a
|
||||
// single packet, as any fragmentation should have taken place at a
|
||||
// higher level.
|
||||
cipherText := b.sendCipher.Encrypt(nil, nil, p)
|
||||
_, err := w.Write(cipherText)
|
||||
return err
|
||||
// Finally, generate the encrypted packet itself.
|
||||
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush attempts to write a message buffered using WriteMessage to the provided
|
||||
|
@ -220,11 +220,9 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
// payload length.
|
||||
payloadToReject := make([]byte, math.MaxUint16+1)
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
// A write of the payload generated above to the state machine should
|
||||
// be rejected as it's over the max payload length.
|
||||
err := b.WriteMessage(&buf, payloadToReject)
|
||||
err := b.WriteMessage(payloadToReject)
|
||||
if err != ErrMaxMessageLengthExceeded {
|
||||
t.Fatalf("payload is over the max allowed length, the write " +
|
||||
"should have been rejected")
|
||||
@ -233,7 +231,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
// Generate another payload which should be accepted as a valid
|
||||
// payload.
|
||||
payloadToAccept := make([]byte, math.MaxUint16-1)
|
||||
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
|
||||
if err := b.WriteMessage(payloadToAccept); err != nil {
|
||||
t.Fatalf("write for payload was rejected, should have been " +
|
||||
"accepted")
|
||||
}
|
||||
@ -243,7 +241,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
payloadToReject = make([]byte, math.MaxUint16+1)
|
||||
|
||||
// This payload should be rejected.
|
||||
err = b.WriteMessage(&buf, payloadToReject)
|
||||
err = b.WriteMessage(payloadToReject)
|
||||
if err != ErrMaxMessageLengthExceeded {
|
||||
t.Fatalf("payload is over the max allowed length, the write " +
|
||||
"should have been rejected")
|
||||
@ -502,10 +500,14 @@ func TestBolt0008TestVectors(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
for i := 0; i < 1002; i++ {
|
||||
err = initiator.WriteMessage(&buf, payload)
|
||||
err = initiator.WriteMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("could not write message %s", payload)
|
||||
}
|
||||
_, err = initiator.Flush(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("could not flush message: %v", err)
|
||||
}
|
||||
if val, ok := transportMessageVectors[i]; ok {
|
||||
binaryVal, err := hex.DecodeString(val)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user