From f907fbcadc1339486ebc50176cc9a9e96a1e30d0 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 6 Apr 2020 14:47:49 +0200 Subject: [PATCH] queue: detect close of incoming channel --- queue/queue.go | 27 +++++++++++++++++++++++++-- queue/queue_test.go | 22 ++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/queue/queue.go b/queue/queue.go index e3b01b26..3c070205 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -58,6 +58,7 @@ func (cq *ConcurrentQueue) start() { go func() { defer cq.wg.Done() + readLoop: for { nextElement := cq.overflow.Front() if nextElement == nil { @@ -65,7 +66,10 @@ func (cq *ConcurrentQueue) start() { // directly to the output channel. If output channel is full // though, push to overflow. select { - case item := <-cq.chanIn: + case item, ok := <-cq.chanIn: + if !ok { + break readLoop + } select { case cq.chanOut <- item: // Optimistically push directly to chanOut @@ -79,7 +83,10 @@ func (cq *ConcurrentQueue) start() { // Overflow queue is not empty, so any new items get pushed to // the back to preserve order. select { - case item := <-cq.chanIn: + case item, ok := <-cq.chanIn: + if !ok { + break readLoop + } cq.overflow.PushBack(item) case cq.chanOut <- nextElement.Value: cq.overflow.Remove(nextElement) @@ -88,6 +95,22 @@ func (cq *ConcurrentQueue) start() { } } } + + // Incoming channel has been closed. Empty overflow queue into + // the outgoing channel. + nextElement := cq.overflow.Front() + for nextElement != nil { + select { + case cq.chanOut <- nextElement.Value: + cq.overflow.Remove(nextElement) + case <-cq.quit: + return + } + nextElement = cq.overflow.Front() + } + + // Close outgoing channel. + close(cq.chanOut) }() } diff --git a/queue/queue_test.go b/queue/queue_test.go index 9aee0cfb..bd74dcc0 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -63,3 +63,25 @@ func TestConcurrentQueueIdempotentStop(t *testing.T) { testQueueAddDrain(t, 100, 1, 10, 1000, 1000) } + +// TestQueueCloseIncoming tests that the queue properly handles an incoming +// channel that is closed. +func TestQueueCloseIncoming(t *testing.T) { + t.Parallel() + + queue := queue.NewConcurrentQueue(10) + queue.Start() + + queue.ChanIn() <- 1 + close(queue.ChanIn()) + + item := <-queue.ChanOut() + if item.(int) != 1 { + t.Fatalf("unexpected item") + } + + _, ok := <-queue.ChanOut() + if ok { + t.Fatalf("expected outgoing channel being closed") + } +}