diff --git a/brontide/conn.go b/brontide/conn.go index 643ff860..0ebed66f 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) { // If the message doesn't require any chunking, then we can go ahead // with a single write. if len(b) <= math.MaxUint16 { - return len(b), c.noise.WriteMessage(c.conn, b) + err = c.noise.WriteMessage(b) + if err != nil { + return 0, err + } + return c.noise.Flush(c.conn) } // If we need to split the message into fragments, then we'll write @@ -185,16 +189,43 @@ func (c *Conn) Write(b []byte) (n int, err error) { // Slice off the next chunk to be written based on our running // counter and next chunk size. chunk := b[bytesWritten : bytesWritten+chunkSize] - if err := c.noise.WriteMessage(c.conn, chunk); err != nil { + if err := c.noise.WriteMessage(chunk); err != nil { return bytesWritten, err } - bytesWritten += len(chunk) + n, err := c.noise.Flush(c.conn) + bytesWritten += n + if err != nil { + return bytesWritten, err + } } return bytesWritten, nil } +// WriteMessage encrypts and buffers the next message p for the connection. The +// ciphertext of the message is prepended with an encrypt+auth'd length which +// must be used as the AD to the AEAD construction when being decrypted by the +// other side. +// +// NOTE: This DOES NOT write the message to the wire, it should be followed by a +// call to Flush to ensure the message is written. +func (c *Conn) WriteMessage(b []byte) error { + return c.noise.WriteMessage(b) +} + +// Flush attempts to write a message buffered using WriteMessage to the +// underlying connection. If no buffered message exists, this will result in a +// NOP. Otherwise, it will continue to write the remaining bytes, picking up +// where the byte stream left off in the event of a partial write. The number of +// bytes returned reflects the number of plaintext bytes in the payload, and +// does not account for the overhead of the header or MACs. 
+// +// NOTE: It is safe to call this method again iff a timeout error is returned. +func (c *Conn) Flush() (int, error) { + return c.noise.Flush(c.conn) +} + // Close closes the connection. Any blocked Read or Write operations will be // unblocked and return errors. // diff --git a/brontide/noise.go b/brontide/noise.go index 7b6bdb43..7d5f42d4 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -31,6 +31,10 @@ const ( // length of a message payload. lengthHeaderSize = 2 + // encHeaderSize is the number of bytes required to hold an encrypted + // header and its MAC. + encHeaderSize = lengthHeaderSize + macSize + // keyRotationInterval is the number of messages sent on a single // cipher stream before the keys are rotated forwards. keyRotationInterval = 1000 @@ -48,6 +52,10 @@ var ( ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " + "the max allowed message length of (2^16)-1") + // ErrMessageNotFlushed signals that the connection cannot accept a new + // message because the prior message has not been fully flushed. + ErrMessageNotFlushed = errors.New("prior message not flushed") + // lightningPrologue is the noise prologue that is used to initialize // the brontide noise handshake. lightningPrologue = []byte("lightning") @@ -366,7 +374,17 @@ type Machine struct { // nextCipherHeader is a static buffer that we'll use to read in the // next ciphertext header from the wire. The header is a 2 byte length // (of the next ciphertext), followed by a 16 byte MAC. - nextCipherHeader [lengthHeaderSize + macSize]byte + nextCipherHeader [encHeaderSize]byte + + // nextHeaderSend holds a reference to the remaining header bytes to + // write out for a pending message. This allows us to tolerate timeout + // errors that cause partial writes. + nextHeaderSend []byte + + // nextBodySend holds a reference to the remaining body bytes to write + // out for a pending message. This allows us to tolerate timeout errors + // that cause partial writes. 
+ nextBodySend []byte } // NewBrontideMachine creates a new instance of the brontide state-machine. If @@ -682,11 +700,13 @@ func (b *Machine) split() { } } -// WriteMessage writes the next message p to the passed io.Writer. The -// ciphertext of the message is prepended with an encrypt+auth'd length which -// must be used as the AD to the AEAD construction when being decrypted by the -// other side. -func (b *Machine) WriteMessage(w io.Writer, p []byte) error { +// WriteMessage encrypts and buffers the next message p. The ciphertext of the +// message is prepended with an encrypt+auth'd length which must be used as the +// AD to the AEAD construction when being decrypted by the other side. +// +// NOTE: This DOES NOT write the message to the wire, it should be followed by a +// call to Flush to ensure the message is written. +func (b *Machine) WriteMessage(p []byte) error { // The total length of each message payload including the MAC size // payload exceed the largest number encodable within a 16-bit unsigned // integer. @@ -694,6 +714,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error { return ErrMaxMessageLengthExceeded } + // If a prior message was written but it hasn't been fully flushed, + // return an error as we only support buffering of one message at a + // time. + if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 { + return ErrMessageNotFlushed + } + // The full length of the packet is only the packet length, and does // NOT include the MAC. fullLength := uint16(len(p)) @@ -701,18 +728,88 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error { var pktLen [2]byte binary.BigEndian.PutUint16(pktLen[:], fullLength) - // First, write out the encrypted+MAC'd length prefix for the packet. - cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:]) - if _, err := w.Write(cipherLen); err != nil { - return err + // First, generate the encrypted+MAC'd length prefix for the packet. 
+ b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:]) + + // Finally, generate the encrypted packet itself. + b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p) + + return nil +} + +// Flush attempts to write a message buffered using WriteMessage to the provided +// io.Writer. If no buffered message exists, this will result in a NOP. +// Otherwise, it will continue to write the remaining bytes, picking up where +// the byte stream left off in the event of a partial write. The number of bytes +// returned reflects the number of plaintext bytes in the payload, and does not +// account for the overhead of the header or MACs. +// +// NOTE: It is safe to call this method again iff a timeout error is returned. +func (b *Machine) Flush(w io.Writer) (int, error) { + // First, write out the pending header bytes, if any exist. Any header + // bytes written will not count towards the total amount flushed. + if len(b.nextHeaderSend) > 0 { + // Write any remaining header bytes and shift the slice to point + // to the next segment of unwritten bytes. If an error is + // encountered, we can continue to write the header from where + // we left off on a subsequent call to Flush. + n, err := w.Write(b.nextHeaderSend) + b.nextHeaderSend = b.nextHeaderSend[n:] + if err != nil { + return 0, err + } } - // Finally, write out the encrypted packet itself. We only write out a - // single packet, as any fragmentation should have taken place at a - // higher level. - cipherText := b.sendCipher.Encrypt(nil, nil, p) - _, err := w.Write(cipherText) - return err + // Next, write the pending body bytes, if any exist. Only the number of + // bytes written that correspond to the ciphertext will be included in + // the total bytes written, bytes written as part of the MAC will not be + // counted. + var nn int + if len(b.nextBodySend) > 0 { + // Write out all bytes excluding the mac and shift the body + // slice depending on the number of actual bytes written. 
+ n, err := w.Write(b.nextBodySend) + b.nextBodySend = b.nextBodySend[n:] + + // If we partially or fully wrote any of the body's MAC, we'll + // subtract that contribution from the total amount flushed to + // preserve the abstraction of returning the number of plaintext + // bytes written by the connection. + // + // There are three possible scenarios we must handle to ensure + // the returned value is correct. In the first case, the write + // straddles both payload and MAC bytes, and we must subtract + // the number of MAC bytes written from n. In the second, only + // payload bytes are written, thus we can return n unmodified. + // The final scenario pertains to the case where only MAC bytes + // are written, none of which count towards the total. + // + // |-----------Payload------------|----MAC----| + // Straddle: S---------------------------------E--------0 + // Payload-only: S------------------------E-----------------0 + // MAC-only: S-------E-0 + start, end := n+len(b.nextBodySend), len(b.nextBodySend) + switch { + + // Straddles payload and MAC bytes, subtract number of MAC bytes + // written from the actual number written. + case start > macSize && end <= macSize: + nn = n - (macSize - end) + + // Only payload bytes are written, return n directly. + case start > macSize && end > macSize: + nn = n + + // Only MAC bytes are written, return 0 bytes written. + default: + } + + if err != nil { + return nn, err + } + } + + return nn, nil } // ReadMessage attempts to read the next message from the passed io.Reader. In diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 1918a646..50de346c 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -8,6 +8,7 @@ import ( "net" "sync" "testing" + "testing/iotest" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -220,11 +221,9 @@ func TestMaxPayloadLength(t *testing.T) { // payload length. 
payloadToReject := make([]byte, math.MaxUint16+1) - var buf bytes.Buffer - // A write of the payload generated above to the state machine should // be rejected as it's over the max payload length. - err := b.WriteMessage(&buf, payloadToReject) + err := b.WriteMessage(payloadToReject) if err != ErrMaxMessageLengthExceeded { t.Fatalf("payload is over the max allowed length, the write " + "should have been rejected") @@ -233,7 +232,7 @@ func TestMaxPayloadLength(t *testing.T) { // Generate another payload which should be accepted as a valid // payload. payloadToAccept := make([]byte, math.MaxUint16-1) - if err := b.WriteMessage(&buf, payloadToAccept); err != nil { + if err := b.WriteMessage(payloadToAccept); err != nil { t.Fatalf("write for payload was rejected, should have been " + "accepted") } @@ -243,7 +242,7 @@ func TestMaxPayloadLength(t *testing.T) { payloadToReject = make([]byte, math.MaxUint16+1) // This payload should be rejected. - err = b.WriteMessage(&buf, payloadToReject) + err = b.WriteMessage(payloadToReject) if err != ErrMaxMessageLengthExceeded { t.Fatalf("payload is over the max allowed length, the write " + "should have been rejected") @@ -270,6 +269,8 @@ func TestWriteMessageChunking(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { + defer wg.Done() + bytesWritten, err := localConn.Write(largeMessage) if err != nil { t.Fatalf("unable to write message: %v", err) @@ -281,7 +282,6 @@ func TestWriteMessageChunking(t *testing.T) { t.Fatalf("bytes not fully written!") } - wg.Done() }() // Attempt to read the entirety of the message generated above. 
@@ -502,10 +502,14 @@ func TestBolt0008TestVectors(t *testing.T) { var buf bytes.Buffer for i := 0; i < 1002; i++ { - err = initiator.WriteMessage(&buf, payload) + err = initiator.WriteMessage(payload) if err != nil { t.Fatalf("could not write message %s", payload) } + _, err = initiator.Flush(&buf) + if err != nil { + t.Fatalf("could not flush message: %v", err) + } if val, ok := transportMessageVectors[i]; ok { binaryVal, err := hex.DecodeString(val) if err != nil { @@ -534,3 +538,176 @@ func TestBolt0008TestVectors(t *testing.T) { buf.Reset() } } + +// timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after +// writing n bytes. +type timeoutWriter struct { + w io.Writer + n int64 +} + +func NewTimeoutWriter(w io.Writer, n int64) io.Writer { + return &timeoutWriter{w, n} +} + +func (t *timeoutWriter) Write(p []byte) (int, error) { + n := len(p) + if int64(n) > t.n { + n = int(t.n) + } + n, err := t.w.Write(p[:n]) + t.n -= int64(n) + if err == nil && t.n == 0 { + return n, iotest.ErrTimeout + } + return n, err +} + +const payloadSize = 10 + +type flushChunk struct { + errAfter int64 + expN int + expErr error +} + +type flushTest struct { + name string + chunks []flushChunk +} + +var flushTests = []flushTest{ + { + name: "partial header write", + chunks: []flushChunk{ + // Write 18-byte header in two parts, 16 then 2. + { + errAfter: encHeaderSize - 2, + expN: 0, + expErr: iotest.ErrTimeout, + }, + { + errAfter: 2, + expN: 0, + expErr: iotest.ErrTimeout, + }, + // Write payload and MAC in one go. + { + errAfter: -1, + expN: payloadSize, + }, + }, + }, + { + name: "full payload then full mac", + chunks: []flushChunk{ + // Write entire header and entire payload w/o MAC. + { + errAfter: encHeaderSize + payloadSize, + expN: payloadSize, + expErr: iotest.ErrTimeout, + }, + // Write the entire MAC. 
+ { + errAfter: -1, + expN: 0, + }, + }, + }, + { + name: "payload-only, straddle, mac-only", + chunks: []flushChunk{ + // Write header and all but last byte of payload. + { + errAfter: encHeaderSize + payloadSize - 1, + expN: payloadSize - 1, + expErr: iotest.ErrTimeout, + }, + // Write last byte of payload and first byte of MAC. + { + errAfter: 2, + expN: 1, + expErr: iotest.ErrTimeout, + }, + // Write 10 bytes of the MAC. + { + errAfter: 10, + expN: 0, + expErr: iotest.ErrTimeout, + }, + // Write the remaining 5 MAC bytes. + { + errAfter: -1, + expN: 0, + }, + }, + }, +} + +// TestFlush asserts a Machine's ability to handle timeouts during Flush that +// cause partial writes, and that the machine can properly resume writes on +// subsequent calls to Flush. +func TestFlush(t *testing.T) { + // Run each test individually, to assert that they pass in isolation. + for _, test := range flushTests { + t.Run(test.name, func(t *testing.T) { + var ( + w bytes.Buffer + b Machine + ) + b.split() + testFlush(t, test, &b, &w) + }) + } + + // Finally, run the tests serially as if all on one connection. + t.Run("flush serial", func(t *testing.T) { + var ( + w bytes.Buffer + b Machine + ) + b.split() + for _, test := range flushTests { + testFlush(t, test, &b, &w) + } + }) +} + +// testFlush buffers a message on the Machine, then flushes it to the io.Writer +// in chunks. Once complete, a final call to flush is made to assert that Write +// is not called again. +func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) { + payload := make([]byte, payloadSize) + if err := b.WriteMessage(payload); err != nil { + t.Fatalf("unable to write message: %v", err) + } + + for _, chunk := range test.chunks { + assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr) + } + + // We should always be able to call Flush after a message has been + // successfully written, and it should result in a NOP. 
+ assertFlush(t, b, w, 0, 0, nil) +} + +// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a +// timeoutWriter is used so that the flush stops with iotest.ErrTimeout after +// n bytes. The method asserts that the returned error matches expErr and that +// the number of bytes written by Flush matches expN. +func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int, + expErr error) { + + t.Helper() + + if n >= 0 { + w = NewTimeoutWriter(w, n) + } + nn, err := b.Flush(w) + if err != expErr { + t.Fatalf("expected flush err: %v, got: %v", expErr, err) + } + if nn != expN { + t.Fatalf("expected n: %d, got: %d", expN, nn) + } +} diff --git a/lncfg/workers.go b/lncfg/workers.go index bbd9960b..136b38d7 100644 --- a/lncfg/workers.go +++ b/lncfg/workers.go @@ -9,7 +9,7 @@ const ( // DefaultWriteWorkers is the default maximum number of concurrent // workers used by the daemon's write pool. - DefaultWriteWorkers = 100 + DefaultWriteWorkers = 8 // DefaultSigWorkers is the default maximum number of concurrent workers // used by the daemon's sig pool. @@ -20,13 +20,13 @@ const ( // pools. type Workers struct { // Read is the maximum number of concurrent read pool workers. - Read int `long:"read" description:"Maximum number of concurrent read pool workers."` + Read int `long:"read" description:"Maximum number of concurrent read pool workers. This number should be proportional to the number of peers."` // Write is the maximum number of concurrent write pool workers. - Write int `long:"write" description:"Maximum number of concurrent write pool workers."` + Write int `long:"write" description:"Maximum number of concurrent write pool workers. This number should be proportional to the number of CPUs on the host. "` // Sig is the maximum number of concurrent sig pool workers. - Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers."` + Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers. 
This number should be proportional to the number of CPUs on the host."` } // Validate checks the Workers configuration to ensure that the input values are diff --git a/peer.go b/peer.go index d7fcf2e1..0e79088f 100644 --- a/peer.go +++ b/peer.go @@ -1403,16 +1403,58 @@ func (p *peer) logWireMessage(msg lnwire.Message, read bool) { })) } -// writeMessage writes the target lnwire.Message to the remote peer. +// writeMessage writes and flushes the target lnwire.Message to the remote peer. +// If the passed message is nil, this method will only try to flush an existing +// message buffered on the connection. It is safe to recall this method with a +// nil message iff a timeout error is returned. This will continue to flush the +// pending message to the wire. func (p *peer) writeMessage(msg lnwire.Message) error { // Simply exit if we're shutting down. if atomic.LoadInt32(&p.disconnect) != 0 { return ErrPeerExiting } - p.logWireMessage(msg, false) + // Only log the message on the first attempt. + if msg != nil { + p.logWireMessage(msg, false) + } - var n int + noiseConn, ok := p.conn.(*brontide.Conn) + if !ok { + return fmt.Errorf("brontide.Conn required to write messages") + } + + flushMsg := func() error { + // Ensure the write deadline is set before we attempt to send + // the message. + writeDeadline := time.Now().Add(writeMessageTimeout) + err := noiseConn.SetWriteDeadline(writeDeadline) + if err != nil { + return err + } + + // Flush the pending message to the wire. If an error is + // encountered, e.g. write timeout, the number of bytes written + // so far will be returned. + n, err := noiseConn.Flush() + + // Record the number of bytes written on the wire, if any. + if n > 0 { + atomic.AddUint64(&p.bytesSent, uint64(n)) + } + + return err + } + + // If the current message has already been serialized, encrypted, and + // buffered on the underlying connection we will skip straight to + // flushing it to the wire. 
+ if msg == nil { + return flushMsg() + } + + // Otherwise, this is a new message. We'll acquire a write buffer to + // serialize the message and buffer the ciphertext on the connection. err := p.writePool.Submit(func(buf *bytes.Buffer) error { // Using a buffer allocated by the write pool, encode the // message directly into the buffer. @@ -1421,25 +1463,17 @@ func (p *peer) writeMessage(msg lnwire.Message) error { return writeErr } - // Ensure the write deadline is set before we attempt to send - // the message. - writeDeadline := time.Now().Add(writeMessageTimeout) - writeErr = p.conn.SetWriteDeadline(writeDeadline) - if writeErr != nil { - return writeErr - } - - // Finally, write the message itself in a single swoop. - n, writeErr = p.conn.Write(buf.Bytes()) - return writeErr + // Finally, write the message itself in a single swoop. This + // will buffer the ciphertext on the underlying connection. We + // will defer flushing the message until the write pool has been + // released. + return noiseConn.WriteMessage(buf.Bytes()) }) - - // Record the number of bytes written on the wire, if any. - if n > 0 { - atomic.AddUint64(&p.bytesSent, uint64(n)) + if err != nil { + return err } - return err + return flushMsg() } // writeHandler is a goroutine dedicated to reading messages off of an incoming @@ -1459,38 +1493,10 @@ func (p *peer) writeHandler() { var exitErr error - const ( - minRetryDelay = 5 * time.Second - maxRetryDelay = time.Minute - ) - out: for { select { case outMsg := <-p.sendQueue: - // Record the time at which we first attempt to send the - // message. - startTime := time.Now() - - // Initialize a retry delay of zero, which will be - // increased if we encounter a write timeout on the - // send. - var retryDelay time.Duration - retryWithDelay: - if retryDelay > 0 { - select { - case <-time.After(retryDelay): - case <-p.quit: - // Inform synchronous writes that the - // peer is exiting. 
- if outMsg.errChan != nil { - outMsg.errChan <- ErrPeerExiting - } - exitErr = ErrPeerExiting - break out - } - } - // If we're about to send a ping message, then log the // exact time in which we send the message so we can // use the delay as a rough estimate of latency to the @@ -1502,33 +1508,32 @@ out: atomic.StoreInt64(&p.pingLastSend, now) } + // Record the time at which we first attempt to send the + // message. + startTime := time.Now() + + retry: // Write out the message to the socket. If a timeout // error is encountered, we will catch this and retry // after backing off in case the remote peer is just // slow to process messages from the wire. err := p.writeMessage(outMsg.msg) if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - // Increase the retry delay in the event of a - // timeout error, this prevents us from - // disconnecting if the remote party is slow to - // pull messages off the wire. We back off - // exponentially up to our max delay to prevent - // blocking the write pool. - if retryDelay == 0 { - retryDelay = minRetryDelay - } else { - retryDelay *= 2 - if retryDelay > maxRetryDelay { - retryDelay = maxRetryDelay - } - } - peerLog.Debugf("Write timeout detected for "+ - "peer %s, retrying after %v, "+ - "first attempted %v ago", p, retryDelay, + "peer %s, first write for message "+ + "attempted %v ago", p, time.Since(startTime)) - goto retryWithDelay + // If we received a timeout error, this implies + // that the message was buffered on the + // connection successfully and that a flush was + // attempted. We'll set the message to nil so + // that on a subsequent pass we only try to + // flush the buffered message, and forgo + // reserializing or reencrypting it. + outMsg.msg = nil + + goto retry } // The write succeeded, reset the idle timer to prevent