Merge pull request #2474 from cfromknecht/read-and-write-pools

lnpeer+brontide: reduce memory footprint using read/write pools for message encode/decode
Olaoluwa Osuntokun 2019-02-24 16:39:32 -03:00 committed by GitHub
commit a6ba965bc4
8 changed files with 953 additions and 91 deletions

@@ -104,13 +104,34 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
return b, nil
}
// ReadNextMessage uses the connection in a message-oriented instructing it to
// read the next _full_ message with the brontide stream. This function will
// block until the read succeeds.
// ReadNextMessage uses the connection in a message-oriented manner, instructing
// it to read the next _full_ message with the brontide stream. This function
// will block until the read of the header and body succeeds.
//
// NOTE: This method SHOULD NOT be used in the case that the connection may be
// adversarial and induce long delays. If the caller needs to set read deadlines
// appropriately, it is preferred that they use the split ReadNextHeader and
// ReadNextBody methods so that the deadlines can be set appropriately on each.
func (c *Conn) ReadNextMessage() ([]byte, error) {
return c.noise.ReadMessage(c.conn)
}
// ReadNextHeader uses the connection to read the next header from the brontide
// stream. This function will block until the read of the header succeeds and
// return the packet length (including MAC overhead) that is expected from the
// subsequent call to ReadNextBody.
func (c *Conn) ReadNextHeader() (uint32, error) {
return c.noise.ReadHeader(c.conn)
}
// ReadNextBody uses the connection to read the next message body from the
// brontide stream. This function will block until the read of the body succeeds
// and return the decrypted payload. The provided buffer MUST be the packet
// length returned by the preceding call to ReadNextHeader.
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
return c.noise.ReadBody(c.conn, buf)
}
// Read reads data from the connection. Read can be made to time out and
// return an Error with Timeout() == true after a fixed time limit; see
// SetDeadline and SetReadDeadline.
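The split ReadNextHeader/ReadNextBody methods exist so that a caller can bracket each network read with its own deadline, as the NOTE above recommends. A minimal sketch of that pattern, assuming a caller-owned buffer large enough for the biggest possible packet (the helper name and deadline values are illustrative, not part of this change):

package sketch

import (
	"time"

	"github.com/lightningnetwork/lnd/brontide"
)

// readWithDeadlines is a hypothetical helper demonstrating the split read
// path with a separate deadline for each phase.
func readWithDeadlines(conn *brontide.Conn, buf []byte) ([]byte, error) {
	// The peer may legitimately be idle, so wait for the next header with
	// no deadline at all.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		return nil, err
	}
	pktLen, err := conn.ReadNextHeader()
	if err != nil {
		return nil, err
	}

	// Once the header has arrived the body should follow promptly, so
	// impose a tight deadline before reading it.
	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, err
	}
	return conn.ReadNextBody(buf[:pktLen])
}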

@@ -8,15 +8,12 @@ import (
"fmt"
"io"
"math"
"runtime"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/buffer"
"github.com/lightningnetwork/lnd/pool"
)
const (
@@ -60,14 +57,6 @@ var (
ephemeralGen = func() (*btcec.PrivateKey, error) {
return btcec.NewPrivateKey(btcec.S256())
}
// readBufferPool is a singleton instance of a buffer pool, used to
// conserve memory allocations due to read buffers across the entire
// brontide package.
readBufferPool = pool.NewReadBuffer(
pool.DefaultReadBufferGCInterval,
pool.DefaultReadBufferExpiryInterval,
)
)
// TODO(roasbeef): free buffer pool?
@@ -378,15 +367,6 @@ type Machine struct {
// next ciphertext header from the wire. The header is a 2 byte length
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [lengthHeaderSize + macSize]byte
// nextCipherText is a static buffer that we'll use to read in the
// bytes of the next cipher text message. As all messages in the
// protocol MUST be below 65KB plus our macSize, this will be
// sufficient to buffer all messages from the socket when we need to
// read the next one. Having a fixed buffer that's re-used also means
// that we save on allocations as we don't need to create a new one
// each time.
nextCipherText *buffer.Read
}
// NewBrontideMachine creates a new instance of the brontide state-machine. If
@@ -738,43 +718,59 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
// ReadMessage attempts to read the next message from the passed io.Reader. In
// the case of an authentication error, a non-nil error is returned.
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
if _, err := io.ReadFull(r, b.nextCipherHeader[:]); err != nil {
pktLen, err := b.ReadHeader(r)
if err != nil {
return nil, err
}
buf := make([]byte, pktLen)
return b.ReadBody(r, buf)
}
// ReadHeader attempts to read the next message header from the passed
// io.Reader. The header contains the length of the next body including
// additional overhead of the MAC. In the case of an authentication error, a
// non-nil error is returned.
//
// NOTE: This method SHOULD NOT be used in the case that the io.Reader may be
// adversarial and induce long delays. If the caller needs to set read deadlines
// appropriately, it is preferred that they use the split ReadHeader and
// ReadBody methods so that the deadlines can be set appropriately on each.
func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
_, err := io.ReadFull(r, b.nextCipherHeader[:])
if err != nil {
return 0, err
}
// Attempt to decrypt+auth the packet length present in the stream.
pktLenBytes, err := b.recvCipher.Decrypt(
nil, nil, b.nextCipherHeader[:],
)
if err != nil {
return nil, err
return 0, err
}
// If this is the first message being read, take a read buffer from the
// buffer pool. This is delayed until this point to avoid allocating
// read buffers until after the peer has successfully completed the
// handshake, and is ready to begin sending lnwire messages.
if b.nextCipherText == nil {
b.nextCipherText = readBufferPool.Take()
runtime.SetFinalizer(b, freeReadBuffer)
}
// Next, using the length read from the packet header, read the
// encrypted packet itself.
// Compute the packet length that we will need to read off the wire.
pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
if _, err := io.ReadFull(r, b.nextCipherText[:pktLen]); err != nil {
return pktLen, nil
}
// ReadBody attempts to read the next message body from the passed io.Reader.
// The provided buffer MUST be the length indicated by the packet length
// returned by the preceding call to ReadHeader. In the case of an
// authentication error, a non-nil error is returned.
func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
// Next, using the length read from the packet header, read the
// encrypted packet itself into the buffer allocated by the read
// pool.
_, err := io.ReadFull(r, buf)
if err != nil {
return nil, err
}
// Finally, decrypt the message held in the buffer, and return a
// new byte slice containing the plaintext.
// TODO(roasbeef): modify to let pass in slice
return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen])
}
// freeReadBuffer returns the Machine's read buffer back to the package wide
// read buffer pool.
//
// NOTE: This method should only be called by a Machine's finalizer.
func freeReadBuffer(b *Machine) {
readBufferPool.Return(b.nextCipherText)
b.nextCipherText = nil
return b.recvCipher.Decrypt(nil, nil, buf)
}
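With the header/body split, the ciphertext buffer is supplied by the caller, so a single fixed-size buffer can be reused for every read on a connection; that is what the read pool introduced by this change provides. A rough illustration of the idea with a plain caller-owned buffer sized at 65535+16 bytes for the maximum payload plus MAC (the loop and names are hypothetical):

package sketch

import (
	"io"

	"github.com/lightningnetwork/lnd/brontide"
)

// readLoop is a hypothetical loop that reuses one caller-owned ciphertext
// buffer for every message read from the brontide stream.
func readLoop(noise *brontide.Machine, r io.Reader, handle func([]byte)) error {
	// Sized for the largest possible packet: a 65535-byte payload plus the
	// 16-byte MAC.
	var ctBuf [65535 + 16]byte

	for {
		pktLen, err := noise.ReadHeader(r)
		if err != nil {
			return err
		}

		// ReadBody fills ctBuf with the ciphertext and returns a newly
		// allocated plaintext slice.
		msg, err := noise.ReadBody(r, ctBuf[:pktLen])
		if err != nil {
			return err
		}
		handle(msg)
	}
}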

peer.go (105 changed lines)

@@ -26,6 +26,7 @@ import (
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/pool"
"github.com/lightningnetwork/lnd/ticker"
)
@@ -43,8 +44,12 @@ const (
// idleTimeout is the duration of inactivity before we time out a peer.
idleTimeout = 5 * time.Minute
// writeMessageTimeout is the timeout used when writing a message to peer.
writeMessageTimeout = 50 * time.Second
// writeMessageTimeout is the timeout used when writing a message to a peer.
writeMessageTimeout = 10 * time.Second
// readMessageTimeout is the timeout used when reading a message from a
// peer.
readMessageTimeout = 5 * time.Second
// handshakeTimeout is the timeout used when waiting for peer init message.
handshakeTimeout = 15 * time.Second
@@ -209,11 +214,13 @@ type peer struct {
// TODO(halseth): remove when link failure is properly handled.
failedChannels map[lnwire.ChannelID]struct{}
// writeBuf is a buffer that we'll re-use in order to encode wire
// messages to write out directly on the socket. By re-using this
// buffer, we avoid needing to allocate more memory each time a new
// message is to be sent to a peer.
writeBuf *buffer.Write
// writePool is the task pool that manages reuse of write buffers.
// Write tasks are submitted to the pool in order to conserve the total
// number of write buffers allocated at any one time, and decouple write
// buffer allocation from the peer life cycle.
writePool *pool.Write
readPool *pool.Read
queueQuit chan struct{}
quit chan struct{}
@@ -258,7 +265,8 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,
chanActiveTimeout: chanActiveTimeout,
writeBuf: server.writeBufferPool.Take(),
writePool: server.writePool,
readPool: server.readPool,
queueQuit: make(chan struct{}),
quit: make(chan struct{}),
@@ -608,11 +616,6 @@ func (p *peer) WaitForDisconnect(ready chan struct{}) {
}
p.wg.Wait()
// Now that we are certain all active goroutines which could have been
// modifying the write buffer have exited, return the buffer to the pool
// to be reused.
p.server.writeBufferPool.Return(p.writeBuf)
}
// Disconnect terminates the connection with the remote peer. Additionally, a
@@ -644,11 +647,37 @@ func (p *peer) readNextMessage() (lnwire.Message, error) {
return nil, fmt.Errorf("brontide.Conn required to read messages")
}
err := noiseConn.SetReadDeadline(time.Time{})
if err != nil {
return nil, err
}
pktLen, err := noiseConn.ReadNextHeader()
if err != nil {
return nil, err
}
// First we'll read the next _full_ message. We do this rather than
// reading incrementally from the stream as the Lightning wire protocol
// is message oriented and allows nodes to pad on additional data to
// the message stream.
rawMsg, err := noiseConn.ReadNextMessage()
var rawMsg []byte
err = p.readPool.Submit(func(buf *buffer.Read) error {
// Before reading the body of the message, set the read timeout
// accordingly to ensure we don't block other readers using the
// pool. We do so only after the task has been scheduled to
// ensure the deadline doesn't expire while the message is in
// the process of being scheduled.
readDeadline := time.Now().Add(readMessageTimeout)
readErr := noiseConn.SetReadDeadline(readDeadline)
if readErr != nil {
return readErr
}
rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen])
return readErr
})
atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg)))
if err != nil {
return nil, err
@@ -1359,33 +1388,33 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
p.logWireMessage(msg, false)
// We'll re-slice of static write buffer to allow this new message to
// utilize all available space. We also ensure we cap the capacity of
// this new buffer to the static buffer which is sized for the largest
// possible protocol message.
b := bytes.NewBuffer(p.writeBuf[0:0:len(p.writeBuf)])
var n int
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
// Using a buffer allocated by the write pool, encode the
// message directly into the buffer.
_, writeErr := lnwire.WriteMessage(buf, msg, 0)
if writeErr != nil {
return writeErr
}
// With the temp buffer created and sliced properly (length zero, full
// capacity), we'll now encode the message directly into this buffer.
_, err := lnwire.WriteMessage(b, msg, 0)
if err != nil {
return err
// Ensure the write deadline is set before we attempt to send
// the message.
writeDeadline := time.Now().Add(writeMessageTimeout)
writeErr = p.conn.SetWriteDeadline(writeDeadline)
if writeErr != nil {
return writeErr
}
// Finally, write the message itself in a single swoop.
n, writeErr = p.conn.Write(buf.Bytes())
return writeErr
})
// Record the number of bytes written on the wire, if any.
if n > 0 {
atomic.AddUint64(&p.bytesSent, uint64(n))
}
// Compute and set the write deadline we will impose on the remote peer.
writeDeadline := time.Now().Add(writeMessageTimeout)
err = p.conn.SetWriteDeadline(writeDeadline)
if err != nil {
return err
}
// Finally, write the message itself in a single swoop.
n, err := p.conn.Write(b.Bytes())
// Regardless of the error returned, record how many bytes were written
// to the wire.
atomic.AddUint64(&p.bytesSent, uint64(n))
return err
}
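Because removed and added lines are interleaved above, here is the new body of writeMessage assembled from just the added lines, with comments condensed:

var n int
err := p.writePool.Submit(func(buf *bytes.Buffer) error {
	// Encode the message directly into the pooled buffer.
	_, writeErr := lnwire.WriteMessage(buf, msg, 0)
	if writeErr != nil {
		return writeErr
	}

	// Ensure the write deadline is set before attempting to send.
	writeDeadline := time.Now().Add(writeMessageTimeout)
	writeErr = p.conn.SetWriteDeadline(writeDeadline)
	if writeErr != nil {
		return writeErr
	}

	// Write the message itself in a single swoop.
	n, writeErr = p.conn.Write(buf.Bytes())
	return writeErr
})

// Record the number of bytes written on the wire, if any.
if n > 0 {
	atomic.AddUint64(&p.bytesSent, uint64(n))
}

return err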

pool/read.go (new file, 87 lines)

@@ -0,0 +1,87 @@
package pool
import (
"time"
"github.com/lightningnetwork/lnd/buffer"
)
// Read is a worker pool specifically designed for sharing access to buffer.Read
// objects amongst a set of worker goroutines. This enables an application to
// limit the total number of buffer.Read objects allocated at any given time.
type Read struct {
workerPool *Worker
bufferPool *ReadBuffer
}
// NewRead creates a new Read pool, using an underlying ReadBuffer pool to
// recycle buffer.Read objects across the lifetime of the Read pool's workers.
func NewRead(readBufferPool *ReadBuffer, numWorkers int,
workerTimeout time.Duration) *Read {
r := &Read{
bufferPool: readBufferPool,
}
r.workerPool = NewWorker(&WorkerConfig{
NewWorkerState: r.newWorkerState,
NumWorkers: numWorkers,
WorkerTimeout: workerTimeout,
})
return r
}
// Start safely spins up the Read pool.
func (r *Read) Start() error {
return r.workerPool.Start()
}
// Stop safely shuts down the Read pool.
func (r *Read) Stop() error {
return r.workerPool.Stop()
}
// Submit accepts a function closure that provides access to a fresh
// buffer.Read object. The function's execution will be allocated to one of the
// underlying Worker pool's goroutines.
func (r *Read) Submit(inner func(*buffer.Read) error) error {
return r.workerPool.Submit(func(s WorkerState) error {
state := s.(*readWorkerState)
return inner(state.readBuf)
})
}
// readWorkerState is the per-goroutine state maintained by a Read pool's
// goroutines.
type readWorkerState struct {
// bufferPool is the pool to which the readBuf will be returned when the
// goroutine exits.
bufferPool *ReadBuffer
// readBuf is a buffer taken from the bufferPool on initialization,
// which will be cleaned and provided to any tasks that the goroutine
// processes before exiting.
readBuf *buffer.Read
}
// newWorkerState initializes a new readWorkerState, which will be called
// whenever a new goroutine is allocated to begin processing read tasks.
func (r *Read) newWorkerState() WorkerState {
return &readWorkerState{
bufferPool: r.bufferPool,
readBuf: r.bufferPool.Take(),
}
}
// Cleanup returns the readBuf to the underlying buffer pool, and removes the
// goroutine's reference to the readBuf.
func (r *readWorkerState) Cleanup() {
r.bufferPool.Return(r.readBuf)
r.readBuf = nil
}
// Reset recycles the readBuf to make it ready for any subsequent tasks the
// goroutine may process.
func (r *readWorkerState) Reset() {
r.readBuf.Recycle()
}
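A minimal, self-contained sketch of constructing a Read pool and borrowing a buffer from it (the worker count is arbitrary; lnd's real wiring is shown in the newServer hunk at the end of this diff):

package main

import (
	"fmt"

	"github.com/lightningnetwork/lnd/buffer"
	"github.com/lightningnetwork/lnd/pool"
)

func main() {
	bufPool := pool.NewReadBuffer(
		pool.DefaultReadBufferGCInterval,
		pool.DefaultReadBufferExpiryInterval,
	)
	readPool := pool.NewRead(bufPool, 4, pool.DefaultWorkerTimeout)

	if err := readPool.Start(); err != nil {
		panic(err)
	}
	defer readPool.Stop()

	// Each submitted closure borrows a recycled buffer.Read for its
	// duration; Reset zeroes the buffer between tasks.
	err := readPool.Submit(func(buf *buffer.Read) error {
		fmt.Printf("borrowed a %d-byte read buffer\n", len(buf))
		return nil
	})
	if err != nil {
		panic(err)
	}
}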

pool/worker.go (new file, 250 lines)

@@ -0,0 +1,250 @@
package pool
import (
"errors"
"sync"
"time"
)
// ErrWorkerPoolExiting signals that a shutdown of the Worker has been
// requested.
var ErrWorkerPoolExiting = errors.New("worker pool exiting")
// DefaultWorkerTimeout is the default duration after which a worker goroutine
// will exit to free up resources after having received no newly submitted
// tasks.
const DefaultWorkerTimeout = 5 * time.Second
type (
// WorkerState is an interface used by the Worker to abstract the
// lifecycle of internal state used by a worker goroutine.
WorkerState interface {
// Reset clears any internal state that may have been dirtied in
// processing a prior task.
Reset()
// Cleanup releases any shared state before a worker goroutine
// exits.
Cleanup()
}
// WorkerConfig parameterizes the behavior of a Worker pool.
WorkerConfig struct {
// NewWorkerState allocates a new state for a worker goroutine.
// This method is called each time a new worker goroutine is
// spawned by the pool.
NewWorkerState func() WorkerState
// NumWorkers is the maximum number of workers the Worker pool
// will permit to be allocated. Once the maximum number is
// reached, any newly submitted tasks are forced to be processed
// by existing worker goroutines.
NumWorkers int
// WorkerTimeout is the duration after which a worker goroutine
// will exit after having received no newly submitted tasks.
WorkerTimeout time.Duration
}
// Worker maintains a pool of goroutines that process submitted function
// closures, and enable more efficient reuse of expensive state.
Worker struct {
started sync.Once
stopped sync.Once
cfg *WorkerConfig
// requests is a channel where new tasks are submitted. Tasks
// submitted through this channel may cause a new worker
// goroutine to be allocated.
requests chan *request
// work is a channel where new tasks are submitted, but is only
// read by active worker goroutines.
work chan *request
// workerSem is a channel-based semaphore that is used to limit
// the total number of worker goroutines to the number
// prescribed by the WorkerConfig.
workerSem chan struct{}
wg sync.WaitGroup
quit chan struct{}
}
// request is a tuple of task closure and error channel that is used to
// both submit a task to the pool and respond with any errors
// encountered during the task's execution.
request struct {
fn func(WorkerState) error
errChan chan error
}
)
// NewWorker initializes a new Worker pool using the provided WorkerConfig.
func NewWorker(cfg *WorkerConfig) *Worker {
return &Worker{
cfg: cfg,
requests: make(chan *request),
workerSem: make(chan struct{}, cfg.NumWorkers),
work: make(chan *request),
quit: make(chan struct{}),
}
}
// Start safely spins up the Worker pool.
func (w *Worker) Start() error {
w.started.Do(func() {
w.wg.Add(1)
go w.requestHandler()
})
return nil
}
// Stop safely shuts down the Worker pool.
func (w *Worker) Stop() error {
w.stopped.Do(func() {
close(w.quit)
w.wg.Wait()
})
return nil
}
// Submit dispatches a function closure to the worker pool. The returned error will
// be either the result of the closure's execution or an ErrWorkerPoolExiting if
// a shutdown is requested.
func (w *Worker) Submit(fn func(WorkerState) error) error {
req := &request{
fn: fn,
errChan: make(chan error, 1),
}
select {
// Send request to requestHandler, where either a new worker is spawned
// or the task will be handed to an existing worker.
case w.requests <- req:
// Fast path directly to existing worker.
case w.work <- req:
case <-w.quit:
return ErrWorkerPoolExiting
}
select {
// Wait for task to be processed.
case err := <-req.errChan:
return err
case <-w.quit:
return ErrWorkerPoolExiting
}
}
// requestHandler processes incoming tasks by either allocating new worker
// goroutines to process the incoming tasks, or by feeding a submitted task to
// an already running worker goroutine.
func (w *Worker) requestHandler() {
defer w.wg.Done()
for {
select {
case req := <-w.requests:
select {
// If we have not reached our maximum number of workers,
// spawn one to process the submitted request.
case w.workerSem <- struct{}{}:
w.wg.Add(1)
go w.spawnWorker(req)
// Otherwise, submit the task to any of the active
// workers.
case w.work <- req:
case <-w.quit:
return
}
case <-w.quit:
return
}
}
}
// spawnWorker is used when the Worker pool wishes to create a new worker
// goroutine. The worker's state is initialized by calling the config's
// NewWorkerState method, and will continue to process incoming tasks until the
// pool is shut down or no new tasks are received before the worker's timeout
// elapses.
//
// NOTE: This method MUST be run as a goroutine.
func (w *Worker) spawnWorker(req *request) {
defer w.wg.Done()
defer func() { <-w.workerSem }()
state := w.cfg.NewWorkerState()
defer state.Cleanup()
req.errChan <- req.fn(state)
// We'll use a timer to implement the worker timeouts, as this reduces
// the number of total allocations that would otherwise be necessary
// with time.After.
var t *time.Timer
for {
// Before processing another request, we'll reset the worker
// state so that each request is processed against a clean
// state.
state.Reset()
select {
// Process any new requests that get submitted. We use a
// non-blocking case first so that under high load we can avoid
// allocating a timer.
case req := <-w.work:
req.errChan <- req.fn(state)
continue
case <-w.quit:
return
default:
}
// There were no new requests that could be taken immediately
// from the work channel. Initialize or reset the timeout, which
// will fire if the worker doesn't receive a new task before
// needing to exit.
if t != nil {
t.Reset(w.cfg.WorkerTimeout)
} else {
t = time.NewTimer(w.cfg.WorkerTimeout)
}
select {
// Process any new requests that get submitted.
case req := <-w.work:
req.errChan <- req.fn(state)
// Stop the timer, draining the timer's channel if a
// notification was already delivered.
if !t.Stop() {
<-t.C
}
// The timeout has elapsed, meaning the worker did not receive
// any new tasks. Exit to allow the worker to return and free
// its resources.
case <-t.C:
return
case <-w.quit:
return
}
}
}
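The Read pool above and the Write pool later in this diff are thin wrappers around this Worker. Any other expensive per-goroutine state can be shared the same way by implementing WorkerState; a hypothetical example that gives each worker a reusable SHA-256 hasher (purely illustrative, not part of the change):

package main

import (
	"crypto/sha256"
	"fmt"
	"hash"

	"github.com/lightningnetwork/lnd/pool"
)

// hashState carries a per-worker hash.Hash so tasks avoid allocating one.
type hashState struct {
	h hash.Hash
}

// Reset clears the hasher between tasks.
func (s *hashState) Reset() { s.h.Reset() }

// Cleanup releases any shared resources when the worker exits; nothing to
// do for a plain hasher.
func (s *hashState) Cleanup() {}

func main() {
	w := pool.NewWorker(&pool.WorkerConfig{
		NewWorkerState: func() pool.WorkerState {
			return &hashState{h: sha256.New()}
		},
		NumWorkers:    4,
		WorkerTimeout: pool.DefaultWorkerTimeout,
	})
	if err := w.Start(); err != nil {
		panic(err)
	}
	defer w.Stop()

	err := w.Submit(func(s pool.WorkerState) error {
		hs := s.(*hashState)
		hs.h.Write([]byte("hello"))
		fmt.Printf("sha256: %x\n", hs.h.Sum(nil))
		return nil
	})
	if err != nil {
		panic(err)
	}
}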

pool/worker_test.go (new file, 353 lines)

@@ -0,0 +1,353 @@
package pool_test
import (
"bytes"
crand "crypto/rand"
"fmt"
"io"
"math/rand"
"testing"
"time"
"github.com/lightningnetwork/lnd/buffer"
"github.com/lightningnetwork/lnd/pool"
)
type workerPoolTest struct {
name string
newPool func() interface{}
numWorkers int
}
// TestConcreteWorkerPools asserts the behavior of any concrete implementations
// of worker pools provided by the pool package. Currently this tests the
// pool.Read and pool.Write instances.
func TestConcreteWorkerPools(t *testing.T) {
const (
gcInterval = time.Second
expiryInterval = 250 * time.Millisecond
numWorkers = 5
workerTimeout = 500 * time.Millisecond
)
tests := []workerPoolTest{
{
name: "write pool",
newPool: func() interface{} {
bp := pool.NewWriteBuffer(
gcInterval, expiryInterval,
)
return pool.NewWrite(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
{
name: "read pool",
newPool: func() interface{} {
bp := pool.NewReadBuffer(
gcInterval, expiryInterval,
)
return pool.NewRead(
bp, numWorkers, workerTimeout,
)
},
numWorkers: numWorkers,
},
}
for _, test := range tests {
testWorkerPool(t, test)
}
}
func testWorkerPool(t *testing.T, test workerPoolTest) {
t.Run(test.name+" non blocking", func(t *testing.T) {
t.Parallel()
p := test.newPool()
startGeneric(t, p)
defer stopGeneric(t, p)
submitNonblockingGeneric(t, p, test.numWorkers)
})
t.Run(test.name+" blocking", func(t *testing.T) {
t.Parallel()
p := test.newPool()
startGeneric(t, p)
defer stopGeneric(t, p)
submitBlockingGeneric(t, p, test.numWorkers)
})
t.Run(test.name+" partial blocking", func(t *testing.T) {
t.Parallel()
p := test.newPool()
startGeneric(t, p)
defer stopGeneric(t, p)
submitPartialBlockingGeneric(t, p, test.numWorkers)
})
}
// submitNonblockingGeneric asserts that queueing tasks to the worker pool and
// allowing them all to unblock simultaneously results in all of the tasks being
// completed in a timely manner.
func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
// We'll submit 2*nWorkers tasks that will all be unblocked
// simultaneously.
nUnblocked := 2 * nWorkers
// First we'll queue all of the tasks for the pool.
errChan := make(chan error)
semChan := make(chan struct{})
for i := 0; i < nUnblocked; i++ {
go func() { errChan <- submitGeneric(p, semChan) }()
}
// Since we haven't signaled the semaphore, none of them should
// complete.
pullNothing(t, errChan)
// Now, unblock them all simultaneously. All of the tasks should then be
// processed in parallel. Afterward, no more errors should come through.
close(semChan)
pullParllel(t, nUnblocked, errChan)
pullNothing(t, errChan)
}
// submitBlockingGeneric asserts that submitting blocking tasks to the pool and
// unblocking each sequentially results in a single task being processed at a
// time.
func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
// We'll submit 2*nWorkers tasks that will be unblocked sequentially.
nBlocked := 2 * nWorkers
// First, queue all of the blocking tasks for the pool.
errChan := make(chan error)
semChan := make(chan struct{})
for i := 0; i < nBlocked; i++ {
go func() { errChan <- submitGeneric(p, semChan) }()
}
// Since we haven't signaled the semaphore, none of them should
// complete.
pullNothing(t, errChan)
// Now, pull each blocking task sequentially from the pool. Afterwards,
// no more errors should come through.
pullSequntial(t, nBlocked, errChan, semChan)
pullNothing(t, errChan)
}
// submitPartialBlockingGeneric tests that so long as one worker is not blocked,
// any other non-blocking submitted tasks can still be processed.
func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
// We'll submit nWorkers-1 tasks that will be initially blocked; the
// remainder will all be unblocked simultaneously. After the unblocked
// tasks have finished, we will sequentially unblock the nWorkers-1
// tasks that were first submitted.
nBlocked := nWorkers - 1
nUnblocked := 2*nWorkers - nBlocked
// First, submit all of the blocking tasks to the pool.
errChan := make(chan error)
semChan := make(chan struct{})
for i := 0; i < nBlocked; i++ {
go func() { errChan <- submitGeneric(p, semChan) }()
}
// Since these are all blocked, no errors should be returned yet.
pullNothing(t, errChan)
// Now, add all of the non-blocking tasks to the pool.
semChanNB := make(chan struct{})
for i := 0; i < nUnblocked; i++ {
go func() { errChan <- submitGeneric(p, semChanNB) }()
}
// Since we haven't unblocked the second batch, we again expect no tasks
// to finish.
pullNothing(t, errChan)
// Now, unblock the non-blocking tasks and pull all of them. After they have
// been pulled, we should see no more tasks.
close(semChanNB)
pullParllel(t, nUnblocked, errChan)
pullNothing(t, errChan)
// Finally, unblock each of the blocked tasks we added initially, and
// assert that no further errors come through.
pullSequntial(t, nBlocked, errChan, semChan)
pullNothing(t, errChan)
}
func pullNothing(t *testing.T, errChan chan error) {
t.Helper()
select {
case err := <-errChan:
t.Fatalf("received unexpected error before semaphore "+
"release: %v", err)
case <-time.After(time.Second):
}
}
func pullParllel(t *testing.T, n int, errChan chan error) {
t.Helper()
for i := 0; i < n; i++ {
select {
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
case <-time.After(time.Second):
t.Fatalf("task %d was not processed in time", i)
}
}
}
func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) {
t.Helper()
for i := 0; i < n; i++ {
// Signal for another task to unblock.
select {
case semChan <- struct{}{}:
case <-time.After(time.Second):
t.Fatalf("task %d was not unblocked", i)
}
// Wait for the error to arrive; we expect it to be nil.
select {
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
case <-time.After(time.Second):
t.Fatalf("task %d was not processed in time", i)
}
}
}
func startGeneric(t *testing.T, p interface{}) {
t.Helper()
var err error
switch pp := p.(type) {
case *pool.Write:
err = pp.Start()
case *pool.Read:
err = pp.Start()
default:
t.Fatalf("unknown worker pool type: %T", p)
}
if err != nil {
t.Fatalf("unable to start worker pool: %v", err)
}
}
func stopGeneric(t *testing.T, p interface{}) {
t.Helper()
var err error
switch pp := p.(type) {
case *pool.Write:
err = pp.Stop()
case *pool.Read:
err = pp.Stop()
default:
t.Fatalf("unknown worker pool type: %T", p)
}
if err != nil {
t.Fatalf("unable to stop worker pool: %v", err)
}
}
func submitGeneric(p interface{}, sem <-chan struct{}) error {
var err error
switch pp := p.(type) {
case *pool.Write:
err = pp.Submit(func(buf *bytes.Buffer) error {
// Verify that the provided buffer has been reset to be
// zero length.
if buf.Len() != 0 {
return fmt.Errorf("buf should be length zero, "+
"instead has length %d", buf.Len())
}
// Verify that the capacity of the buffer matches the
// underlying buffer.WriteSize.
if buf.Cap() != buffer.WriteSize {
return fmt.Errorf("buf should have capacity "+
"%d, instead has capacity %d",
buffer.WriteSize, buf.Cap())
}
// Sample some random bytes that we'll use to dirty the
// buffer.
b := make([]byte, rand.Intn(buf.Cap()))
_, err := io.ReadFull(crand.Reader, b)
if err != nil {
return err
}
// Write the random bytes to the buffer.
_, err = buf.Write(b)
// Wait until this task is signaled to exit.
<-sem
return err
})
case *pool.Read:
err = pp.Submit(func(buf *buffer.Read) error {
// Assert that all of the bytes in the provided array
// are zero, indicating that the buffer was reset
// between uses.
for i := range buf[:] {
if buf[i] != 0x00 {
return fmt.Errorf("byte %d of "+
"buffer.Read should be "+
"0, instead is %d", i, buf[i])
}
}
// Sample some random bytes to read into the buffer.
_, err := io.ReadFull(crand.Reader, buf[:])
// Wait until this task is signaled to exit.
<-sem
return err
})
default:
return fmt.Errorf("unknown worker pool type: %T", p)
}
if err != nil {
return fmt.Errorf("unable to submit task: %v", err)
}
return nil
}

pool/write.go (new file, 100 lines)

@@ -0,0 +1,100 @@
package pool
import (
"bytes"
"time"
"github.com/lightningnetwork/lnd/buffer"
)
// Write is a worker pool specifically designed for sharing access to
// buffer.Write objects amongst a set of worker goroutines. This enables an
// application to limit the total number of buffer.Write objects allocated at
// any given time.
type Write struct {
workerPool *Worker
bufferPool *WriteBuffer
}
// NewWrite creates a Write pool, using an underlying WriteBuffer pool to
// recycle buffer.Write objects across the lifetime of the Write pool's
// workers.
func NewWrite(writeBufferPool *WriteBuffer, numWorkers int,
workerTimeout time.Duration) *Write {
w := &Write{
bufferPool: writeBufferPool,
}
w.workerPool = NewWorker(&WorkerConfig{
NewWorkerState: w.newWorkerState,
NumWorkers: numWorkers,
WorkerTimeout: workerTimeout,
})
return w
}
// Start safely spins up the Write pool.
func (w *Write) Start() error {
return w.workerPool.Start()
}
// Stop safely shuts down the Write pool.
func (w *Write) Stop() error {
return w.workerPool.Stop()
}
// Submit accepts a function closure that provides access to a fresh
// bytes.Buffer backed by a buffer.Write object. The function's execution will
// be allocated to one of the underlying Worker pool's goroutines.
func (w *Write) Submit(inner func(*bytes.Buffer) error) error {
return w.workerPool.Submit(func(s WorkerState) error {
state := s.(*writeWorkerState)
return inner(state.buf)
})
}
// writeWorkerState is the per-goroutine state maintained by a Write pool's
// goroutines.
type writeWorkerState struct {
// bufferPool is the pool to which the writeBuf will be returned when
// the goroutine exits.
bufferPool *WriteBuffer
// writeBuf is the buffer taken from the bufferPool on initialization,
// which will be used to back the buf object provided to any tasks that
// the goroutine processes before exiting.
writeBuf *buffer.Write
// buf is a buffer backed by writeBuf that can be written to by tasks
// submitted to the Write pool. The buf is reset between each task a
// goroutine processes, and allows the task submitters to interact with
// the writeBuf as if it were an io.Writer.
buf *bytes.Buffer
}
// newWorkerState initializes a new writeWorkerState, which will be called
// whenever a new goroutine is allocated to begin processing write tasks.
func (w *Write) newWorkerState() WorkerState {
writeBuf := w.bufferPool.Take()
return &writeWorkerState{
bufferPool: w.bufferPool,
writeBuf: writeBuf,
buf: bytes.NewBuffer(writeBuf[0:0:len(writeBuf)]),
}
}
// Cleanup returns the writeBuf to the underlying buffer pool, and removes the
// goroutine's reference to the writeBuf and encapsulating buf.
func (w *writeWorkerState) Cleanup() {
w.bufferPool.Return(w.writeBuf)
w.writeBuf = nil
w.buf = nil
}
// Reset resets the bytes.Buffer so that it is zero-length and has the capacity
// of the underlying buffer.Write.
func (w *writeWorkerState) Reset() {
w.buf.Reset()
}
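The Write pool is used analogously; a compact sketch (worker count arbitrary, and the payload is a stand-in for an encoded lnwire message):

package main

import (
	"bytes"
	"fmt"

	"github.com/lightningnetwork/lnd/pool"
)

func main() {
	bufPool := pool.NewWriteBuffer(
		pool.DefaultWriteBufferGCInterval,
		pool.DefaultWriteBufferExpiryInterval,
	)
	writePool := pool.NewWrite(bufPool, 4, pool.DefaultWorkerTimeout)

	if err := writePool.Start(); err != nil {
		panic(err)
	}
	defer writePool.Stop()

	// The closure sees an empty bytes.Buffer backed by a pooled
	// buffer.Write; in lnd the encoded bytes are then flushed to the
	// socket while still inside the closure.
	err := writePool.Submit(func(buf *bytes.Buffer) error {
		buf.WriteString("stand-in for an encoded lnwire message")
		fmt.Println("encoded", buf.Len(), "bytes")
		return nil
	})
	if err != nil {
		panic(err)
	}
}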

@@ -171,7 +171,9 @@ type server struct {
sigPool *lnwallet.SigPool
writeBufferPool *pool.WriteBuffer
writePool *pool.Write
readPool *pool.Read
// globalFeatures is the feature vector which affects HTLCs and thus is also
// advertised to other nodes.
@@ -263,16 +265,31 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)
writeBufferPool := pool.NewWriteBuffer(
pool.DefaultWriteBufferGCInterval,
pool.DefaultWriteBufferExpiryInterval,
)
writePool := pool.NewWrite(
writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
)
readBufferPool := pool.NewReadBuffer(
pool.DefaultReadBufferGCInterval,
pool.DefaultReadBufferExpiryInterval,
)
readPool := pool.NewRead(
readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
)
s := &server{
chanDB: chanDB,
cc: cc,
sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
writeBufferPool: writeBufferPool,
chanDB: chanDB,
cc: cc,
sigPool: lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
writePool: writePool,
readPool: readPool,
invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
@@ -1010,6 +1027,12 @@ func (s *server) Start() error {
if err := s.sigPool.Start(); err != nil {
return err
}
if err := s.writePool.Start(); err != nil {
return err
}
if err := s.readPool.Start(); err != nil {
return err
}
if err := s.cc.chainNotifier.Start(); err != nil {
return err
}
@@ -1102,7 +1125,6 @@ func (s *server) Stop() error {
// Shutdown the wallet, funding manager, and the rpc server.
s.chanStatusMgr.Stop()
s.sigPool.Stop()
s.cc.chainNotifier.Stop()
s.chanRouter.Stop()
s.htlcSwitch.Stop()
@@ -1129,6 +1151,10 @@ func (s *server) Stop() error {
// Wait for all lingering goroutines to quit.
s.wg.Wait()
s.sigPool.Stop()
s.writePool.Stop()
s.readPool.Stop()
return nil
}