Merge pull request #2924 from cfromknecht/write-and-flush

peer: resume partial writes due to timeouts
This commit is contained in:
Olaoluwa Osuntokun 2019-04-26 18:25:37 -07:00 committed by GitHub
commit 42b081bb37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 405 additions and 95 deletions

@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// If the message doesn't require any chunking, then we can go ahead
// with a single write.
if len(b) <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b)
err = c.noise.WriteMessage(b)
if err != nil {
return 0, err
}
return c.noise.Flush(c.conn)
}
// If we need to split the message into fragments, then we'll write
@ -185,16 +189,43 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// Slice off the next chunk to be written based on our running
// counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
if err := c.noise.WriteMessage(chunk); err != nil {
return bytesWritten, err
}
bytesWritten += len(chunk)
n, err := c.noise.Flush(c.conn)
bytesWritten += n
if err != nil {
return bytesWritten, err
}
}
return bytesWritten, nil
}
// WriteMessage encrypts and buffers the next message b for the connection. The
// ciphertext of the message is prepended with an encrypt+auth'd length which
// must be used as the AD to the AEAD construction when being decrypted by the
// other side.
//
// The call is forwarded unchanged to the connection's noise state machine, and
// whatever error it reports is returned as-is. Presumably that covers
// ErrMaxMessageLengthExceeded for oversized payloads and ErrMessageNotFlushed
// when a previously buffered message has not been fully flushed; confirm
// against the Machine implementation.
//
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (c *Conn) WriteMessage(b []byte) error {
return c.noise.WriteMessage(b)
}
// Flush attempts to write a message buffered using WriteMessage to the
// underlying connection. If no buffered message exists, this will result in a
// NOP. Otherwise, it will continue to write the remaining bytes, picking up
// where the byte stream left off in the event of a partial write. The number of
// bytes returned reflects the number of plaintext bytes in the payload, and
// does not account for the overhead of the header or MACs.
//
// The work is delegated to the noise state machine's Flush, using c.conn as
// the destination writer; its results are returned unmodified.
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (c *Conn) Flush() (int, error) {
return c.noise.Flush(c.conn)
}
// Close closes the connection. Any blocked Read or Write operations will be
// unblocked and return errors.
//

