Merge pull request #1551 from cfromknecht/switch-revert-replace-link

[htlcswitch]: revert replace link, ensure removed links are stopped
This commit is contained in:
Olaoluwa Osuntokun 2018-08-13 21:44:42 -07:00 committed by GitHub
commit 7a113d469b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 287 additions and 69 deletions

@ -73,7 +73,7 @@ type chanCloseCfg struct {
// unregisterChannel is a function closure that allows the // unregisterChannel is a function closure that allows the
// channelCloser to re-register a channel. Once this has been done, no // channelCloser to re-register a channel. Once this has been done, no
// further HTLC's should be routed through the channel. // further HTLC's should be routed through the channel.
unregisterChannel func(lnwire.ChannelID) error unregisterChannel func(lnwire.ChannelID)
// broadcastTx broadcasts the passed transaction to the network. // broadcastTx broadcasts the passed transaction to the network.
broadcastTx func(*wire.MsgTx) error broadcastTx func(*wire.MsgTx) error

@ -117,9 +117,10 @@ type ChannelLinkConfig struct {
Switch *Switch Switch *Switch
// ForwardPackets attempts to forward the batch of htlcs through the // ForwardPackets attempts to forward the batch of htlcs through the
// switch. Any failed packets will be returned to the provided // switch, any failed packets will be returned to the provided
// ChannelLink. // ChannelLink. The link's quit signal should be provided to allow
ForwardPackets func(...*htlcPacket) chan error // cancellation of forwarding during link shutdown.
ForwardPackets func(chan struct{}, ...*htlcPacket) chan error
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion // DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
// blobs, which are then used to inform how to forward an HTLC. // blobs, which are then used to inform how to forward an HTLC.
@ -359,21 +360,6 @@ func (l *channelLink) Start() error {
log.Infof("ChannelLink(%v) is starting", l) log.Infof("ChannelLink(%v) is starting", l)
// Before we start the link, we'll update the ChainArbitrator with the
// set of new channel signals for this channel.
//
// TODO(roasbeef): split goroutines within channel arb to avoid
go func() {
err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{
HtlcUpdates: l.htlcUpdates,
ShortChanID: l.channel.ShortChanID(),
})
if err != nil {
log.Errorf("Unable to update signals for "+
"ChannelLink(%v)", l)
}
}()
l.mailBox.ResetMessages() l.mailBox.ResetMessages()
l.overflowQueue.Start() l.overflowQueue.Start()
@ -401,6 +387,24 @@ func (l *channelLink) Start() error {
return fmt.Errorf("unable to trim circuits above "+ return fmt.Errorf("unable to trim circuits above "+
"local htlc index %d: %v", localHtlcIndex, err) "local htlc index %d: %v", localHtlcIndex, err)
} }
// Since the link is live, before we start the link we'll update
// the ChainArbitrator with the set of new channel signals for
// this channel.
//
// TODO(roasbeef): split goroutines within channel arb to avoid
go func() {
signals := &contractcourt.ContractSignals{
HtlcUpdates: l.htlcUpdates,
ShortChanID: l.channel.ShortChanID(),
}
err := l.cfg.UpdateContractSignals(signals)
if err != nil {
log.Errorf("Unable to update signals for "+
"ChannelLink(%v)", l)
}
}()
} }
l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout()) l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout())
@ -2539,7 +2543,7 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) {
filteredPkts = append(filteredPkts, pkt) filteredPkts = append(filteredPkts, pkt)
} }
errChan := l.cfg.ForwardPackets(filteredPkts...) errChan := l.cfg.ForwardPackets(l.quit, filteredPkts...)
go l.handleBatchFwdErrs(errChan) go l.handleBatchFwdErrs(errChan)
} }

