Merge pull request #2924 from cfromknecht/write-and-flush
peer: resume partial writes due to timeouts
This commit is contained in:
commit
42b081bb37
@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||||||
// If the message doesn't require any chunking, then we can go ahead
|
// If the message doesn't require any chunking, then we can go ahead
|
||||||
// with a single write.
|
// with a single write.
|
||||||
if len(b) <= math.MaxUint16 {
|
if len(b) <= math.MaxUint16 {
|
||||||
return len(b), c.noise.WriteMessage(c.conn, b)
|
err = c.noise.WriteMessage(b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return c.noise.Flush(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we need to split the message into fragments, then we'll write
|
// If we need to split the message into fragments, then we'll write
|
||||||
@ -185,16 +189,43 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||||||
// Slice off the next chunk to be written based on our running
|
// Slice off the next chunk to be written based on our running
|
||||||
// counter and next chunk size.
|
// counter and next chunk size.
|
||||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||||
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
|
if err := c.noise.WriteMessage(chunk); err != nil {
|
||||||
return bytesWritten, err
|
return bytesWritten, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bytesWritten += len(chunk)
|
n, err := c.noise.Flush(c.conn)
|
||||||
|
bytesWritten += n
|
||||||
|
if err != nil {
|
||||||
|
return bytesWritten, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytesWritten, nil
|
return bytesWritten, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteMessage encrypts and buffers the next message p for the connection. The
|
||||||
|
// ciphertext of the message is prepended with an encrypt+auth'd length which
|
||||||
|
// must be used as the AD to the AEAD construction when being decrypted by the
|
||||||
|
// other side.
|
||||||
|
//
|
||||||
|
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||||
|
// call to Flush to ensure the message is written.
|
||||||
|
func (c *Conn) WriteMessage(b []byte) error {
|
||||||
|
return c.noise.WriteMessage(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush attempts to write a message buffered using WriteMessage to the
|
||||||
|
// underlying connection. If no buffered message exists, this will result in a
|
||||||
|
// NOP. Otherwise, it will continue to write the remaining bytes, picking up
|
||||||
|
// where the byte stream left off in the event of a partial write. The number of
|
||||||
|
// bytes returned reflects the number of plaintext bytes in the payload, and
|
||||||
|
// does not account for the overhead of the header or MACs.
|
||||||
|
//
|
||||||
|
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||||
|
func (c *Conn) Flush() (int, error) {
|
||||||
|
return c.noise.Flush(c.conn)
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the connection. Any blocked Read or Write operations will be
|
// Close closes the connection. Any blocked Read or Write operations will be
|
||||||
// unblocked and return errors.
|
// unblocked and return errors.
|
||||||
//
|
//
|
||||||
|
@ -31,6 +31,10 @@ const (
|
|||||||
// length of a message payload.
|
// length of a message payload.
|
||||||
lengthHeaderSize = 2
|
lengthHeaderSize = 2
|
||||||
|
|
||||||
|
// encHeaderSize is the number of bytes required to hold an encrypted
|
||||||
|
// header and it's MAC.
|
||||||
|
encHeaderSize = lengthHeaderSize + macSize
|
||||||
|
|
||||||
// keyRotationInterval is the number of messages sent on a single
|
// keyRotationInterval is the number of messages sent on a single
|
||||||
// cipher stream before the keys are rotated forwards.
|
// cipher stream before the keys are rotated forwards.
|
||||||
keyRotationInterval = 1000
|
keyRotationInterval = 1000
|
||||||
@ -48,6 +52,10 @@ var (
|
|||||||
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
|
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
|
||||||
"the max allowed message length of (2^16)-1")
|
"the max allowed message length of (2^16)-1")
|
||||||
|
|
||||||
|
// ErrMessageNotFlushed signals that the connection cannot accept a new
|
||||||
|
// message because the prior message has not been fully flushed.
|
||||||
|
ErrMessageNotFlushed = errors.New("prior message not flushed")
|
||||||
|
|
||||||
// lightningPrologue is the noise prologue that is used to initialize
|
// lightningPrologue is the noise prologue that is used to initialize
|
||||||
// the brontide noise handshake.
|
// the brontide noise handshake.
|
||||||
lightningPrologue = []byte("lightning")
|
lightningPrologue = []byte("lightning")
|
||||||
@ -366,7 +374,17 @@ type Machine struct {
|
|||||||
// nextCipherHeader is a static buffer that we'll use to read in the
|
// nextCipherHeader is a static buffer that we'll use to read in the
|
||||||
// next ciphertext header from the wire. The header is a 2 byte length
|
// next ciphertext header from the wire. The header is a 2 byte length
|
||||||
// (of the next ciphertext), followed by a 16 byte MAC.
|
// (of the next ciphertext), followed by a 16 byte MAC.
|
||||||
nextCipherHeader [lengthHeaderSize + macSize]byte
|
nextCipherHeader [encHeaderSize]byte
|
||||||
|
|
||||||
|
// nextHeaderSend holds a reference to the remaining header bytes to
|
||||||
|
// write out for a pending message. This allows us to tolerate timeout
|
||||||
|
// errors that cause partial writes.
|
||||||
|
nextHeaderSend []byte
|
||||||
|
|
||||||
|
// nextHeaderBody holds a reference to the remaining body bytes to write
|
||||||
|
// out for a pending message. This allows us to tolerate timeout errors
|
||||||
|
// that cause partial writes.
|
||||||
|
nextBodySend []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBrontideMachine creates a new instance of the brontide state-machine. If
|
// NewBrontideMachine creates a new instance of the brontide state-machine. If
|
||||||
@ -682,11 +700,13 @@ func (b *Machine) split() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMessage writes the next message p to the passed io.Writer. The
|
// WriteMessage encrypts and buffers the next message p. The ciphertext of the
|
||||||
// ciphertext of the message is prepended with an encrypt+auth'd length which
|
// message is prepended with an encrypt+auth'd length which must be used as the
|
||||||
// must be used as the AD to the AEAD construction when being decrypted by the
|
// AD to the AEAD construction when being decrypted by the other side.
|
||||||
// other side.
|
//
|
||||||
func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||||
|
// call to Flush to ensure the message is written.
|
||||||
|
func (b *Machine) WriteMessage(p []byte) error {
|
||||||
// The total length of each message payload including the MAC size
|
// The total length of each message payload including the MAC size
|
||||||
// payload exceed the largest number encodable within a 16-bit unsigned
|
// payload exceed the largest number encodable within a 16-bit unsigned
|
||||||
// integer.
|
// integer.
|
||||||
@ -694,6 +714,13 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
|||||||
return ErrMaxMessageLengthExceeded
|
return ErrMaxMessageLengthExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a prior message was written but it hasn't been fully flushed,
|
||||||
|
// return an error as we only support buffering of one message at a
|
||||||
|
// time.
|
||||||
|
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
|
||||||
|
return ErrMessageNotFlushed
|
||||||
|
}
|
||||||
|
|
||||||
// The full length of the packet is only the packet length, and does
|
// The full length of the packet is only the packet length, and does
|
||||||
// NOT include the MAC.
|
// NOT include the MAC.
|
||||||
fullLength := uint16(len(p))
|
fullLength := uint16(len(p))
|
||||||
@ -701,18 +728,88 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
|
|||||||
var pktLen [2]byte
|
var pktLen [2]byte
|
||||||
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
binary.BigEndian.PutUint16(pktLen[:], fullLength)
|
||||||
|
|
||||||
// First, write out the encrypted+MAC'd length prefix for the packet.
|
// First, generate the encrypted+MAC'd length prefix for the packet.
|
||||||
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
|
||||||
if _, err := w.Write(cipherLen); err != nil {
|
|
||||||
return err
|
// Finally, generate the encrypted packet itself.
|
||||||
|
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, write out the encrypted packet itself. We only write out a
|
// Flush attempts to write a message buffered using WriteMessage to the provided
|
||||||
// single packet, as any fragmentation should have taken place at a
|
// io.Writer. If no buffered message exists, this will result in a NOP.
|
||||||
// higher level.
|
// Otherwise, it will continue to write the remaining bytes, picking up where
|
||||||
cipherText := b.sendCipher.Encrypt(nil, nil, p)
|
// the byte stream left off in the event of a partial write. The number of bytes
|
||||||
_, err := w.Write(cipherText)
|
// returned reflects the number of plaintext bytes in the payload, and does not
|
||||||
return err
|
// account for the overhead of the header or MACs.
|
||||||
|
//
|
||||||
|
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||||
|
func (b *Machine) Flush(w io.Writer) (int, error) {
|
||||||
|
// First, write out the pending header bytes, if any exist. Any header
|
||||||
|
// bytes written will not count towards the total amount flushed.
|
||||||
|
if len(b.nextHeaderSend) > 0 {
|
||||||
|
// Write any remaining header bytes and shift the slice to point
|
||||||
|
// to the next segment of unwritten bytes. If an error is
|
||||||
|
// encountered, we can continue to write the header from where
|
||||||
|
// we left off on a subsequent call to Flush.
|
||||||
|
n, err := w.Write(b.nextHeaderSend)
|
||||||
|
b.nextHeaderSend = b.nextHeaderSend[n:]
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, write the pending body bytes, if any exist. Only the number of
|
||||||
|
// bytes written that correspond to the ciphertext will be included in
|
||||||
|
// the total bytes written, bytes written as part of the MAC will not be
|
||||||
|
// counted.
|
||||||
|
var nn int
|
||||||
|
if len(b.nextBodySend) > 0 {
|
||||||
|
// Write out all bytes excluding the mac and shift the body
|
||||||
|
// slice depending on the number of actual bytes written.
|
||||||
|
n, err := w.Write(b.nextBodySend)
|
||||||
|
b.nextBodySend = b.nextBodySend[n:]
|
||||||
|
|
||||||
|
// If we partially or fully wrote any of the body's MAC, we'll
|
||||||
|
// subtract that contribution from the total amount flushed to
|
||||||
|
// preserve the abstraction of returning the number of plaintext
|
||||||
|
// bytes written by the connection.
|
||||||
|
//
|
||||||
|
// There are three possible scenarios we must handle to ensure
|
||||||
|
// the returned value is correct. In the first case, the write
|
||||||
|
// straddles both payload and MAC bytes, and we must subtract
|
||||||
|
// the number of MAC bytes written from n. In the second, only
|
||||||
|
// payload bytes are written, thus we can return n unmodified.
|
||||||
|
// The final scenario pertains to the case where only MAC bytes
|
||||||
|
// are written, none of which count towards the total.
|
||||||
|
//
|
||||||
|
// |-----------Payload------------|----MAC----|
|
||||||
|
// Straddle: S---------------------------------E--------0
|
||||||
|
// Payload-only: S------------------------E-----------------0
|
||||||
|
// MAC-only: S-------E-0
|
||||||
|
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
|
||||||
|
switch {
|
||||||
|
|
||||||
|
// Straddles payload and MAC bytes, subtract number of MAC bytes
|
||||||
|
// written from the actual number written.
|
||||||
|
case start > macSize && end <= macSize:
|
||||||
|
nn = n - (macSize - end)
|
||||||
|
|
||||||
|
// Only payload bytes are written, return n directly.
|
||||||
|
case start > macSize && end > macSize:
|
||||||
|
nn = n
|
||||||
|
|
||||||
|
// Only MAC bytes are written, return 0 bytes written.
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nn, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage attempts to read the next message from the passed io.Reader. In
|
// ReadMessage attempts to read the next message from the passed io.Reader. In
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
@ -220,11 +221,9 @@ func TestMaxPayloadLength(t *testing.T) {
|
|||||||
// payload length.
|
// payload length.
|
||||||
payloadToReject := make([]byte, math.MaxUint16+1)
|
payloadToReject := make([]byte, math.MaxUint16+1)
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
// A write of the payload generated above to the state machine should
|
// A write of the payload generated above to the state machine should
|
||||||
// be rejected as it's over the max payload length.
|
// be rejected as it's over the max payload length.
|
||||||
err := b.WriteMessage(&buf, payloadToReject)
|
err := b.WriteMessage(payloadToReject)
|
||||||
if err != ErrMaxMessageLengthExceeded {
|
if err != ErrMaxMessageLengthExceeded {
|
||||||
t.Fatalf("payload is over the max allowed length, the write " +
|
t.Fatalf("payload is over the max allowed length, the write " +
|
||||||
"should have been rejected")
|
"should have been rejected")
|
||||||
@ -233,7 +232,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
|||||||
// Generate another payload which should be accepted as a valid
|
// Generate another payload which should be accepted as a valid
|
||||||
// payload.
|
// payload.
|
||||||
payloadToAccept := make([]byte, math.MaxUint16-1)
|
payloadToAccept := make([]byte, math.MaxUint16-1)
|
||||||
if err := b.WriteMessage(&buf, payloadToAccept); err != nil {
|
if err := b.WriteMessage(payloadToAccept); err != nil {
|
||||||
t.Fatalf("write for payload was rejected, should have been " +
|
t.Fatalf("write for payload was rejected, should have been " +
|
||||||
"accepted")
|
"accepted")
|
||||||
}
|
}
|
||||||
@ -243,7 +242,7 @@ func TestMaxPayloadLength(t *testing.T) {
|
|||||||
payloadToReject = make([]byte, math.MaxUint16+1)
|
payloadToReject = make([]byte, math.MaxUint16+1)
|
||||||
|
|
||||||
// This payload should be rejected.
|
// This payload should be rejected.
|
||||||
err = b.WriteMessage(&buf, payloadToReject)
|
err = b.WriteMessage(payloadToReject)
|
||||||
if err != ErrMaxMessageLengthExceeded {
|
if err != ErrMaxMessageLengthExceeded {
|
||||||
t.Fatalf("payload is over the max allowed length, the write " +
|
t.Fatalf("payload is over the max allowed length, the write " +
|
||||||
"should have been rejected")
|
"should have been rejected")
|
||||||
@ -270,6 +269,8 @@ func TestWriteMessageChunking(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
bytesWritten, err := localConn.Write(largeMessage)
|
bytesWritten, err := localConn.Write(largeMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to write message: %v", err)
|
t.Fatalf("unable to write message: %v", err)
|
||||||
@ -281,7 +282,6 @@ func TestWriteMessageChunking(t *testing.T) {
|
|||||||
t.Fatalf("bytes not fully written!")
|
t.Fatalf("bytes not fully written!")
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Done()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Attempt to read the entirety of the message generated above.
|
// Attempt to read the entirety of the message generated above.
|
||||||
@ -502,10 +502,14 @@ func TestBolt0008TestVectors(t *testing.T) {
|
|||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|
||||||
for i := 0; i < 1002; i++ {
|
for i := 0; i < 1002; i++ {
|
||||||
err = initiator.WriteMessage(&buf, payload)
|
err = initiator.WriteMessage(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("could not write message %s", payload)
|
t.Fatalf("could not write message %s", payload)
|
||||||
}
|
}
|
||||||
|
_, err = initiator.Flush(&buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not flush message: %v", err)
|
||||||
|
}
|
||||||
if val, ok := transportMessageVectors[i]; ok {
|
if val, ok := transportMessageVectors[i]; ok {
|
||||||
binaryVal, err := hex.DecodeString(val)
|
binaryVal, err := hex.DecodeString(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -534,3 +538,176 @@ func TestBolt0008TestVectors(t *testing.T) {
|
|||||||
buf.Reset()
|
buf.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after
|
||||||
|
// writing n bytes.
|
||||||
|
type timeoutWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
|
||||||
|
return &timeoutWriter{w, n}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *timeoutWriter) Write(p []byte) (int, error) {
|
||||||
|
n := len(p)
|
||||||
|
if int64(n) > t.n {
|
||||||
|
n = int(t.n)
|
||||||
|
}
|
||||||
|
n, err := t.w.Write(p[:n])
|
||||||
|
t.n -= int64(n)
|
||||||
|
if err == nil && t.n == 0 {
|
||||||
|
return n, iotest.ErrTimeout
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const payloadSize = 10
|
||||||
|
|
||||||
|
type flushChunk struct {
|
||||||
|
errAfter int64
|
||||||
|
expN int
|
||||||
|
expErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type flushTest struct {
|
||||||
|
name string
|
||||||
|
chunks []flushChunk
|
||||||
|
}
|
||||||
|
|
||||||
|
var flushTests = []flushTest{
|
||||||
|
{
|
||||||
|
name: "partial header write",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write 18-byte header in two parts, 16 then 2.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize - 2,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
errAfter: 2,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write payload and MAC in one go.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: payloadSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full payload then full mac",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write entire header and entire payload w/o MAC.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize + payloadSize,
|
||||||
|
expN: payloadSize,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write the entire MAC.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "payload-only, straddle, mac-only",
|
||||||
|
chunks: []flushChunk{
|
||||||
|
// Write header and all but last byte of payload.
|
||||||
|
{
|
||||||
|
errAfter: encHeaderSize + payloadSize - 1,
|
||||||
|
expN: payloadSize - 1,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write last byte of payload and first byte of MAC.
|
||||||
|
{
|
||||||
|
errAfter: 2,
|
||||||
|
expN: 1,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write 10 bytes of the MAC.
|
||||||
|
{
|
||||||
|
errAfter: 10,
|
||||||
|
expN: 0,
|
||||||
|
expErr: iotest.ErrTimeout,
|
||||||
|
},
|
||||||
|
// Write the remaining 5 MAC bytes.
|
||||||
|
{
|
||||||
|
errAfter: -1,
|
||||||
|
expN: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFlush asserts a Machine's ability to handle timeouts during Flush that
|
||||||
|
// cause partial writes, and that the machine can properly resume writes on
|
||||||
|
// subsequent calls to Flush.
|
||||||
|
func TestFlush(t *testing.T) {
|
||||||
|
// Run each test individually, to assert that they pass in isolation.
|
||||||
|
for _, test := range flushTests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
w bytes.Buffer
|
||||||
|
b Machine
|
||||||
|
)
|
||||||
|
b.split()
|
||||||
|
testFlush(t, test, &b, &w)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, run the tests serially as if all on one connection.
|
||||||
|
t.Run("flush serial", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
w bytes.Buffer
|
||||||
|
b Machine
|
||||||
|
)
|
||||||
|
b.split()
|
||||||
|
for _, test := range flushTests {
|
||||||
|
testFlush(t, test, &b, &w)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testFlush buffers a message on the Machine, then flushes it to the io.Writer
|
||||||
|
// in chunks. Once complete, a final call to flush is made to assert that Write
|
||||||
|
// is not called again.
|
||||||
|
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
if err := b.WriteMessage(payload); err != nil {
|
||||||
|
t.Fatalf("unable to write message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chunk := range test.chunks {
|
||||||
|
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should always be able to call Flush after a message has been
|
||||||
|
// successfully written, and it should result in a NOP.
|
||||||
|
assertFlush(t, b, w, 0, 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
|
||||||
|
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
|
||||||
|
// n bytes. The method asserts that the returned error matches expErr and that
|
||||||
|
// the number of bytes written by Flush matches expN.
|
||||||
|
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
|
||||||
|
expErr error) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if n >= 0 {
|
||||||
|
w = NewTimeoutWriter(w, n)
|
||||||
|
}
|
||||||
|
nn, err := b.Flush(w)
|
||||||
|
if err != expErr {
|
||||||
|
t.Fatalf("expected flush err: %v, got: %v", expErr, err)
|
||||||
|
}
|
||||||
|
if nn != expN {
|
||||||
|
t.Fatalf("expected n: %d, got: %d", expN, nn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@ const (
|
|||||||
|
|
||||||
// DefaultWriteWorkers is the default maximum number of concurrent
|
// DefaultWriteWorkers is the default maximum number of concurrent
|
||||||
// workers used by the daemon's write pool.
|
// workers used by the daemon's write pool.
|
||||||
DefaultWriteWorkers = 100
|
DefaultWriteWorkers = 8
|
||||||
|
|
||||||
// DefaultSigWorkers is the default maximum number of concurrent workers
|
// DefaultSigWorkers is the default maximum number of concurrent workers
|
||||||
// used by the daemon's sig pool.
|
// used by the daemon's sig pool.
|
||||||
@ -20,13 +20,13 @@ const (
|
|||||||
// pools.
|
// pools.
|
||||||
type Workers struct {
|
type Workers struct {
|
||||||
// Read is the maximum number of concurrent read pool workers.
|
// Read is the maximum number of concurrent read pool workers.
|
||||||
Read int `long:"read" description:"Maximum number of concurrent read pool workers."`
|
Read int `long:"read" description:"Maximum number of concurrent read pool workers. This number should be proportional to the number of peers."`
|
||||||
|
|
||||||
// Write is the maximum number of concurrent write pool workers.
|
// Write is the maximum number of concurrent write pool workers.
|
||||||
Write int `long:"write" description:"Maximum number of concurrent write pool workers."`
|
Write int `long:"write" description:"Maximum number of concurrent write pool workers. This number should be proportional to the number of CPUs on the host. "`
|
||||||
|
|
||||||
// Sig is the maximum number of concurrent sig pool workers.
|
// Sig is the maximum number of concurrent sig pool workers.
|
||||||
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers."`
|
Sig int `long:"sig" description:"Maximum number of concurrent sig pool workers. This number should be proportional to the number of CPUs on the host."`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks the Workers configuration to ensure that the input values are
|
// Validate checks the Workers configuration to ensure that the input values are
|
||||||
|
133
peer.go
133
peer.go
@ -1403,16 +1403,58 @@ func (p *peer) logWireMessage(msg lnwire.Message, read bool) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMessage writes the target lnwire.Message to the remote peer.
|
// writeMessage writes and flushes the target lnwire.Message to the remote peer.
|
||||||
|
// If the passed message is nil, this method will only try to flush an existing
|
||||||
|
// message buffered on the connection. It is safe to recall this method with a
|
||||||
|
// nil message iff a timeout error is returned. This will continue to flush the
|
||||||
|
// pending message to the wire.
|
||||||
func (p *peer) writeMessage(msg lnwire.Message) error {
|
func (p *peer) writeMessage(msg lnwire.Message) error {
|
||||||
// Simply exit if we're shutting down.
|
// Simply exit if we're shutting down.
|
||||||
if atomic.LoadInt32(&p.disconnect) != 0 {
|
if atomic.LoadInt32(&p.disconnect) != 0 {
|
||||||
return ErrPeerExiting
|
return ErrPeerExiting
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only log the message on the first attempt.
|
||||||
|
if msg != nil {
|
||||||
p.logWireMessage(msg, false)
|
p.logWireMessage(msg, false)
|
||||||
|
}
|
||||||
|
|
||||||
var n int
|
noiseConn, ok := p.conn.(*brontide.Conn)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("brontide.Conn required to write messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
flushMsg := func() error {
|
||||||
|
// Ensure the write deadline is set before we attempt to send
|
||||||
|
// the message.
|
||||||
|
writeDeadline := time.Now().Add(writeMessageTimeout)
|
||||||
|
err := noiseConn.SetWriteDeadline(writeDeadline)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush the pending message to the wire. If an error is
|
||||||
|
// encountered, e.g. write timeout, the number of bytes written
|
||||||
|
// so far will be returned.
|
||||||
|
n, err := noiseConn.Flush()
|
||||||
|
|
||||||
|
// Record the number of bytes written on the wire, if any.
|
||||||
|
if n > 0 {
|
||||||
|
atomic.AddUint64(&p.bytesSent, uint64(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the current message has already been serialized, encrypted, and
|
||||||
|
// buffered on the underlying connection we will skip straight to
|
||||||
|
// flushing it to the wire.
|
||||||
|
if msg == nil {
|
||||||
|
return flushMsg()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, this is a new message. We'll acquire a write buffer to
|
||||||
|
// serialize the message and buffer the ciphertext on the connection.
|
||||||
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
|
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
|
||||||
// Using a buffer allocated by the write pool, encode the
|
// Using a buffer allocated by the write pool, encode the
|
||||||
// message directly into the buffer.
|
// message directly into the buffer.
|
||||||
@ -1421,25 +1463,17 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
|
|||||||
return writeErr
|
return writeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the write deadline is set before we attempt to send
|
// Finally, write the message itself in a single swoop. This
|
||||||
// the message.
|
// will buffer the ciphertext on the underlying connection. We
|
||||||
writeDeadline := time.Now().Add(writeMessageTimeout)
|
// will defer flushing the message until the write pool has been
|
||||||
writeErr = p.conn.SetWriteDeadline(writeDeadline)
|
// released.
|
||||||
if writeErr != nil {
|
return noiseConn.WriteMessage(buf.Bytes())
|
||||||
return writeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, write the message itself in a single swoop.
|
|
||||||
n, writeErr = p.conn.Write(buf.Bytes())
|
|
||||||
return writeErr
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
// Record the number of bytes written on the wire, if any.
|
return err
|
||||||
if n > 0 {
|
|
||||||
atomic.AddUint64(&p.bytesSent, uint64(n))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return flushMsg()
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHandler is a goroutine dedicated to reading messages off of an incoming
|
// writeHandler is a goroutine dedicated to reading messages off of an incoming
|
||||||
@ -1459,38 +1493,10 @@ func (p *peer) writeHandler() {
|
|||||||
|
|
||||||
var exitErr error
|
var exitErr error
|
||||||
|
|
||||||
const (
|
|
||||||
minRetryDelay = 5 * time.Second
|
|
||||||
maxRetryDelay = time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
out:
|
out:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case outMsg := <-p.sendQueue:
|
case outMsg := <-p.sendQueue:
|
||||||
// Record the time at which we first attempt to send the
|
|
||||||
// message.
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
// Initialize a retry delay of zero, which will be
|
|
||||||
// increased if we encounter a write timeout on the
|
|
||||||
// send.
|
|
||||||
var retryDelay time.Duration
|
|
||||||
retryWithDelay:
|
|
||||||
if retryDelay > 0 {
|
|
||||||
select {
|
|
||||||
case <-time.After(retryDelay):
|
|
||||||
case <-p.quit:
|
|
||||||
// Inform synchronous writes that the
|
|
||||||
// peer is exiting.
|
|
||||||
if outMsg.errChan != nil {
|
|
||||||
outMsg.errChan <- ErrPeerExiting
|
|
||||||
}
|
|
||||||
exitErr = ErrPeerExiting
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we're about to send a ping message, then log the
|
// If we're about to send a ping message, then log the
|
||||||
// exact time in which we send the message so we can
|
// exact time in which we send the message so we can
|
||||||
// use the delay as a rough estimate of latency to the
|
// use the delay as a rough estimate of latency to the
|
||||||
@ -1502,33 +1508,32 @@ out:
|
|||||||
atomic.StoreInt64(&p.pingLastSend, now)
|
atomic.StoreInt64(&p.pingLastSend, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record the time at which we first attempt to send the
|
||||||
|
// message.
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
retry:
|
||||||
// Write out the message to the socket. If a timeout
|
// Write out the message to the socket. If a timeout
|
||||||
// error is encountered, we will catch this and retry
|
// error is encountered, we will catch this and retry
|
||||||
// after backing off in case the remote peer is just
|
// after backing off in case the remote peer is just
|
||||||
// slow to process messages from the wire.
|
// slow to process messages from the wire.
|
||||||
err := p.writeMessage(outMsg.msg)
|
err := p.writeMessage(outMsg.msg)
|
||||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||||
// Increase the retry delay in the event of a
|
|
||||||
// timeout error, this prevents us from
|
|
||||||
// disconnecting if the remote party is slow to
|
|
||||||
// pull messages off the wire. We back off
|
|
||||||
// exponentially up to our max delay to prevent
|
|
||||||
// blocking the write pool.
|
|
||||||
if retryDelay == 0 {
|
|
||||||
retryDelay = minRetryDelay
|
|
||||||
} else {
|
|
||||||
retryDelay *= 2
|
|
||||||
if retryDelay > maxRetryDelay {
|
|
||||||
retryDelay = maxRetryDelay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peerLog.Debugf("Write timeout detected for "+
|
peerLog.Debugf("Write timeout detected for "+
|
||||||
"peer %s, retrying after %v, "+
|
"peer %s, first write for message "+
|
||||||
"first attempted %v ago", p, retryDelay,
|
"attempted %v ago", p,
|
||||||
time.Since(startTime))
|
time.Since(startTime))
|
||||||
|
|
||||||
goto retryWithDelay
|
// If we received a timeout error, this implies
|
||||||
|
// that the message was buffered on the
|
||||||
|
// connection successfully and that a flush was
|
||||||
|
// attempted. We'll set the message to nil so
|
||||||
|
// that on a subsequent pass we only try to
|
||||||
|
// flush the buffered message, and forgo
|
||||||
|
// reserializing or reencrypting it.
|
||||||
|
outMsg.msg = nil
|
||||||
|
|
||||||
|
goto retry
|
||||||
}
|
}
|
||||||
|
|
||||||
// The write succeeded, reset the idle timer to prevent
|
// The write succeeded, reset the idle timer to prevent
|
||||||
|
Loading…
Reference in New Issue
Block a user