@ -31,6 +31,10 @@ const (
// length of a message payload.
lengthHeaderSize = 2
// encHeaderSize is the number of bytes required to hold an encrypted
// header and its MAC.
encHeaderSize = lengthHeaderSize + macSize
// keyRotationInterval is the number of messages sent on a single
// cipher stream before the keys are rotated forwards.
keyRotationInterval = 1000
@ -48,6 +52,10 @@ var (
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
"the max allowed message length of (2^16)-1")
// ErrMessageNotFlushed signals that the connection cannot accept a new
// message because the prior message has not been fully flushed.
ErrMessageNotFlushed = errors.New("prior message not flushed")
// lightningPrologue is the noise prologue that is used to initialize
// the brontide noise handshake.
lightningPrologue = []byte("lightning")
@ -366,7 +374,17 @@ type Machine struct {
// nextCipherHeader is a static buffer that we'll use to read in the
// next ciphertext header from the wire. The header is a 2 byte length
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [lengthHeaderSize + macSize]byte
nextCipherHeader [encHeaderSize]byte
// nextHeaderSend holds a reference to the remaining header bytes to
// write out for a pending message. This allows us to tolerate timeout
// errors that cause partial writes.
nextHeaderSend []byte
// nextBodySend holds a reference to the remaining body bytes to write
// out for a pending message. This allows us to tolerate timeout errors
// that cause partial writes.
nextBodySend []byte
}
// NewBrontideMachine creates a new instance of the brontide state-machine. If
@ -682,11 +700,13 @@ func (b *Machine) split() {
}
}
// WriteMessage writes the next message p to the passed io.Writer. The
// ciphertext of the message is prepended with an encrypt+auth'd length which
// must be used as the AD to the AEAD construction when being decrypted by the
// other side.
func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
// WriteMessage encrypts and buffers the next message p. The ciphertext of the
// message is prepended with an encrypt+auth'd length which must be used as the
// AD to the AEAD construction when being decrypted by the other side.
//
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (b *Machine) WriteMessage(p []byte) error {
// The total length of each message payload must not exceed the largest
// number encodable within a 16-bit unsigned integer.
@ -694,6 +714,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
return ErrMaxMessageLengthExceeded
}
// If a prior message was written but it hasn't been fully flushed,
// return an error as we only support buffering of one message at a
// time.
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
return ErrMessageNotFlushed
}
// The full length of the packet is only the packet length, and does
// NOT include the MAC.
fullLength := uint16(len(p))
@ -701,18 +728,88 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength)
// First, write out the encrypted+MAC'd length prefix for the packet.
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:])
if _, err := w.Write(cipherLen); err != nil {
return err
// First, generate the encrypted+MAC'd length prefix for the packet.
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
// Finally, generate the encrypted packet itself.
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
return nil
}
// Flush attempts to write a message buffered using WriteMessage to the provided
// io.Writer. If no buffered message exists, this will result in a NOP.
// Otherwise, it will continue to write the remaining bytes, picking up where
// the byte stream left off in the event of a partial write. The number of bytes
// returned reflects the number of plaintext bytes in the payload, and does not
// account for the overhead of the header or MACs.
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (b *Machine) Flush(w io.Writer) (int, error) {
// First, write out the pending header bytes, if any exist. Any header
// bytes written will not count towards the total amount flushed.
if len(b.nextHeaderSend) > 0 {
// Write any remaining header bytes and shift the slice to point
// to the next segment of unwritten bytes. If an error is
// encountered, we can continue to write the header from where
// we left off on a subsequent call to Flush.
n, err := w.Write(b.nextHeaderSend)
b.nextHeaderSend = b.nextHeaderSend[n:]
if err != nil {
return 0, err
}
}
// Finally, write out the encrypted packet itself. We only write out a
// single packet, as any fragmentation should have taken place at a
// higher level.
cipherText := b.sendCipher.Encrypt(nil, nil, p)
_, err := w.Write(cipherText)
return err
// Next, write the pending body bytes, if any exist. Only the number of
// bytes written that correspond to the ciphertext will be included in
// the total bytes written, bytes written as part of the MAC will not be
// counted.
var nn int
if len(b.nextBodySend) > 0 {
// Write out the remaining body bytes (payload ciphertext followed by
// its MAC) and shift the body slice by the number of bytes actually
// written.
n, err := w.Write(b.nextBodySend)
b.nextBodySend = b.nextBodySend[n:]
// If we partially or fully wrote any of the body's MAC, we'll
// subtract that contribution from the total amount flushed to
// preserve the abstraction of returning the number of plaintext
// bytes written by the connection.
//
// There are three possible scenarios we must handle to ensure
// the returned value is correct. In the first case, the write
// straddles both payload and MAC bytes, and we must subtract
// the number of MAC bytes written from n. In the second, only
// payload bytes are written, thus we can return n unmodified.
// The final scenario pertains to the case where only MAC bytes
// are written, none of which count towards the total.
//
// |-----------Payload------------|----MAC----|
// Straddle: S---------------------------------E--------0
// Payload-only: S------------------------E-----------------0
// MAC-only: S-------E-0
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
switch {
// Straddles payload and MAC bytes, subtract number of MAC bytes
// written from the actual number written.
case start > macSize && end <= macSize:
nn = n - (macSize - end)
// Only payload bytes are written, return n directly.
case start > macSize && end > macSize:
nn = n
// Only MAC bytes are written, return 0 bytes written.
default:
}
if err != nil {
return nn, err
}
}
return nn, nil
}
// ReadMessage attempts to read the next message from the passed io.Reader. In