@ -3396,6 +3396,136 @@ func TestShouldAdjustCommitFee(t *testing.T) {
} }
} }
// TestChannelLinkShutdownDuringForward asserts that a link can be fully
// stopped when it is trying to send synchronously through the switch. The
// specific case this can occur is when a link forwards incoming Adds. We test
// this by forcing the switch into a state where it will not accept new packets,
// and then killing the link, which can only succeed if forwarding can be
// canceled by a call to Stop.
func TestChannelLinkShutdownDuringForward(t *testing.T) {
t.Parallel()
// First, we'll create our traditional three hop network. We're
// interested in testing the ability to stop the link when it is
// synchronously forwarding to the switch, which happens when an
// incoming link forwards Adds. Thus, the test will be performed
// against Bob's first link.
channels, cleanUp, _, err := createClusterChannels(
btcutil.SatoshiPerBitcoin*3,
btcutil.SatoshiPerBitcoin*5)
if err != nil {
t.Fatalf("unable to create channel: %v", err)
}
defer cleanUp()
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
channels.bobToCarol, channels.carolToBob, testStartingHeight)
if err := n.start(); err != nil {
t.Fatal(err)
}
defer n.stop()
defer n.feeEstimator.Stop()
// Define a helper method that strobes the switch's log ticker, and
// unblocks after nothing has been pulled for two seconds.
waitForBobsSwitchToBlock := func() {
bobSwitch := n.firstBobChannelLink.cfg.Switch
ticker := bobSwitch.cfg.LogEventTicker.(*ticker.Mock)
timeout := time.After(15 * time.Second)
for {
time.Sleep(50 * time.Millisecond)
select {
case ticker.Force <- time.Now():
case <-time.After(2 * time.Second):
return
case <-timeout:
t.Fatalf("switch did not block")
}
}
}
// Define a helper method that strobes the link's batch ticker, and
// unblocks after nothing has been pulled for two seconds.
waitForBobsIncomingLinkToBlock := func() {
ticker := n.firstBobChannelLink.cfg.BatchTicker.(*ticker.Mock)
timeout := time.After(15 * time.Second)
for {
time.Sleep(50 * time.Millisecond)
select {
case ticker.Force <- time.Now():
case <-time.After(2 * time.Second):
// We'll give a little extra time here, to
// ensure that the packet is being pressed
// against the htlcPlex.
time.Sleep(50 * time.Millisecond)
return
case <-timeout:
t.Fatalf("link did not block")
}
}
}
// To test that the cancellation is happening properly, we will set the
// switch's htlcPlex to nil, so that calls to routeAsync block, and can
// only exit if the link (or switch) is exiting. We will only be testing
// the link here.
//
// In order to avoid data races, we need to ensure the switch isn't
// selecting on that channel in the meantime. We'll prevent this by
// first acquiring the index mutex and forcing a log event so that the
// htlcForwarder is blocked inside the logTicker case, which also needs
// the indexMtx.
n.firstBobChannelLink.cfg.Switch.indexMtx.Lock()
// Strobe the log ticker, and wait for switch to stop accepting any more
// log ticks.
waitForBobsSwitchToBlock()
// While the htlcForwarder is blocked, swap out the htlcPlex with a nil
// channel, and unlock the indexMtx to allow return to the
// htlcForwarder's main select. After this, any attempt to forward
// through the switch will block.
n.firstBobChannelLink.cfg.Switch.htlcPlex = nil
n.firstBobChannelLink.cfg.Switch.indexMtx.Unlock()
// Now, make a payment from Alice to Carol, which should cause Bob's
// incoming link to block when it tries to submit the packet to the nil
// htlcPlex.
amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
htlcAmt, totalTimelock, hops := generateHops(
amount, testStartingHeight,
n.firstBobChannelLink, n.carolChannelLink,
)
n.makePayment(
n.aliceServer, n.carolServer, n.bobServer.PubKey(),
hops, amount, htlcAmt, totalTimelock,
)
// Strobe the batch ticker of Bob's incoming link, waiting for it to
// become fully blocked.
waitForBobsIncomingLinkToBlock()
// Finally, stop the link to test that it can exit while synchronously
// forwarding Adds to the switch.
done := make(chan struct{})
go func() {
n.firstBobChannelLink.Stop()
close(done)
}()
select {
case <-time.After(3 * time.Second):
t.Fatalf("unable to shutdown link while fwding incoming Adds")
case <-done:
}
}
// TestChannelLinkUpdateCommitFee tests that when a new block comes in, the // TestChannelLinkUpdateCommitFee tests that when a new block comes in, the
// channel link properly checks to see if it should update the commitment fee. // channel link properly checks to see if it should update the commitment fee.
func TestChannelLinkUpdateCommitFee(t *testing.T) { func TestChannelLinkUpdateCommitFee(t *testing.T) {
@ -3709,7 +3839,6 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
// First, remove the link from the switch. // First, remove the link from the switch.
h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID()) h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID())
h.coreLink.WaitForShutdown()
var htlcSwitch *Switch var htlcSwitch *Switch
if restartSwitch { if restartSwitch {

@ -527,12 +527,15 @@ func (s *Switch) forward(packet *htlcPacket) error {
// ForwardPackets adds a list of packets to the switch for processing. Fails // ForwardPackets adds a list of packets to the switch for processing. Fails
// and settles are added on a first past, simultaneously constructing circuits // and settles are added on a first past, simultaneously constructing circuits
// for any adds. After persisting the circuits, another pass of the adds is // for any adds. After persisting the circuits, another pass of the adds is
// given to forward them through the router. // given to forward them through the router. The sending link's quit channel is
// used to prevent deadlocks when the switch stops a link in the midst of
// forwarding.
// //
// NOTE: This method guarantees that the returned err chan will eventually be // NOTE: This method guarantees that the returned err chan will eventually be
// closed. The receiver should read on the channel until receiving such a // closed. The receiver should read on the channel until receiving such a
// signal. // signal.
func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { func (s *Switch) ForwardPackets(linkQuit chan struct{},
packets ...*htlcPacket) chan error {
var ( var (
// fwdChan is a buffered channel used to receive err msgs from // fwdChan is a buffered channel used to receive err msgs from
@ -568,6 +571,9 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
// so, we exit early to avoid incrementing the switch's waitgroup while // so, we exit early to avoid incrementing the switch's waitgroup while
// it is already in the process of shutting down. // it is already in the process of shutting down.
select { select {
case <-linkQuit:
close(errChan)
return errChan
case <-s.quit: case <-s.quit:
close(errChan) close(errChan)
return errChan return errChan
@ -593,7 +599,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
circuits = append(circuits, circuit) circuits = append(circuits, circuit)
addBatch = append(addBatch, packet) addBatch = append(addBatch, packet)
default: default:
s.routeAsync(packet, fwdChan) err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return errChan
}
numSent++ numSent++
} }
} }
@ -635,7 +644,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
// Now, forward any packets for circuits that were successfully added to // Now, forward any packets for circuits that were successfully added to
// the switch's circuit map. // the switch's circuit map.
for _, packet := range addedPackets { for _, packet := range addedPackets {
s.routeAsync(packet, fwdChan) err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return errChan
}
numSent++ numSent++
} }
@ -722,9 +734,13 @@ func (s *Switch) route(packet *htlcPacket) error {
} }
// routeAsync sends a packet through the htlc switch, using the provided err // routeAsync sends a packet through the htlc switch, using the provided err
// chan to propagate errors back to the caller. This method does not wait for // chan to propagate errors back to the caller. The link's quit channel is
// a response before returning. // provided so that the send can be canceled if either the link or the switch
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error { // receive a shutdown requuest. This method does not wait for a response from
// the htlcForwarder before returning.
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
linkQuit chan struct{}) error {
command := &plexPacket{ command := &plexPacket{
pkt: packet, pkt: packet,
err: errChan, err: errChan,
@ -733,6 +749,8 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error {
select { select {
case s.htlcPlex <- command: case s.htlcPlex <- command:
return nil return nil
case <-linkQuit:
return ErrLinkShuttingDown
case <-s.quit: case <-s.quit:
return errors.New("Htlc Switch was stopped") return errors.New("Htlc Switch was stopped")
} }
@ -1380,21 +1398,34 @@ func (s *Switch) htlcForwarder() {
s.blockEpochStream.Cancel() s.blockEpochStream.Cancel()
// Remove all links once we've been signalled for shutdown. // Remove all links once we've been signalled for shutdown.
var linksToStop []ChannelLink
s.indexMtx.Lock() s.indexMtx.Lock()
for _, link := range s.linkIndex { for _, link := range s.linkIndex {
if err := s.removeLink(link.ChanID()); err != nil { activeLink := s.removeLink(link.ChanID())
log.Errorf("unable to remove "+ if activeLink == nil {
"channel link on stop: %v", err) log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
} }
linksToStop = append(linksToStop, activeLink)
} }
for _, link := range s.pendingLinkIndex { for _, link := range s.pendingLinkIndex {
if err := s.removeLink(link.ChanID()); err != nil { pendingLink := s.removeLink(link.ChanID())
log.Errorf("unable to remove pending "+ if pendingLink == nil {
"channel link on stop: %v", err) log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
} }
linksToStop = append(linksToStop, pendingLink)
} }
s.indexMtx.Unlock() s.indexMtx.Unlock()
// Now that all pending and live links have been removed from
// the forwarding indexes, stop each one before shutting down.
for _, link := range linksToStop {
link.Stop()
}
// Before we exit fully, we'll attempt to flush out any // Before we exit fully, we'll attempt to flush out any
// forwarding events that may still be lingering since the last // forwarding events that may still be lingering since the last
// batch flush. // batch flush.
@ -1721,7 +1752,10 @@ func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) {
} }
} }
errChan := s.ForwardPackets(switchPackets...) // Since this send isn't tied to a specific link, we pass a nil
// link quit channel, meaning the send will fail only if the
// switch receives a shutdown request.
errChan := s.ForwardPackets(nil, switchPackets...)
go handleBatchFwdErrs(errChan) go handleBatchFwdErrs(errChan)
} }
} }
@ -1776,11 +1810,11 @@ func (s *Switch) AddLink(link ChannelLink) error {
chanID := link.ChanID() chanID := link.ChanID()
// If a link already exists, then remove the prior one so we can // First, ensure that this link is not already active in the switch.
// replace it with this fresh instance.
_, err := s.getLink(chanID) _, err := s.getLink(chanID)
if err == nil { if err == nil {
s.removeLink(chanID) return fmt.Errorf("unable to add ChannelLink(%v), already "+
"active", chanID)
} }
// Get and attach the mailbox for this link, which buffers packets in // Get and attach the mailbox for this link, which buffers packets in
@ -1868,24 +1902,28 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er
return link, nil return link, nil
} }
// RemoveLink is used to initiate the handling of the remove link command. The // RemoveLink purges the switch of any link associated with chanID. If a pending
// request will be propagated/handled to/in the main goroutine. // or active link is not found, this method does nothing. Otherwise, the method
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { // returns after the link has been completely shutdown.
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
s.indexMtx.Lock() s.indexMtx.Lock()
defer s.indexMtx.Unlock() link := s.removeLink(chanID)
s.indexMtx.Unlock()
return s.removeLink(chanID) if link != nil {
link.Stop()
}
} }
// removeLink is used to remove and stop the channel link. // removeLink is used to remove and stop the channel link.
// //
// NOTE: This MUST be called with the indexMtx held. // NOTE: This MUST be called with the indexMtx held.
func (s *Switch) removeLink(chanID lnwire.ChannelID) error { func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink {
log.Infof("Removing channel link with ChannelID(%v)", chanID) log.Infof("Removing channel link with ChannelID(%v)", chanID)
link, err := s.getLink(chanID) link, err := s.getLink(chanID)
if err != nil { if err != nil {
return err return nil
} }
// Remove the channel from live link indexes. // Remove the channel from live link indexes.
@ -1906,9 +1944,7 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
} }
} }
go link.Stop() return link
return nil
} }
// UpdateShortChanID updates the short chan ID for an existing channel. This is // UpdateShortChanID updates the short chan ID for an existing channel. This is

