diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index 04c5cd6d..e32d75a7 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -78,8 +78,9 @@ type memoryMailBox struct { pktOutbox chan *htlcPacket pktReset chan chan struct{} - wg sync.WaitGroup - quit chan struct{} + wireShutdown chan struct{} + pktShutdown chan struct{} + quit chan struct{} } // newMemoryMailBox creates a new instance of the memoryMailBox. @@ -92,6 +93,8 @@ func newMemoryMailBox() *memoryMailBox { msgReset: make(chan chan struct{}, 1), pktReset: make(chan chan struct{}, 1), pktIndex: make(map[CircuitKey]*list.Element), + wireShutdown: make(chan struct{}), + pktShutdown: make(chan struct{}), quit: make(chan struct{}), } box.wireCond = sync.NewCond(&box.wireMtx) @@ -122,7 +125,6 @@ const ( // NOTE: This method is part of the MailBox interface. func (m *memoryMailBox) Start() { m.started.Do(func() { - m.wg.Add(2) go m.mailCourier(wireCourier) go m.mailCourier(pktCourier) }) @@ -157,6 +159,7 @@ func (m *memoryMailBox) signalUntilReset(cType courierType, done chan struct{}) error { for { + switch cType { case wireCourier: m.wireCond.Signal() @@ -209,17 +212,49 @@ func (m *memoryMailBox) Stop() { m.stopped.Do(func() { close(m.quit) - m.wireCond.Signal() - m.pktCond.Signal() + m.signalUntilShutdown(wireCourier) + m.signalUntilShutdown(pktCourier) }) } +// signalUntilShutdown strobes the condition variable of the passed courier +// type, blocking until the worker has exited. +func (m *memoryMailBox) signalUntilShutdown(cType courierType) { + var ( + cond *sync.Cond + shutdown chan struct{} + ) + + switch cType { + case wireCourier: + cond = m.wireCond + shutdown = m.wireShutdown + case pktCourier: + cond = m.pktCond + shutdown = m.pktShutdown + } + + for { + select { + case <-time.After(time.Millisecond): + cond.Signal() + case <-shutdown: + return + } + } +} + // mailCourier is a dedicated goroutine whose job is to reliably deliver // messages of a particular type. There are two types of couriers: wire // couriers, and mail couriers. Depending on the passed courierType, this // goroutine will assume one of two roles. func (m *memoryMailBox) mailCourier(cType courierType) { - defer m.wg.Done() + switch cType { + case wireCourier: + defer close(m.wireShutdown) + case pktCourier: + defer close(m.pktShutdown) + } // TODO(roasbeef): refactor... diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index e8356c97..83dfc282 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -148,6 +148,28 @@ func TestMailBoxCouriers(t *testing.T) { } } +// TestMailBoxResetAfterShutdown tests that ResetMessages and ResetPackets +// return ErrMailBoxShuttingDown after the mailbox has been stopped. +func TestMailBoxResetAfterShutdown(t *testing.T) { + t.Parallel() + + m := newMemoryMailBox() + m.Start() + + // Stop the mailbox, then try to reset the message and packet couriers. + m.Stop() + + err := m.ResetMessages() + if err != ErrMailBoxShuttingDown { + t.Fatalf("expected ErrMailBoxShuttingDown, got: %v", err) + } + + err = m.ResetPackets() + if err != ErrMailBoxShuttingDown { + t.Fatalf("expected ErrMailBoxShuttingDown, got: %v", err) + } +} + // TestMailOrchestrator asserts that the orchestrator properly buffers packets // for channels that haven't been made live, such that they are delivered // immediately after BindLiveShortChanID. It also tests that packets are delivered