diff --git a/brontide/conn.go b/brontide/conn.go index e5bfecc1..05f17be3 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -61,7 +61,11 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, // We'll ensure that we get ActTwo from the remote peer in a timely // manner. If they don't respond within 1s, then we'll kill the // connection. - conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + err = conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + if err != nil { + b.conn.Close() + return nil, err + } // If the first act was successful (we know that address is actually // remotePub), then read the second act after which we'll be able to @@ -91,7 +95,11 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress, // We'll reset the deadline as it's no longer critical beyond the // initial handshake. - conn.SetReadDeadline(time.Time{}) + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + b.conn.Close() + return nil, err + } return b, nil } diff --git a/brontide/listener.go b/brontide/listener.go index 8fd09a1b..95505ecf 100644 --- a/brontide/listener.go +++ b/brontide/listener.go @@ -116,7 +116,12 @@ func (l *Listener) doHandshake(conn net.Conn) { // We'll ensure that we get ActOne from the remote peer in a timely // manner. If they don't respond within 1s, then we'll kill the // connection. - conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + err := conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + if err != nil { + brontideConn.conn.Close() + l.rejectConn(rejectedConnErr(err, remoteAddr)) + return + } // Attempt to carry out the first act of the handshake protocol. If the // connecting node doesn't know our long-term static public key, then @@ -156,7 +161,12 @@ func (l *Listener) doHandshake(conn net.Conn) { // We'll ensure that we get ActTwo from the remote peer in a timely // manner. If they don't respond within 1 second, then we'll kill the // connection. - conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + err = conn.SetReadDeadline(time.Now().Add(handshakeReadTimeout)) + if err != nil { + brontideConn.conn.Close() + l.rejectConn(rejectedConnErr(err, remoteAddr)) + return + } // Finally, finish the handshake processes by reading and decrypting // the connection peer's static public key. If this succeeds then both @@ -175,7 +185,12 @@ func (l *Listener) doHandshake(conn net.Conn) { // We'll reset the deadline as it's no longer critical beyond the // initial handshake. - conn.SetReadDeadline(time.Time{}) + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + brontideConn.conn.Close() + l.rejectConn(rejectedConnErr(err, remoteAddr)) + return + } l.acceptConn(brontideConn) } diff --git a/brontide/noise.go b/brontide/noise.go index 9a87f424..a01d518b 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -8,12 +8,15 @@ import ( "fmt" "io" "math" + "runtime" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/buffer" + "github.com/lightningnetwork/lnd/pool" ) const ( @@ -47,6 +50,24 @@ var ( // the cipher session exceeds the maximum allowed message payload. ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " + "the max allowed message length of (2^16)-1") + + // lightningPrologue is the noise prologue that is used to initialize + // the brontide noise handshake. + lightningPrologue = []byte("lightning") + + // ephemeralGen is the default ephemeral key generator, used to derive a + // unique ephemeral key for each brontide handshake. + ephemeralGen = func() (*btcec.PrivateKey, error) { + return btcec.NewPrivateKey(btcec.S256()) + } + + // readBufferPool is a singleton instance of a buffer pool, used to + // conserve memory allocations due to read buffers across the entire + // brontide package. + readBufferPool = pool.NewReadBuffer( + pool.DefaultReadBufferGCInterval, + pool.DefaultReadBufferExpiryInterval, + ) ) // TODO(roasbeef): free buffer pool? @@ -365,7 +386,7 @@ type Machine struct { // read the next one. Having a fixed buffer that's re-used also means // that we save on allocations as we don't need to create a new one // each time. - nextCipherText [math.MaxUint16 + macSize]byte + nextCipherText *buffer.Read } // NewBrontideMachine creates a new instance of the brontide state-machine. If @@ -377,15 +398,13 @@ type Machine struct { func NewBrontideMachine(initiator bool, localPub *btcec.PrivateKey, remotePub *btcec.PublicKey, options ...func(*Machine)) *Machine { - handshake := newHandshakeState(initiator, []byte("lightning"), localPub, - remotePub) + handshake := newHandshakeState( + initiator, lightningPrologue, localPub, remotePub, + ) - m := &Machine{handshakeState: handshake} - - // With the initial base machine created, we'll assign our default - // version of the ephemeral key generator. - m.ephemeralGen = func() (*btcec.PrivateKey, error) { - return btcec.NewPrivateKey(btcec.S256()) + m := &Machine{ + handshakeState: handshake, + ephemeralGen: ephemeralGen, } // With the default options established, we'll now process all the @@ -731,6 +750,15 @@ func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) { return nil, err } + // If this is the first message being read, take a read buffer from the + // buffer pool. This is delayed until this point to avoid allocating + // read buffers until after the peer has successfully completed the + // handshake, and is ready to begin sending lnwire messages. + if b.nextCipherText == nil { + b.nextCipherText = readBufferPool.Take() + runtime.SetFinalizer(b, freeReadBuffer) + } + // Next, using the length read from the packet header, read the // encrypted packet itself. pktLen := uint32(binary.BigEndian.Uint16(pktLenBytes)) + macSize @@ -741,3 +769,12 @@ func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) { // TODO(roasbeef): modify to let pass in slice return b.recvCipher.Decrypt(nil, nil, b.nextCipherText[:pktLen]) } + +// freeReadBuffer returns the Machine's read buffer back to the package wide +// read buffer pool. +// +// NOTE: This method should only be called by a Machine's finalizer. +func freeReadBuffer(b *Machine) { + readBufferPool.Return(b.nextCipherText) + b.nextCipherText = nil +} diff --git a/buffer/buffer_test.go b/buffer/buffer_test.go new file mode 100644 index 00000000..754da0e6 --- /dev/null +++ b/buffer/buffer_test.go @@ -0,0 +1,44 @@ +package buffer_test + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/buffer" +) + +// TestRecycleSlice asserts that RecycleSlice always zeros a byte slice. +func TestRecycleSlice(t *testing.T) { + tests := []struct { + name string + slice []byte + }{ + { + name: "length zero", + }, + { + name: "length one", + slice: []byte("a"), + }, + { + name: "length power of two length", + slice: bytes.Repeat([]byte("b"), 16), + }, + { + name: "length non power of two", + slice: bytes.Repeat([]byte("c"), 27), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + buffer.RecycleSlice(test.slice) + + expSlice := make([]byte, len(test.slice)) + if !bytes.Equal(expSlice, test.slice) { + t.Fatalf("slice not recycled, want: %v, got: %v", + expSlice, test.slice) + } + }) + } +} diff --git a/buffer/read.go b/buffer/read.go new file mode 100644 index 00000000..57050d58 --- /dev/null +++ b/buffer/read.go @@ -0,0 +1,19 @@ +package buffer + +import ( + "github.com/lightningnetwork/lnd/lnwire" +) + +// ReadSize represents the size of the maximum message that can be read off the +// wire by brontide. The buffer is used to hold the ciphertext while the +// brontide state machine decrypts the message. +const ReadSize = lnwire.MaxMessagePayload + 16 + +// Read is a static byte array sized to the maximum-allowed Lightning message +// size, plus 16 bytes for the MAC. +type Read [ReadSize]byte + +// Recycle zeroes the Read, making it fresh for another use. +func (b *Read) Recycle() { + RecycleSlice(b[:]) +} diff --git a/buffer/utils.go b/buffer/utils.go new file mode 100644 index 00000000..40a386a9 --- /dev/null +++ b/buffer/utils.go @@ -0,0 +1,17 @@ +package buffer + +// RecycleSlice zeroes byte slice, making it fresh for another use. +// Zeroing the buffer using a logarithmic number of calls to the optimized copy +// method. Benchmarking shows this to be ~30 times faster than a for loop that +// sets each index to 0 for ~65KB buffers use for wire messages. Inspired by: +// https://stackoverflow.com/questions/30614165/is-there-analog-of-memset-in-go +func RecycleSlice(b []byte) { + if len(b) == 0 { + return + } + + b[0] = 0 + for i := 1; i < len(b); i *= 2 { + copy(b[i:], b[:i]) + } +} diff --git a/buffer/write.go b/buffer/write.go new file mode 100644 index 00000000..47010c29 --- /dev/null +++ b/buffer/write.go @@ -0,0 +1,19 @@ +package buffer + +import ( + "github.com/lightningnetwork/lnd/lnwire" +) + +// WriteSize represents the size of the maximum plaintext message than can be +// sent using brontide. The buffer does not include extra space for the MAC, as +// that is applied by the Noise protocol after encrypting the plaintext. +const WriteSize = lnwire.MaxMessagePayload + +// Write is static byte array occupying to maximum-allowed plaintext-message +// size. +type Write [WriteSize]byte + +// Recycle zeroes the Write, making it fresh for another use. +func (b *Write) Recycle() { + RecycleSlice(b[:]) +} diff --git a/lnpeer/write_buffer_pool.go b/lnpeer/write_buffer_pool.go deleted file mode 100644 index ec266ed9..00000000 --- a/lnpeer/write_buffer_pool.go +++ /dev/null @@ -1,79 +0,0 @@ -package lnpeer - -import ( - "time" - - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/queue" -) - -const ( - // DefaultGCInterval is the default interval that the WriteBufferPool - // will perform a sweep to see which expired buffers can be released to - // the runtime. - DefaultGCInterval = 15 * time.Second - - // DefaultExpiryInterval is the default, minimum interval that must - // elapse before a WriteBuffer will be released. The maximum time before - // the buffer can be released is equal to the expiry interval plus the - // gc interval. - DefaultExpiryInterval = 30 * time.Second -) - -// WriteBuffer is static byte array occupying to maximum-allowed -// plaintext-message size. -type WriteBuffer [lnwire.MaxMessagePayload]byte - -// Recycle zeroes the WriteBuffer, making it fresh for another use. -// Zeroing the buffer using a logarithmic number of calls to the optimized copy -// method. Benchmarking shows this to be ~30 times faster than a for loop that -// sets each index to 0 for this buffer size. Inspired by: -// https://stackoverflow.com/questions/30614165/is-there-analog-of-memset-in-go -// -// This is part of the queue.Recycler interface. -func (b *WriteBuffer) Recycle() { - b[0] = 0 - for i := 1; i < lnwire.MaxMessagePayload; i *= 2 { - copy(b[i:], b[:i]) - } -} - -// newRecyclableWriteBuffer is a constructor that returns a WriteBuffer typed as -// a queue.Recycler. -func newRecyclableWriteBuffer() queue.Recycler { - return new(WriteBuffer) -} - -// A compile-time constraint to ensure that *WriteBuffer implements the -// queue.Recycler interface. -var _ queue.Recycler = (*WriteBuffer)(nil) - -// WriteBufferPool acts a global pool of WriteBuffers, that dynamically -// allocates and reclaims buffers in response to load. -type WriteBufferPool struct { - pool *queue.GCQueue -} - -// NewWriteBufferPool returns a freshly instantiated WriteBufferPool, using the -// given gcInterval and expiryIntervals. -func NewWriteBufferPool( - gcInterval, expiryInterval time.Duration) *WriteBufferPool { - - return &WriteBufferPool{ - pool: queue.NewGCQueue( - newRecyclableWriteBuffer, 100, - gcInterval, expiryInterval, - ), - } -} - -// Take returns a fresh WriteBuffer to the caller. -func (p *WriteBufferPool) Take() *WriteBuffer { - return p.pool.Take().(*WriteBuffer) -} - -// Return returns the WriteBuffer to the pool, so that it can be recycled or -// released. -func (p *WriteBufferPool) Return(buf *WriteBuffer) { - p.pool.Return(buf) -} diff --git a/lnpeer/write_buffer_pool_test.go b/lnpeer/write_buffer_pool_test.go deleted file mode 100644 index 7580de6e..00000000 --- a/lnpeer/write_buffer_pool_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package lnpeer_test - -import ( - "testing" - "time" - - "github.com/lightningnetwork/lnd/lnpeer" -) - -// TestWriteBufferPool verifies that buffer pool properly resets used write -// buffers. -func TestWriteBufferPool(t *testing.T) { - const ( - gcInterval = time.Second - expiryInterval = 250 * time.Millisecond - ) - - bp := lnpeer.NewWriteBufferPool(gcInterval, expiryInterval) - - // Take a fresh write buffer from the pool. - writeBuf := bp.Take() - - // Dirty the write buffer. - for i := range writeBuf[:] { - writeBuf[i] = 0xff - } - - // Return the buffer to the pool. - bp.Return(writeBuf) - - // Take buffers from the pool until we find the original. We expect at - // most two, in the even that a fresh buffer is populated after the - // first is taken. - for i := 0; i < 2; i++ { - // Wait a small duration to ensure the tests behave reliable, - // and don't activate the non-blocking case unintentionally. - <-time.After(time.Millisecond) - - // Take a buffer, skipping those whose pointer does not match - // the one we dirtied. - writeBuf2 := bp.Take() - if writeBuf2 != writeBuf { - continue - } - - // Finally, verify that the buffer has been properly cleaned. - for i := range writeBuf2[:] { - if writeBuf2[i] != 0 { - t.Fatalf("buffer was not recycled") - } - } - - return - } - - t.Fatalf("original buffer not found") -} - -// BenchmarkWriteBufferRecycle tests how quickly a WriteBuffer can be zeroed. -func BenchmarkWriteBufferRecycle(b *testing.B) { - b.ReportAllocs() - - buffer := new(lnpeer.WriteBuffer) - for i := 0; i < b.N; i++ { - buffer.Recycle() - } -} diff --git a/peer.go b/peer.go index 976804e9..41a55892 100644 --- a/peer.go +++ b/peer.go @@ -18,6 +18,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/brontide" + "github.com/lightningnetwork/lnd/buffer" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -212,7 +213,7 @@ type peer struct { // messages to write out directly on the socket. By re-using this // buffer, we avoid needing to allocate more memory each time a new // message is to be sent to a peer. - writeBuf *lnpeer.WriteBuffer + writeBuf *buffer.Write queueQuit chan struct{} quit chan struct{} diff --git a/pool/read_buffer.go b/pool/read_buffer.go new file mode 100644 index 00000000..58f6014e --- /dev/null +++ b/pool/read_buffer.go @@ -0,0 +1,48 @@ +package pool + +import ( + "time" + + "github.com/lightningnetwork/lnd/buffer" +) + +const ( + // DefaultReadBufferGCInterval is the default interval that a Read will + // perform a sweep to see which expired buffer.Reads can be released to + // the runtime. + DefaultReadBufferGCInterval = 15 * time.Second + + // DefaultReadBufferExpiryInterval is the default, minimum interval that + // must elapse before a Read will release a buffer.Read. The maximum + // time before the buffer can be released is equal to the expiry + // interval plus the gc interval. + DefaultReadBufferExpiryInterval = 30 * time.Second +) + +// ReadBuffer is a pool of buffer.Read items, that dynamically allocates and +// reclaims buffers in response to load. +type ReadBuffer struct { + pool *Recycle +} + +// NewReadBuffer returns a freshly instantiated ReadBuffer, using the given +// gcInterval and expieryInterval. +func NewReadBuffer(gcInterval, expiryInterval time.Duration) *ReadBuffer { + return &ReadBuffer{ + pool: NewRecycle( + func() interface{} { return new(buffer.Read) }, + 100, gcInterval, expiryInterval, + ), + } +} + +// Take returns a fresh buffer.Read to the caller. +func (p *ReadBuffer) Take() *buffer.Read { + return p.pool.Take().(*buffer.Read) +} + +// Return returns the buffer.Read to the pool, so that it can be cycled or +// released. +func (p *ReadBuffer) Return(buf *buffer.Read) { + p.pool.Return(buf) +} diff --git a/pool/recycle.go b/pool/recycle.go new file mode 100644 index 00000000..4d13674b --- /dev/null +++ b/pool/recycle.go @@ -0,0 +1,52 @@ +package pool + +import ( + "time" + + "github.com/lightningnetwork/lnd/queue" +) + +// Recycler is an interface that allows an object to be reclaimed without +// needing to be returned to the runtime. +type Recycler interface { + // Recycle resets the object to its default state. + Recycle() +} + +// Recycle is a generic queue for recycling objects implementing the Recycler +// interface. It is backed by an underlying queue.GCQueue, and invokes the +// Recycle method on returned objects before returning them to the queue. +type Recycle struct { + queue *queue.GCQueue +} + +// NewRecycle initializes a fresh Recycle instance. +func NewRecycle(newItem func() interface{}, returnQueueSize int, + gcInterval, expiryInterval time.Duration) *Recycle { + + return &Recycle{ + queue: queue.NewGCQueue( + newItem, returnQueueSize, + gcInterval, expiryInterval, + ), + } +} + +// Take returns an element from the pool. +func (r *Recycle) Take() interface{} { + return r.queue.Take() +} + +// Return returns an item implementing the Recycler interface to the pool. The +// Recycle method is invoked before returning the item to improve performance +// and utilization under load. +func (r *Recycle) Return(item Recycler) { + // Recycle the item to ensure that a dirty instance is never offered + // from Take. The call is done here so that the CPU cycles spent + // clearing the buffer are owned by the caller, and not by the queue + // itself. This makes the queue more likely to be available to deliver + // items in the free list. + item.Recycle() + + r.queue.Return(item) +} diff --git a/pool/recycle_test.go b/pool/recycle_test.go new file mode 100644 index 00000000..2750da00 --- /dev/null +++ b/pool/recycle_test.go @@ -0,0 +1,217 @@ +package pool_test + +import ( + "bytes" + "testing" + "time" + + "github.com/lightningnetwork/lnd/buffer" + "github.com/lightningnetwork/lnd/pool" +) + +type mockRecycler bool + +func (m *mockRecycler) Recycle() { + *m = false +} + +// TestRecyclers verifies that known recyclable types properly return to their +// zero-value after invoking Recycle. +func TestRecyclers(t *testing.T) { + tests := []struct { + name string + newItem func() interface{} + }{ + { + "mock recycler", + func() interface{} { return new(mockRecycler) }, + }, + { + "write_buffer", + func() interface{} { return new(buffer.Write) }, + }, + { + "read_buffer", + func() interface{} { return new(buffer.Read) }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Initialize the Recycler to test. + r := test.newItem().(pool.Recycler) + + // Dirty the item. + dirtyGeneric(t, r) + + // Invoke Recycle to clear the item. + r.Recycle() + + // Assert the item is now clean. + isCleanGeneric(t, r) + }) + } +} + +type recyclePoolTest struct { + name string + newPool func() interface{} +} + +// TestGenericRecyclePoolTests generically tests that pools derived from the +// base Recycle pool properly are properly configured. +func TestConcreteRecyclePoolTests(t *testing.T) { + const ( + gcInterval = time.Second + expiryInterval = 250 * time.Millisecond + ) + + tests := []recyclePoolTest{ + { + name: "write buffer pool", + newPool: func() interface{} { + return pool.NewWriteBuffer( + gcInterval, expiryInterval, + ) + }, + }, + { + name: "read buffer pool", + newPool: func() interface{} { + return pool.NewReadBuffer( + gcInterval, expiryInterval, + ) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testRecyclePool(t, test) + }) + } +} + +func testRecyclePool(t *testing.T, test recyclePoolTest) { + p := test.newPool() + + // Take an item from the pool. + r1 := takeGeneric(t, p) + + // Dirty the item. + dirtyGeneric(t, r1) + + // Return the item to the pool. + returnGeneric(t, p, r1) + + // Take items from the pool until we find the original. We expect at + // most two, in the event that a fresh item is populated after the + // first is taken. + for i := 0; i < 2; i++ { + // Wait a small duration to ensure the tests are reliable, and + // don't to active the non-blocking case unintentionally. + <-time.After(time.Millisecond) + + r2 := takeGeneric(t, p) + + // Take an item, skipping those whose pointer does not match the + // one we dirtied. + if r1 != r2 { + continue + } + + // Finally, verify that the item has been properly cleaned. + isCleanGeneric(t, r2) + + return + } + + t.Fatalf("original item not found") +} + +func takeGeneric(t *testing.T, p interface{}) pool.Recycler { + t.Helper() + + switch pp := p.(type) { + case *pool.WriteBuffer: + return pp.Take() + + case *pool.ReadBuffer: + return pp.Take() + + default: + t.Fatalf("unknown pool type: %T", p) + } + + return nil +} + +func returnGeneric(t *testing.T, p, item interface{}) { + t.Helper() + + switch pp := p.(type) { + case *pool.WriteBuffer: + pp.Return(item.(*buffer.Write)) + + case *pool.ReadBuffer: + pp.Return(item.(*buffer.Read)) + + default: + t.Fatalf("unknown pool type: %T", p) + } +} + +func dirtyGeneric(t *testing.T, i interface{}) { + t.Helper() + + switch item := i.(type) { + case *mockRecycler: + *item = true + + case *buffer.Write: + dirtySlice(item[:]) + + case *buffer.Read: + dirtySlice(item[:]) + + default: + t.Fatalf("unknown item type: %T", i) + } + +} + +func dirtySlice(slice []byte) { + for i := range slice { + slice[i] = 0xff + } +} + +func isCleanGeneric(t *testing.T, i interface{}) { + t.Helper() + + switch item := i.(type) { + case *mockRecycler: + if isDirty := *item; isDirty { + t.Fatalf("mock recycler still diry") + } + + case *buffer.Write: + isCleanSlice(t, item[:]) + + case *buffer.Read: + isCleanSlice(t, item[:]) + + default: + t.Fatalf("unknown item type: %T", i) + } +} + +func isCleanSlice(t *testing.T, slice []byte) { + t.Helper() + + expSlice := make([]byte, len(slice)) + if !bytes.Equal(expSlice, slice) { + t.Fatalf("slice not recycled, want: %v, got: %v", + expSlice, slice) + } +} diff --git a/pool/write_buffer.go b/pool/write_buffer.go new file mode 100644 index 00000000..dfcfd9f0 --- /dev/null +++ b/pool/write_buffer.go @@ -0,0 +1,48 @@ +package pool + +import ( + "time" + + "github.com/lightningnetwork/lnd/buffer" +) + +const ( + // DefaultWriteBufferGCInterval is the default interval that a Write + // will perform a sweep to see which expired buffer.Writes can be + // released to the runtime. + DefaultWriteBufferGCInterval = 15 * time.Second + + // DefaultWriteBufferExpiryInterval is the default, minimum interval + // that must elapse before a Write will release a buffer.Write. The + // maximum time before the buffer can be released is equal to the expiry + // interval plus the gc interval. + DefaultWriteBufferExpiryInterval = 30 * time.Second +) + +// WriteBuffer is a pool of recycled buffer.Write items, that dynamically +// allocates and reclaims buffers in response to load. +type WriteBuffer struct { + pool *Recycle +} + +// NewWriteBuffer returns a freshly instantiated WriteBuffer, using the given +// gcInterval and expiryIntervals. +func NewWriteBuffer(gcInterval, expiryInterval time.Duration) *WriteBuffer { + return &WriteBuffer{ + pool: NewRecycle( + func() interface{} { return new(buffer.Write) }, + 100, gcInterval, expiryInterval, + ), + } +} + +// Take returns a fresh buffer.Write to the caller. +func (p *WriteBuffer) Take() *buffer.Write { + return p.pool.Take().(*buffer.Write) +} + +// Return returns the buffer.Write to the pool, so that it can be recycled or +// released. +func (p *WriteBuffer) Return(buf *buffer.Write) { + p.pool.Return(buf) +} diff --git a/queue/gc_queue.go b/queue/gc_queue.go index 41366c33..7698f324 100644 --- a/queue/gc_queue.go +++ b/queue/gc_queue.go @@ -8,21 +8,6 @@ import ( "github.com/lightningnetwork/lnd/ticker" ) -// Recycler is an interface that allows an object to be reclaimed without -// needing to be returned to the runtime. -type Recycler interface { - // Recycle resets the object to its default state. - Recycle() -} - -// gcQueueEntry is a tuple containing a Recycler and the time at which the item -// was added to the queue. The recorded time is used to determine when the entry -// becomes stale, and can be released if it has not already been taken. -type gcQueueEntry struct { - item Recycler - time time.Time -} - // GCQueue is garbage collecting queue, which dynamically grows and contracts // based on load. If the queue has items which have been returned, the queue // will check every gcInterval amount of time to see if any elements are @@ -36,15 +21,15 @@ type gcQueueEntry struct { type GCQueue struct { // takeBuffer coordinates the delivery of items taken from the queue // such that they are delivered to requesters. - takeBuffer chan Recycler + takeBuffer chan interface{} // returnBuffer coordinates the return of items back into the queue, // where they will be kept until retaken or released. - returnBuffer chan Recycler + returnBuffer chan interface{} // newItem is a constructor, used to generate new elements if none are // otherwise available for reuse. - newItem func() Recycler + newItem func() interface{} // expiryInterval is the minimum amount of time an element will remain // in the queue before being released. @@ -75,12 +60,12 @@ type GCQueue struct { // the steady state. The returnQueueSize parameter is used to size the maximal // number of items that can be returned without being dropped during large // bursts in attempts to return items to the GCQUeue. -func NewGCQueue(newItem func() Recycler, returnQueueSize int, +func NewGCQueue(newItem func() interface{}, returnQueueSize int, gcInterval, expiryInterval time.Duration) *GCQueue { q := &GCQueue{ - takeBuffer: make(chan Recycler), - returnBuffer: make(chan Recycler, returnQueueSize), + takeBuffer: make(chan interface{}), + returnBuffer: make(chan interface{}, returnQueueSize), expiryInterval: expiryInterval, freeList: list.New(), recycleTicker: ticker.New(gcInterval), @@ -95,7 +80,7 @@ func NewGCQueue(newItem func() Recycler, returnQueueSize int, // Take returns either a recycled element from the queue, or creates a new item // if none are available. -func (q *GCQueue) Take() Recycler { +func (q *GCQueue) Take() interface{} { select { case item := <-q.takeBuffer: return item @@ -107,20 +92,21 @@ func (q *GCQueue) Take() Recycler { // Return adds the returned item to freelist if the queue's returnBuffer has // available capacity. Under load, items may be dropped to ensure this method // does not block. -func (q *GCQueue) Return(item Recycler) { - // Recycle the item to ensure that a dirty instance is never offered - // from Take. The call is done here so that the CPU cycles spent - // clearing the buffer are owned by the caller, and not by the queue - // itself. This makes the queue more likely to be available to deliver - // items in the free list. - item.Recycle() - +func (q *GCQueue) Return(item interface{}) { select { case q.returnBuffer <- item: default: } } +// gcQueueEntry is a tuple containing an interface{} and the time at which the +// item was added to the queue. The recorded time is used to determine when the +// entry becomes stale, and can be released if it has not already been taken. +type gcQueueEntry struct { + item interface{} + time time.Time +} + // queueManager maintains the free list of elements by popping the head of the // queue when items are needed, and appending them to the end of the queue when // items are returned. The queueManager will periodically attempt to release any @@ -190,20 +176,20 @@ func (q *GCQueue) queueManager() { next = e.Next() entry := e.Value.(gcQueueEntry) - // Use now - insertTime > expiryInterval to - // determine if this entry has expired. - if time.Since(entry.time) > q.expiryInterval { - // Remove the expired entry from the - // linked-list. - q.freeList.Remove(e) - entry.item = nil - e.Value = nil - } else { + // Use now - insertTime <= expiryInterval to + // determine if this entry has not expired. + if time.Since(entry.time) <= q.expiryInterval { // If this entry hasn't expired, then // all entries that follow will still be // valid. break } + + // Otherwise, remove the expired entry from the + // linked-list. + q.freeList.Remove(e) + entry.item = nil + e.Value = nil } } } diff --git a/queue/gc_queue_test.go b/queue/gc_queue_test.go index cda0f998..3fc5d3b2 100644 --- a/queue/gc_queue_test.go +++ b/queue/gc_queue_test.go @@ -7,10 +7,8 @@ import ( "github.com/lightningnetwork/lnd/queue" ) -// mockRecycler implements the queue.Recycler interface using a NOP. -type mockRecycler bool - -func (*mockRecycler) Recycle() {} +// testItem is an item type we'll be using to test the GCQueue. +type testItem uint32 // TestGCQueueGCCycle asserts that items that are kept in the GCQueue past their // expiration will be released by a subsequent gc cycle. @@ -23,7 +21,7 @@ func TestGCQueueGCCycle(t *testing.T) { numItems = 6 ) - newItem := func() queue.Recycler { return new(mockRecycler) } + newItem := func() interface{} { return new(testItem) } bp := queue.NewGCQueue(newItem, 100, gcInterval, expiryInterval) @@ -61,7 +59,7 @@ func TestGCQueuePartialGCCycle(t *testing.T) { numItems = 6 ) - newItem := func() queue.Recycler { return new(mockRecycler) } + newItem := func() interface{} { return new(testItem) } bp := queue.NewGCQueue(newItem, 100, gcInterval, expiryInterval) @@ -104,10 +102,10 @@ func TestGCQueuePartialGCCycle(t *testing.T) { // takeN draws n items from the provided GCQueue. This method also asserts that // n unique items are drawn, and then returns the resulting set. -func takeN(t *testing.T, q *queue.GCQueue, n int) map[queue.Recycler]struct{} { +func takeN(t *testing.T, q *queue.GCQueue, n int) map[interface{}]struct{} { t.Helper() - items := make(map[queue.Recycler]struct{}) + items := make(map[interface{}]struct{}) for i := 0; i < n; i++ { // Wait a small duration to ensure the tests behave reliable, // and don't activate the non-blocking case unintentionally. @@ -125,7 +123,7 @@ func takeN(t *testing.T, q *queue.GCQueue, n int) map[queue.Recycler]struct{} { } // returnAll returns the items of the given set back to the GCQueue. -func returnAll(q *queue.GCQueue, items map[queue.Recycler]struct{}) { +func returnAll(q *queue.GCQueue, items map[interface{}]struct{}) { for item := range items { q.Return(item) @@ -138,11 +136,11 @@ func returnAll(q *queue.GCQueue, items map[queue.Recycler]struct{}) { // returnN returns n items at random from the set of items back to the GCQueue. // This method fails if the set's cardinality is smaller than n. func returnN(t *testing.T, q *queue.GCQueue, - items map[queue.Recycler]struct{}, n int) map[queue.Recycler]struct{} { + items map[interface{}]struct{}, n int) map[interface{}]struct{} { t.Helper() - var remainingItems = make(map[queue.Recycler]struct{}) + var remainingItems = make(map[interface{}]struct{}) var numReturned int for item := range items { if numReturned < n { diff --git a/server.go b/server.go index 4d2b526d..b32b1034 100644 --- a/server.go +++ b/server.go @@ -41,6 +41,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/nat" "github.com/lightningnetwork/lnd/netann" + "github.com/lightningnetwork/lnd/pool" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/ticker" @@ -170,7 +171,7 @@ type server struct { sigPool *lnwallet.SigPool - writeBufferPool *lnpeer.WriteBufferPool + writeBufferPool *pool.WriteBuffer // globalFeatures feature vector which affects HTLCs and thus are also // advertised to other nodes. @@ -262,8 +263,9 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, sharedSecretPath := filepath.Join(graphDir, "sphinxreplay.db") replayLog := htlcswitch.NewDecayedLog(sharedSecretPath, cc.chainNotifier) sphinxRouter := sphinx.NewRouter(privKey, activeNetParams.Params, replayLog) - writeBufferPool := lnpeer.NewWriteBufferPool( - lnpeer.DefaultGCInterval, lnpeer.DefaultExpiryInterval, + writeBufferPool := pool.NewWriteBuffer( + pool.DefaultWriteBufferGCInterval, + pool.DefaultWriteBufferExpiryInterval, ) s := &server{