Merge pull request #2474 from cfromknecht/read-and-write-pools
lnpeer+brontide: reduce memory footprint using read/write pools for message encode/decode
commit a6ba965bc4
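The footprint reduction named in the title comes from decoupling buffer count from peer count: before this change every connected peer effectively pinned its own buffers for the lifetime of the connection (the peer held a buffer.Write, and the brontide Machine lazily took a buffer.Read), while afterwards only the pool's worker goroutines, bounded by runtime.NumCPU() in server.go below, hold buffers, and idle workers give theirs back after a timeout. A rough, hedged sketch of the effect, assuming each buffer is about 64 KiB (the exact sizes live in the lnd buffer package and are not part of this diff):

// approxBufferFootprint is a back-of-the-envelope illustration only; the real
// buffer sizes are defined in github.com/lightningnetwork/lnd/buffer.
func approxBufferFootprint(peers, workers int) (before, after int) {
    const approxBufSize = 64 * 1024 // assumed ~64 KiB per read or write buffer

    // Old model: each peer pins one read and one write buffer for its whole
    // lifetime.
    before = peers * 2 * approxBufSize

    // New model: only pool workers (roughly runtime.NumCPU of them) pin
    // buffers, and idle workers return theirs after a timeout.
    after = workers * 2 * approxBufSize
    return before, after
}

Under these assumptions, 1000 peers and 8 workers works out to roughly 125 MiB versus 1 MiB of steady-state buffer memory.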
@@ -104,13 +104,34 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
    return b, nil
}

// ReadNextMessage uses the connection in a message-oriented instructing it to
// read the next _full_ message with the brontide stream. This function will
// block until the read succeeds.
// ReadNextMessage uses the connection in a message-oriented manner, instructing
// it to read the next _full_ message with the brontide stream. This function
// will block until the read of the header and body succeeds.
//
// NOTE: This method SHOULD NOT be used in the case that the connection may be
// adversarial and induce long delays. If the caller needs to set read deadlines
// appropriately, it is preferred that they use the split ReadNextHeader and
// ReadNextBody methods so that the deadlines can be set appropriately on each.
func (c *Conn) ReadNextMessage() ([]byte, error) {
    return c.noise.ReadMessage(c.conn)
}

// ReadNextHeader uses the connection to read the next header from the brontide
// stream. This function will block until the read of the header succeeds and
// return the packet length (including MAC overhead) that is expected from the
// subsequent call to ReadNextBody.
func (c *Conn) ReadNextHeader() (uint32, error) {
    return c.noise.ReadHeader(c.conn)
}

// ReadNextBody uses the connection to read the next message body from the
// brontide stream. This function will block until the read of the body succeeds
// and return the decrypted payload. The provided buffer MUST be the packet
// length returned by the preceding call to ReadNextHeader.
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
    return c.noise.ReadBody(c.conn, buf)
}

// Read reads data from the connection. Read can be made to time out and
// return an Error with Timeout() == true after a fixed time limit; see
// SetDeadline and SetReadDeadline.
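To make the NOTE above concrete, here is a minimal hedged sketch (not part of the diff; the helper name and timeout are hypothetical, imports elided) of how a caller can give the header read and the body read separate deadlines using the split API:

// readWithDeadlines reads one full brontide message using the split API,
// re-arming the read deadline before each blocking call. buf must be large
// enough to hold any pktLen reported by ReadNextHeader (max message plus MAC).
func readWithDeadlines(conn *brontide.Conn, buf []byte,
    timeout time.Duration) ([]byte, error) {

    // Bound how long we wait for the encrypted length header.
    if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
        return nil, err
    }
    pktLen, err := conn.ReadNextHeader()
    if err != nil {
        return nil, err
    }

    // Bound the body read separately, and hand ReadNextBody a slice of
    // exactly pktLen bytes, as its contract requires.
    if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
        return nil, err
    }
    return conn.ReadNextBody(buf[:pktLen])
}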
@@ -8,15 +8,12 @@ import (
    "fmt"
    "io"
    "math"
    "runtime"
    "time"

    "golang.org/x/crypto/chacha20poly1305"
    "golang.org/x/crypto/hkdf"

    "github.com/btcsuite/btcd/btcec"
    "github.com/lightningnetwork/lnd/buffer"
    "github.com/lightningnetwork/lnd/pool"
)

const (
@@ -60,14 +57,6 @@ var (
    ephemeralGen = func() (*btcec.PrivateKey, error) {
        return btcec.NewPrivateKey(btcec.S256())
    }

    // readBufferPool is a singleton instance of a buffer pool, used to
    // conserve memory allocations due to read buffers across the entire
    // brontide package.
    readBufferPool = pool.NewReadBuffer(
        pool.DefaultReadBufferGCInterval,
        pool.DefaultReadBufferExpiryInterval,
    )
)

// TODO(roasbeef): free buffer pool?
@@ -378,15 +367,6 @@ type Machine struct {
    // next ciphertext header from the wire. The header is a 2 byte length
    // (of the next ciphertext), followed by a 16 byte MAC.
    nextCipherHeader [lengthHeaderSize + macSize]byte

    // nextCipherText is a static buffer that we'll use to read in the
    // bytes of the next cipher text message. As all messages in the
    // protocol MUST be below 65KB plus our macSize, this will be
    // sufficient to buffer all messages from the socket when we need to
    // read the next one. Having a fixed buffer that's re-used also means
    // that we save on allocations as we don't need to create a new one
    // each time.
    nextCipherText *buffer.Read
}

// NewBrontideMachine creates a new instance of the brontide state-machine. If
@@ -738,43 +718,59 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
// ReadMessage attempts to read the next message from the passed io.Reader. In
// the case of an authentication error, a non-nil error is returned.
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
    if _, err := io.ReadFull(r, b.nextCipherHeader[:]); err != nil {
    pktLen, err := b.ReadHeader(r)
    if err != nil {
        return nil, err
    }

    buf := make([]byte, pktLen)
    return b.ReadBody(r, buf)
}

// ReadHeader attempts to read the next message header from the passed
// io.Reader. The header contains the length of the next body including
// additional overhead of the MAC. In the case of an authentication error, a
// non-nil error is returned.
//
// NOTE: This method SHOULD NOT be used in the case that the io.Reader may be
// adversarial and induce long delays. If the caller needs to set read deadlines
// appropriately, it is preferred that they use the split ReadHeader and
// ReadBody methods so that the deadlines can be set appropriately on each.
func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
    _, err := io.ReadFull(r, b.nextCipherHeader[:])
    if err != nil {
        return 0, err
    }

    // Attempt to decrypt+auth the packet length present in the stream.
    pktLenBytes, err := b.recvCipher.Decrypt(
        nil, nil, b.nextCipherHeader[:],
    )
    if err != nil {
        return nil, err
        return 0, err
    }

    // If this is the first message being read, take a read buffer from the
    // buffer pool. This is delayed until this point to avoid allocating
    // read buffers until after the peer has successfully completed the
    // handshake, and is ready to begin sending lnwire messages.
    if b.nextCipherText == nil {
        b.nextCipherText = readBufferPool.Take()
        runtime.SetFinalizer(b, freeReadBuffer)
    }

    // Next, using the length read from the packet header, read the
    // encrypted packet itself.
    // Compute the packet length that we will need to read off the wire.
    pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
    if _, err := io.ReadFull(r, b.nextCipherText[:pktLen]); err != nil {

    return pktLen, nil
}

// ReadBody attempts to read the next message body from the passed io.Reader.
// The provided buffer MUST be the length indicated by the packet length
// returned by the preceding call to ReadHeader. In the case of an
// authentication error, a non-nil error is returned.
func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
    // Next, using the length read from the packet header, read the
    // encrypted packet itself into the buffer allocated by the read
    // pool.
    _, err := io.ReadFull(r, buf)
    if err != nil {
        return nil, err
    }

    // Finally, decrypt the message held in the buffer, and return a
    // new byte slice containing the plaintext.
    // TODO(roasbeef): modify to let pass in slice
    return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen])
}

// freeReadBuffer returns the Machine's read buffer back to the package wide
// read buffer pool.
//
// NOTE: This method should only be called by a Machine's finalizer.
func freeReadBuffer(b *Machine) {
    readBufferPool.Return(b.nextCipherText)
    b.nextCipherText = nil
    return b.recvCipher.Decrypt(nil, nil, buf)
}
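One point worth spelling out: the reworked ReadMessage above still allocates a fresh ciphertext buffer per call, while the split ReadHeader/ReadBody pair lets the caller supply a reusable one. A hedged sketch of that pattern (helper name hypothetical; the plaintext returned by ReadBody is still a new slice produced by the cipher):

// readInto reads the next ciphertext into the caller-owned buf rather than
// allocating per message. buf must be able to hold any pktLen reported by
// ReadHeader (max message plus MAC).
func readInto(b *brontide.Machine, r io.Reader, buf []byte) ([]byte, error) {
    pktLen, err := b.ReadHeader(r)
    if err != nil {
        return nil, err
    }
    return b.ReadBody(r, buf[:pktLen])
}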
peer.go | 105
@@ -26,6 +26,7 @@ import (
    "github.com/lightningnetwork/lnd/lnpeer"
    "github.com/lightningnetwork/lnd/lnwallet"
    "github.com/lightningnetwork/lnd/lnwire"
    "github.com/lightningnetwork/lnd/pool"
    "github.com/lightningnetwork/lnd/ticker"
)
@@ -43,8 +44,12 @@ const (
    // idleTimeout is the duration of inactivity before we time out a peer.
    idleTimeout = 5 * time.Minute

    // writeMessageTimeout is the timeout used when writing a message to peer.
    writeMessageTimeout = 50 * time.Second
    // writeMessageTimeout is the timeout used when writing a message to a peer.
    writeMessageTimeout = 10 * time.Second

    // readMessageTimeout is the timeout used when reading a message from a
    // peer.
    readMessageTimeout = 5 * time.Second

    // handshakeTimeout is the timeout used when waiting for peer init message.
    handshakeTimeout = 15 * time.Second
@@ -209,11 +214,13 @@ type peer struct {
    // TODO(halseth): remove when link failure is properly handled.
    failedChannels map[lnwire.ChannelID]struct{}

    // writeBuf is a buffer that we'll re-use in order to encode wire
    // messages to write out directly on the socket. By re-using this
    // buffer, we avoid needing to allocate more memory each time a new
    // message is to be sent to a peer.
    writeBuf *buffer.Write
    // writePool is the task pool that manages reuse of write buffers.
    // Write tasks are submitted to the pool in order to conserve the total
    // number of write buffers allocated at any one time, and decouple write
    // buffer allocation from the peer life cycle.
    writePool *pool.Write

    readPool *pool.Read

    queueQuit chan struct{}
    quit chan struct{}
@@ -258,7 +265,8 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,

        chanActiveTimeout: chanActiveTimeout,

        writeBuf: server.writeBufferPool.Take(),
        writePool: server.writePool,
        readPool: server.readPool,

        queueQuit: make(chan struct{}),
        quit: make(chan struct{}),
@@ -608,11 +616,6 @@ func (p *peer) WaitForDisconnect(ready chan struct{}) {
    }

    p.wg.Wait()

    // Now that we are certain all active goroutines which could have been
    // modifying the write buffer have exited, return the buffer to the pool
    // to be reused.
    p.server.writeBufferPool.Return(p.writeBuf)
}

// Disconnect terminates the connection with the remote peer. Additionally, a
@@ -644,11 +647,37 @@ func (p *peer) readNextMessage() (lnwire.Message, error) {
        return nil, fmt.Errorf("brontide.Conn required to read messages")
    }

    err := noiseConn.SetReadDeadline(time.Time{})
    if err != nil {
        return nil, err
    }

    pktLen, err := noiseConn.ReadNextHeader()
    if err != nil {
        return nil, err
    }

    // First we'll read the next _full_ message. We do this rather than
    // reading incrementally from the stream as the Lightning wire protocol
    // is message oriented and allows nodes to pad on additional data to
    // the message stream.
    rawMsg, err := noiseConn.ReadNextMessage()
    var rawMsg []byte
    err = p.readPool.Submit(func(buf *buffer.Read) error {
        // Before reading the body of the message, set the read timeout
        // accordingly to ensure we don't block other readers using the
        // pool. We do so only after the task has been scheduled to
        // ensure the deadline doesn't expire while the message is in
        // the process of being scheduled.
        readDeadline := time.Now().Add(readMessageTimeout)
        readErr := noiseConn.SetReadDeadline(readDeadline)
        if readErr != nil {
            return readErr
        }

        rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen])
        return readErr
    })

    atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg)))
    if err != nil {
        return nil, err
@@ -1359,33 +1388,33 @@ func (p *peer) writeMessage(msg lnwire.Message) error {

    p.logWireMessage(msg, false)

    // We'll re-slice of static write buffer to allow this new message to
    // utilize all available space. We also ensure we cap the capacity of
    // this new buffer to the static buffer which is sized for the largest
    // possible protocol message.
    b := bytes.NewBuffer(p.writeBuf[0:0:len(p.writeBuf)])
    var n int
    err := p.writePool.Submit(func(buf *bytes.Buffer) error {
        // Using a buffer allocated by the write pool, encode the
        // message directly into the buffer.
        _, writeErr := lnwire.WriteMessage(buf, msg, 0)
        if writeErr != nil {
            return writeErr
        }

    // With the temp buffer created and sliced properly (length zero, full
    // capacity), we'll now encode the message directly into this buffer.
    _, err := lnwire.WriteMessage(b, msg, 0)
    if err != nil {
        return err
        // Ensure the write deadline is set before we attempt to send
        // the message.
        writeDeadline := time.Now().Add(writeMessageTimeout)
        writeErr = p.conn.SetWriteDeadline(writeDeadline)
        if writeErr != nil {
            return writeErr
        }

        // Finally, write the message itself in a single swoop.
        n, writeErr = p.conn.Write(buf.Bytes())
        return writeErr
    })

    // Record the number of bytes written on the wire, if any.
    if n > 0 {
        atomic.AddUint64(&p.bytesSent, uint64(n))
    }

    // Compute and set the write deadline we will impose on the remote peer.
    writeDeadline := time.Now().Add(writeMessageTimeout)
    err = p.conn.SetWriteDeadline(writeDeadline)
    if err != nil {
        return err
    }

    // Finally, write the message itself in a single swoop.
    n, err := p.conn.Write(b.Bytes())

    // Regardless of the error returned, record how many bytes were written
    // to the wire.
    atomic.AddUint64(&p.bytesSent, uint64(n))

    return err
}
pool/read.go | 87 (new file)
@@ -0,0 +1,87 @@
package pool

import (
    "time"

    "github.com/lightningnetwork/lnd/buffer"
)

// Read is a worker pool specifically designed for sharing access to buffer.Read
// objects amongst a set of worker goroutines. This enables an application to
// limit the total number of buffer.Read objects allocated at any given time.
type Read struct {
    workerPool *Worker
    bufferPool *ReadBuffer
}

// NewRead creates a new Read pool, using an underlying ReadBuffer pool to
// recycle buffer.Read objects across the lifetime of the Read pool's workers.
func NewRead(readBufferPool *ReadBuffer, numWorkers int,
    workerTimeout time.Duration) *Read {

    r := &Read{
        bufferPool: readBufferPool,
    }
    r.workerPool = NewWorker(&WorkerConfig{
        NewWorkerState: r.newWorkerState,
        NumWorkers: numWorkers,
        WorkerTimeout: workerTimeout,
    })

    return r
}

// Start safely spins up the Read pool.
func (r *Read) Start() error {
    return r.workerPool.Start()
}

// Stop safely shuts down the Read pool.
func (r *Read) Stop() error {
    return r.workerPool.Stop()
}

// Submit accepts a function closure that provides access to the fresh
// buffer.Read object. The function's execution will be allocated to one of the
// underlying Worker pool's goroutines.
func (r *Read) Submit(inner func(*buffer.Read) error) error {
    return r.workerPool.Submit(func(s WorkerState) error {
        state := s.(*readWorkerState)
        return inner(state.readBuf)
    })
}

// readWorkerState is the per-goroutine state maintained by a Read pool's
// goroutines.
type readWorkerState struct {
    // bufferPool is the pool to which the readBuf will be returned when the
    // goroutine exits.
    bufferPool *ReadBuffer

    // readBuf is a buffer taken from the bufferPool on initialization,
    // which will be cleaned and provided to any tasks that the goroutine
    // processes before exiting.
    readBuf *buffer.Read
}

// newWorkerState initializes a new readWorkerState, which will be called
// whenever a new goroutine is allocated to begin processing read tasks.
func (r *Read) newWorkerState() WorkerState {
    return &readWorkerState{
        bufferPool: r.bufferPool,
        readBuf: r.bufferPool.Take(),
    }
}

// Cleanup returns the readBuf to the underlying buffer pool, and removes the
// goroutine's reference to the readBuf.
func (r *readWorkerState) Cleanup() {
    r.bufferPool.Return(r.readBuf)
    r.readBuf = nil
}

// Reset recycles the readBuf to make it ready for any subsequent tasks the
// goroutine may process.
func (r *readWorkerState) Reset() {
    r.readBuf.Recycle()
}
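For orientation, a hedged sketch (not part of the new file) of how a consumer drives a Read pool that has already been constructed with pool.NewReadBuffer/pool.NewRead and started, mirroring the peer.go change above; the helper name and connection argument are placeholders:

// readOneMessage borrows a pooled buffer.Read only for the duration of the
// body read. The *buffer.Read must not escape the closure; the decrypted
// payload returned by ReadNextBody is a fresh slice and may be retained.
func readOneMessage(readPool *pool.Read, noiseConn *brontide.Conn) ([]byte, error) {
    pktLen, err := noiseConn.ReadNextHeader()
    if err != nil {
        return nil, err
    }

    var rawMsg []byte
    err = readPool.Submit(func(buf *buffer.Read) error {
        var readErr error
        rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen])
        return readErr
    })
    return rawMsg, err
}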
pool/worker.go | 250 (new file)
@@ -0,0 +1,250 @@
package pool

import (
    "errors"
    "sync"
    "time"
)

// ErrWorkerPoolExiting signals that a shutdown of the Worker has been
// requested.
var ErrWorkerPoolExiting = errors.New("worker pool exiting")

// DefaultWorkerTimeout is the default duration after which a worker goroutine
// will exit to free up resources after having received no newly submitted
// tasks.
const DefaultWorkerTimeout = 5 * time.Second

type (
    // WorkerState is an interface used by the Worker to abstract the
    // lifecycle of internal state used by a worker goroutine.
    WorkerState interface {
        // Reset clears any internal state that may have been dirtied in
        // processing a prior task.
        Reset()

        // Cleanup releases any shared state before a worker goroutine
        // exits.
        Cleanup()
    }

    // WorkerConfig parameterizes the behavior of a Worker pool.
    WorkerConfig struct {
        // NewWorkerState allocates a new state for a worker goroutine.
        // This method is called each time a new worker goroutine is
        // spawned by the pool.
        NewWorkerState func() WorkerState

        // NumWorkers is the maximum number of workers the Worker pool
        // will permit to be allocated. Once the maximum number is
        // reached, any newly submitted tasks are forced to be processed
        // by existing worker goroutines.
        NumWorkers int

        // WorkerTimeout is the duration after which a worker goroutine
        // will exit after having received no newly submitted tasks.
        WorkerTimeout time.Duration
    }

    // Worker maintains a pool of goroutines that process submitted function
    // closures, and enable more efficient reuse of expensive state.
    Worker struct {
        started sync.Once
        stopped sync.Once

        cfg *WorkerConfig

        // requests is a channel where new tasks are submitted. Tasks
        // submitted through this channel may cause a new worker
        // goroutine to be allocated.
        requests chan *request

        // work is a channel where new tasks are submitted, but is only
        // read by active worker goroutines.
        work chan *request

        // workerSem is a channel-based semaphore that is used to limit
        // the total number of worker goroutines to the number
        // prescribed by the WorkerConfig.
        workerSem chan struct{}

        wg sync.WaitGroup
        quit chan struct{}
    }

    // request is a tuple of task closure and error channel that is used to
    // both submit a task to the pool and respond with any errors
    // encountered during the task's execution.
    request struct {
        fn func(WorkerState) error
        errChan chan error
    }
)

// NewWorker initializes a new Worker pool using the provided WorkerConfig.
func NewWorker(cfg *WorkerConfig) *Worker {
    return &Worker{
        cfg: cfg,
        requests: make(chan *request),
        workerSem: make(chan struct{}, cfg.NumWorkers),
        work: make(chan *request),
        quit: make(chan struct{}),
    }
}

// Start safely spins up the Worker pool.
func (w *Worker) Start() error {
    w.started.Do(func() {
        w.wg.Add(1)
        go w.requestHandler()
    })
    return nil
}

// Stop safely shuts down the Worker pool.
func (w *Worker) Stop() error {
    w.stopped.Do(func() {
        close(w.quit)
        w.wg.Wait()
    })
    return nil
}

// Submit accepts a function closure to the worker pool. The returned error will
// be either the result of the closure's execution or an ErrWorkerPoolExiting if
// a shutdown is requested.
func (w *Worker) Submit(fn func(WorkerState) error) error {
    req := &request{
        fn: fn,
        errChan: make(chan error, 1),
    }

    select {

    // Send request to requestHandler, where either a new worker is spawned
    // or the task will be handed to an existing worker.
    case w.requests <- req:

    // Fast path directly to existing worker.
    case w.work <- req:

    case <-w.quit:
        return ErrWorkerPoolExiting
    }

    select {

    // Wait for task to be processed.
    case err := <-req.errChan:
        return err

    case <-w.quit:
        return ErrWorkerPoolExiting
    }
}

// requestHandler processes incoming tasks by either allocating new worker
// goroutines to process the incoming tasks, or by feeding a submitted task to
// an already running worker goroutine.
func (w *Worker) requestHandler() {
    defer w.wg.Done()

    for {
        select {
        case req := <-w.requests:
            select {

            // If we have not reached our maximum number of workers,
            // spawn one to process the submitted request.
            case w.workerSem <- struct{}{}:
                w.wg.Add(1)
                go w.spawnWorker(req)

            // Otherwise, submit the task to any of the active
            // workers.
            case w.work <- req:

            case <-w.quit:
                return
            }

        case <-w.quit:
            return
        }
    }
}

// spawnWorker is used when the Worker pool wishes to create a new worker
// goroutine. The worker's state is initialized by calling the config's
// NewWorkerState method, and will continue to process incoming tasks until the
// pool is shut down or no new tasks are received before the worker's timeout
// elapses.
//
// NOTE: This method MUST be run as a goroutine.
func (w *Worker) spawnWorker(req *request) {
    defer w.wg.Done()
    defer func() { <-w.workerSem }()

    state := w.cfg.NewWorkerState()
    defer state.Cleanup()

    req.errChan <- req.fn(state)

    // We'll use a timer to implement the worker timeouts, as this reduces
    // the number of total allocations that would otherwise be necessary
    // with time.After.
    var t *time.Timer
    for {
        // Before processing another request, we'll reset the worker
        // state so that each request is processed against a clean
        // state.
        state.Reset()

        select {

        // Process any new requests that get submitted. We use a
        // non-blocking case first so that under high load we can spare
        // allocating a timeout.
        case req := <-w.work:
            req.errChan <- req.fn(state)
            continue

        case <-w.quit:
            return

        default:
        }

        // There were no new requests that could be taken immediately
        // from the work channel. Initialize or reset the timeout, which
        // will fire if the worker doesn't receive a new task before
        // needing to exit.
        if t != nil {
            t.Reset(w.cfg.WorkerTimeout)
        } else {
            t = time.NewTimer(w.cfg.WorkerTimeout)
        }

        select {

        // Process any new requests that get submitted.
        case req := <-w.work:
            req.errChan <- req.fn(state)

            // Stop the timer, draining the timer's channel if a
            // notification was already delivered.
            if !t.Stop() {
                <-t.C
            }

        // The timeout has elapsed, meaning the worker did not receive
        // any new tasks. Exit to allow the worker to return and free
        // its resources.
        case <-t.C:
            return

        case <-w.quit:
            return
        }
    }
}
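To make the WorkerState contract concrete, here is a small hedged sketch of a hypothetical pool that reuses a SHA-256 hasher per worker goroutine (imports of crypto/sha256, hash, and the pool package elided). Reset and Cleanup follow the lifecycle described above, and the caller would Start the pool before submitting and Stop it on shutdown, just as server.go does for the read and write pools:

// shaWorkerState reuses one SHA-256 hasher per worker goroutine.
type shaWorkerState struct {
    h hash.Hash
}

// Reset clears the hasher so no state leaks between tasks.
func (s *shaWorkerState) Reset() { s.h.Reset() }

// Cleanup would release shared resources when the worker exits; a plain
// hasher has nothing to release.
func (s *shaWorkerState) Cleanup() {}

// newShaPool wires the state into a Worker pool via WorkerConfig.
func newShaPool(numWorkers int) *pool.Worker {
    return pool.NewWorker(&pool.WorkerConfig{
        NewWorkerState: func() pool.WorkerState {
            return &shaWorkerState{h: sha256.New()}
        },
        NumWorkers: numWorkers,
        WorkerTimeout: pool.DefaultWorkerTimeout,
    })
}

// hashWith type-asserts the WorkerState it receives, just as pool.Read and
// pool.Write do internally.
func hashWith(p *pool.Worker, data []byte) ([]byte, error) {
    var digest []byte
    err := p.Submit(func(s pool.WorkerState) error {
        state := s.(*shaWorkerState)
        state.h.Write(data)
        digest = state.h.Sum(nil)
        return nil
    })
    return digest, err
}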
pool/worker_test.go | 353 (new file)
@@ -0,0 +1,353 @@
package pool_test

import (
    "bytes"
    crand "crypto/rand"
    "fmt"
    "io"
    "math/rand"
    "testing"
    "time"

    "github.com/lightningnetwork/lnd/buffer"
    "github.com/lightningnetwork/lnd/pool"
)

type workerPoolTest struct {
    name string
    newPool func() interface{}
    numWorkers int
}

// TestConcreteWorkerPools asserts the behavior of any concrete implementations
// of worker pools provided by the pool package. Currently this tests the
// pool.Read and pool.Write instances.
func TestConcreteWorkerPools(t *testing.T) {
    const (
        gcInterval = time.Second
        expiryInterval = 250 * time.Millisecond
        numWorkers = 5
        workerTimeout = 500 * time.Millisecond
    )

    tests := []workerPoolTest{
        {
            name: "write pool",
            newPool: func() interface{} {
                bp := pool.NewWriteBuffer(
                    gcInterval, expiryInterval,
                )

                return pool.NewWrite(
                    bp, numWorkers, workerTimeout,
                )
            },
            numWorkers: numWorkers,
        },
        {
            name: "read pool",
            newPool: func() interface{} {
                bp := pool.NewReadBuffer(
                    gcInterval, expiryInterval,
                )

                return pool.NewRead(
                    bp, numWorkers, workerTimeout,
                )
            },
            numWorkers: numWorkers,
        },
    }

    for _, test := range tests {
        testWorkerPool(t, test)
    }
}

func testWorkerPool(t *testing.T, test workerPoolTest) {
    t.Run(test.name+" non blocking", func(t *testing.T) {
        t.Parallel()

        p := test.newPool()
        startGeneric(t, p)
        defer stopGeneric(t, p)

        submitNonblockingGeneric(t, p, test.numWorkers)
    })

    t.Run(test.name+" blocking", func(t *testing.T) {
        t.Parallel()

        p := test.newPool()
        startGeneric(t, p)
        defer stopGeneric(t, p)

        submitBlockingGeneric(t, p, test.numWorkers)
    })

    t.Run(test.name+" partial blocking", func(t *testing.T) {
        t.Parallel()

        p := test.newPool()
        startGeneric(t, p)
        defer stopGeneric(t, p)

        submitPartialBlockingGeneric(t, p, test.numWorkers)
    })
}

// submitNonblockingGeneric asserts that queueing tasks to the worker pool and
// allowing them all to unblock simultaneously results in all of the tasks being
// completed in a timely manner.
func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
    // We'll submit 2*nWorkers tasks that will all be unblocked
    // simultaneously.
    nUnblocked := 2 * nWorkers

    // First we'll queue all of the tasks for the pool.
    errChan := make(chan error)
    semChan := make(chan struct{})
    for i := 0; i < nUnblocked; i++ {
        go func() { errChan <- submitGeneric(p, semChan) }()
    }

    // Since we haven't signaled the semaphore, none of them should
    // complete.
    pullNothing(t, errChan)

    // Now, unblock them all simultaneously. All of the tasks should then be
    // processed in parallel. Afterward, no more errors should come through.
    close(semChan)
    pullParllel(t, nUnblocked, errChan)
    pullNothing(t, errChan)
}

// submitBlockingGeneric asserts that submitting blocking tasks to the pool and
// unblocking each sequentially results in a single task being processed at a
// time.
func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
    // We'll submit 2*nWorkers tasks that will be unblocked sequentially.
    nBlocked := 2 * nWorkers

    // First, queue all of the blocking tasks for the pool.
    errChan := make(chan error)
    semChan := make(chan struct{})
    for i := 0; i < nBlocked; i++ {
        go func() { errChan <- submitGeneric(p, semChan) }()
    }

    // Since we haven't signaled the semaphore, none of them should
    // complete.
    pullNothing(t, errChan)

    // Now, pull each blocking task sequentially from the pool. Afterwards,
    // no more errors should come through.
    pullSequntial(t, nBlocked, errChan, semChan)
    pullNothing(t, errChan)

}

// submitPartialBlockingGeneric tests that so long as one worker is not blocked,
// any other non-blocking submitted tasks can still be processed.
func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
    // We'll submit nWorkers-1 tasks that will be initially blocked, the
    // remainder will all be unblocked simultaneously. After the unblocked
    // tasks have finished, we will sequentially unblock the nWorkers-1
    // tasks that were first submitted.
    nBlocked := nWorkers - 1
    nUnblocked := 2*nWorkers - nBlocked

    // First, submit all of the blocking tasks to the pool.
    errChan := make(chan error)
    semChan := make(chan struct{})
    for i := 0; i < nBlocked; i++ {
        go func() { errChan <- submitGeneric(p, semChan) }()
    }

    // Since these are all blocked, no errors should be returned yet.
    pullNothing(t, errChan)

    // Now, add all of the non-blocking tasks to the pool.
    semChanNB := make(chan struct{})
    for i := 0; i < nUnblocked; i++ {
        go func() { errChan <- submitGeneric(p, semChanNB) }()
    }

    // Since we haven't unblocked the second batch, we again expect no tasks
    // to finish.
    pullNothing(t, errChan)

    // Now, unblock the non-blocking tasks and pull all of them. After they
    // have been pulled, we should see no more tasks.
    close(semChanNB)
    pullParllel(t, nUnblocked, errChan)
    pullNothing(t, errChan)

    // Finally, unblock each of the blocked tasks we added initially, and
    // assert that no further errors come through.
    pullSequntial(t, nBlocked, errChan, semChan)
    pullNothing(t, errChan)
}

func pullNothing(t *testing.T, errChan chan error) {
    t.Helper()

    select {
    case err := <-errChan:
        t.Fatalf("received unexpected error before semaphore "+
            "release: %v", err)

    case <-time.After(time.Second):
    }
}

func pullParllel(t *testing.T, n int, errChan chan error) {
    t.Helper()

    for i := 0; i < n; i++ {
        select {
        case err := <-errChan:
            if err != nil {
                t.Fatal(err)
            }

        case <-time.After(time.Second):
            t.Fatalf("task %d was not processed in time", i)
        }
    }
}

func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) {
    t.Helper()

    for i := 0; i < n; i++ {
        // Signal for another task to unblock.
        select {
        case semChan <- struct{}{}:
        case <-time.After(time.Second):
            t.Fatalf("task %d was not unblocked", i)
        }

        // Wait for the error to arrive, we expect it to be nil.
        select {
        case err := <-errChan:
            if err != nil {
                t.Fatal(err)
            }

        case <-time.After(time.Second):
            t.Fatalf("task %d was not processed in time", i)
        }
    }
}

func startGeneric(t *testing.T, p interface{}) {
    t.Helper()

    var err error
    switch pp := p.(type) {
    case *pool.Write:
        err = pp.Start()

    case *pool.Read:
        err = pp.Start()

    default:
        t.Fatalf("unknown worker pool type: %T", p)
    }

    if err != nil {
        t.Fatalf("unable to start worker pool: %v", err)
    }
}

func stopGeneric(t *testing.T, p interface{}) {
    t.Helper()

    var err error
    switch pp := p.(type) {
    case *pool.Write:
        err = pp.Stop()

    case *pool.Read:
        err = pp.Stop()

    default:
        t.Fatalf("unknown worker pool type: %T", p)
    }

    if err != nil {
        t.Fatalf("unable to stop worker pool: %v", err)
    }
}

func submitGeneric(p interface{}, sem <-chan struct{}) error {
    var err error
    switch pp := p.(type) {
    case *pool.Write:
        err = pp.Submit(func(buf *bytes.Buffer) error {
            // Verify that the provided buffer has been reset to be
            // zero length.
            if buf.Len() != 0 {
                return fmt.Errorf("buf should be length zero, "+
                    "instead has length %d", buf.Len())
            }

            // Verify that the capacity of the buffer has the
            // correct underlying size of a buffer.WriteSize.
            if buf.Cap() != buffer.WriteSize {
                return fmt.Errorf("buf should have capacity "+
                    "%d, instead has capacity %d",
                    buffer.WriteSize, buf.Cap())
            }

            // Sample some random bytes that we'll use to dirty the
            // buffer.
            b := make([]byte, rand.Intn(buf.Cap()))
            _, err := io.ReadFull(crand.Reader, b)
            if err != nil {
                return err
            }

            // Write the random bytes to the buffer.
            _, err = buf.Write(b)

            // Wait until this task is signaled to exit.
            <-sem

            return err
        })

    case *pool.Read:
        err = pp.Submit(func(buf *buffer.Read) error {
            // Assert that all of the bytes in the provided array
            // are zero, indicating that the buffer was reset
            // between uses.
            for i := range buf[:] {
                if buf[i] != 0x00 {
                    return fmt.Errorf("byte %d of "+
                        "buffer.Read should be "+
                        "0, instead is %d", i, buf[i])
                }
            }

            // Sample some random bytes to read into the buffer.
            _, err := io.ReadFull(crand.Reader, buf[:])

            // Wait until this task is signaled to exit.
            <-sem

            return err
        })

    default:
        return fmt.Errorf("unknown worker pool type: %T", p)
    }

    if err != nil {
        return fmt.Errorf("unable to submit task: %v", err)
    }

    return nil
}
pool/write.go | 100 (new file)
@@ -0,0 +1,100 @@
package pool

import (
    "bytes"
    "time"

    "github.com/lightningnetwork/lnd/buffer"
)

// Write is a worker pool specifically designed for sharing access to
// buffer.Write objects amongst a set of worker goroutines. This enables an
// application to limit the total number of buffer.Write objects allocated at
// any given time.
type Write struct {
    workerPool *Worker
    bufferPool *WriteBuffer
}

// NewWrite creates a Write pool, using an underlying WriteBuffer pool to
// recycle buffer.Write objects across the lifetime of the Write pool's
// workers.
func NewWrite(writeBufferPool *WriteBuffer, numWorkers int,
    workerTimeout time.Duration) *Write {

    w := &Write{
        bufferPool: writeBufferPool,
    }
    w.workerPool = NewWorker(&WorkerConfig{
        NewWorkerState: w.newWorkerState,
        NumWorkers: numWorkers,
        WorkerTimeout: workerTimeout,
    })

    return w
}

// Start safely spins up the Write pool.
func (w *Write) Start() error {
    return w.workerPool.Start()
}

// Stop safely shuts down the Write pool.
func (w *Write) Stop() error {
    return w.workerPool.Stop()
}

// Submit accepts a function closure that provides access to a fresh
// bytes.Buffer backed by a buffer.Write object. The function's execution will
// be allocated to one of the underlying Worker pool's goroutines.
func (w *Write) Submit(inner func(*bytes.Buffer) error) error {
    return w.workerPool.Submit(func(s WorkerState) error {
        state := s.(*writeWorkerState)
        return inner(state.buf)
    })
}

// writeWorkerState is the per-goroutine state maintained by a Write pool's
// goroutines.
type writeWorkerState struct {
    // bufferPool is the pool to which the writeBuf will be returned when
    // the goroutine exits.
    bufferPool *WriteBuffer

    // writeBuf is the buffer taken from the bufferPool on initialization,
    // which will be used to back the buf object provided to any tasks that
    // the goroutine processes before exiting.
    writeBuf *buffer.Write

    // buf is a buffer backed by writeBuf, that can be written to by tasks
    // submitted to the Write pool. The buf will be reset between each task
    // processed by a goroutine before exiting, and allows the task
    // submitters to interact with the writeBuf as if it were an io.Writer.
    buf *bytes.Buffer
}

// newWorkerState initializes a new writeWorkerState, which will be called
// whenever a new goroutine is allocated to begin processing write tasks.
func (w *Write) newWorkerState() WorkerState {
    writeBuf := w.bufferPool.Take()

    return &writeWorkerState{
        bufferPool: w.bufferPool,
        writeBuf: writeBuf,
        buf: bytes.NewBuffer(writeBuf[0:0:len(writeBuf)]),
    }
}

// Cleanup returns the writeBuf to the underlying buffer pool, and removes the
// goroutine's reference to the writeBuf and encapsulating buf.
func (w *writeWorkerState) Cleanup() {
    w.bufferPool.Return(w.writeBuf)
    w.writeBuf = nil
    w.buf = nil
}

// Reset resets the bytes.Buffer so that it is zero-length and has the capacity
// of the underlying buffer.Write.
func (w *writeWorkerState) Reset() {
    w.buf.Reset()
}
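As a usage note, here is a hedged sketch (not part of the new file) of a hypothetical helper that serializes an lnwire message through a Write pool; because the pooled buffer must not escape the closure, its contents are copied out before returning:

// encodeMsg serializes msg using a shared Write pool and returns the encoded
// bytes. The pooled bytes.Buffer is only valid inside the closure, so the
// result is copied into a caller-owned slice.
func encodeMsg(wp *pool.Write, msg lnwire.Message) ([]byte, error) {
    var encoded []byte
    err := wp.Submit(func(buf *bytes.Buffer) error {
        if _, err := lnwire.WriteMessage(buf, msg, 0); err != nil {
            return err
        }
        encoded = append(encoded, buf.Bytes()...)
        return nil
    })
    return encoded, err
}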
server.go | 38
@@ -171,7 +171,9 @@ type server struct {

    sigPool *lnwallet.SigPool

    writeBufferPool *pool.WriteBuffer
    writePool *pool.Write

    readPool *pool.Read

    // globalFeatures feature vector which affects HTLCs and thus are also
    // advertised to other nodes.
@@ -263,16 +265,31 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
    sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
    replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
    sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)

    writeBufferPool := pool.NewWriteBuffer(
        pool.DefaultWriteBufferGCInterval,
        pool.DefaultWriteBufferExpiryInterval,
    )

    writePool := pool.NewWrite(
        writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
    )

    readBufferPool := pool.NewReadBuffer(
        pool.DefaultReadBufferGCInterval,
        pool.DefaultReadBufferExpiryInterval,
    )

    readPool := pool.NewRead(
        readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
    )

    s := &server{
        chanDB: chanDB,
        cc: cc,
        sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
        writeBufferPool: writeBufferPool,
        chanDB: chanDB,
        cc: cc,
        sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
        writePool: writePool,
        readPool: readPool,

        invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),

@@ -1010,6 +1027,12 @@ func (s *server) Start() error {
    if err := s.sigPool.Start(); err != nil {
        return err
    }
    if err := s.writePool.Start(); err != nil {
        return err
    }
    if err := s.readPool.Start(); err != nil {
        return err
    }
    if err := s.cc.chainNotifier.Start(); err != nil {
        return err
    }
@@ -1102,7 +1125,6 @@ func (s *server) Stop() error {

    // Shutdown the wallet, funding manager, and the rpc server.
    s.chanStatusMgr.Stop()
    s.sigPool.Stop()
    s.cc.chainNotifier.Stop()
    s.chanRouter.Stop()
    s.htlcSwitch.Stop()
@@ -1129,6 +1151,10 @@ func (s *server) Stop() error {
    // Wait for all lingering goroutines to quit.
    s.wg.Wait()

    s.sigPool.Stop()
    s.writePool.Stop()
    s.readPool.Stop()

    return nil
}