@ -8,6 +8,7 @@ import (
"net"
"sync"
"testing"
"testing/iotest"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/lnwire"
@ -220,11 +221,9 @@ func TestMaxPayloadLength(t *testing.T) {
// payload length.
payloadToReject := make([]byte, math.MaxUint16+1)
var buf bytes.Buffer
// A write of the payload generated above to the state machine should
// be rejected as it's over the max payload length.
err := b.WriteMessage(&buf, payloadToReject)
err := b.WriteMessage(payloadToReject)
if err != ErrMaxMessageLengthExceeded {
t.Fatalf("payload is over the max allowed length, the write " +
"should have been rejected")
@ -233,7 +232,7 @@ func TestMaxPayloadLength(t *testing.T) {
// Generate another payload which should be accepted as a valid
// payload.
payloadToAccept := make([]byte, math.MaxUint16-1)
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
if err := b.WriteMessage(payloadToAccept); err != nil {
t.Fatalf("write for payload was rejected, should have been " +
"accepted")
}
@ -243,7 +242,7 @@ func TestMaxPayloadLength(t *testing.T) {
payloadToReject = make([]byte, math.MaxUint16+1)
// This payload should be rejected.
err = b.WriteMessage(&buf, payloadToReject)
err = b.WriteMessage(payloadToReject)
if err != ErrMaxMessageLengthExceeded {
t.Fatalf("payload is over the max allowed length, the write " +
"should have been rejected")
@ -270,6 +269,8 @@ func TestWriteMessageChunking(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
bytesWritten, err := localConn.Write(largeMessage)
if err != nil {
t.Fatalf("unable to write message: %v", err)
@ -281,7 +282,6 @@ func TestWriteMessageChunking(t *testing.T) {
t.Fatalf("bytes not fully written!")
}
wg.Done()
}()
// Attempt to read the entirety of the message generated above.
@ -502,10 +502,14 @@ func TestBolt0008TestVectors(t *testing.T) {
var buf bytes.Buffer
for i := 0; i < 1002; i++ {
err = initiator.WriteMessage(&buf, payload)
err = initiator.WriteMessage(payload)
if err != nil {
t.Fatalf("could not write message %s", payload)
}
_, err = initiator.Flush(&buf)
if err != nil {
t.Fatalf("could not flush message: %v", err)
}
if val, ok := transportMessageVectors[i]; ok {
binaryVal, err := hex.DecodeString(val)
if err != nil {
@ -534,3 +538,176 @@ func TestBolt0008TestVectors(t *testing.T) {
buf.Reset()
}
}
// timeoutWriter wraps an io.Writer and returns iotest.ErrTimeout once a fixed
// budget of bytes has been written through it. It is used to simulate write
// deadlines that expire part-way through a message.
type timeoutWriter struct {
	// w is the underlying writer that receives the permitted bytes.
	w io.Writer

	// n is the number of bytes that may still be written before the
	// writer reports a timeout.
	n int64
}

// NewTimeoutWriter returns an io.Writer that forwards at most n bytes to w.
// The Write call that exhausts the budget, and every Write after it, returns
// iotest.ErrTimeout. A negative n is treated as an already exhausted budget.
func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
	return &timeoutWriter{w: w, n: n}
}

// Write forwards as many bytes of p as the remaining budget allows and returns
// the number of bytes actually written. If the underlying writer succeeds and
// the budget is exhausted by (or before) this call, iotest.ErrTimeout is
// returned alongside the count. Errors from the underlying writer take
// precedence and are returned unchanged.
func (t *timeoutWriter) Write(p []byte) (int, error) {
	// A negative budget would otherwise be used as a slice bound below and
	// panic; treat it as having no bytes left to write.
	if t.n < 0 {
		t.n = 0
	}

	// Truncate the write to whatever budget remains.
	n := len(p)
	if int64(n) > t.n {
		n = int(t.n)
	}

	n, err := t.w.Write(p[:n])
	t.n -= int64(n)

	// Only signal a timeout when the underlying write itself succeeded, so
	// that genuine write errors are not masked.
	if err == nil && t.n == 0 {
		return n, iotest.ErrTimeout
	}
	return n, err
}
// payloadSize is the length in bytes of the plaintext message buffered by each
// flush test case.
const payloadSize = 10
// flushChunk describes a single call to Flush within a flush test case along
// with the result that call is expected to produce.
type flushChunk struct {
// errAfter is the number of bytes the writer accepts before returning a
// timeout. A negative value means the writer is used without a timeout.
errAfter int64
// expN is the byte count Flush is expected to return for this call.
expN int
// expErr is the error Flush is expected to return for this call.
expErr error
}
// flushTest is a named sequence of Flush calls that together write out one
// buffered message.
type flushTest struct {
// name identifies the case when run as a subtest.
name string
// chunks lists the Flush calls to perform, in order.
chunks []flushChunk
}
// flushTests enumerates the partial-write scenarios exercised by TestFlush.
// Each case interrupts the flush of a single payloadSize-byte message at
// different offsets within the encrypted header, the payload, and the payload's
// MAC, and records the plaintext byte count each Flush call should report.
var flushTests = []flushTest{
{
name: "partial header write",
chunks: []flushChunk{
// Write 18-byte header in two parts, 16 then 2.
{
errAfter: encHeaderSize - 2,
expN: 0,
expErr: iotest.ErrTimeout,
},
{
errAfter: 2,
expN: 0,
expErr: iotest.ErrTimeout,
},
// Write payload and MAC in one go.
{
errAfter: -1,
expN: payloadSize,
},
},
},
{
name: "full payload then full mac",
chunks: []flushChunk{
// Write entire header and entire payload w/o MAC.
{
errAfter: encHeaderSize + payloadSize,
expN: payloadSize,
expErr: iotest.ErrTimeout,
},
// Write the entire MAC.
{
errAfter: -1,
expN: 0,
},
},
},
{
name: "payload-only, straddle, mac-only",
chunks: []flushChunk{
// Write header and all but last byte of payload.
{
errAfter: encHeaderSize + payloadSize - 1,
expN: payloadSize - 1,
expErr: iotest.ErrTimeout,
},
// Write last byte of payload and first byte of MAC.
{
errAfter: 2,
expN: 1,
expErr: iotest.ErrTimeout,
},
// Write 10 bytes of the MAC.
{
errAfter: 10,
expN: 0,
expErr: iotest.ErrTimeout,
},
// Write the remaining 5 MAC bytes.
{
errAfter: -1,
expN: 0,
},
},
},
}
// TestFlush checks that a Machine copes with timeouts that interrupt Flush
// part-way through a message, and that later calls to Flush resume the write
// from where the previous one stopped.
func TestFlush(t *testing.T) {
	// First give every case its own Machine and buffer so that each one is
	// shown to pass in isolation.
	for i := range flushTests {
		tc := flushTests[i]
		t.Run(tc.name, func(t *testing.T) {
			var sink bytes.Buffer
			machine := new(Machine)
			machine.split()

			testFlush(t, tc, machine, &sink)
		})
	}

	// Then replay all cases back-to-back against a single Machine and
	// buffer, as though they were sent over one connection.
	t.Run("flush serial", func(t *testing.T) {
		var sink bytes.Buffer
		machine := new(Machine)
		machine.split()

		for i := range flushTests {
			testFlush(t, flushTests[i], machine, &sink)
		}
	})
}
// testFlush buffers one payloadSize-byte message on the Machine and then
// drains it to w using the sequence of Flush calls described by the test case.
// A closing Flush call verifies that nothing is left to write.
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
	msg := make([]byte, payloadSize)
	err := b.WriteMessage(msg)
	if err != nil {
		t.Fatalf("unable to write message: %v", err)
	}

	for i := range test.chunks {
		step := test.chunks[i]
		assertFlush(t, b, w, step.errAfter, step.expN, step.expErr)
	}

	// With the message fully written, another Flush must be a NOP. The
	// writer is given a zero-byte budget here, so any attempt to write
	// would surface as a timeout rather than the expected nil error.
	assertFlush(t, b, w, 0, 0, nil)
}
// assertFlush performs a single Flush against the passed io.Writer. If n >= 0,
// the writer is wrapped in a timeoutWriter so that the flush stops with
// iotest.ErrTimeout after n bytes; a negative n leaves the writer unwrapped.
// The method asserts that the returned error matches expErr and that the
// number of bytes reported by Flush matches expN.
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
	expErr error) {

	t.Helper()

	dst := w
	if n >= 0 {
		dst = NewTimeoutWriter(w, n)
	}

	flushed, err := b.Flush(dst)
	switch {
	case err != expErr:
		t.Fatalf("expected flush err: %v, got: %v", expErr, err)

	case flushed != expN:
		t.Fatalf("expected n: %d, got: %d", expN, flushed)
	}
}

@ -9,7 +9,7 @@ const (
// DefaultWriteWorkers is the default maximum number of concurrent
// workers used by the daemon's write pool.
DefaultWriteWorkers = 100
DefaultWriteWorkers = 8
// DefaultSigWorkers is the default maximum number of concurrent workers
// used by the daemon's sig pool.
@ -20,13 +20,13 @@ const (
// pools.
type Workers struct {
// Read is the maximum number of concurrent read pool workers.
Read int `long:"read" description:"Maximum number of concurrent read pool workers."`
Read int `long:"read" description:"Maximum number of concurrent read pool workers. This number should be proportional to the number of peers."`
// Write is the maximum number of concurrent write pool workers.
Write int `long:"write" description:"Maximum number of concurrent write pool workers."`
Write int `long:"write" description:"Maximum number of concurrent write pool workers. This number should be proportional to the number of CPUs on the host. "`
// Sig is the maximum number of concurrent sig pool workers.
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers."`
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers. This number should be proportional to the number of CPUs on the host."`
}
// Validate checks the Workers configuration to ensure that the input values are

135
peer.go

@ -1403,16 +1403,58 @@ func (p *peer) logWireMessage(msg lnwire.Message, read bool) {
}))
}
// writeMessage writes the target lnwire.Message to the remote peer.
// writeMessage writes and flushes the target lnwire.Message to the remote peer.
// If the passed message is nil, this method will only try to flush an existing
// message buffered on the connection. It is safe to recall this method with a
// nil message iff a timeout error is returned. This will continue to flush the
// pending message to the wire.
func (p *peer) writeMessage(msg lnwire.Message) error {
// Simply exit if we're shutting down.
if atomic.LoadInt32(&p.disconnect) != 0 {
return ErrPeerExiting
}
p.logWireMessage(msg, false)
// Only log the message on the first attempt.
if msg != nil {
p.logWireMessage(msg, false)
}
var n int
noiseConn, ok := p.conn.(*brontide.Conn)
if !ok {
return fmt.Errorf("brontide.Conn required to write messages")
}
flushMsg := func() error {
// Ensure the write deadline is set before we attempt to send
// the message.
writeDeadline := time.Now().Add(writeMessageTimeout)
err := noiseConn.SetWriteDeadline(writeDeadline)
if err != nil {
return err
}
// Flush the pending message to the wire. If an error is
// encountered, e.g. write timeout, the number of bytes written
// so far will be returned.
n, err := noiseConn.Flush()
// Record the number of bytes written on the wire, if any.
if n > 0 {
atomic.AddUint64(&p.bytesSent, uint64(n))
}
return err
}
// If the current message has already been serialized, encrypted, and
// buffered on the underlying connection we will skip straight to
// flushing it to the wire.
if msg == nil {
return flushMsg()
}
// Otherwise, this is a new message. We'll acquire a write buffer to
// serialize the message and buffer the ciphertext on the connection.
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
// Using a buffer allocated by the write pool, encode the
// message directly into the buffer.
@ -1421,25 +1463,17 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
return writeErr
}
// Ensure the write deadline is set before we attempt to send
// the message.
writeDeadline := time.Now().Add(writeMessageTimeout)
writeErr = p.conn.SetWriteDeadline(writeDeadline)
if writeErr != nil {
return writeErr
}
// Finally, write the message itself in a single swoop.
n, writeErr = p.conn.Write(buf.Bytes())
return writeErr
// Finally, write the message itself in a single swoop. This
// will buffer the ciphertext on the underlying connection. We
// will defer flushing the message until the write pool has been
// released.
return noiseConn.WriteMessage(buf.Bytes())
})
// Record the number of bytes written on the wire, if any.
if n > 0 {
atomic.AddUint64(&p.bytesSent, uint64(n))
if err != nil {
return err
}
return err
return flushMsg()
}
// writeHandler is a goroutine dedicated to reading messages off of an incoming
@ -1459,38 +1493,10 @@ func (p *peer) writeHandler() {
var exitErr error
const (
minRetryDelay = 5 * time.Second
maxRetryDelay = time.Minute
)
out:
for {
select {
case outMsg := <-p.sendQueue:
// Record the time at which we first attempt to send the
// message.
startTime := time.Now()
// Initialize a retry delay of zero, which will be
// increased if we encounter a write timeout on the
// send.
var retryDelay time.Duration
retryWithDelay:
if retryDelay > 0 {
select {
case <-time.After(retryDelay):
case <-p.quit:
// Inform synchronous writes that the
// peer is exiting.
if outMsg.errChan != nil {
outMsg.errChan <- ErrPeerExiting
}
exitErr = ErrPeerExiting
break out
}
}
// If we're about to send a ping message, then log the
// exact time in which we send the message so we can
// use the delay as a rough estimate of latency to the
@ -1502,33 +1508,32 @@ out:
atomic.StoreInt64(&p.pingLastSend, now)
}
// Record the time at which we first attempt to send the
// message.
startTime := time.Now()
retry:
// Write out the message to the socket. If a timeout
// error is encountered, we will catch this and retry
// after backing off in case the remote peer is just
// slow to process messages from the wire.
err := p.writeMessage(outMsg.msg)
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// Increase the retry delay in the event of a
// timeout error, this prevents us from
// disconnecting if the remote party is slow to
// pull messages off the wire. We back off
// exponentially up to our max delay to prevent
// blocking the write pool.
if retryDelay == 0 {
retryDelay = minRetryDelay
} else {
retryDelay *= 2
if retryDelay > maxRetryDelay {
retryDelay = maxRetryDelay
}
}
peerLog.Debugf("Write timeout detected for "+
"peer %s, retrying after %v, "+
"first attempted %v ago", p, retryDelay,
"peer %s, first write for message "+
"attempted %v ago", p,
time.Since(startTime))
goto retryWithDelay
// If we received a timeout error, this implies
// that the message was buffered on the
// connection successfully and that a flush was
// attempted. We'll set the message to nil so
// that on a subsequent pass we only try to
// flush the buffered message, and forgo
// reserializing or reencrypting it.
outMsg.msg = nil
goto retry
}
// The write succeeded, reset the idle timer to prevent