From 37d866328bad079f6a118bde12894b5f51d08984 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:10:06 -0800
Subject: [PATCH 1/9] pool/worker: add generic Worker pool

---
 pool/worker.go | 250 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 250 insertions(+)
 create mode 100644 pool/worker.go

diff --git a/pool/worker.go b/pool/worker.go
new file mode 100644
index 00000000..49325a8b
--- /dev/null
+++ b/pool/worker.go
@@ -0,0 +1,250 @@
+package pool
+
+import (
+	"errors"
+	"sync"
+	"time"
+)
+
+// ErrWorkerPoolExiting signals that a shutdown of the Worker has been
+// requested.
+var ErrWorkerPoolExiting = errors.New("worker pool exiting")
+
+// DefaultWorkerTimeout is the default duration after which a worker goroutine
+// will exit to free up resources after having received no newly submitted
+// tasks.
+const DefaultWorkerTimeout = 5 * time.Second
+
+type (
+	// WorkerState is an interface used by the Worker to abstract the
+	// lifecycle of internal state used by a worker goroutine.
+	WorkerState interface {
+		// Reset clears any internal state that may have been dirtied in
+		// processing a prior task.
+		Reset()
+
+		// Cleanup releases any shared state before a worker goroutine
+		// exits.
+		Cleanup()
+	}
+
+	// WorkerConfig parameterizes the behavior of a Worker pool.
+	WorkerConfig struct {
+		// NewWorkerState allocates a new state for a worker goroutine.
+		// This method is called each time a new worker goroutine is
+		// spawned by the pool.
+		NewWorkerState func() WorkerState
+
+		// NumWorkers is the maximum number of workers the Worker pool
+		// will permit to be allocated. Once the maximum number is
+		// reached, any newly submitted tasks are forced to be processed
+		// by existing worker goroutines.
+		NumWorkers int
+
+		// WorkerTimeout is the duration after which a worker goroutine
+		// will exit after having received no newly submitted tasks.
+		WorkerTimeout time.Duration
+	}
+
+	// Worker maintains a pool of goroutines that process submitted function
+	// closures, and enable more efficient reuse of expensive state.
+	Worker struct {
+		started sync.Once
+		stopped sync.Once
+
+		cfg *WorkerConfig
+
+		// requests is a channel where new tasks are submitted. Tasks
+		// submitted through this channel may cause a new worker
+		// goroutine to be allocated.
+		requests chan *request
+
+		// work is a channel where new tasks are submitted, but is only
+		// read by active worker goroutines.
+		work chan *request
+
+		// workerSem is a channel-based semaphore that is used to limit
+		// the total number of worker goroutines to the number
+		// prescribed by the WorkerConfig.
+		workerSem chan struct{}
+
+		wg   sync.WaitGroup
+		quit chan struct{}
+	}
+
+	// request is a tuple of task closure and error channel that is used to
+	// both submit a task to the pool and respond with any errors
+	// encountered during the task's execution.
+	request struct {
+		fn      func(WorkerState) error
+		errChan chan error
+	}
+)
+
+// NewWorker initializes a new Worker pool using the provided WorkerConfig.
+func NewWorker(cfg *WorkerConfig) *Worker {
+	return &Worker{
+		cfg:       cfg,
+		requests:  make(chan *request),
+		workerSem: make(chan struct{}, cfg.NumWorkers),
+		work:      make(chan *request),
+		quit:      make(chan struct{}),
+	}
+}
+
+// Start safely spins up the Worker pool.
+func (w *Worker) Start() error {
+	w.started.Do(func() {
+		w.wg.Add(1)
+		go w.requestHandler()
+	})
+	return nil
+}
+
+// Stop safely shuts down the Worker pool.
+func (w *Worker) Stop() error {
+	w.stopped.Do(func() {
+		close(w.quit)
+		w.wg.Wait()
+	})
+	return nil
+}
+
+// Submit accepts a function closure to the worker pool. The returned error will
+// be either the result of the closure's execution or an ErrWorkerPoolExiting if
+// a shutdown is requested.
+func (w *Worker) Submit(fn func(WorkerState) error) error {
+	req := &request{
+		fn:      fn,
+		errChan: make(chan error, 1),
+	}
+
+	select {
+
+	// Send request to requestHandler, where either a new worker is spawned
+	// or the task will be handed to an existing worker.
+	case w.requests <- req:
+
+	// Fast path directly to existing worker.
+	case w.work <- req:
+
+	case <-w.quit:
+		return ErrWorkerPoolExiting
+	}
+
+	select {
+
+	// Wait for task to be processed.
+	case err := <-req.errChan:
+		return err
+
+	case <-w.quit:
+		return ErrWorkerPoolExiting
+	}
+}
+
+// requestHandler processes incoming tasks by either allocating new worker
+// goroutines to process the incoming tasks, or by feeding a submitted task to
+// an already running worker goroutine.
+func (w *Worker) requestHandler() {
+	defer w.wg.Done()
+
+	for {
+		select {
+		case req := <-w.requests:
+			select {
+
+			// If we have not reached our maximum number of workers,
+			// spawn one to process the submitted request.
+			case w.workerSem <- struct{}{}:
+				w.wg.Add(1)
+				go w.spawnWorker(req)
+
+			// Otherwise, submit the task to any of the active
+			// workers.
+			case w.work <- req:
+
+			case <-w.quit:
+				return
+			}
+
+		case <-w.quit:
+			return
+		}
+	}
+}
+
+// spawnWorker is used when the Worker pool wishes to create a new worker
+// goroutine. The worker's state is initialized by calling the config's
+// NewWorkerState method, and will continue to process incoming tasks until the
+// pool is shut down or no new tasks are received before the worker's timeout
+// elapses.
+//
+// NOTE: This method MUST be run as a goroutine.
+func (w *Worker) spawnWorker(req *request) {
+	defer w.wg.Done()
+	defer func() { <-w.workerSem }()
+
+	state := w.cfg.NewWorkerState()
+	defer state.Cleanup()
+
+	req.errChan <- req.fn(state)
+
+	// We'll use a timer to implement the worker timeouts, as this reduces
+	// the number of total allocations that would otherwise be necessary
+	// with time.After.
+	var t *time.Timer
+	for {
+		// Before processing another request, we'll reset the worker
+		// state so that each request is processed against a clean
+		// state.
+		state.Reset()
+
+		select {
+
+		// Process any new requests that get submitted. We use a
+		// non-blocking case first so that under high load we can spare
+		// allocating a timeout.
+		case req := <-w.work:
+			req.errChan <- req.fn(state)
+			continue
+
+		case <-w.quit:
+			return
+
+		default:
+		}
+
+		// There were no new requests that could be taken immediately
+		// from the work channel. Initialize or reset the timeout, which
+		// will fire if the worker doesn't receive a new task before
+		// needing to exit.
+		if t != nil {
+			t.Reset(w.cfg.WorkerTimeout)
+		} else {
+			t = time.NewTimer(w.cfg.WorkerTimeout)
+		}
+
+		select {
+
+		// Process any new requests that get submitted.
+		case req := <-w.work:
+			req.errChan <- req.fn(state)
+
+			// Stop the timer, draining the timer's channel if a
+			// notification was already delivered.
+			if !t.Stop() {
+				<-t.C
+			}
+
+		// The timeout has elapsed, meaning the worker did not receive
+		// any new tasks. Exit to allow the worker to return and free
+		// its resources.
+		case <-t.C:
+			return
+
+		case <-w.quit:
+			return
+		}
+	}
+}
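
To illustrate how the generic Worker pool above is intended to be consumed, the
following sketch shows a caller-defined WorkerState and a single task
submission. It is illustrative only and not part of the patch series; the
hashState type, its use of sha256, and the choice of NumWorkers are
hypothetical.

package main

import (
	"crypto/sha256"
	"fmt"
	"hash"

	"github.com/lightningnetwork/lnd/pool"
)

// hashState is hypothetical per-goroutine state that is expensive to create
// and cheap to reset between tasks.
type hashState struct {
	h hash.Hash
}

// Reset is called by the pool before each new task runs on this goroutine.
func (s *hashState) Reset() { s.h.Reset() }

// Cleanup is called once when the worker goroutine exits.
func (s *hashState) Cleanup() {}

func main() {
	w := pool.NewWorker(&pool.WorkerConfig{
		NewWorkerState: func() pool.WorkerState {
			return &hashState{h: sha256.New()}
		},
		NumWorkers:    4,
		WorkerTimeout: pool.DefaultWorkerTimeout,
	})
	if err := w.Start(); err != nil {
		panic(err)
	}
	defer w.Stop()

	// Submit blocks until a worker executes the closure, and returns either
	// the closure's error or ErrWorkerPoolExiting if the pool is shutting
	// down.
	err := w.Submit(func(s pool.WorkerState) error {
		state := s.(*hashState)
		state.h.Write([]byte("hello"))
		fmt.Printf("%x\n", state.h.Sum(nil))
		return nil
	})
	if err != nil {
		panic(err)
	}
}

Because Submit blocks on the task's error channel, callers get synchronous
error handling while the pool caps concurrency at NumWorkers and tears idle
workers down after WorkerTimeout.
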
From d2eeee7a12abbfaf1fba21dd664493581c6529b4 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:10:17 -0800
Subject: [PATCH 2/9] pool/write: adds Write pool

---
 pool/write.go | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 pool/write.go

diff --git a/pool/write.go b/pool/write.go
new file mode 100644
index 00000000..1322a289
--- /dev/null
+++ b/pool/write.go
@@ -0,0 +1,100 @@
+package pool
+
+import (
+	"bytes"
+	"time"
+
+	"github.com/lightningnetwork/lnd/buffer"
+)
+
+// Write is a worker pool specifically designed for sharing access to
+// buffer.Write objects amongst a set of worker goroutines. This enables an
+// application to limit the total number of buffer.Write objects allocated at
+// any given time.
+type Write struct {
+	workerPool *Worker
+	bufferPool *WriteBuffer
+}
+
+// NewWrite creates a Write pool, using an underlying WriteBuffer pool to
+// recycle buffer.Write objects across the lifetime of the Write pool's
+// workers.
+func NewWrite(writeBufferPool *WriteBuffer, numWorkers int,
+	workerTimeout time.Duration) *Write {
+
+	w := &Write{
+		bufferPool: writeBufferPool,
+	}
+	w.workerPool = NewWorker(&WorkerConfig{
+		NewWorkerState: w.newWorkerState,
+		NumWorkers:     numWorkers,
+		WorkerTimeout:  workerTimeout,
+	})
+
+	return w
+}
+
+// Start safely spins up the Write pool.
+func (w *Write) Start() error {
+	return w.workerPool.Start()
+}
+
+// Stop safely shuts down the Write pool.
+func (w *Write) Stop() error {
+	return w.workerPool.Stop()
+}
+
+// Submit accepts a function closure that provides access to a fresh
+// bytes.Buffer backed by a buffer.Write object. The function's execution will
+// be allocated to one of the underlying Worker pool's goroutines.
+func (w *Write) Submit(inner func(*bytes.Buffer) error) error {
+	return w.workerPool.Submit(func(s WorkerState) error {
+		state := s.(*writeWorkerState)
+		return inner(state.buf)
+	})
+}
+
+// writeWorkerState is the per-goroutine state maintained by a Write pool's
+// goroutines.
+type writeWorkerState struct {
+	// bufferPool is the pool to which the writeBuf will be returned when
+	// the goroutine exits.
+	bufferPool *WriteBuffer
+
+	// writeBuf is the buffer taken from the bufferPool on initialization,
+	// which will be used to back the buf object provided to any tasks that
+	// the goroutine processes before exiting.
+	writeBuf *buffer.Write
+
+	// buf is a buffer backed by writeBuf, that can be written to by tasks
+	// submitted to the Write pool. The buf will be reset between each task
+	// processed by a goroutine before exiting, and allows the task
+	// submitters to interact with the writeBuf as if it were an io.Writer.
+	buf *bytes.Buffer
+}
+
+// newWorkerState initializes a new writeWorkerState, which will be called
+// whenever a new goroutine is allocated to begin processing write tasks.
+func (w *Write) newWorkerState() WorkerState {
+	writeBuf := w.bufferPool.Take()
+
+	return &writeWorkerState{
+		bufferPool: w.bufferPool,
+		writeBuf:   writeBuf,
+		buf:        bytes.NewBuffer(writeBuf[0:0:len(writeBuf)]),
+	}
+}
+
+// Cleanup returns the writeBuf to the underlying buffer pool, and removes the
+// goroutine's reference to the writeBuf and encapsulating buf.
+func (w *writeWorkerState) Cleanup() {
+	w.bufferPool.Return(w.writeBuf)
+	w.writeBuf = nil
+	w.buf = nil
+}
+
+// Reset resets the bytes.Buffer so that it is zero-length and has the capacity
+// of the underlying buffer.Write.
+func (w *writeWorkerState) Reset() {
+	w.buf.Reset()
+}
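
A sketch of how the Write pool composes with the existing WriteBuffer pool;
illustrative only and not part of the patch series. The writeMsg helper and the
destination writer are hypothetical names; the real consumer is peer.writeMessage
in a later patch.

package main

import (
	"bytes"
	"io"
	"os"
	"runtime"

	"github.com/lightningnetwork/lnd/pool"
)

// writeMsg serializes a payload into a pooled write buffer and copies it to
// the destination writer from within a pool worker.
func writeMsg(w *pool.Write, dst io.Writer, payload []byte) error {
	return w.Submit(func(buf *bytes.Buffer) error {
		// buf is reset before each task and is backed by a pooled
		// buffer.Write, so no per-message allocation is needed.
		if _, err := buf.Write(payload); err != nil {
			return err
		}
		_, err := dst.Write(buf.Bytes())
		return err
	})
}

func main() {
	bufPool := pool.NewWriteBuffer(
		pool.DefaultWriteBufferGCInterval,
		pool.DefaultWriteBufferExpiryInterval,
	)
	writePool := pool.NewWrite(
		bufPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
	)
	if err := writePool.Start(); err != nil {
		panic(err)
	}
	defer writePool.Stop()

	if err := writeMsg(writePool, os.Stdout, []byte("ping\n")); err != nil {
		panic(err)
	}
}
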
From 32339a92d3b0e5fe37af21c4dc162cafa39f8c6a Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:10:28 -0800
Subject: [PATCH 3/9] pool/read: adds Read pool

---
 pool/read.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 pool/read.go

diff --git a/pool/read.go b/pool/read.go
new file mode 100644
index 00000000..171a0d2a
--- /dev/null
+++ b/pool/read.go
@@ -0,0 +1,87 @@
+package pool
+
+import (
+	"time"
+
+	"github.com/lightningnetwork/lnd/buffer"
+)
+
+// Read is a worker pool specifically designed for sharing access to buffer.Read
+// objects amongst a set of worker goroutines. This enables an application to
+// limit the total number of buffer.Read objects allocated at any given time.
+type Read struct {
+	workerPool *Worker
+	bufferPool *ReadBuffer
+}
+
+// NewRead creates a new Read pool, using an underlying ReadBuffer pool to
+// recycle buffer.Read objects across the lifetime of the Read pool's workers.
+func NewRead(readBufferPool *ReadBuffer, numWorkers int,
+	workerTimeout time.Duration) *Read {
+
+	r := &Read{
+		bufferPool: readBufferPool,
+	}
+	r.workerPool = NewWorker(&WorkerConfig{
+		NewWorkerState: r.newWorkerState,
+		NumWorkers:     numWorkers,
+		WorkerTimeout:  workerTimeout,
+	})
+
+	return r
+}
+
+// Start safely spins up the Read pool.
+func (r *Read) Start() error {
+	return r.workerPool.Start()
+}
+
+// Stop safely shuts down the Read pool.
+func (r *Read) Stop() error {
+	return r.workerPool.Stop()
+}
+
+// Submit accepts a function closure that provides access to the fresh
+// buffer.Read object. The function's execution will be allocated to one of the
+// underlying Worker pool's goroutines.
+func (r *Read) Submit(inner func(*buffer.Read) error) error {
+	return r.workerPool.Submit(func(s WorkerState) error {
+		state := s.(*readWorkerState)
+		return inner(state.readBuf)
+	})
+}
+
+// readWorkerState is the per-goroutine state maintained by a Read pool's
+// goroutines.
+type readWorkerState struct {
+	// bufferPool is the pool to which the readBuf will be returned when the
+	// goroutine exits.
+	bufferPool *ReadBuffer
+
+	// readBuf is a buffer taken from the bufferPool on initialization,
+	// which will be cleaned and provided to any tasks that the goroutine
+	// processes before exiting.
+	readBuf *buffer.Read
+}
+
+// newWorkerState initializes a new readWorkerState, which will be called
+// whenever a new goroutine is allocated to begin processing read tasks.
+func (r *Read) newWorkerState() WorkerState {
+	return &readWorkerState{
+		bufferPool: r.bufferPool,
+		readBuf:    r.bufferPool.Take(),
+	}
+}
+
+// Cleanup returns the readBuf to the underlying buffer pool, and removes the
+// goroutine's reference to the readBuf.
+func (r *readWorkerState) Cleanup() {
+	r.bufferPool.Return(r.readBuf)
+	r.readBuf = nil
+}
+
+// Reset recycles the readBuf to make it ready for any subsequent tasks the
+// goroutine may process.
+func (r *readWorkerState) Reset() {
+	r.readBuf.Recycle()
+}
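
For symmetry, a sketch of driving the Read pool directly; illustrative only and
not part of the patch series. The readFrame helper, frameLen, and the source
reader are hypothetical, and the real consumer is peer.readNextMessage in patch 8.

package main

import (
	"io"
	"runtime"
	"strings"

	"github.com/lightningnetwork/lnd/buffer"
	"github.com/lightningnetwork/lnd/pool"
)

// readFrame fills the pooled read buffer with a frame of known length and
// copies the bytes out before the buffer is recycled for the next task.
func readFrame(r *pool.Read, src io.Reader, frameLen int) ([]byte, error) {
	var frame []byte
	err := r.Submit(func(buf *buffer.Read) error {
		// buf is recycled (zeroed) between tasks, so it can be sliced
		// to the frame length and filled from the source.
		if _, err := io.ReadFull(src, buf[:frameLen]); err != nil {
			return err
		}

		// Copy out before returning, since buf is reused by the pool.
		frame = append(frame, buf[:frameLen]...)
		return nil
	})
	return frame, err
}

func main() {
	readPool := pool.NewRead(
		pool.NewReadBuffer(
			pool.DefaultReadBufferGCInterval,
			pool.DefaultReadBufferExpiryInterval,
		),
		runtime.NumCPU(), pool.DefaultWorkerTimeout,
	)
	if err := readPool.Start(); err != nil {
		panic(err)
	}
	defer readPool.Stop()

	frame, err := readFrame(readPool, strings.NewReader("hello"), 5)
	if err != nil {
		panic(err)
	}
	println(string(frame))
}
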
From ce1bd4be2ce645a12c28f5ea94ca39e985f01cc9 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:10:40 -0800
Subject: [PATCH 4/9] pool/worker_test: add tests for concrete Worker pools

---
 pool/worker_test.go | 353 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 353 insertions(+)
 create mode 100644 pool/worker_test.go

diff --git a/pool/worker_test.go b/pool/worker_test.go
new file mode 100644
index 00000000..ee23e7a5
--- /dev/null
+++ b/pool/worker_test.go
@@ -0,0 +1,353 @@
+package pool_test
+
+import (
+	"bytes"
+	crand "crypto/rand"
+	"fmt"
+	"io"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/lightningnetwork/lnd/buffer"
+	"github.com/lightningnetwork/lnd/pool"
+)
+
+type workerPoolTest struct {
+	name       string
+	newPool    func() interface{}
+	numWorkers int
+}
+
+// TestConcreteWorkerPools asserts the behavior of any concrete implementations
+// of worker pools provided by the pool package. Currently this tests the
+// pool.Read and pool.Write instances.
+func TestConcreteWorkerPools(t *testing.T) {
+	const (
+		gcInterval     = time.Second
+		expiryInterval = 250 * time.Millisecond
+		numWorkers     = 5
+		workerTimeout  = 500 * time.Millisecond
+	)
+
+	tests := []workerPoolTest{
+		{
+			name: "write pool",
+			newPool: func() interface{} {
+				bp := pool.NewWriteBuffer(
+					gcInterval, expiryInterval,
+				)
+
+				return pool.NewWrite(
+					bp, numWorkers, workerTimeout,
+				)
+			},
+			numWorkers: numWorkers,
+		},
+		{
+			name: "read pool",
+			newPool: func() interface{} {
+				bp := pool.NewReadBuffer(
+					gcInterval, expiryInterval,
+				)
+
+				return pool.NewRead(
+					bp, numWorkers, workerTimeout,
+				)
+			},
+			numWorkers: numWorkers,
+		},
+	}
+
+	for _, test := range tests {
+		testWorkerPool(t, test)
+	}
+}
+
+func testWorkerPool(t *testing.T, test workerPoolTest) {
+	t.Run(test.name+" non blocking", func(t *testing.T) {
+		t.Parallel()
+
+		p := test.newPool()
+		startGeneric(t, p)
+		defer stopGeneric(t, p)
+
+		submitNonblockingGeneric(t, p, test.numWorkers)
+	})
+
+	t.Run(test.name+" blocking", func(t *testing.T) {
+		t.Parallel()
+
+		p := test.newPool()
+		startGeneric(t, p)
+		defer stopGeneric(t, p)
+
+		submitBlockingGeneric(t, p, test.numWorkers)
+	})
+
+	t.Run(test.name+" partial blocking", func(t *testing.T) {
+		t.Parallel()
+
+		p := test.newPool()
+		startGeneric(t, p)
+		defer stopGeneric(t, p)
+
+		submitPartialBlockingGeneric(t, p, test.numWorkers)
+	})
+}
+
+// submitNonblockingGeneric asserts that queueing tasks to the worker pool and
+// allowing them all to unblock simultaneously results in all of the tasks being
+// completed in a timely manner.
+func submitNonblockingGeneric(t *testing.T, p interface{}, nWorkers int) {
+	// We'll submit 2*nWorkers tasks that will all be unblocked
+	// simultaneously.
+	nUnblocked := 2 * nWorkers
+
+	// First we'll queue all of the tasks for the pool.
+	errChan := make(chan error)
+	semChan := make(chan struct{})
+	for i := 0; i < nUnblocked; i++ {
+		go func() { errChan <- submitGeneric(p, semChan) }()
+	}
+
+	// Since we haven't signaled the semaphore, none of them should
+	// complete.
+	pullNothing(t, errChan)
+
+	// Now, unblock them all simultaneously. All of the tasks should then be
+	// processed in parallel. Afterward, no more errors should come through.
+	close(semChan)
+	pullParllel(t, nUnblocked, errChan)
+	pullNothing(t, errChan)
+}
+
+// submitBlockingGeneric asserts that submitting blocking tasks to the pool and
+// unblocking each sequentially results in a single task being processed at a
+// time.
+func submitBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
+	// We'll submit 2*nWorkers tasks that will be unblocked sequentially.
+	nBlocked := 2 * nWorkers
+
+	// First, queue all of the blocking tasks for the pool.
+	errChan := make(chan error)
+	semChan := make(chan struct{})
+	for i := 0; i < nBlocked; i++ {
+		go func() { errChan <- submitGeneric(p, semChan) }()
+	}
+
+	// Since we haven't signaled the semaphore, none of them should
+	// complete.
+	pullNothing(t, errChan)
+
+	// Now, pull each blocking task sequentially from the pool. Afterwards,
+	// no more errors should come through.
+	pullSequntial(t, nBlocked, errChan, semChan)
+	pullNothing(t, errChan)
+
+}
+
+// submitPartialBlockingGeneric tests that so long as one worker is not blocked,
+// any other non-blocking submitted tasks can still be processed.
+func submitPartialBlockingGeneric(t *testing.T, p interface{}, nWorkers int) {
+	// We'll submit nWorkers-1 tasks that will be initially blocked, the
+	// remainder will all be unblocked simultaneously. After the unblocked
+	// tasks have finished, we will sequentially unblock the nWorkers-1
+	// tasks that were first submitted.
+	nBlocked := nWorkers - 1
+	nUnblocked := 2*nWorkers - nBlocked
+
+	// First, submit all of the blocking tasks to the pool.
+	errChan := make(chan error)
+	semChan := make(chan struct{})
+	for i := 0; i < nBlocked; i++ {
+		go func() { errChan <- submitGeneric(p, semChan) }()
+	}
+
+	// Since these are all blocked, no errors should be returned yet.
+	pullNothing(t, errChan)
+
+	// Now, add all of the non-blocking tasks to the pool.
+	semChanNB := make(chan struct{})
+	for i := 0; i < nUnblocked; i++ {
+		go func() { errChan <- submitGeneric(p, semChanNB) }()
+	}
+
+	// Since we haven't unblocked the second batch, we again expect no tasks
+	// to finish.
+	pullNothing(t, errChan)
+
+	// Now, unblock the unblocked tasks and pull all of them. After they have
+	// been pulled, we should see no more tasks.
+	close(semChanNB)
+	pullParllel(t, nUnblocked, errChan)
+	pullNothing(t, errChan)
+
+	// Finally, unblock each of the blocked tasks we added initially, and
+	// assert that no further errors come through.
+	pullSequntial(t, nBlocked, errChan, semChan)
+	pullNothing(t, errChan)
+}
+
+func pullNothing(t *testing.T, errChan chan error) {
+	t.Helper()
+
+	select {
+	case err := <-errChan:
+		t.Fatalf("received unexpected error before semaphore "+
+			"release: %v", err)
+
+	case <-time.After(time.Second):
+	}
+}
+
+func pullParllel(t *testing.T, n int, errChan chan error) {
+	t.Helper()
+
+	for i := 0; i < n; i++ {
+		select {
+		case err := <-errChan:
+			if err != nil {
+				t.Fatal(err)
+			}
+
+		case <-time.After(time.Second):
+			t.Fatalf("task %d was not processed in time", i)
+		}
+	}
+}
+
+func pullSequntial(t *testing.T, n int, errChan chan error, semChan chan struct{}) {
+	t.Helper()
+
+	for i := 0; i < n; i++ {
+		// Signal for another task to unblock.
+		select {
+		case semChan <- struct{}{}:
+		case <-time.After(time.Second):
+			t.Fatalf("task %d was not unblocked", i)
+		}
+
+		// Wait for the error to arrive, we expect it to be nil.
+		select {
+		case err := <-errChan:
+			if err != nil {
+				t.Fatal(err)
+			}
+
+		case <-time.After(time.Second):
+			t.Fatalf("task %d was not processed in time", i)
+		}
+	}
+}
+
+func startGeneric(t *testing.T, p interface{}) {
+	t.Helper()
+
+	var err error
+	switch pp := p.(type) {
+	case *pool.Write:
+		err = pp.Start()
+
+	case *pool.Read:
+		err = pp.Start()
+
+	default:
+		t.Fatalf("unknown worker pool type: %T", p)
+	}
+
+	if err != nil {
+		t.Fatalf("unable to start worker pool: %v", err)
+	}
+}
+
+func stopGeneric(t *testing.T, p interface{}) {
+	t.Helper()
+
+	var err error
+	switch pp := p.(type) {
+	case *pool.Write:
+		err = pp.Stop()
+
+	case *pool.Read:
+		err = pp.Stop()
+
+	default:
+		t.Fatalf("unknown worker pool type: %T", p)
+	}
+
+	if err != nil {
+		t.Fatalf("unable to stop worker pool: %v", err)
+	}
+}
+
+func submitGeneric(p interface{}, sem <-chan struct{}) error {
+	var err error
+	switch pp := p.(type) {
+	case *pool.Write:
+		err = pp.Submit(func(buf *bytes.Buffer) error {
+			// Verify that the provided buffer has been reset to be
+			// zero length.
+			if buf.Len() != 0 {
+				return fmt.Errorf("buf should be length zero, "+
+					"instead has length %d", buf.Len())
+			}
+
+			// Verify that the capacity of the buffer has the
+			// correct underlying size of a buffer.WriteSize.
+			if buf.Cap() != buffer.WriteSize {
+				return fmt.Errorf("buf should have capacity "+
+					"%d, instead has capacity %d",
+					buffer.WriteSize, buf.Cap())
+			}
+
+			// Sample some random bytes that we'll use to dirty the
+			// buffer.
+			b := make([]byte, rand.Intn(buf.Cap()))
+			_, err := io.ReadFull(crand.Reader, b)
+			if err != nil {
+				return err
+			}
+
+			// Write the random bytes to the buffer.
+			_, err = buf.Write(b)
+
+			// Wait until this task is signaled to exit.
+			<-sem
+
+			return err
+		})
+
+	case *pool.Read:
+		err = pp.Submit(func(buf *buffer.Read) error {
+			// Assert that all of the bytes in the provided array
+			// are zero, indicating that the buffer was reset
+			// between uses.
+			for i := range buf[:] {
+				if buf[i] != 0x00 {
+					return fmt.Errorf("byte %d of "+
+						"buffer.Read should be "+
+						"0, instead is %d", i, buf[i])
+				}
+			}
+
+			// Sample some random bytes to read into the buffer.
+			_, err := io.ReadFull(crand.Reader, buf[:])
+
+			// Wait until this task is signaled to exit.
+			<-sem
+
+			return err
+
+		})
+
+	default:
+		return fmt.Errorf("unknown worker pool type: %T", p)
+	}
+
+	if err != nil {
+		return fmt.Errorf("unable to submit task: %v", err)
+	}
+
+	return nil
+}
From 9a3c0b8bca53076f3aa586cf4ac1f1dfd99c18b6 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:10:51 -0800
Subject: [PATCH 5/9] peer+server: switch to pool.Write from pool.WriteBuffer

---
 peer.go   | 67 +++++++++++++++++++++++++------------------------------
 server.go | 20 ++++++++++++-----
 2 files changed, 45 insertions(+), 42 deletions(-)

diff --git a/peer.go b/peer.go
index 41a55892..34b44020 100644
--- a/peer.go
+++ b/peer.go
@@ -18,7 +18,6 @@ import (
 
 	"github.com/davecgh/go-spew/spew"
 	"github.com/lightningnetwork/lnd/brontide"
-	"github.com/lightningnetwork/lnd/buffer"
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/contractcourt"
@@ -26,6 +25,7 @@ import (
 	"github.com/lightningnetwork/lnd/lnpeer"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/pool"
 	"github.com/lightningnetwork/lnd/ticker"
 )
 
@@ -209,11 +209,11 @@ type peer struct {
 	// TODO(halseth): remove when link failure is properly handled.
 	failedChannels map[lnwire.ChannelID]struct{}
 
-	// writeBuf is a buffer that we'll re-use in order to encode wire
-	// messages to write out directly on the socket. By re-using this
-	// buffer, we avoid needing to allocate more memory each time a new
-	// message is to be sent to a peer.
-	writeBuf *buffer.Write
+	// writePool is the task pool that manages reuse of write buffers.
+	// Write tasks are submitted to the pool in order to conserve the total
+	// number of write buffers allocated at any one time, and decouple write
+	// buffer allocation from the peer life cycle.
+	writePool *pool.Write
 
 	queueQuit chan struct{}
 	quit      chan struct{}
@@ -258,7 +258,7 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,
 
 		chanActiveTimeout: chanActiveTimeout,
 
-		writeBuf: server.writeBufferPool.Take(),
+		writePool: server.writePool,
 
 		queueQuit: make(chan struct{}),
 		quit:      make(chan struct{}),
@@ -608,11 +608,6 @@ func (p *peer) WaitForDisconnect(ready chan struct{}) {
 	}
 
 	p.wg.Wait()
-
-	// Now that we are certain all active goroutines which could have been
-	// modifying the write buffer have exited, return the buffer to the pool
-	// to be reused.
-	p.server.writeBufferPool.Return(p.writeBuf)
 }
 
 // Disconnect terminates the connection with the remote peer. Additionally, a
@@ -1359,33 +1354,33 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
 	}
 
 	p.logWireMessage(msg, false)
 
-	// We'll re-slice of static write buffer to allow this new message to
-	// utilize all available space. We also ensure we cap the capacity of
-	// this new buffer to the static buffer which is sized for the largest
-	// possible protocol message.
-	b := bytes.NewBuffer(p.writeBuf[0:0:len(p.writeBuf)])
+	var n int
+	err := p.writePool.Submit(func(buf *bytes.Buffer) error {
+		// Using a buffer allocated by the write pool, encode the
+		// message directly into the buffer.
+		_, writeErr := lnwire.WriteMessage(buf, msg, 0)
+		if writeErr != nil {
+			return writeErr
+		}
 
-	// With the temp buffer created and sliced properly (length zero, full
-	// capacity), we'll now encode the message directly into this buffer.
-	_, err := lnwire.WriteMessage(b, msg, 0)
-	if err != nil {
-		return err
-	}
+		// Ensure the write deadline is set before we attempt to send
+		// the message.
+		writeDeadline := time.Now().Add(writeMessageTimeout)
+		writeErr = p.conn.SetWriteDeadline(writeDeadline)
+		if writeErr != nil {
+			return writeErr
+		}
 
-	// Compute and set the write deadline we will impose on the remote peer.
-	writeDeadline := time.Now().Add(writeMessageTimeout)
-	err = p.conn.SetWriteDeadline(writeDeadline)
-	if err != nil {
-		return err
-	}
+		// Finally, write the message itself in a single swoop.
+		n, writeErr = p.conn.Write(buf.Bytes())
+		return writeErr
+	})
 
-	// Finally, write the message itself in a single swoop.
-	n, err := p.conn.Write(b.Bytes())
-
-	// Regardless of the error returned, record how many bytes were written
-	// to the wire.
-	atomic.AddUint64(&p.bytesSent, uint64(n))
+	// Record the number of bytes written on the wire, if any.
+	if n > 0 {
+		atomic.AddUint64(&p.bytesSent, uint64(n))
+	}
 
 	return err
 }
diff --git a/server.go b/server.go
index 6b4cb7a4..1a501bae 100644
--- a/server.go
+++ b/server.go
@@ -171,7 +171,7 @@ type server struct {
 
 	sigPool *lnwallet.SigPool
 
-	writeBufferPool *pool.WriteBuffer
+	writePool *pool.Write
 
 	// globalFeatures feature vector which affects HTLCs and thus are also
 	// advertised to other nodes.
@@ -267,12 +267,15 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
 		pool.DefaultWriteBufferGCInterval,
 		pool.DefaultWriteBufferExpiryInterval,
 	)
+	writePool := pool.NewWrite(
+		writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
+	)
 
 	s := &server{
-		chanDB:          chanDB,
-		cc:              cc,
-		sigPool:         lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
-		writeBufferPool: writeBufferPool,
+		chanDB:    chanDB,
+		cc:        cc,
+		sigPool:   lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
+		writePool: writePool,
 
 		invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
 
@@ -1010,6 +1013,9 @@ func (s *server) Start() error {
 	if err := s.sigPool.Start(); err != nil {
 		return err
 	}
+	if err := s.writePool.Start(); err != nil {
+		return err
+	}
 	if err := s.cc.chainNotifier.Start(); err != nil {
 		return err
 	}
@@ -1102,7 +1108,6 @@ func (s *server) Stop() error {
 
 	// Shutdown the wallet, funding manager, and the rpc server.
 	s.chanStatusMgr.Stop()
-	s.sigPool.Stop()
 	s.cc.chainNotifier.Stop()
 	s.chanRouter.Stop()
 	s.htlcSwitch.Stop()
@@ -1129,6 +1134,9 @@ func (s *server) Stop() error {
 	// Wait for all lingering goroutines to quit.
 	s.wg.Wait()
 
+	s.sigPool.Stop()
+	s.writePool.Stop()
+
 	return nil
 }
 
From 93ce4a7575400c56e606fe8e09b78b6976909431 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:11:05 -0800
Subject: [PATCH 6/9] brontide/noise: compose ReadMessage from ReadHeader+ReadBody

---
 brontide/noise.go | 84 ++++++++++++++++++++++++------------------------
 1 file changed, 40 insertions(+), 44 deletions(-)

diff --git a/brontide/noise.go b/brontide/noise.go
index a01d518b..7b6bdb43 100644
--- a/brontide/noise.go
+++ b/brontide/noise.go
@@ -8,15 +8,12 @@ import (
 	"fmt"
 	"io"
 	"math"
-	"runtime"
 	"time"
 
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/crypto/hkdf"
 
 	"github.com/btcsuite/btcd/btcec"
-	"github.com/lightningnetwork/lnd/buffer"
-	"github.com/lightningnetwork/lnd/pool"
 )
 
 const (
@@ -60,14 +57,6 @@ var (
 	ephemeralGen = func() (*btcec.PrivateKey, error) {
 		return btcec.NewPrivateKey(btcec.S256())
 	}
-
-	// readBufferPool is a singleton instance of a buffer pool, used to
-	// conserve memory allocations due to read buffers across the entire
-	// brontide package.
-	readBufferPool = pool.NewReadBuffer(
-		pool.DefaultReadBufferGCInterval,
-		pool.DefaultReadBufferExpiryInterval,
-	)
 )
 
 // TODO(roasbeef): free buffer pool?
@@ -378,15 +367,6 @@ type Machine struct {
 	// next ciphertext header from the wire. The header is a 2 byte length
 	// (of the next ciphertext), followed by a 16 byte MAC.
 	nextCipherHeader [lengthHeaderSize + macSize]byte
-
-	// nextCipherText is a static buffer that we'll use to read in the
-	// bytes of the next cipher text message. As all messages in the
-	// protocol MUST be below 65KB plus our macSize, this will be
-	// sufficient to buffer all messages from the socket when we need to
-	// read the next one. Having a fixed buffer that's re-used also means
-	// that we save on allocations as we don't need to create a new one
-	// each time.
-	nextCipherText *buffer.Read
 }
 
 // NewBrontideMachine creates a new instance of the brontide state-machine. If
@@ -738,43 +718,59 @@ func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
 // ReadMessage attempts to read the next message from the passed io.Reader. In
 // the case of an authentication error, a non-nil error is returned.
 func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
-	if _, err := io.ReadFull(r, b.nextCipherHeader[:]); err != nil {
+	pktLen, err := b.ReadHeader(r)
+	if err != nil {
 		return nil, err
 	}
 
+	buf := make([]byte, pktLen)
+	return b.ReadBody(r, buf)
+}
+
+// ReadHeader attempts to read the next message header from the passed
+// io.Reader. The header contains the length of the next body including
+// additional overhead of the MAC. In the case of an authentication error, a
+// non-nil error is returned.
+//
+// NOTE: This method SHOULD NOT be used in the case that the io.Reader may be
+// adversarial and induce long delays. If the caller needs to set read deadlines
+// appropriately, it is preferred that they use the split ReadHeader and
+// ReadBody methods so that the deadlines can be set appropriately on each.
+func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
+	_, err := io.ReadFull(r, b.nextCipherHeader[:])
+	if err != nil {
+		return 0, err
+	}
+
 	// Attempt to decrypt+auth the packet length present in the stream.
 	pktLenBytes, err := b.recvCipher.Decrypt(
 		nil, nil, b.nextCipherHeader[:],
 	)
 	if err != nil {
-		return nil, err
+		return 0, err
 	}
 
-	// If this is the first message being read, take a read buffer from the
-	// buffer pool. This is delayed until this point to avoid allocating
-	// read buffers until after the peer has successfully completed the
-	// handshake, and is ready to begin sending lnwire messages.
-	if b.nextCipherText == nil {
-		b.nextCipherText = readBufferPool.Take()
-		runtime.SetFinalizer(b, freeReadBuffer)
-	}
-
-	// Next, using the length read from the packet header, read the
-	// encrypted packet itself.
+	// Compute the packet length that we will need to read off the wire.
 	pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize
-	if _, err := io.ReadFull(r, b.nextCipherText[:pktLen]); err != nil {
+
+	return pktLen, nil
+}
+
+// ReadBody attempts to read the next message body from the passed io.Reader.
+// The provided buffer MUST be the length indicated by the packet length
+// returned by the preceding call to ReadHeader. In the case of an
+// authentication error, a non-nil error is returned.
+func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
+	// Next, using the length read from the packet header, read the
+	// encrypted packet itself into the buffer allocated by the read
+	// pool.
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
 		return nil, err
 	}
 
+	// Finally, decrypt the message held in the buffer, and return a
+	// new byte slice containing the plaintext.
 	// TODO(roasbeef): modify to let pass in slice
-	return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen])
-}
-
-// freeReadBuffer returns the Machine's read buffer back to the package wide
-// read buffer pool.
-//
-// NOTE: This method should only be called by a Machine's finalizer.
-func freeReadBuffer(b *Machine) {
-	readBufferPool.Return(b.nextCipherText)
-	b.nextCipherText = nil
+	return b.recvCipher.Decrypt(nil, nil, buf)
 }
 
From 8ac8d95b544d1d13608f632323576c8c0794046f Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:11:19 -0800
Subject: [PATCH 7/9] brontide/conn: expose ReadNextHeader+ReadNextBody

---
 brontide/conn.go | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/brontide/conn.go b/brontide/conn.go
index 05f17be3..643ff860 100644
--- a/brontide/conn.go
+++ b/brontide/conn.go
@@ -104,13 +104,34 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress,
 	return b, nil
 }
 
-// ReadNextMessage uses the connection in a message-oriented instructing it to
-// read the next _full_ message with the brontide stream. This function will
-// block until the read succeeds.
+// ReadNextMessage uses the connection in a message-oriented manner, instructing
+// it to read the next _full_ message with the brontide stream. This function
+// will block until the read of the header and body succeeds.
+//
+// NOTE: This method SHOULD NOT be used in the case that the connection may be
+// adversarial and induce long delays. If the caller needs to set read deadlines
+// appropriately, it is preferred that they use the split ReadNextHeader and
+// ReadNextBody methods so that the deadlines can be set appropriately on each.
 func (c *Conn) ReadNextMessage() ([]byte, error) {
 	return c.noise.ReadMessage(c.conn)
 }
 
+// ReadNextHeader uses the connection to read the next header from the brontide
+// stream. This function will block until the read of the header succeeds and
+// return the packet length (including MAC overhead) that is expected from the
+// subsequent call to ReadNextBody.
+func (c *Conn) ReadNextHeader() (uint32, error) {
+	return c.noise.ReadHeader(c.conn)
+}
+
+// ReadNextBody uses the connection to read the next message body from the
+// brontide stream. This function will block until the read of the body succeeds
+// and return the decrypted payload. The provided buffer MUST be the packet
+// length returned by the preceding call to ReadNextHeader.
+func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
+	return c.noise.ReadBody(c.conn, buf)
+}
+
 // Read reads data from the connection. Read can be made to time out and
 // return an Error with Timeout() == true after a fixed time limit; see
 // SetDeadline and SetReadDeadline.
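
A sketch of the intended calling pattern for the split API on a brontide.Conn,
setting an independent read deadline for the header and the body as the NOTE
above recommends. Illustrative only and not part of the patch series — the
peerio package name, ReadWithDeadlines, and the timeout parameters are
hypothetical (patch 8 wires this into peer.readNextMessage with no header
deadline and a pooled body buffer).

package peerio

import (
	"time"

	"github.com/lightningnetwork/lnd/brontide"
)

// ReadWithDeadlines reads one full message from the connection, bounding how
// long the remote peer may take to deliver the header and the body
// independently. buf must be at least as large as the returned packet length.
func ReadWithDeadlines(conn *brontide.Conn, buf []byte,
	headerTimeout, bodyTimeout time.Duration) ([]byte, error) {

	// Bound the wait for the encrypted header (2-byte length plus 16-byte
	// MAC).
	if err := conn.SetReadDeadline(time.Now().Add(headerTimeout)); err != nil {
		return nil, err
	}
	pktLen, err := conn.ReadNextHeader()
	if err != nil {
		return nil, err
	}

	// Once the header has arrived, apply a fresh deadline for the body and
	// read it into the caller-provided buffer sliced to exactly pktLen
	// bytes.
	if err := conn.SetReadDeadline(time.Now().Add(bodyTimeout)); err != nil {
		return nil, err
	}
	return conn.ReadNextBody(buf[:pktLen])
}
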
From 603601a4c8ea8151524d0f3d6f3ca0abfa0aef92 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:11:33 -0800
Subject: [PATCH 8/9] peer+server: use peer-level readPool

---
 peer.go   | 36 +++++++++++++++++++++++++++++++++++-
 server.go | 18 ++++++++++++++++++
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/peer.go b/peer.go
index 34b44020..91510852 100644
--- a/peer.go
+++ b/peer.go
@@ -18,6 +18,7 @@ import (
 
 	"github.com/davecgh/go-spew/spew"
 	"github.com/lightningnetwork/lnd/brontide"
+	"github.com/lightningnetwork/lnd/buffer"
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/contractcourt"
@@ -46,6 +47,10 @@ const (
 	// writeMessageTimeout is the timeout used when writing a message to peer.
 	writeMessageTimeout = 50 * time.Second
 
+	// readMessageTimeout is the timeout used when reading a message from a
+	// peer.
+	readMessageTimeout = 5 * time.Second
+
 	// handshakeTimeout is the timeout used when waiting for peer init message.
 	handshakeTimeout = 15 * time.Second
 
@@ -215,6 +220,8 @@ type peer struct {
 	// buffer allocation from the peer life cycle.
 	writePool *pool.Write
 
+	readPool *pool.Read
+
 	queueQuit chan struct{}
 	quit      chan struct{}
 	wg        sync.WaitGroup
@@ -259,6 +266,7 @@ func newPeer(conn net.Conn, connReq *connmgr.ConnReq, server *server,
 		chanActiveTimeout: chanActiveTimeout,
 
 		writePool: server.writePool,
+		readPool:  server.readPool,
 
 		queueQuit: make(chan struct{}),
 		quit:      make(chan struct{}),
@@ -639,11 +647,37 @@ func (p *peer) readNextMessage() (lnwire.Message, error) {
 		return nil, fmt.Errorf("brontide.Conn required to read messages")
 	}
 
+	err := noiseConn.SetReadDeadline(time.Time{})
+	if err != nil {
+		return nil, err
+	}
+
+	pktLen, err := noiseConn.ReadNextHeader()
+	if err != nil {
+		return nil, err
+	}
+
 	// First we'll read the next _full_ message. We do this rather than
 	// reading incrementally from the stream as the Lightning wire protocol
 	// is message oriented and allows nodes to pad on additional data to
 	// the message stream.
-	rawMsg, err := noiseConn.ReadNextMessage()
+	var rawMsg []byte
+	err = p.readPool.Submit(func(buf *buffer.Read) error {
+		// Before reading the body of the message, set the read timeout
+		// accordingly to ensure we don't block other readers using the
+		// pool. We do so only after the task has been scheduled to
+		// ensure the deadline doesn't expire while the message is in
+		// the process of being scheduled.
+		readDeadline := time.Now().Add(readMessageTimeout)
+		readErr := noiseConn.SetReadDeadline(readDeadline)
+		if readErr != nil {
+			return readErr
+		}
+
+		rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen])
+		return readErr
+	})
+
 	atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg)))
 	if err != nil {
 		return nil, err
diff --git a/server.go b/server.go
index 1a501bae..3385c8ce 100644
--- a/server.go
+++ b/server.go
@@ -173,6 +173,8 @@ type server struct {
 
 	writePool *pool.Write
 
+	readPool *pool.Read
+
 	// globalFeatures feature vector which affects HTLCs and thus are also
 	// advertised to other nodes.
 	globalFeatures *lnwire.FeatureVector
@@ -263,19 +265,31 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
 	sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db")
 	replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier)
 	sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog)
+
 	writeBufferPool := pool.NewWriteBuffer(
 		pool.DefaultWriteBufferGCInterval,
 		pool.DefaultWriteBufferExpiryInterval,
 	)
+
 	writePool := pool.NewWrite(
 		writeBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
 	)
 
+	readBufferPool := pool.NewReadBuffer(
+		pool.DefaultReadBufferGCInterval,
+		pool.DefaultReadBufferExpiryInterval,
+	)
+
+	readPool := pool.NewRead(
+		readBufferPool, runtime.NumCPU(), pool.DefaultWorkerTimeout,
+	)
+
 	s := &server{
 		chanDB:    chanDB,
 		cc:        cc,
 		sigPool:   lnwallet.NewSigPool(runtime.NumCPU()*2, cc.signer),
 		writePool: writePool,
+		readPool:  readPool,
 
 		invoices: invoices.NewRegistry(chanDB, activeNetParams.Params),
 
@@ -1016,6 +1030,9 @@ func (s *server) Start() error {
 	if err := s.writePool.Start(); err != nil {
 		return err
 	}
+	if err := s.readPool.Start(); err != nil {
+		return err
+	}
 	if err := s.cc.chainNotifier.Start(); err != nil {
 		return err
 	}
@@ -1136,6 +1153,7 @@ func (s *server) Stop() error {
 
 	s.sigPool.Stop()
 	s.writePool.Stop()
+	s.readPool.Stop()
 
 	return nil
 }
From db2c104111c3c12e1651120fe5a51ce03dad8c58 Mon Sep 17 00:00:00 2001
From: Conner Fromknecht
Date: Thu, 21 Feb 2019 20:11:47 -0800
Subject: [PATCH 9/9] peer: reduce write timeout to 10s

---
 peer.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/peer.go b/peer.go
index 91510852..4b8973dc 100644
--- a/peer.go
+++ b/peer.go
@@ -44,8 +44,8 @@ const (
 	// idleTimeout is the duration of inactivity before we time out a peer.
 	idleTimeout = 5 * time.Minute
 
-	// writeMessageTimeout is the timeout used when writing a message to peer.
-	writeMessageTimeout = 50 * time.Second
+	// writeMessageTimeout is the timeout used when writing a message to a peer.
+	writeMessageTimeout = 10 * time.Second
 
 	// readMessageTimeout is the timeout used when reading a message from a
 	// peer.