diff --git a/brontide/conn.go b/brontide/conn.go index 05f17be3..643ff860 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -104,13 +104,34 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, return b, nil } -// ReadNextMessage uses the connection in a message-oriented instructing it to -// read the next _full_ message with the brontide stream. This function will -// block until the read succeeds. +// ReadNextMessage uses the connection in a message-oriented manner, instructing +// it to read the next _full_ message with the brontide stream. This function +// will block until the read of the header and body succeeds. +// +// NOTE: This method SHOULD NOT be used in the case that the connection may be +// adversarial and induce long delays. If the caller needs to set read deadlines +// appropriately, it is preferred that they use the split ReadNextHeader and +// ReadNextBody methods so that the deadlines can be set appropriately on each. func (c *Conn) ReadNextMessage() ([]byte, error) { return c.noise.ReadMessage(c.conn) } +// ReadNextHeader uses the connection to read the next header from the brontide +// stream. This function will block until the read of the header succeeds and +// return the packet length (including MAC overhead) that is expected from the +// subsequent call to ReadNextBody. +func (c *Conn) ReadNextHeader() (uint32, error) { + return c.noise.ReadHeader(c.conn) +} + +// ReadNextBody uses the connection to read the next message body from the +// brontide stream. This function will block until the read of the body succeeds +// and return the decrypted payload. The provided buffer MUST be the packet +// length returned by the preceding call to ReadNextHeader. +func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) { + return c.noise.ReadBody(c.conn, buf) +} + // Read reads data from the connection. Read can be made to time out and // return an Error with Timeout() == true after a fixed time limit; see // SetDeadline and SetReadDeadline. diff --git a/brontide/noise.go b/brontide/noise.go index a01d518b..7b6bdb43 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -8,15 +8,12 @@ import ( "fmt" "io" "math" - "runtime" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" "github.com/btcsuite/btcd/btcec" - "github.com/lightningnetwork/lnd/buffer" - "github.com/lightningnetwork/lnd/pool" ) const ( @@ -60,14 +57,6 @@ var ( ephemeralGen = func() (*btcec.PrivateKey, error) { return btcec.NewPrivateKey(btcec.S256()) } - - // readBufferPool is a singleton instance of a buffer pool, used to - // conserve memory allocations due to read buffers across the entire - // brontide package. - readBufferPool = pool.NewReadBuffer( - pool.DefaultReadBufferGCInterval, - pool.DefaultReadBufferExpiryInterval, - ) ) // TODO(roasbeef): free buffer pool? @@ -378,15 +367,6 @@ type Machine struct { // next ciphertext header from the wire. The header is a 2 byte length // (of the next ciphertext), followed by a 16 byte MAC. nextCipherHeader [lengthHeaderSize + macSize]byte - - // nextCipherText is a static buffer that we'll use to read in the - // bytes of the next cipher text message. As all messages in the - // protocol MUST be below 65KB plus our macSize, this will be - // sufficient to buffer all messages from the socket when we need to - // read the next one. Having a fixed buffer that's re-used also means - // that we save on allocations as we don't need to create a new one - // each time. 
-	nextCipherText *buffer.Read
 }
 
 // NewBrontideMachine creates a new instance of the brontide state-machine. If
@@ -738,43 +718,59 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
 // ReadMessage attempts to read the next message from the passed io.Reader. In
 // the case of an authentication error, a non-nil error is returned.
+//
+// NOTE: This method SHOULD NOT be used in the case that the io.Reader may be
+// adversarial and induce long delays. If the caller needs to set read
+// deadlines, it is preferred that they use the split ReadHeader and ReadBody
+// methods so that a deadline can be set on each.
 func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
-	if _, err := io.ReadFull(r, b.nextCipherHeader[:]); err != nil {
+	pktLen, err := b.ReadHeader(r)
+	if err != nil {
 		return nil, err
 	}
 
+	buf := make([]byte, pktLen)
+	return b.ReadBody(r, buf)
+}
+
+// ReadHeader attempts to read the next message header from the passed
+// io.Reader. The header contains the length of the next body, including the
+// additional overhead of the MAC. In the case of an authentication error, a
+// non-nil error is returned.
+func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
+	_, err := io.ReadFull(r, b.nextCipherHeader[:])
+	if err != nil {
+		return 0, err
+	}
+
 	// Attempt to decrypt+auth the packet length present in the stream.
 	pktLenBytes, err := b.recvCipher.Decrypt(
 		nil, nil, b.nextCipherHeader[:],
 	)
 	if err != nil {
-		return nil, err
+		return 0, err
 	}
 
-	// If this is the first message being read, take a read buffer from the
-	// buffer pool. This is delayed until this point to avoid allocating
-	// read buffers until after the peer has successfully completed the
-	// handshake, and is ready to begin sending lnwire messages.
-	if b.nextCipherText == nil {
-		b.nextCipherText = readBufferPool.Take()
-		runtime.SetFinalizer(b, freeReadBuffer)
-	}
-
-	// Next, using the length read from the packet header, read the
-	// encrypted packet itself.
+	// Compute the packet length that we will need to read off the wire.
 	pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
-	if _, err := io.ReadFull(r, b.nextCipherText[:pktLen]); err != nil {
+
+	return pktLen, nil
+}
+
+// ReadBody attempts to read the next message body from the passed io.Reader.
+// The provided buffer MUST be the packet length returned by the preceding
+// call to ReadHeader. In the case of an authentication error, a non-nil error
+// is returned.
+func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
+	// Next, using the length read from the packet header, read the
+	// encrypted packet itself into the buffer provided by the
+	// caller.
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
 		return nil, err
 	}
 
+	// Finally, decrypt the message held in the buffer, and return a
+	// new byte slice containing the plaintext.
 	// TODO(roasbeef): modify to let pass in slice
-	return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen])
-}
-
-// freeReadBuffer returns the Machine's read buffer back to the package wide
-// read buffer pool.
-//
-// NOTE: This method should only be called by a Machine's finalizer.
-func freeReadBuffer(b *Machine) {
-	readBufferPool.Return(b.nextCipherText)
-	b.nextCipherText = nil
+	return b.recvCipher.Decrypt(nil, nil, buf)
 }
diff --git a/peer.go b/peer.go
index 41a55892..4b8973dc 100644
--- a/peer.go
+++ b/peer.go
@@ -26,6 +26,7 @@ import (
 	"github.com/lightningnetwork/lnd/lnpeer"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/pool"
 	"github.com/lightningnetwork/lnd/ticker"
 )
 
@@ -43,8 +44,12 @@ const (
 	// idleTimeout is the duration of inactivity before we time out a peer.
 	idleTimeout = 5 * time.Minute
 
-	// writeMessageTimeout is the timeout used when writing a message to peer.
-	writeMessageTimeout = 50 * time.Second
+	// writeMessageTimeout is the timeout used when writing a message to a peer.
+	writeMessageTimeout = 10 * time.Second
+
+	// readMessageTimeout is the timeout used when reading a message from a
+	// peer.
+	readMessageTimeout = 5 * time.Second
 
 	// handshakeTimeout is the timeout used when waiting for peer init message.
 	handshakeTimeout = 15 * time.Second
@@ -209,11 +214,13 @@ type peer struct {
 	// TODO(halseth): remove when link failure is properly handled.
 	failedChannels map[lnwire.ChannelID]struct{}
 
-	// writeBuf is a buffer that we'll re-use in order to encode wire
-	// messages to write out directly on the socket. By re-using this
-	// buffer, we avoid needing to allocate more memory each time a new
-	// message is to be sent to a peer.
-	writeBuf *buffer.Write
+	// writePool is the task pool that manages reuse of write buffers.
+	// Write tasks are submitted to the pool in order to conserve the total
+	// number of write buffers allocated at any one time, and decouple write
+	// buffer allocation from the peer life cycle.
+	writePool *pool.Write
+
+	readPool *pool.Read
 
 	queueQuit chan struct{}
 	quit      chan struct{}
@@ -258,7 +265,8 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,
 
 		chanActiveTimeout: chanActiveTimeout,
 
-		writeBuf: server.writeBufferPool.Take(),
+		writePool: server.writePool,
+		readPool:  server.readPool,
 
 		queueQuit: make(chan struct{}),
 		quit:      make(chan struct{}),
@@ -608,11 +616,6 @@ func (p *peer) WaitForDisconnect(ready chan struct{}) {
 	}
 
 	p.wg.Wait()
-
-	// Now that we are certain all active goroutines which could have been
-	// modifying the write buffer have exited, return the buffer to the pool
-	// to be reused.
-	p.server.writeBufferPool.Return(p.writeBuf)
 }
 
 // Disconnect terminates the connection with the remote peer. Additionally, a
@@ -644,11 +647,37 @@ func (p *peer) readNextMessage() (lnwire.Message, error) {
 		return nil, fmt.Errorf("brontide.Conn required to read messages")
 	}
 
+	err := noiseConn.SetReadDeadline(time.Time{})
+	if err != nil {
+		return nil, err
+	}
+
+	pktLen, err := noiseConn.ReadNextHeader()
+	if err != nil {
+		return nil, err
+	}
+
 	// First we'll read the next _full_ message. We do this rather than
 	// reading incrementally from the stream as the Lightning wire protocol
 	// is message oriented and allows nodes to pad on additional data to
 	// the message stream.
-	rawMsg, err := noiseConn.ReadNextMessage()
+	var rawMsg []byte
+	err = p.readPool.Submit(func(buf *buffer.Read) error {
+		// Before reading the body of the message, set the read timeout
+		// accordingly to ensure we don't block other readers using the
+		// pool. We do so only after the task has been scheduled to
+		// ensure the deadline doesn't expire while the message is in
+		// the process of being scheduled.
+ readDeadline := time.Now().Add(readMessageTimeout) + readErr := noiseConn.SetReadDeadline(readDeadline) + if readErr != nil { + return readErr + } + + rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen]) + return readErr + }) + atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg))) if err != nil { return nil, err @@ -1359,33 +1388,33 @@ func (p *peer) writeMessage(msg lnwire.Message) error { p.logWireMessage(msg, false) - // We'll re-slice of static write buffer to allow this new message to - // utilize all available space. We also ensure we cap the capacity of - // this new buffer to the static buffer which is sized for the largest - // possible protocol message. - b := bytes.NewBuffer(p.writeBuf[0:0:len(p.writeBuf)]) + var n int + err := p.writePool.Submit(func(buf *bytes.Buffer) error { + // Using a buffer allocated by the write pool, encode the + // message directly into the buffer. + _, writeErr := lnwire.WriteMessage(buf, msg, 0) + if writeErr != nil { + return writeErr + } - // With the temp buffer created and sliced properly (length zero, full - // capacity), we'll now encode the message directly into this buffer. - _, err := lnwire.WriteMessage(b, msg, 0) - if err != nil { - return err + // Ensure the write deadline is set before we attempt to send + // the message. + writeDeadline := time.Now().Add(writeMessageTimeout) + writeErr = p.conn.SetWriteDeadline(writeDeadline) + if writeErr != nil { + return writeErr + } + + // Finally, write the message itself in a single swoop. + n, writeErr = p.conn.Write(buf.Bytes()) + return writeErr + }) + + // Record the number of bytes written on the wire, if any. + if n > 0 { + atomic.AddUint64(&p.bytesSent, uint64(n)) } - // Compute and set the write deadline we will impose on the remote peer. - writeDeadline := time.Now().Add(writeMessageTimeout) - err = p.conn.SetWriteDeadline(writeDeadline) - if err != nil { - return err - } - - // Finally, write the message itself in a single swoop. - n, err := p.conn.Write(b.Bytes()) - - // Regardless of the error returned, record how many bytes were written - // to the wire. - atomic.AddUint64(&p.bytesSent, uint64(n)) - return err } diff --git a/pool/read.go b/pool/read.go new file mode 100644 index 00000000..171a0d2a --- /dev/null +++ b/pool/read.go @@ -0,0 +1,87 @@ +package pool + +import ( + "time" + + "github.com/lightningnetwork/lnd/buffer" +) + +// Read is a worker pool specifically designed for sharing access to buffer.Read +// objects amongst a set of worker goroutines. This enables an application to +// limit the total number of buffer.Read objects allocated at any given time. +type Read struct { + workerPool *Worker + bufferPool *ReadBuffer +} + +// NewRead creates a new Read pool, using an underlying ReadBuffer pool to +// recycle buffer.Read objects across the lifetime of the Read pool's workers. +func NewRead(readBufferPool *ReadBuffer, numWorkers int, + workerTimeout time.Duration) *Read { + + r := &Read{ + bufferPool: readBufferPool, + } + r.workerPool = NewWorker(&WorkerConfig{ + NewWorkerState: r.newWorkerState, + NumWorkers: numWorkers, + WorkerTimeout: workerTimeout, + }) + + return r +} + +// Start safely spins up the Read pool. +func (r *Read) Start() error { + return r.workerPool.Start() +} + +// Stop safely shuts down the Read pool. +func (r *Read) Stop() error { + return r.workerPool.Stop() +} + +// Submit accepts a function closure that provides access to the fresh +// buffer.Read object. 
The function's execution will be allocated to one of the +// underlying Worker pool's goroutines. +func (r *Read) Submit(inner func(*buffer.Read) error) error { + return r.workerPool.Submit(func(s WorkerState) error { + state := s.(*readWorkerState) + return inner(state.readBuf) + }) +} + +// readWorkerState is the per-goroutine state maintained by a Read pool's +// goroutines. +type readWorkerState struct { + // bufferPool is the pool to which the readBuf will be returned when the + // goroutine exits. + bufferPool *ReadBuffer + + // readBuf is a buffer taken from the bufferPool on initialization, + // which will be cleaned and provided to any tasks that the goroutine + // processes before exiting. + readBuf *buffer.Read +} + +// newWorkerState initializes a new readWorkerState, which will be called +// whenever a new goroutine is allocated to begin processing read tasks. +func (r *Read) newWorkerState() WorkerState { + return &readWorkerState{ + bufferPool: r.bufferPool, + readBuf: r.bufferPool.Take(), + } +} + +// Cleanup returns the readBuf to the underlying buffer pool, and removes the +// goroutine's reference to the readBuf. +func (r *readWorkerState) Cleanup() { + r.bufferPool.Return(r.readBuf) + r.readBuf = nil +} + +// Reset recycles the readBuf to make it ready for any subsequent tasks the +// goroutine may process. +func (r *readWorkerState) Reset() { + r.readBuf.Recycle() +} diff --git a/pool/worker.go b/pool/worker.go new file mode 100644 index 00000000..49325a8b --- /dev/null +++ b/pool/worker.go @@ -0,0 +1,250 @@ +package pool + +import ( + "errors" + "sync" + "time" +) + +// ErrWorkerPoolExiting signals that a shutdown of the Worker has been +// requested. +var ErrWorkerPoolExiting = errors.New("worker pool exiting") + +// DefaultWorkerTimeout is the default duration after which a worker goroutine +// will exit to free up resources after having received no newly submitted +// tasks. +const DefaultWorkerTimeout = 5 * time.Second + +type ( + // WorkerState is an interface used by the Worker to abstract the + // lifecycle of internal state used by a worker goroutine. + WorkerState interface { + // Reset clears any internal state that may have been dirtied in + // processing a prior task. + Reset() + + // Cleanup releases any shared state before a worker goroutine + // exits. + Cleanup() + } + + // WorkerConfig parameterizes the behavior of a Worker pool. + WorkerConfig struct { + // NewWorkerState allocates a new state for a worker goroutine. + // This method is called each time a new worker goroutine is + // spawned by the pool. + NewWorkerState func() WorkerState + + // NumWorkers is the maximum number of workers the Worker pool + // will permit to be allocated. Once the maximum number is + // reached, any newly submitted tasks are forced to be processed + // by existing worker goroutines. + NumWorkers int + + // WorkerTimeout is the duration after which a worker goroutine + // will exit after having received no newly submitted tasks. + WorkerTimeout time.Duration + } + + // Worker maintains a pool of goroutines that process submitted function + // closures, and enable more efficient reuse of expensive state. + Worker struct { + started sync.Once + stopped sync.Once + + cfg *WorkerConfig + + // requests is a channel where new tasks are submitted. Tasks + // submitted through this channel may cause a new worker + // goroutine to be allocated. + requests chan *request + + // work is a channel where new tasks are submitted, but is only + // read by active worker gorotuines. 
+		work chan *request
+
+		// workerSem is a channel-based semaphore that is used to limit
+		// the total number of worker goroutines to the number
+		// prescribed by the WorkerConfig.
+		workerSem chan struct{}
+
+		wg   sync.WaitGroup
+		quit chan struct{}
+	}
+
+	// request is a tuple of task closure and error channel that is used to
+	// both submit a task to the pool and respond with any errors
+	// encountered during the task's execution.
+	request struct {
+		fn      func(WorkerState) error
+		errChan chan error
+	}
+)
+
+// NewWorker initializes a new Worker pool using the provided WorkerConfig.
+func NewWorker(cfg *WorkerConfig) *Worker {
+	return &Worker{
+		cfg:       cfg,
+		requests:  make(chan *request),
+		workerSem: make(chan struct{}, cfg.NumWorkers),
+		work:      make(chan *request),
+		quit:      make(chan struct{}),
+	}
+}
+
+// Start safely spins up the Worker pool.
+func (w *Worker) Start() error {
+	w.started.Do(func() {
+		w.wg.Add(1)
+		go w.requestHandler()
+	})
+	return nil
+}
+
+// Stop safely shuts down the Worker pool.
+func (w *Worker) Stop() error {
+	w.stopped.Do(func() {
+		close(w.quit)
+		w.wg.Wait()
+	})
+	return nil
+}
+
+// Submit accepts a function closure to the worker pool. The returned error will
+// be either the result of the closure's execution or an ErrWorkerPoolExiting if
+// a shutdown is requested.
+func (w *Worker) Submit(fn func(WorkerState) error) error {
+	req := &request{
+		fn:      fn,
+		errChan: make(chan error, 1),
+	}
+
+	select {
+
+	// Send request to requestHandler, where either a new worker is spawned
+	// or the task will be handed to an existing worker.
+	case w.requests <- req:
+
+	// Fast path directly to existing worker.
+	case w.work <- req:
+
+	case <-w.quit:
+		return ErrWorkerPoolExiting
+	}
+
+	select {
+
+	// Wait for task to be processed.
+	case err := <-req.errChan:
+		return err
+
+	case <-w.quit:
+		return ErrWorkerPoolExiting
+	}
+}
+
+// requestHandler processes incoming tasks by either allocating new worker
+// goroutines to process the incoming tasks, or by feeding a submitted task to
+// an already running worker goroutine.
+func (w *Worker) requestHandler() {
+	defer w.wg.Done()
+
+	for {
+		select {
+		case req := <-w.requests:
+			select {
+
+			// If we have not reached our maximum number of workers,
+			// spawn one to process the submitted request.
+			case w.workerSem <- struct{}{}:
+				w.wg.Add(1)
+				go w.spawnWorker(req)
+
+			// Otherwise, submit the task to any of the active
+			// workers.
+			case w.work <- req:
+
+			case <-w.quit:
+				return
+			}
+
+		case <-w.quit:
+			return
+		}
+	}
+}
+
+// spawnWorker is used when the Worker pool wishes to create a new worker
+// goroutine. The worker's state is initialized by calling the config's
+// NewWorkerState method, and the worker will continue to process incoming
+// tasks until the pool is shut down or no new tasks are received before the
+// worker's timeout elapses.
+//
+// NOTE: This method MUST be run as a goroutine.
+func (w *Worker) spawnWorker(req *request) {
+	defer w.wg.Done()
+	defer func() { <-w.workerSem }()
+
+	state := w.cfg.NewWorkerState()
+	defer state.Cleanup()
+
+	req.errChan <- req.fn(state)
+
+	// We'll use a timer to implement the worker timeouts, as this reduces
+	// the number of total allocations that would otherwise be necessary
+	// with time.After.
+	var t *time.Timer
+	for {
+		// Before processing another request, we'll reset the worker
+		// state so that each request is processed against a clean
+		// state.
+		state.Reset()
+
+		select {
+
+		// Process any new requests that get submitted.
We use a + // non-blocking case first so that under high load we can spare + // allocating a timeout. + case req := <-w.work: + req.errChan <- req.fn(state) + continue + + case <-w.quit: + return + + default: + } + + // There were no new requests that could be taken immediately + // from the work channel. Initialize or reset the timeout, which + // will fire if the worker doesn't receive a new task before + // needing to exit. + if t != nil { + t.Reset(w.cfg.WorkerTimeout) + } else { + t = time.NewTimer(w.cfg.WorkerTimeout) + } + + select { + + // Process any new requests that get submitted. + case req := <-w.work: + req.errChan <- req.fn(state) + + // Stop the timer, draining the timer's channel if a + // notification was already delivered. + if !t.Stop() { + <-t.C + } + + // The timeout has elapsed, meaning the worker did not receive + // any new tasks. Exit to allow the worker to return and free + // its resources. + case <-t.C: + return + + case <-w.quit: + return + } + } +} diff --git a/pool/worker_test.go b/pool/worker_test.go new file mode 100644 index 00000000..ee23e7a5 --- /dev/null +++ b/pool/worker_test.go @@ -0,0 +1,353 @@ +package pool_test + +import ( + "bytes" + crand "crypto/rand" + "fmt" + "io" + "math/rand" + "testing" + "time" + + "github.com/lightningnetwork/lnd/buffer" + "github.com/lightningnetwork/lnd/pool" +) + +type workerPoolTest struct { + name string + newPool func() interface{} + numWorkers int +} + +// TestConcreteWorkerPools asserts the behavior of any concrete implementations +// of worker pools provided by the pool package. Currently this tests the +// pool.Read and pool.Write instances. +func TestConcreteWorkerPools(t *testing.T) { + const ( + gcInterval = time.Second + expiryInterval = 250 * time.Millisecond + numWorkers = 5 + workerTimeout = 500 * time.Millisecond + ) + + tests := []workerPoolTest{ + { + name: "write pool", + newPool: func() interface{} { + bp := pool.NewWriteBuffer( + gcInterval, expiryInterval, + ) + + return pool.NewWrite( + bp, numWorkers, workerTimeout, + ) + }, + numWorkers: numWorkers, + }, + { + name: "read pool", + newPool: func() interface{} { + bp := pool.NewReadBuffer( + gcInterval, expiryInterval, + ) + + return pool.NewRead( + bp, numWorkers, workerTimeout, + ) + }, + numWorkers: numWorkers, + }, + } + + for _, test := range tests { + testWorkerPool(t, test) + } +} + +func testWorkerPool(t *testing.T, test workerPoolTest) { + t.Run(test.name+" non blocking", func(t *testing.T) { + t.Parallel() + + p := test.newPool() + startGeneric(t, p) + defer stopGeneric(t, p) + + submitNonblockingGeneric(t, p, test.numWorkers) + }) + + t.Run(test.name+" blocking", func(t *testing.T) { + t.Parallel() + + p := test.newPool() + startGeneric(t, p) + defer stopGeneric(t, p) + + submitBlockingGeneric(t, p, test.numWorkers) + }) + + t.Run(test.name+" partial blocking", func(t *testing.T) { + t.Parallel() + + p := test.newPool() + startGeneric(t, p) + defer stopGeneric(t, p) + + submitPartialBlockingGeneric(t, p, test.numWorkers) + }) +} + +// submitNonblockingGeneric asserts that queueing tasks to the worker pool and +// allowing them all to unblock simultaneously results in all of the tasks being +// completed in a timely manner. +func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) { + // We'll submit 2*nWorkers tasks that will all be unblocked + // simultaneously. + nUnblocked := 2 * nWorkers + + // First we'll queue all of the tasks for the pool. 
+ errChan := make(chan error) + semChan := make(chan struct{}) + for i := 0; i < nUnblocked; i++ { + go func() { errChan <- submitGeneric(p, semChan) }() + } + + // Since we haven't signaled the semaphore, none of the them should + // complete. + pullNothing(t, errChan) + + // Now, unblock them all simultaneously. All of the tasks should then be + // processed in parallel. Afterward, no more errors should come through. + close(semChan) + pullParllel(t, nUnblocked, errChan) + pullNothing(t, errChan) +} + +// submitBlockingGeneric asserts that submitting blocking tasks to the pool and +// unblocking each sequentially results in a single task being processed at a +// time. +func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) { + // We'll submit 2*nWorkers tasks that will be unblocked sequentially. + nBlocked := 2 * nWorkers + + // First, queue all of the blocking tasks for the pool. + errChan := make(chan error) + semChan := make(chan struct{}) + for i := 0; i < nBlocked; i++ { + go func() { errChan <- submitGeneric(p, semChan) }() + } + + // Since we haven't signaled the semaphore, none of them should + // complete. + pullNothing(t, errChan) + + // Now, pull each blocking task sequentially from the pool. Afterwards, + // no more errors should come through. + pullSequntial(t, nBlocked, errChan, semChan) + pullNothing(t, errChan) + +} + +// submitPartialBlockingGeneric tests that so long as one worker is not blocked, +// any other non-blocking submitted tasks can still be processed. +func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) { + // We'll submit nWorkers-1 tasks that will be initially blocked, the + // remainder will all be unblocked simultaneously. After the unblocked + // tasks have finished, we will sequentially unblock the nWorkers-1 + // tasks that were first submitted. + nBlocked := nWorkers - 1 + nUnblocked := 2*nWorkers - nBlocked + + // First, submit all of the blocking tasks to the pool. + errChan := make(chan error) + semChan := make(chan struct{}) + for i := 0; i < nBlocked; i++ { + go func() { errChan <- submitGeneric(p, semChan) }() + } + + // Since these are all blocked, no errors should be returned yet. + pullNothing(t, errChan) + + // Now, add all of the non-blocking task to the pool. + semChanNB := make(chan struct{}) + for i := 0; i < nUnblocked; i++ { + go func() { errChan <- submitGeneric(p, semChanNB) }() + } + + // Since we haven't unblocked the second batch, we again expect no tasks + // to finish. + pullNothing(t, errChan) + + // Now, unblock the unblocked task and pull all of them. After they have + // been pulled, we should see no more tasks. + close(semChanNB) + pullParllel(t, nUnblocked, errChan) + pullNothing(t, errChan) + + // Finally, unblock each the blocked tasks we added initially, and + // assert that no further errors come through. 
+ pullSequntial(t, nBlocked, errChan, semChan) + pullNothing(t, errChan) +} + +func pullNothing(t *testing.T, errChan chan error) { + t.Helper() + + select { + case err := <-errChan: + t.Fatalf("received unexpected error before semaphore "+ + "release: %v", err) + + case <-time.After(time.Second): + } +} + +func pullParllel(t *testing.T, n int, errChan chan error) { + t.Helper() + + for i := 0; i < n; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + + case <-time.After(time.Second): + t.Fatalf("task %d was not processed in time", i) + } + } +} + +func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) { + t.Helper() + + for i := 0; i < n; i++ { + // Signal for another task to unblock. + select { + case semChan <- struct{}{}: + case <-time.After(time.Second): + t.Fatalf("task %d was not unblocked", i) + } + + // Wait for the error to arrive, we expect it to be non-nil. + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + + case <-time.After(time.Second): + t.Fatalf("task %d was not processed in time", i) + } + } +} + +func startGeneric(t *testing.T, p interface{}) { + t.Helper() + + var err error + switch pp := p.(type) { + case *pool.Write: + err = pp.Start() + + case *pool.Read: + err = pp.Start() + + default: + t.Fatalf("unknown worker pool type: %T", p) + } + + if err != nil { + t.Fatalf("unable to start worker pool: %v", err) + } +} + +func stopGeneric(t *testing.T, p interface{}) { + t.Helper() + + var err error + switch pp := p.(type) { + case *pool.Write: + err = pp.Stop() + + case *pool.Read: + err = pp.Stop() + + default: + t.Fatalf("unknown worker pool type: %T", p) + } + + if err != nil { + t.Fatalf("unable to stop worker pool: %v", err) + } +} + +func submitGeneric(p interface{}, sem <-chan struct{}) error { + var err error + switch pp := p.(type) { + case *pool.Write: + err = pp.Submit(func(buf *bytes.Buffer) error { + // Verify that the provided buffer has been reset to be + // zero length. + if buf.Len() != 0 { + return fmt.Errorf("buf should be length zero, "+ + "instead has length %d", buf.Len()) + } + + // Verify that the capacity of the buffer has the + // correct underlying size of a buffer.WriteSize. + if buf.Cap() != buffer.WriteSize { + return fmt.Errorf("buf should have capacity "+ + "%d, instead has capacity %d", + buffer.WriteSize, buf.Cap()) + } + + // Sample some random bytes that we'll use to dirty the + // buffer. + b := make([]byte, rand.Intn(buf.Cap())) + _, err := io.ReadFull(crand.Reader, b) + if err != nil { + return err + } + + // Write the random bytes the buffer. + _, err = buf.Write(b) + + // Wait until this task is signaled to exit. + <-sem + + return err + }) + + case *pool.Read: + err = pp.Submit(func(buf *buffer.Read) error { + // Assert that all of the bytes in the provided array + // are zero, indicating that the buffer was reset + // between uses. + for i := range buf[:] { + if buf[i] != 0x00 { + return fmt.Errorf("byte %d of "+ + "buffer.Read should be "+ + "0, instead is %d", i, buf[i]) + } + } + + // Sample some random bytes to read into the buffer. + _, err := io.ReadFull(crand.Reader, buf[:]) + + // Wait until this task is signaled to exit. 
+			<-sem
+
+			return err
+
+		})
+
+	default:
+		return fmt.Errorf("unknown worker pool type: %T", p)
+	}
+
+	if err != nil {
+		return fmt.Errorf("unable to submit task: %v", err)
+	}
+
+	return nil
+}
diff --git a/pool/write.go b/pool/write.go
new file mode 100644
index 00000000..1322a289
--- /dev/null
+++ b/pool/write.go
@@ -0,0 +1,100 @@
+package pool
+
+import (
+	"bytes"
+	"time"
+
+	"github.com/lightningnetwork/lnd/buffer"
+)
+
+// Write is a worker pool specifically designed for sharing access to
+// buffer.Write objects amongst a set of worker goroutines. This enables an
+// application to limit the total number of buffer.Write objects allocated at
+// any given time.
+type Write struct {
+	workerPool *Worker
+	bufferPool *WriteBuffer
+}
+
+// NewWrite creates a Write pool, using an underlying WriteBuffer pool to
+// recycle buffer.Write objects across the lifetime of the Write pool's
+// workers.
+func NewWrite(writeBufferPool *WriteBuffer, numWorkers int,
+	workerTimeout time.Duration) *Write {
+
+	w := &Write{
+		bufferPool: writeBufferPool,
+	}
+	w.workerPool = NewWorker(&WorkerConfig{
+		NewWorkerState: w.newWorkerState,
+		NumWorkers:     numWorkers,
+		WorkerTimeout:  workerTimeout,
+	})
+
+	return w
+}
+
+// Start safely spins up the Write pool.
+func (w *Write) Start() error {
+	return w.workerPool.Start()
+}
+
+// Stop safely shuts down the Write pool.
+func (w *Write) Stop() error {
+	return w.workerPool.Stop()
+}
+
+// Submit accepts a function closure that provides access to a fresh
+// bytes.Buffer backed by a buffer.Write object. The function's execution will
+// be allocated to one of the underlying Worker pool's goroutines.
+func (w *Write) Submit(inner func(*bytes.Buffer) error) error {
+	return w.workerPool.Submit(func(s WorkerState) error {
+		state := s.(*writeWorkerState)
+		return inner(state.buf)
+	})
+}
+
+// writeWorkerState is the per-goroutine state maintained by a Write pool's
+// goroutines.
+type writeWorkerState struct {
+	// bufferPool is the pool to which the writeBuf will be returned when
+	// the goroutine exits.
+	bufferPool *WriteBuffer
+
+	// writeBuf is the buffer taken from the bufferPool on initialization,
+	// which will be used to back the buf object provided to any tasks that
+	// the goroutine processes before exiting.
+	writeBuf *buffer.Write
+
+	// buf is a buffer backed by writeBuf, that can be written to by tasks
+	// submitted to the Write pool. The buf will be reset between each task
+	// processed by a goroutine before exiting, and allows the task
+	// submitters to interact with the writeBuf as if it were an io.Writer.
+	buf *bytes.Buffer
+}
+
+// newWorkerState initializes a new writeWorkerState, which will be called
+// whenever a new goroutine is allocated to begin processing write tasks.
+func (w *Write) newWorkerState() WorkerState {
+	writeBuf := w.bufferPool.Take()
+
+	return &writeWorkerState{
+		bufferPool: w.bufferPool,
+		writeBuf:   writeBuf,
+		buf:        bytes.NewBuffer(writeBuf[0:0:len(writeBuf)]),
+	}
+}
+
+// Cleanup returns the writeBuf to the underlying buffer pool, and removes the
+// goroutine's references to the writeBuf and the encapsulating buf.
+func (w *writeWorkerState) Cleanup() {
+	w.bufferPool.Return(w.writeBuf)
+	w.writeBuf = nil
+	w.buf = nil
+}
+
+// Reset resets the bytes.Buffer so that it is zero-length and has the capacity
+// of the underlying buffer.Write.
+func (w *writeWorkerState) Reset() {
+	w.buf.Reset()
+}
diff --git a/server.go b/server.go
index d2fa352c..b75ea503 100644
--- a/server.go
+++ b/server.go
@@ -171,7 +171,9 @@ type server struct {
 
 	sigPool *lnwallet.SigPool
 
-	writeBufferPool *pool.WriteBuffer
+	writePool *pool.Write
+
+	readPool *pool.Read
 
 	// globalFeatures feature vector which affects HTLCs and thus are also
 	// advertised to other nodes.
@@ -263,16 +265,31 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
 	sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
 	replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
 	sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)
+
 	writeBufferPool := pool.NewWriteBuffer(
 		pool.DefaultWriteBufferGCInterval,
 		pool.DefaultWriteBufferExpiryInterval,
 	)
 
+	writePool := pool.NewWrite(
+		writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
+	)
+
+	readBufferPool := pool.NewReadBuffer(
+		pool.DefaultReadBufferGCInterval,
+		pool.DefaultReadBufferExpiryInterval,
+	)
+
+	readPool := pool.NewRead(
+		readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
+	)
+
 	s := &server{
-		chanDB:          chanDB,
-		cc:              cc,
-		sigPool:         lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
-		writeBufferPool: writeBufferPool,
+		chanDB:    chanDB,
+		cc:        cc,
+		sigPool:   lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
+		writePool: writePool,
+		readPool:  readPool,
 
 		invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
 
@@ -1010,6 +1027,12 @@ func (s *server) Start() error {
 	if err := s.sigPool.Start(); err != nil {
 		return err
 	}
+	if err := s.writePool.Start(); err != nil {
+		return err
+	}
+	if err := s.readPool.Start(); err != nil {
+		return err
+	}
 	if err := s.cc.chainNotifier.Start(); err != nil {
 		return err
 	}
@@ -1102,7 +1125,6 @@ func (s *server) Stop() error {
 
 	// Shutdown the wallet, funding manager, and the rpc server.
 	s.chanStatusMgr.Stop()
-	s.sigPool.Stop()
 	s.cc.chainNotifier.Stop()
 	s.chanRouter.Stop()
 	s.htlcSwitch.Stop()
@@ -1129,6 +1151,10 @@ func (s *server) Stop() error {
 	// Wait for all lingering goroutines to quit.
 	s.wg.Wait()
 
+	s.sigPool.Stop()
+	s.writePool.Stop()
+	s.readPool.Stop()
+
 	return nil
 }
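
Reviewer note: the sketch below is editorial and not part of the patch. It shows how a standalone caller might drive the new pools using only the APIs introduced above (pool.NewWriteBuffer/NewReadBuffer, pool.NewWrite/NewRead, the Default* intervals, and pool.DefaultWorkerTimeout). The worker count of 4, the example payload strings, and the main wrapper are arbitrary assumptions; the real callers are peer.writeMessage and peer.readNextMessage, which additionally set connection deadlines and slice the read buffer to the packet length returned by ReadNextHeader.

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/lightningnetwork/lnd/buffer"
	"github.com/lightningnetwork/lnd/pool"
)

func main() {
	// Build a write pool backed by a recycling write-buffer pool, mirroring
	// the wiring in server.go above. The worker count of 4 is arbitrary.
	writeBufPool := pool.NewWriteBuffer(
		pool.DefaultWriteBufferGCInterval,
		pool.DefaultWriteBufferExpiryInterval,
	)
	writePool := pool.NewWrite(writeBufPool, 4, pool.DefaultWorkerTimeout)
	if err := writePool.Start(); err != nil {
		log.Fatal(err)
	}
	defer writePool.Stop()

	// Submit runs the closure on one of the pool's workers, which lends it
	// a recycled bytes.Buffer for the duration of the call. This is the
	// pattern peer.writeMessage uses before copying the encoded bytes onto
	// the wire under a write deadline.
	err := writePool.Submit(func(buf *bytes.Buffer) error {
		buf.WriteString("encoded wire message would go here")
		fmt.Println("staged", buf.Len(), "bytes for writing")
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}

	// The read pool is analogous, but lends a fixed-size *buffer.Read that
	// is zeroed between tasks. peer.readNextMessage slices it to the packet
	// length learned from ReadNextHeader before calling ReadNextBody.
	readBufPool := pool.NewReadBuffer(
		pool.DefaultReadBufferGCInterval,
		pool.DefaultReadBufferExpiryInterval,
	)
	readPool := pool.NewRead(readBufPool, 4, pool.DefaultWorkerTimeout)
	if err := readPool.Start(); err != nil {
		log.Fatal(err)
	}
	defer readPool.Stop()

	err = readPool.Submit(func(buf *buffer.Read) error {
		fmt.Println("borrowed a read buffer of", len(buf), "bytes")
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}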