@ -25,6 +25,63 @@ func genPreimage() ([32]byte, error) {
return preimage, nil return preimage, nil
} }
// TestSwitchAddDuplicateLink tests that the switch will reject duplicate links
// for both pending and live links. It also tests that we can successfully
// add a link after having removed it.
func TestSwitchAddDuplicateLink(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, _, aliceChanID, _ := genIDs()
pendingChanID := lnwire.ShortChannelID{}
aliceChannelLink := newMockChannelLink(
s, chanID1, pendingChanID, alicePeer, false,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
// Alice should have a pending link, adding again should fail.
if err := s.AddLink(aliceChannelLink); err == nil {
t.Fatalf("adding duplicate link should have failed")
}
// Update the short chan id of the channel, so that the link goes live.
aliceChannelLink.setLiveShortChanID(aliceChanID)
err = s.UpdateShortChanID(chanID1)
if err != nil {
t.Fatalf("unable to update alice short_chan_id: %v", err)
}
// Alice should have a live link, adding again should fail.
if err := s.AddLink(aliceChannelLink); err == nil {
t.Fatalf("adding duplicate link should have failed")
}
// Remove the live link to ensure the indexes are cleared.
s.RemoveLink(chanID1)
// Alice has no links, adding should succeed.
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
}
// TestSwitchSendPending checks the inability of htlc switch to forward adds // TestSwitchSendPending checks the inability of htlc switch to forward adds
// over pending links, and the UpdateShortChanID makes a pending link live. // over pending links, and the UpdateShortChanID makes a pending link live.
func TestSwitchSendPending(t *testing.T) { func TestSwitchSendPending(t *testing.T) {

25
peer.go

@ -469,11 +469,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
// mailboxes such that we can safely force close // mailboxes such that we can safely force close
// without the link being added again and updates being // without the link being added again and updates being
// applied. // applied.
err := p.server.htlcSwitch.RemoveLink(chanID) p.server.htlcSwitch.RemoveLink(chanID)
if err != nil {
peerLog.Errorf("unable to stop link(%v): %v",
shortChanID, err)
}
// If the error encountered was severe enough, we'll // If the error encountered was severe enough, we'll
// now force close the channel. // now force close the channel.
@ -557,6 +553,12 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
link := htlcswitch.NewChannelLink(linkCfg, lnChan) link := htlcswitch.NewChannelLink(linkCfg, lnChan)
// Before adding our new link, purge the switch of any pending or live
// links going by the same channel id. If one is found, we'll shut it
// down to ensure that the mailboxes are only ever under the control of
// one link.
p.server.htlcSwitch.RemoveLink(link.ChanID())
// With the channel link created, we'll now notify the htlc switch so // With the channel link created, we'll now notify the htlc switch so
// this channel can be used to dispatch local payments and also // this channel can be used to dispatch local payments and also
// passively forward payments. // passively forward payments.
@ -1526,8 +1528,8 @@ out:
) )
if err != nil { if err != nil {
peerLog.Errorf("can't register new channel "+ peerLog.Errorf("can't register new channel "+
"link(%v) with NodeKey(%x): %v", chanPoint, "link(%v) with NodeKey(%x)", chanPoint,
p.PubKey(), err) p.PubKey())
} }
close(newChanReq.done) close(newChanReq.done)
@ -1922,14 +1924,7 @@ func (p *peer) WipeChannel(chanPoint *wire.OutPoint) error {
// Instruct the HtlcSwitch to close this link as the channel is no // Instruct the HtlcSwitch to close this link as the channel is no
// longer active. // longer active.
if err := p.server.htlcSwitch.RemoveLink(chanID); err != nil { p.server.htlcSwitch.RemoveLink(chanID)
if err == htlcswitch.ErrChannelLinkNotFound {
peerLog.Warnf("unable remove channel link with "+
"ChannelPoint(%v): %v", chanID, err)
return nil
}
return err
}
return nil return nil
} }

@ -646,7 +646,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
ChainIO: cc.chainIO, ChainIO: cc.chainIO,
MarkLinkInactive: func(chanPoint wire.OutPoint) error { MarkLinkInactive: func(chanPoint wire.OutPoint) error {
chanID := lnwire.NewChanIDFromOutPoint(&chanPoint) chanID := lnwire.NewChanIDFromOutPoint(&chanPoint)
return s.htlcSwitch.RemoveLink(chanID) s.htlcSwitch.RemoveLink(chanID)
return nil
}, },
IsOurAddress: func(addr btcutil.Address) bool { IsOurAddress: func(addr btcutil.Address) bool {
_, err := cc.wallet.GetPrivKey(addr) _, err := cc.wallet.GetPrivKey(addr)
@ -1960,11 +1961,7 @@ func (s *server) peerTerminationWatcher(p *peer) {
} }
for _, link := range links { for _, link := range links {
err := p.server.htlcSwitch.RemoveLink(link.ChanID()) p.server.htlcSwitch.RemoveLink(link.ChanID())
if err != nil {
srvrLog.Errorf("unable to remove channel link: %v",
err)
}
} }
s.mu.Lock() s.mu.Lock()