diff --git a/chancloser.go b/chancloser.go index bce0e8a5..903d1fa9 100644 --- a/chancloser.go +++ b/chancloser.go @@ -73,7 +73,7 @@ type chanCloseCfg struct { // unregisterChannel is a function closure that allows the // channelCloser to re-register a channel. Once this has been done, no // further HTLC's should be routed through the channel. - unregisterChannel func(lnwire.ChannelID) error + unregisterChannel func(lnwire.ChannelID) // broadcastTx broadcasts the passed transaction to the network. broadcastTx func(*wire.MsgTx) error diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 8697624c..3112b8ed 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -117,9 +117,10 @@ type ChannelLinkConfig struct { Switch *Switch // ForwardPackets attempts to forward the batch of htlcs through the - // switch. Any failed packets will be returned to the provided - // ChannelLink. - ForwardPackets func(...*htlcPacket) chan error + // switch, any failed packets will be returned to the provided + // ChannelLink. The link's quit signal should be provided to allow + // cancellation of forwarding during link shutdown. + ForwardPackets func(chan struct{}, ...*htlcPacket) chan error // DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion // blobs, which are then used to inform how to forward an HTLC. @@ -359,21 +360,6 @@ func (l *channelLink) Start() error { log.Infof("ChannelLink(%v) is starting", l) - // Before we start the link, we'll update the ChainArbitrator with the - // set of new channel signals for this channel. - // - // TODO(roasbeef): split goroutines within channel arb to avoid - go func() { - err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{ - HtlcUpdates: l.htlcUpdates, - ShortChanID: l.channel.ShortChanID(), - }) - if err != nil { - log.Errorf("Unable to update signals for "+ - "ChannelLink(%v)", l) - } - }() - l.mailBox.ResetMessages() l.overflowQueue.Start() @@ -401,6 +387,24 @@ func (l *channelLink) Start() error { return fmt.Errorf("unable to trim circuits above "+ "local htlc index %d: %v", localHtlcIndex, err) } + + // Since the link is live, before we start the link we'll update + // the ChainArbitrator with the set of new channel signals for + // this channel. + // + // TODO(roasbeef): split goroutines within channel arb to avoid + go func() { + signals := &contractcourt.ContractSignals{ + HtlcUpdates: l.htlcUpdates, + ShortChanID: l.channel.ShortChanID(), + } + + err := l.cfg.UpdateContractSignals(signals) + if err != nil { + log.Errorf("Unable to update signals for "+ + "ChannelLink(%v)", l) + } + }() } l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout()) @@ -2539,7 +2543,7 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) { filteredPkts = append(filteredPkts, pkt) } - errChan := l.cfg.ForwardPackets(filteredPkts...) + errChan := l.cfg.ForwardPackets(l.quit, filteredPkts...) go l.handleBatchFwdErrs(errChan) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 540dd937..d261d8f6 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -3396,6 +3396,136 @@ func TestShouldAdjustCommitFee(t *testing.T) { } } +// TestChannelLinkShutdownDuringForward asserts that a link can be fully +// stopped when it is trying to send synchronously through the switch. The +// specific case this can occur is when a link forwards incoming Adds. We test +// this by forcing the switch into a state where it will not accept new packets, +// and then killing the link, which can only succeed if forwarding can be +// canceled by a call to Stop. +func TestChannelLinkShutdownDuringForward(t *testing.T) { + t.Parallel() + + // First, we'll create our traditional three hop network. We're + // interested in testing the ability to stop the link when it is + // synchronously forwarding to the switch, which happens when an + // incoming link forwards Adds. Thus, the test will be performed + // against Bob's first link. + channels, cleanUp, _, err := createClusterChannels( + btcutil.SatoshiPerBitcoin*3, + btcutil.SatoshiPerBitcoin*5) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + defer cleanUp() + + n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, + channels.bobToCarol, channels.carolToBob, testStartingHeight) + + if err := n.start(); err != nil { + t.Fatal(err) + } + defer n.stop() + defer n.feeEstimator.Stop() + + // Define a helper method that strobes the switch's log ticker, and + // unblocks after nothing has been pulled for two seconds. + waitForBobsSwitchToBlock := func() { + bobSwitch := n.firstBobChannelLink.cfg.Switch + ticker := bobSwitch.cfg.LogEventTicker.(*ticker.Mock) + timeout := time.After(15 * time.Second) + for { + time.Sleep(50 * time.Millisecond) + select { + case ticker.Force <- time.Now(): + + case <-time.After(2 * time.Second): + return + + case <-timeout: + t.Fatalf("switch did not block") + } + } + } + + // Define a helper method that strobes the link's batch ticker, and + // unblocks after nothing has been pulled for two seconds. + waitForBobsIncomingLinkToBlock := func() { + ticker := n.firstBobChannelLink.cfg.BatchTicker.(*ticker.Mock) + timeout := time.After(15 * time.Second) + for { + time.Sleep(50 * time.Millisecond) + select { + case ticker.Force <- time.Now(): + + case <-time.After(2 * time.Second): + // We'll give a little extra time here, to + // ensure that the packet is being pressed + // against the htlcPlex. + time.Sleep(50 * time.Millisecond) + return + + case <-timeout: + t.Fatalf("link did not block") + } + } + } + + // To test that the cancellation is happening properly, we will set the + // switch's htlcPlex to nil, so that calls to routeAsync block, and can + // only exit if the link (or switch) is exiting. We will only be testing + // the link here. + // + // In order to avoid data races, we need to ensure the switch isn't + // selecting on that channel in the meantime. We'll prevent this by + // first acquiring the index mutex and forcing a log event so that the + // htlcForwarder is blocked inside the logTicker case, which also needs + // the indexMtx. + n.firstBobChannelLink.cfg.Switch.indexMtx.Lock() + + // Strobe the log ticker, and wait for switch to stop accepting any more + // log ticks. + waitForBobsSwitchToBlock() + + // While the htlcForwarder is blocked, swap out the htlcPlex with a nil + // channel, and unlock the indexMtx to allow return to the + // htlcForwarder's main select. After this, any attempt to forward + // through the switch will block. + n.firstBobChannelLink.cfg.Switch.htlcPlex = nil + n.firstBobChannelLink.cfg.Switch.indexMtx.Unlock() + + // Now, make a payment from Alice to Carol, which should cause Bob's + // incoming link to block when it tries to submit the packet to the nil + // htlcPlex. + amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) + htlcAmt, totalTimelock, hops := generateHops( + amount, testStartingHeight, + n.firstBobChannelLink, n.carolChannelLink, + ) + + n.makePayment( + n.aliceServer, n.carolServer, n.bobServer.PubKey(), + hops, amount, htlcAmt, totalTimelock, + ) + + // Strobe the batch ticker of Bob's incoming link, waiting for it to + // become fully blocked. + waitForBobsIncomingLinkToBlock() + + // Finally, stop the link to test that it can exit while synchronously + // forwarding Adds to the switch. + done := make(chan struct{}) + go func() { + n.firstBobChannelLink.Stop() + close(done) + }() + + select { + case <-time.After(3 * time.Second): + t.Fatalf("unable to shutdown link while fwding incoming Adds") + case <-done: + } +} + // TestChannelLinkUpdateCommitFee tests that when a new block comes in, the // channel link properly checks to see if it should update the commitment fee. func TestChannelLinkUpdateCommitFee(t *testing.T) { @@ -3709,7 +3839,6 @@ func (h *persistentLinkHarness) restart(restartSwitch bool, // First, remove the link from the switch. h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID()) - h.coreLink.WaitForShutdown() var htlcSwitch *Switch if restartSwitch { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 8f19cfb3..ecb56a5f 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -527,12 +527,15 @@ func (s *Switch) forward(packet *htlcPacket) error { // ForwardPackets adds a list of packets to the switch for processing. Fails // and settles are added on a first past, simultaneously constructing circuits // for any adds. After persisting the circuits, another pass of the adds is -// given to forward them through the router. +// given to forward them through the router. The sending link's quit channel is +// used to prevent deadlocks when the switch stops a link in the midst of +// forwarding. // // NOTE: This method guarantees that the returned err chan will eventually be // closed. The receiver should read on the channel until receiving such a // signal. -func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { +func (s *Switch) ForwardPackets(linkQuit chan struct{}, + packets ...*htlcPacket) chan error { var ( // fwdChan is a buffered channel used to receive err msgs from @@ -568,6 +571,9 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { // so, we exit early to avoid incrementing the switch's waitgroup while // it is already in the process of shutting down. select { + case <-linkQuit: + close(errChan) + return errChan case <-s.quit: close(errChan) return errChan @@ -593,7 +599,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { circuits = append(circuits, circuit) addBatch = append(addBatch, packet) default: - s.routeAsync(packet, fwdChan) + err := s.routeAsync(packet, fwdChan, linkQuit) + if err != nil { + return errChan + } numSent++ } } @@ -635,7 +644,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { // Now, forward any packets for circuits that were successfully added to // the switch's circuit map. for _, packet := range addedPackets { - s.routeAsync(packet, fwdChan) + err := s.routeAsync(packet, fwdChan, linkQuit) + if err != nil { + return errChan + } numSent++ } @@ -722,9 +734,13 @@ func (s *Switch) route(packet *htlcPacket) error { } // routeAsync sends a packet through the htlc switch, using the provided err -// chan to propagate errors back to the caller. This method does not wait for -// a response before returning. -func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error { +// chan to propagate errors back to the caller. The link's quit channel is +// provided so that the send can be canceled if either the link or the switch +// receive a shutdown requuest. This method does not wait for a response from +// the htlcForwarder before returning. +func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, + linkQuit chan struct{}) error { + command := &plexPacket{ pkt: packet, err: errChan, @@ -733,6 +749,8 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error { select { case s.htlcPlex <- command: return nil + case <-linkQuit: + return ErrLinkShuttingDown case <-s.quit: return errors.New("Htlc Switch was stopped") } @@ -1380,21 +1398,34 @@ func (s *Switch) htlcForwarder() { s.blockEpochStream.Cancel() // Remove all links once we've been signalled for shutdown. + var linksToStop []ChannelLink s.indexMtx.Lock() for _, link := range s.linkIndex { - if err := s.removeLink(link.ChanID()); err != nil { - log.Errorf("unable to remove "+ - "channel link on stop: %v", err) + activeLink := s.removeLink(link.ChanID()) + if activeLink == nil { + log.Errorf("unable to remove ChannelLink(%v) "+ + "on stop", link.ChanID()) + continue } + linksToStop = append(linksToStop, activeLink) } for _, link := range s.pendingLinkIndex { - if err := s.removeLink(link.ChanID()); err != nil { - log.Errorf("unable to remove pending "+ - "channel link on stop: %v", err) + pendingLink := s.removeLink(link.ChanID()) + if pendingLink == nil { + log.Errorf("unable to remove ChannelLink(%v) "+ + "on stop", link.ChanID()) + continue } + linksToStop = append(linksToStop, pendingLink) } s.indexMtx.Unlock() + // Now that all pending and live links have been removed from + // the forwarding indexes, stop each one before shutting down. + for _, link := range linksToStop { + link.Stop() + } + // Before we exit fully, we'll attempt to flush out any // forwarding events that may still be lingering since the last // batch flush. @@ -1721,7 +1752,10 @@ func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) { } } - errChan := s.ForwardPackets(switchPackets...) + // Since this send isn't tied to a specific link, we pass a nil + // link quit channel, meaning the send will fail only if the + // switch receives a shutdown request. + errChan := s.ForwardPackets(nil, switchPackets...) go handleBatchFwdErrs(errChan) } } @@ -1776,11 +1810,11 @@ func (s *Switch) AddLink(link ChannelLink) error { chanID := link.ChanID() - // If a link already exists, then remove the prior one so we can - // replace it with this fresh instance. + // First, ensure that this link is not already active in the switch. _, err := s.getLink(chanID) if err == nil { - s.removeLink(chanID) + return fmt.Errorf("unable to add ChannelLink(%v), already "+ + "active", chanID) } // Get and attach the mailbox for this link, which buffers packets in @@ -1868,24 +1902,28 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er return link, nil } -// RemoveLink is used to initiate the handling of the remove link command. The -// request will be propagated/handled to/in the main goroutine. -func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { +// RemoveLink purges the switch of any link associated with chanID. If a pending +// or active link is not found, this method does nothing. Otherwise, the method +// returns after the link has been completely shutdown. +func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { s.indexMtx.Lock() - defer s.indexMtx.Unlock() + link := s.removeLink(chanID) + s.indexMtx.Unlock() - return s.removeLink(chanID) + if link != nil { + link.Stop() + } } // removeLink is used to remove and stop the channel link. // // NOTE: This MUST be called with the indexMtx held. -func (s *Switch) removeLink(chanID lnwire.ChannelID) error { +func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink { log.Infof("Removing channel link with ChannelID(%v)", chanID) link, err := s.getLink(chanID) if err != nil { - return err + return nil } // Remove the channel from live link indexes. @@ -1906,9 +1944,7 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error { } } - go link.Stop() - - return nil + return link } // UpdateShortChanID updates the short chan ID for an existing channel. This is diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 904113e4..ca607e82 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -25,6 +25,63 @@ func genPreimage() ([32]byte, error) { return preimage, nil } +// TestSwitchAddDuplicateLink tests that the switch will reject duplicate links +// for both pending and live links. It also tests that we can successfully +// add a link after having removed it. +func TestSwitchAddDuplicateLink(t *testing.T) { + t.Parallel() + + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + + s, err := initSwitchWithDB(testStartingHeight, nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, _, aliceChanID, _ := genIDs() + + pendingChanID := lnwire.ShortChannelID{} + + aliceChannelLink := newMockChannelLink( + s, chanID1, pendingChanID, alicePeer, false, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + + // Alice should have a pending link, adding again should fail. + if err := s.AddLink(aliceChannelLink); err == nil { + t.Fatalf("adding duplicate link should have failed") + } + + // Update the short chan id of the channel, so that the link goes live. + aliceChannelLink.setLiveShortChanID(aliceChanID) + err = s.UpdateShortChanID(chanID1) + if err != nil { + t.Fatalf("unable to update alice short_chan_id: %v", err) + } + + // Alice should have a live link, adding again should fail. + if err := s.AddLink(aliceChannelLink); err == nil { + t.Fatalf("adding duplicate link should have failed") + } + + // Remove the live link to ensure the indexes are cleared. + s.RemoveLink(chanID1) + + // Alice has no links, adding should succeed. + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } +} + // TestSwitchSendPending checks the inability of htlc switch to forward adds // over pending links, and the UpdateShortChanID makes a pending link live. func TestSwitchSendPending(t *testing.T) { diff --git a/peer.go b/peer.go index 48374c5d..74ff4c74 100644 --- a/peer.go +++ b/peer.go @@ -469,11 +469,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, // mailboxes such that we can safely force close // without the link being added again and updates being // applied. - err := p.server.htlcSwitch.RemoveLink(chanID) - if err != nil { - peerLog.Errorf("unable to stop link(%v): %v", - shortChanID, err) - } + p.server.htlcSwitch.RemoveLink(chanID) // If the error encountered was severe enough, we'll // now force close the channel. @@ -557,6 +553,12 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, link := htlcswitch.NewChannelLink(linkCfg, lnChan) + // Before adding our new link, purge the switch of any pending or live + // links going by the same channel id. If one is found, we'll shut it + // down to ensure that the mailboxes are only ever under the control of + // one link. + p.server.htlcSwitch.RemoveLink(link.ChanID()) + // With the channel link created, we'll now notify the htlc switch so // this channel can be used to dispatch local payments and also // passively forward payments. @@ -1526,8 +1528,8 @@ out: ) if err != nil { peerLog.Errorf("can't register new channel "+ - "link(%v) with NodeKey(%x): %v", chanPoint, - p.PubKey(), err) + "link(%v) with NodeKey(%x)", chanPoint, + p.PubKey()) } close(newChanReq.done) @@ -1922,14 +1924,7 @@ func (p *peer) WipeChannel(chanPoint *wire.OutPoint) error { // Instruct the HtlcSwitch to close this link as the channel is no // longer active. - if err := p.server.htlcSwitch.RemoveLink(chanID); err != nil { - if err == htlcswitch.ErrChannelLinkNotFound { - peerLog.Warnf("unable remove channel link with "+ - "ChannelPoint(%v): %v", chanID, err) - return nil - } - return err - } + p.server.htlcSwitch.RemoveLink(chanID) return nil } diff --git a/server.go b/server.go index 565f66d6..2729fa76 100644 --- a/server.go +++ b/server.go @@ -646,7 +646,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, ChainIO: cc.chainIO, MarkLinkInactive: func(chanPoint wire.OutPoint) error { chanID := lnwire.NewChanIDFromOutPoint(&chanPoint) - return s.htlcSwitch.RemoveLink(chanID) + s.htlcSwitch.RemoveLink(chanID) + return nil }, IsOurAddress: func(addr btcutil.Address) bool { _, err := cc.wallet.GetPrivKey(addr) @@ -1960,11 +1961,7 @@ func (s *server) peerTerminationWatcher(p *peer) { } for _, link := range links { - err := p.server.htlcSwitch.RemoveLink(link.ChanID()) - if err != nil { - srvrLog.Errorf("unable to remove channel link: %v", - err) - } + p.server.htlcSwitch.RemoveLink(link.ChanID()) } s.mu.Lock()