Merge pull request #2924 from cfromknecht/write-and-flush
peer: resume partial writes due to timeouts
This commit is contained in:
commit
42b081bb37
@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// If the message doesn't require any chunking, then we can go ahead
|
||||
// with a single write.
|
||||
if len(b) <= math.MaxUint16 {
|
||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
||||
err = c.noise.WriteMessage(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.noise.Flush(c.conn)
|
||||
}
|
||||
|
||||
// If we need to split the message into fragments, then we'll write
|
||||
@ -185,16 +189,43 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// Slice off the next chunk to be written based on our running
|
||||
// counter and next chunk size.
|
||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
|
||||
if err := c.noise.WriteMessage(chunk); err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
bytesWritten += len(chunk)
|
||||
n, err := c.noise.Flush(c.conn)
|
||||
bytesWritten += n
|
||||
if err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
}
|
||||
|
||||
return bytesWritten, nil
|
||||
}
|
||||
|
||||
// WriteMessage encrypts and buffers the next message p for the connection. The
|
||||
// ciphertext of the message is prepended with an encrypt+auth'd length which
|
||||
// must be used as the AD to the AEAD construction when being decrypted by the
|
||||
// other side.
|
||||
//
|
||||
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||
// call to Flush to ensure the message is written.
|
||||
func (c *Conn) WriteMessage(b []byte) error {
|
||||
return c.noise.WriteMessage(b)
|
||||
}
|
||||
|
||||
// Flush attempts to write a message buffered using WriteMessage to the
|
||||
// underlying connection. If no buffered message exists, this will result in a
|
||||
// NOP. Otherwise, it will continue to write the remaining bytes, picking up
|
||||
// where the byte stream left off in the event of a partial write. The number of
|
||||
// bytes returned reflects the number of plaintext bytes in the payload, and
|
||||
// does not account for the overhead of the header or MACs.
|
||||
//
|
||||
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||
func (c *Conn) Flush() (int, error) {
|
||||
return c.noise.Flush(c.conn)
|
||||
}
|
||||
|
||||
// Close closes the connection. Any blocked Read or Write operations will be
|
||||
// unblocked and return errors.
|
||||
//
|
||||
|
@ -31,6 +31,10 @@ const (
|
||||
// length of a message payload.
|
||||
lengthHeaderSize = 2
|
||||
|
||||
// encHeaderSize is the number of bytes required to hold an encrypted
|
||||
// header and its MAC.
|
||||
encHeaderSize = lengthHeaderSize + macSize
|
||||
|
||||
// keyRotationInterval is the number of messages sent on a single
|
||||
// cipher stream before the keys are rotated forwards.
|
||||
keyRotationInterval = 1000
|
||||
@ -48,6 +52,10 @@ var (
|
||||
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
|
||||
"the max allowed message length of (2^16)-1")
|
||||
|
||||
// ErrMessageNotFlushed signals that the connection cannot accept a new
|
||||
// message because the prior message has not been fully flushed.
|
||||
ErrMessageNotFlushed = errors.New("prior message not flushed")
|
||||
|
||||
// lightningPrologue is the noise prologue that is used to initialize
|
||||
// the brontide noise handshake.
|
||||
lightningPrologue = []byte("lightning")
|
||||
@ -366,7 +374,17 @@ type Machine struct {
|
||||
// nextCipherHeader is a static buffer that we'll use to read in the
|
||||
// next ciphertext header from the wire. The header is a 2 byte length
|
||||
// (of the next ciphertext), followed by a 16 byte MAC.
|
||||
nextCipherHeader [lengthHeaderSize + macSize]byte
|
||||
nextCipherHeader [encHeaderSize]byte
|
||||
|
||||
// nextHeaderSend holds a reference to the remaining header bytes to
|
||||
// write out for a pending message. This allows us to tolerate timeout
|
||||
// errors that cause partial writes.
|
||||
nextHeaderSend []byte
|
||||
|
||||
// nextBodySend holds a reference to the remaining body bytes to write
|
||||
// out for a pending message. This allows us to tolerate timeout errors
|
||||
// that cause partial writes.
|
||||
nextBodySend []byte
|
||||
}
|
||||
|
||||
// NewBrontideMachine creates a new instance of the brontide state-machine. If
|
||||
@ -682,11 +700,13 @@ func (b *Machine) split() {
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMessage writes the next message p to the passed io.Writer. The
|
||||
// ciphertext of the message is prepended with an encrypt+auth'd length which
|
||||
// must be used as the AD to the AEAD construction when being decrypted by the
|
||||
// other side.
|
||||
func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
// WriteMessage encrypts and buffers the next message p. The ciphertext of the
|
||||
// message is prepended with an encrypt+auth'd length which must be used as the
|
||||
// AD to the AEAD construction when being decrypted by the other side.
|
||||
//
|
||||
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||
// call to Flush to ensure the message is written.
|
||||
func (b *Machine) WriteMessage(p []byte) error {
|
||||
// The length of each message payload, which does not include the MAC,
|
||||
// must not exceed the largest number encodable within a 16-bit unsigned
|
||||
// integer.
|
||||
@ -694,6 +714,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
return ErrMaxMessageLengthExceeded
|
||||
}
|
||||
|
||||
// If a prior message was written but it hasn't been fully flushed,
|
||||
// return an error as we only support buffering of one message at a
|
||||
// time.
|
||||
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
|
||||
return ErrMessageNotFlushed
|
||||
}
|
||||
|
||||
// The full length of the packet is only the packet length, and does
|
||||
// NOT include the MAC.
|
||||
fullLength := uint16(len(p))
|
||||
@ -701,18 +728,88 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
||||
var pktLen [2]byte
|
||||
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
||||
|
||||
// First, write out the encrypted+MAC'd length prefix for the packet.
|
||||
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
||||
if _, err := w.Write(cipherLen); err != nil {
|
||||
return err
|
||||
// First, generate the encrypted+MAC'd length prefix for the packet.
|
||||
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
||||
|
||||
// Finally, generate the encrypted packet itself.
|
||||
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush attempts to write a message buffered using WriteMessage to the provided
|
||||
// io.Writer. If no buffered message exists, this will result in a NOP.
|
||||
// Otherwise, it will continue to write the remaining bytes, picking up where
|
||||
// the byte stream left off in the event of a partial write. The number of bytes
|
||||
// returned reflects the number of plaintext bytes in the payload, and does not
|
||||
// account for the overhead of the header or MACs.
|
||||
//
|
||||
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||
func (b *Machine) Flush(w io.Writer) (int, error) {
|
||||
// First, write out the pending header bytes, if any exist. Any header
|
||||
// bytes written will not count towards the total amount flushed.
|
||||
if len(b.nextHeaderSend) > 0 {
|
||||
// Write any remaining header bytes and shift the slice to point
|
||||
// to the next segment of unwritten bytes. If an error is
|
||||
// encountered, we can continue to write the header from where
|
||||
// we left off on a subsequent call to Flush.
|
||||
n, err := w.Write(b.nextHeaderSend)
|
||||
b.nextHeaderSend = b.nextHeaderSend[n:]
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, write out the encrypted packet itself. We only write out a
|
||||
// single packet, as any fragmentation should have taken place at a
|
||||
// higher level.
|
||||
cipherText := b.sendCipher.Encrypt(nil, nil, p)
|
||||
_, err := w.Write(cipherText)
|
||||
return err
|
||||
// Next, write the pending body bytes, if any exist. Only the number of
|
||||
// bytes written that correspond to the ciphertext will be included in
|
||||
// the total bytes written, bytes written as part of the MAC will not be
|
||||
// counted.
|
||||
var nn int
|
||||
if len(b.nextBodySend) > 0 {
|
||||
// Write out the remaining body bytes, MAC included, and shift the body
|
||||
// slice depending on the number of actual bytes written.
|
||||
n, err := w.Write(b.nextBodySend)
|
||||
b.nextBodySend = b.nextBodySend[n:]
|
||||
|
||||
// If we partially or fully wrote any of the body's MAC, we'll
|
||||
// subtract that contribution from the total amount flushed to
|
||||
// preserve the abstraction of returning the number of plaintext
|
||||
// bytes written by the connection.
|
||||
//
|
||||
// There are three possible scenarios we must handle to ensure
|
||||
// the returned value is correct. In the first case, the write
|
||||
// straddles both payload and MAC bytes, and we must subtract
|
||||
// the number of MAC bytes written from n. In the second, only
|
||||
// payload bytes are written, thus we can return n unmodified.
|
||||
// The final scenario pertains to the case where only MAC bytes
|
||||
// are written, none of which count towards the total.
|
||||
//
|
||||
// |-----------Payload------------|----MAC----|
|
||||
// Straddle: S---------------------------------E--------0
|
||||
// Payload-only: S------------------------E-----------------0
|
||||
// MAC-only: S-------E-0
|
||||
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
|
||||
switch {
|
||||
|
||||
// Straddles payload and MAC bytes, subtract number of MAC bytes
|
||||
// written from the actual number written.
|
||||
case start > macSize && end <= macSize:
|
||||
nn = n - (macSize - end)
|
||||
|
||||
// Only payload bytes are written, return n directly.
|
||||
case start > macSize && end > macSize:
|
||||
nn = n
|
||||
|
||||
// Only MAC bytes are written, return 0 bytes written.
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nn, err
|
||||
}
|
||||
}
|
||||
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
// ReadMessage attempts to read the next message from the passed io.Reader. In
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
@ -220,11 +221,9 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
// payload length.
|
||||
payloadToReject := make([]byte, math.MaxUint16+1)
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
// A write of the payload generated above to the state machine should
|
||||
// be rejected as it's over the max payload length.
|
||||
err := b.WriteMessage(&buf, payloadToReject)
|
||||
err := b.WriteMessage(payloadToReject)
|
||||
if err != ErrMaxMessageLengthExceeded {
|
||||
t.Fatalf("payload is over the max allowed length, the write " +
|
||||
"should have been rejected")
|
||||
@ -233,7 +232,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
// Generate another payload which should be accepted as a valid
|
||||
// payload.
|
||||
payloadToAccept := make([]byte, math.MaxUint16-1)
|
||||
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
|
||||
if err := b.WriteMessage(payloadToAccept); err != nil {
|
||||
t.Fatalf("write for payload was rejected, should have been " +
|
||||
"accepted")
|
||||
}
|
||||
@ -243,7 +242,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
||||
payloadToReject = make([]byte, math.MaxUint16+1)
|
||||
|
||||
// This payload should be rejected.
|
||||
err = b.WriteMessage(&buf, payloadToReject)
|
||||
err = b.WriteMessage(payloadToReject)
|
||||
if err != ErrMaxMessageLengthExceeded {
|
||||
t.Fatalf("payload is over the max allowed length, the write " +
|
||||
"should have been rejected")
|
||||
@ -270,6 +269,8 @@ func TestWriteMessageChunking(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
bytesWritten, err := localConn.Write(largeMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write message: %v", err)
|
||||
@ -281,7 +282,6 @@ func TestWriteMessageChunking(t *testing.T) {
|
||||
t.Fatalf("bytes not fully written!")
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Attempt to read the entirety of the message generated above.
|
||||
@ -502,10 +502,14 @@ func TestBolt0008TestVectors(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
for i := 0; i < 1002; i++ {
|
||||
err = initiator.WriteMessage(&buf, payload)
|
||||
err = initiator.WriteMessage(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("could not write message %s", payload)
|
||||
}
|
||||
_, err = initiator.Flush(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("could not flush message: %v", err)
|
||||
}
|
||||
if val, ok := transportMessageVectors[i]; ok {
|
||||
binaryVal, err := hex.DecodeString(val)
|
||||
if err != nil {
|
||||
@ -534,3 +538,176 @@ func TestBolt0008TestVectors(t *testing.T) {
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// timeoutWriter decorates an io.Writer with a byte budget. Once the budget has
// been used up, Write reports iotest.ErrTimeout.
type timeoutWriter struct {
	// w is the wrapped writer that receives the bytes.
	w io.Writer

	// n is the number of bytes that may still be written before a timeout
	// is reported.
	n int64
}

// NewTimeoutWriter returns an io.Writer that forwards at most n bytes to w and
// reports iotest.ErrTimeout as soon as that budget reaches zero.
func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
	return &timeoutWriter{w: w, n: n}
}

// Write forwards as much of p as the remaining budget allows to the wrapped
// writer and deducts the number of bytes actually written from the budget. An
// error from the wrapped writer is returned unchanged; otherwise, if the budget
// is now exactly zero, iotest.ErrTimeout is returned alongside the byte count.
func (t *timeoutWriter) Write(p []byte) (int, error) {
	// Clamp the write to whatever budget is left.
	limit := len(p)
	if remaining := t.n; int64(limit) > remaining {
		limit = int(remaining)
	}

	written, err := t.w.Write(p[:limit])
	t.n -= int64(written)

	switch {
	// Errors from the wrapped writer take precedence over the timeout.
	case err != nil:
		return written, err

	// The budget is spent, signal a timeout to the caller.
	case t.n == 0:
		return written, iotest.ErrTimeout

	default:
		return written, nil
	}
}
|
||||
|
||||
const payloadSize = 10
|
||||
|
||||
type flushChunk struct {
|
||||
errAfter int64
|
||||
expN int
|
||||
expErr error
|
||||
}
|
||||
|
||||
type flushTest struct {
|
||||
name string
|
||||
chunks []flushChunk
|
||||
}
|
||||
|
||||
var flushTests = []flushTest{
|
||||
{
|
||||
name: "partial header write",
|
||||
chunks: []flushChunk{
|
||||
// Write 18-byte header in two parts, 16 then 2.
|
||||
{
|
||||
errAfter: encHeaderSize - 2,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
{
|
||||
errAfter: 2,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write payload and MAC in one go.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: payloadSize,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full payload then full mac",
|
||||
chunks: []flushChunk{
|
||||
// Write entire header and entire payload w/o MAC.
|
||||
{
|
||||
errAfter: encHeaderSize + payloadSize,
|
||||
expN: payloadSize,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write the entire MAC.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload-only, straddle, mac-only",
|
||||
chunks: []flushChunk{
|
||||
// Write header and all but last byte of payload.
|
||||
{
|
||||
errAfter: encHeaderSize + payloadSize - 1,
|
||||
expN: payloadSize - 1,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write last byte of payload and first byte of MAC.
|
||||
{
|
||||
errAfter: 2,
|
||||
expN: 1,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write 10 bytes of the MAC.
|
||||
{
|
||||
errAfter: 10,
|
||||
expN: 0,
|
||||
expErr: iotest.ErrTimeout,
|
||||
},
|
||||
// Write the remaining 5 MAC bytes.
|
||||
{
|
||||
errAfter: -1,
|
||||
expN: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestFlush asserts a Machine's ability to handle timeouts during Flush that
|
||||
// cause partial writes, and that the machine can properly resume writes on
|
||||
// subsequent calls to Flush.
|
||||
func TestFlush(t *testing.T) {
|
||||
// Run each test individually, to assert that they pass in isolation.
|
||||
for _, test := range flushTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var (
|
||||
w bytes.Buffer
|
||||
b Machine
|
||||
)
|
||||
b.split()
|
||||
testFlush(t, test, &b, &w)
|
||||
})
|
||||
}
|
||||
|
||||
// Finally, run the tests serially as if all on one connection.
|
||||
t.Run("flush serial", func(t *testing.T) {
|
||||
var (
|
||||
w bytes.Buffer
|
||||
b Machine
|
||||
)
|
||||
b.split()
|
||||
for _, test := range flushTests {
|
||||
testFlush(t, test, &b, &w)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testFlush buffers a message on the Machine, then flushes it to the io.Writer
|
||||
// in chunks. Once complete, a final call to flush is made to assert that Write
|
||||
// is not called again.
|
||||
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
|
||||
payload := make([]byte, payloadSize)
|
||||
if err := b.WriteMessage(payload); err != nil {
|
||||
t.Fatalf("unable to write message: %v", err)
|
||||
}
|
||||
|
||||
for _, chunk := range test.chunks {
|
||||
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
|
||||
}
|
||||
|
||||
// We should always be able to call Flush after a message has been
|
||||
// successfully written, and it should result in a NOP.
|
||||
assertFlush(t, b, w, 0, 0, nil)
|
||||
}
|
||||
|
||||
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
|
||||
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
|
||||
// n bytes. The method asserts that the returned error matches expErr and that
|
||||
// the number of bytes written by Flush matches expN.
|
||||
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
|
||||
expErr error) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
if n >= 0 {
|
||||
w = NewTimeoutWriter(w, n)
|
||||
}
|
||||
nn, err := b.Flush(w)
|
||||
if err != expErr {
|
||||
t.Fatalf("expected flush err: %v, got: %v", expErr, err)
|
||||
}
|
||||
if nn != expN {
|
||||
t.Fatalf("expected n: %d, got: %d", expN, nn)
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ const (
|
||||
|
||||
// DefaultWriteWorkers is the default maximum number of concurrent
|
||||
// workers used by the daemon's write pool.
|
||||
DefaultWriteWorkers = 100
|
||||
DefaultWriteWorkers = 8
|
||||
|
||||
// DefaultSigWorkers is the default maximum number of concurrent workers
|
||||
// used by the daemon's sig pool.
|
||||
@ -20,13 +20,13 @@ const (
|
||||
// pools.
|
||||
type Workers struct {
|
||||
// Read is the maximum number of concurrent read pool workers.
|
||||
Read int `long:"read" description:"Maximum number of concurrent read pool workers."`
|
||||
Read int `long:"read" description:"Maximum number of concurrent read pool workers. This number should be proportional to the number of peers."`
|
||||
|
||||
// Write is the maximum number of concurrent write pool workers.
|
||||
Write int `long:"write" description:"Maximum number of concurrent write pool workers."`
|
||||
Write int `long:"write" description:"Maximum number of concurrent write pool workers. This number should be proportional to the number of CPUs on the host. "`
|
||||
|
||||
// Sig is the maximum number of concurrent sig pool workers.
|
||||
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers."`
|
||||
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers. This number should be proportional to the number of CPUs on the host."`
|
||||
}
|
||||
|
||||
// Validate checks the Workers configuration to ensure that the input values are
|
||||
|
135
peer.go
135
peer.go
@ -1403,16 +1403,58 @@ func (p *peer) logWireMessage(msg lnwire.Message, read bool) {
|
||||
}))
|
||||
}
|
||||
|
||||
// writeMessage writes the target lnwire.Message to the remote peer.
|
||||
// writeMessage writes and flushes the target lnwire.Message to the remote peer.
|
||||
// If the passed message is nil, this method will only try to flush an existing
|
||||
// message buffered on the connection. It is safe to call this method again with a
|
||||
// nil message iff a timeout error is returned. This will continue to flush the
|
||||
// pending message to the wire.
|
||||
func (p *peer) writeMessage(msg lnwire.Message) error {
|
||||
// Simply exit if we're shutting down.
|
||||
if atomic.LoadInt32(&p.disconnect) != 0 {
|
||||
return ErrPeerExiting
|
||||
}
|
||||
|
||||
p.logWireMessage(msg, false)
|
||||
// Only log the message on the first attempt.
|
||||
if msg != nil {
|
||||
p.logWireMessage(msg, false)
|
||||
}
|
||||
|
||||
var n int
|
||||
noiseConn, ok := p.conn.(*brontide.Conn)
|
||||
if !ok {
|
||||
return fmt.Errorf("brontide.Conn required to write messages")
|
||||
}
|
||||
|
||||
flushMsg := func() error {
|
||||
// Ensure the write deadline is set before we attempt to send
|
||||
// the message.
|
||||
writeDeadline := time.Now().Add(writeMessageTimeout)
|
||||
err := noiseConn.SetWriteDeadline(writeDeadline)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush the pending message to the wire. If an error is
|
||||
// encountered, e.g. write timeout, the number of bytes written
|
||||
// so far will be returned.
|
||||
n, err := noiseConn.Flush()
|
||||
|
||||
// Record the number of bytes written on the wire, if any.
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&p.bytesSent, uint64(n))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// If the current message has already been serialized, encrypted, and
|
||||
// buffered on the underlying connection we will skip straight to
|
||||
// flushing it to the wire.
|
||||
if msg == nil {
|
||||
return flushMsg()
|
||||
}
|
||||
|
||||
// Otherwise, this is a new message. We'll acquire a write buffer to
|
||||
// serialize the message and buffer the ciphertext on the connection.
|
||||
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
|
||||
// Using a buffer allocated by the write pool, encode the
|
||||
// message directly into the buffer.
|
||||
@ -1421,25 +1463,17 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// Ensure the write deadline is set before we attempt to send
|
||||
// the message.
|
||||
writeDeadline := time.Now().Add(writeMessageTimeout)
|
||||
writeErr = p.conn.SetWriteDeadline(writeDeadline)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// Finally, write the message itself in a single swoop.
|
||||
n, writeErr = p.conn.Write(buf.Bytes())
|
||||
return writeErr
|
||||
// Finally, write the message itself in a single swoop. This
|
||||
// will buffer the ciphertext on the underlying connection. We
|
||||
// will defer flushing the message until the write pool has been
|
||||
// released.
|
||||
return noiseConn.WriteMessage(buf.Bytes())
|
||||
})
|
||||
|
||||
// Record the number of bytes written on the wire, if any.
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&p.bytesSent, uint64(n))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
return flushMsg()
|
||||
}
|
||||
|
||||
// writeHandler is a goroutine dedicated to reading messages off of an incoming
|
||||
@ -1459,38 +1493,10 @@ func (p *peer) writeHandler() {
|
||||
|
||||
var exitErr error
|
||||
|
||||
const (
|
||||
minRetryDelay = 5 * time.Second
|
||||
maxRetryDelay = time.Minute
|
||||
)
|
||||
|
||||
out:
|
||||
for {
|
||||
select {
|
||||
case outMsg := <-p.sendQueue:
|
||||
// Record the time at which we first attempt to send the
|
||||
// message.
|
||||
startTime := time.Now()
|
||||
|
||||
// Initialize a retry delay of zero, which will be
|
||||
// increased if we encounter a write timeout on the
|
||||
// send.
|
||||
var retryDelay time.Duration
|
||||
retryWithDelay:
|
||||
if retryDelay > 0 {
|
||||
select {
|
||||
case <-time.After(retryDelay):
|
||||
case <-p.quit:
|
||||
// Inform synchronous writes that the
|
||||
// peer is exiting.
|
||||
if outMsg.errChan != nil {
|
||||
outMsg.errChan <- ErrPeerExiting
|
||||
}
|
||||
exitErr = ErrPeerExiting
|
||||
break out
|
||||
}
|
||||
}
|
||||
|
||||
// If we're about to send a ping message, then log the
|
||||
// exact time in which we send the message so we can
|
||||
// use the delay as a rough estimate of latency to the
|
||||
@ -1502,33 +1508,32 @@ out:
|
||||
atomic.StoreInt64(&p.pingLastSend, now)
|
||||
}
|
||||
|
||||
// Record the time at which we first attempt to send the
|
||||
// message.
|
||||
startTime := time.Now()
|
||||
|
||||
retry:
|
||||
// Write out the message to the socket. If a timeout
|
||||
// error is encountered, we will catch this and retry
|
||||
// after backing off in case the remote peer is just
|
||||
// slow to process messages from the wire.
|
||||
err := p.writeMessage(outMsg.msg)
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
// Increase the retry delay in the event of a
|
||||
// timeout error, this prevents us from
|
||||
// disconnecting if the remote party is slow to
|
||||
// pull messages off the wire. We back off
|
||||
// exponentially up to our max delay to prevent
|
||||
// blocking the write pool.
|
||||
if retryDelay == 0 {
|
||||
retryDelay = minRetryDelay
|
||||
} else {
|
||||
retryDelay *= 2
|
||||
if retryDelay > maxRetryDelay {
|
||||
retryDelay = maxRetryDelay
|
||||
}
|
||||
}
|
||||
|
||||
peerLog.Debugf("Write timeout detected for "+
|
||||
"peer %s, retrying after %v, "+
|
||||
"first attempted %v ago", p, retryDelay,
|
||||
"peer %s, first write for message "+
|
||||
"attempted %v ago", p,
|
||||
time.Since(startTime))
|
||||
|
||||
goto retryWithDelay
|
||||
// If we received a timeout error, this implies
|
||||
// that the message was buffered on the
|
||||
// connection successfully and that a flush was
|
||||
// attempted. We'll set the message to nil so
|
||||
// that on a subsequent pass we only try to
|
||||
// flush the buffered message, and forgo
|
||||
// reserializing or reencrypting it.
|
||||
outMsg.msg = nil
|
||||
|
||||
goto retry
|
||||
}
|
||||
|
||||
// The write succeeded, reset the idle timer to prevent
|
||||
|
Loading…
Reference in New Issue
Block a user