Merge pull request #1551 from cfromknecht/switch-revert-replace-link
[htlcswitch]: revert replace link, ensure removed links are stopped
This commit is contained in:
commit
7a113d469b
@ -73,7 +73,7 @@ type chanCloseCfg struct {
|
|||||||
// unregisterChannel is a function closure that allows the
|
// unregisterChannel is a function closure that allows the
|
||||||
// channelCloser to re-register a channel. Once this has been done, no
|
// channelCloser to re-register a channel. Once this has been done, no
|
||||||
// further HTLC's should be routed through the channel.
|
// further HTLC's should be routed through the channel.
|
||||||
unregisterChannel func(lnwire.ChannelID) error
|
unregisterChannel func(lnwire.ChannelID)
|
||||||
|
|
||||||
// broadcastTx broadcasts the passed transaction to the network.
|
// broadcastTx broadcasts the passed transaction to the network.
|
||||||
broadcastTx func(*wire.MsgTx) error
|
broadcastTx func(*wire.MsgTx) error
|
||||||
|
@ -117,9 +117,10 @@ type ChannelLinkConfig struct {
|
|||||||
Switch *Switch
|
Switch *Switch
|
||||||
|
|
||||||
// ForwardPackets attempts to forward the batch of htlcs through the
|
// ForwardPackets attempts to forward the batch of htlcs through the
|
||||||
// switch. Any failed packets will be returned to the provided
|
// switch, any failed packets will be returned to the provided
|
||||||
// ChannelLink.
|
// ChannelLink. The link's quit signal should be provided to allow
|
||||||
ForwardPackets func(...*htlcPacket) chan error
|
// cancellation of forwarding during link shutdown.
|
||||||
|
ForwardPackets func(chan struct{}, ...*htlcPacket) chan error
|
||||||
|
|
||||||
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
|
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
|
||||||
// blobs, which are then used to inform how to forward an HTLC.
|
// blobs, which are then used to inform how to forward an HTLC.
|
||||||
@ -359,21 +360,6 @@ func (l *channelLink) Start() error {
|
|||||||
|
|
||||||
log.Infof("ChannelLink(%v) is starting", l)
|
log.Infof("ChannelLink(%v) is starting", l)
|
||||||
|
|
||||||
// Before we start the link, we'll update the ChainArbitrator with the
|
|
||||||
// set of new channel signals for this channel.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): split goroutines within channel arb to avoid
|
|
||||||
go func() {
|
|
||||||
err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{
|
|
||||||
HtlcUpdates: l.htlcUpdates,
|
|
||||||
ShortChanID: l.channel.ShortChanID(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Unable to update signals for "+
|
|
||||||
"ChannelLink(%v)", l)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
l.mailBox.ResetMessages()
|
l.mailBox.ResetMessages()
|
||||||
l.overflowQueue.Start()
|
l.overflowQueue.Start()
|
||||||
|
|
||||||
@ -401,6 +387,24 @@ func (l *channelLink) Start() error {
|
|||||||
return fmt.Errorf("unable to trim circuits above "+
|
return fmt.Errorf("unable to trim circuits above "+
|
||||||
"local htlc index %d: %v", localHtlcIndex, err)
|
"local htlc index %d: %v", localHtlcIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Since the link is live, before we start the link we'll update
|
||||||
|
// the ChainArbitrator with the set of new channel signals for
|
||||||
|
// this channel.
|
||||||
|
//
|
||||||
|
// TODO(roasbeef): split goroutines within channel arb to avoid
|
||||||
|
go func() {
|
||||||
|
signals := &contractcourt.ContractSignals{
|
||||||
|
HtlcUpdates: l.htlcUpdates,
|
||||||
|
ShortChanID: l.channel.ShortChanID(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := l.cfg.UpdateContractSignals(signals)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to update signals for "+
|
||||||
|
"ChannelLink(%v)", l)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout())
|
l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout())
|
||||||
@ -2539,7 +2543,7 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) {
|
|||||||
filteredPkts = append(filteredPkts, pkt)
|
filteredPkts = append(filteredPkts, pkt)
|
||||||
}
|
}
|
||||||
|
|
||||||
errChan := l.cfg.ForwardPackets(filteredPkts...)
|
errChan := l.cfg.ForwardPackets(l.quit, filteredPkts...)
|
||||||
go l.handleBatchFwdErrs(errChan)
|
go l.handleBatchFwdErrs(errChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3396,6 +3396,136 @@ func TestShouldAdjustCommitFee(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestChannelLinkShutdownDuringForward asserts that a link can be fully
|
||||||
|
// stopped when it is trying to send synchronously through the switch. The
|
||||||
|
// specific case this can occur is when a link forwards incoming Adds. We test
|
||||||
|
// this by forcing the switch into a state where it will not accept new packets,
|
||||||
|
// and then killing the link, which can only succeed if forwarding can be
|
||||||
|
// canceled by a call to Stop.
|
||||||
|
func TestChannelLinkShutdownDuringForward(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First, we'll create our traditional three hop network. We're
|
||||||
|
// interested in testing the ability to stop the link when it is
|
||||||
|
// synchronously forwarding to the switch, which happens when an
|
||||||
|
// incoming link forwards Adds. Thus, the test will be performed
|
||||||
|
// against Bob's first link.
|
||||||
|
channels, cleanUp, _, err := createClusterChannels(
|
||||||
|
btcutil.SatoshiPerBitcoin*3,
|
||||||
|
btcutil.SatoshiPerBitcoin*5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create channel: %v", err)
|
||||||
|
}
|
||||||
|
defer cleanUp()
|
||||||
|
|
||||||
|
n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
|
||||||
|
channels.bobToCarol, channels.carolToBob, testStartingHeight)
|
||||||
|
|
||||||
|
if err := n.start(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer n.stop()
|
||||||
|
defer n.feeEstimator.Stop()
|
||||||
|
|
||||||
|
// Define a helper method that strobes the switch's log ticker, and
|
||||||
|
// unblocks after nothing has been pulled for two seconds.
|
||||||
|
waitForBobsSwitchToBlock := func() {
|
||||||
|
bobSwitch := n.firstBobChannelLink.cfg.Switch
|
||||||
|
ticker := bobSwitch.cfg.LogEventTicker.(*ticker.Mock)
|
||||||
|
timeout := time.After(15 * time.Second)
|
||||||
|
for {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case ticker.Force <- time.Now():
|
||||||
|
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("switch did not block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define a helper method that strobes the link's batch ticker, and
|
||||||
|
// unblocks after nothing has been pulled for two seconds.
|
||||||
|
waitForBobsIncomingLinkToBlock := func() {
|
||||||
|
ticker := n.firstBobChannelLink.cfg.BatchTicker.(*ticker.Mock)
|
||||||
|
timeout := time.After(15 * time.Second)
|
||||||
|
for {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case ticker.Force <- time.Now():
|
||||||
|
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
// We'll give a little extra time here, to
|
||||||
|
// ensure that the packet is being pressed
|
||||||
|
// against the htlcPlex.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("link did not block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To test that the cancellation is happening properly, we will set the
|
||||||
|
// switch's htlcPlex to nil, so that calls to routeAsync block, and can
|
||||||
|
// only exit if the link (or switch) is exiting. We will only be testing
|
||||||
|
// the link here.
|
||||||
|
//
|
||||||
|
// In order to avoid data races, we need to ensure the switch isn't
|
||||||
|
// selecting on that channel in the meantime. We'll prevent this by
|
||||||
|
// first acquiring the index mutex and forcing a log event so that the
|
||||||
|
// htlcForwarder is blocked inside the logTicker case, which also needs
|
||||||
|
// the indexMtx.
|
||||||
|
n.firstBobChannelLink.cfg.Switch.indexMtx.Lock()
|
||||||
|
|
||||||
|
// Strobe the log ticker, and wait for switch to stop accepting any more
|
||||||
|
// log ticks.
|
||||||
|
waitForBobsSwitchToBlock()
|
||||||
|
|
||||||
|
// While the htlcForwarder is blocked, swap out the htlcPlex with a nil
|
||||||
|
// channel, and unlock the indexMtx to allow return to the
|
||||||
|
// htlcForwarder's main select. After this, any attempt to forward
|
||||||
|
// through the switch will block.
|
||||||
|
n.firstBobChannelLink.cfg.Switch.htlcPlex = nil
|
||||||
|
n.firstBobChannelLink.cfg.Switch.indexMtx.Unlock()
|
||||||
|
|
||||||
|
// Now, make a payment from Alice to Carol, which should cause Bob's
|
||||||
|
// incoming link to block when it tries to submit the packet to the nil
|
||||||
|
// htlcPlex.
|
||||||
|
amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
|
||||||
|
htlcAmt, totalTimelock, hops := generateHops(
|
||||||
|
amount, testStartingHeight,
|
||||||
|
n.firstBobChannelLink, n.carolChannelLink,
|
||||||
|
)
|
||||||
|
|
||||||
|
n.makePayment(
|
||||||
|
n.aliceServer, n.carolServer, n.bobServer.PubKey(),
|
||||||
|
hops, amount, htlcAmt, totalTimelock,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Strobe the batch ticker of Bob's incoming link, waiting for it to
|
||||||
|
// become fully blocked.
|
||||||
|
waitForBobsIncomingLinkToBlock()
|
||||||
|
|
||||||
|
// Finally, stop the link to test that it can exit while synchronously
|
||||||
|
// forwarding Adds to the switch.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
n.firstBobChannelLink.Stop()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatalf("unable to shutdown link while fwding incoming Adds")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestChannelLinkUpdateCommitFee tests that when a new block comes in, the
|
// TestChannelLinkUpdateCommitFee tests that when a new block comes in, the
|
||||||
// channel link properly checks to see if it should update the commitment fee.
|
// channel link properly checks to see if it should update the commitment fee.
|
||||||
func TestChannelLinkUpdateCommitFee(t *testing.T) {
|
func TestChannelLinkUpdateCommitFee(t *testing.T) {
|
||||||
@ -3709,7 +3839,6 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
|
|||||||
|
|
||||||
// First, remove the link from the switch.
|
// First, remove the link from the switch.
|
||||||
h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID())
|
h.coreLink.cfg.Switch.RemoveLink(h.link.ChanID())
|
||||||
h.coreLink.WaitForShutdown()
|
|
||||||
|
|
||||||
var htlcSwitch *Switch
|
var htlcSwitch *Switch
|
||||||
if restartSwitch {
|
if restartSwitch {
|
||||||
|
@ -527,12 +527,15 @@ func (s *Switch) forward(packet *htlcPacket) error {
|
|||||||
// ForwardPackets adds a list of packets to the switch for processing. Fails
|
// ForwardPackets adds a list of packets to the switch for processing. Fails
|
||||||
// and settles are added on a first past, simultaneously constructing circuits
|
// and settles are added on a first past, simultaneously constructing circuits
|
||||||
// for any adds. After persisting the circuits, another pass of the adds is
|
// for any adds. After persisting the circuits, another pass of the adds is
|
||||||
// given to forward them through the router.
|
// given to forward them through the router. The sending link's quit channel is
|
||||||
|
// used to prevent deadlocks when the switch stops a link in the midst of
|
||||||
|
// forwarding.
|
||||||
//
|
//
|
||||||
// NOTE: This method guarantees that the returned err chan will eventually be
|
// NOTE: This method guarantees that the returned err chan will eventually be
|
||||||
// closed. The receiver should read on the channel until receiving such a
|
// closed. The receiver should read on the channel until receiving such a
|
||||||
// signal.
|
// signal.
|
||||||
func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
|
func (s *Switch) ForwardPackets(linkQuit chan struct{},
|
||||||
|
packets ...*htlcPacket) chan error {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// fwdChan is a buffered channel used to receive err msgs from
|
// fwdChan is a buffered channel used to receive err msgs from
|
||||||
@ -568,6 +571,9 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
|
|||||||
// so, we exit early to avoid incrementing the switch's waitgroup while
|
// so, we exit early to avoid incrementing the switch's waitgroup while
|
||||||
// it is already in the process of shutting down.
|
// it is already in the process of shutting down.
|
||||||
select {
|
select {
|
||||||
|
case <-linkQuit:
|
||||||
|
close(errChan)
|
||||||
|
return errChan
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return errChan
|
return errChan
|
||||||
@ -593,7 +599,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
|
|||||||
circuits = append(circuits, circuit)
|
circuits = append(circuits, circuit)
|
||||||
addBatch = append(addBatch, packet)
|
addBatch = append(addBatch, packet)
|
||||||
default:
|
default:
|
||||||
s.routeAsync(packet, fwdChan)
|
err := s.routeAsync(packet, fwdChan, linkQuit)
|
||||||
|
if err != nil {
|
||||||
|
return errChan
|
||||||
|
}
|
||||||
numSent++
|
numSent++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -635,7 +644,10 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
|
|||||||
// Now, forward any packets for circuits that were successfully added to
|
// Now, forward any packets for circuits that were successfully added to
|
||||||
// the switch's circuit map.
|
// the switch's circuit map.
|
||||||
for _, packet := range addedPackets {
|
for _, packet := range addedPackets {
|
||||||
s.routeAsync(packet, fwdChan)
|
err := s.routeAsync(packet, fwdChan, linkQuit)
|
||||||
|
if err != nil {
|
||||||
|
return errChan
|
||||||
|
}
|
||||||
numSent++
|
numSent++
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -722,9 +734,13 @@ func (s *Switch) route(packet *htlcPacket) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// routeAsync sends a packet through the htlc switch, using the provided err
|
// routeAsync sends a packet through the htlc switch, using the provided err
|
||||||
// chan to propagate errors back to the caller. This method does not wait for
|
// chan to propagate errors back to the caller. The link's quit channel is
|
||||||
// a response before returning.
|
// provided so that the send can be canceled if either the link or the switch
|
||||||
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error {
|
// receive a shutdown requuest. This method does not wait for a response from
|
||||||
|
// the htlcForwarder before returning.
|
||||||
|
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
|
||||||
|
linkQuit chan struct{}) error {
|
||||||
|
|
||||||
command := &plexPacket{
|
command := &plexPacket{
|
||||||
pkt: packet,
|
pkt: packet,
|
||||||
err: errChan,
|
err: errChan,
|
||||||
@ -733,6 +749,8 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error {
|
|||||||
select {
|
select {
|
||||||
case s.htlcPlex <- command:
|
case s.htlcPlex <- command:
|
||||||
return nil
|
return nil
|
||||||
|
case <-linkQuit:
|
||||||
|
return ErrLinkShuttingDown
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
return errors.New("Htlc Switch was stopped")
|
return errors.New("Htlc Switch was stopped")
|
||||||
}
|
}
|
||||||
@ -1380,21 +1398,34 @@ func (s *Switch) htlcForwarder() {
|
|||||||
s.blockEpochStream.Cancel()
|
s.blockEpochStream.Cancel()
|
||||||
|
|
||||||
// Remove all links once we've been signalled for shutdown.
|
// Remove all links once we've been signalled for shutdown.
|
||||||
|
var linksToStop []ChannelLink
|
||||||
s.indexMtx.Lock()
|
s.indexMtx.Lock()
|
||||||
for _, link := range s.linkIndex {
|
for _, link := range s.linkIndex {
|
||||||
if err := s.removeLink(link.ChanID()); err != nil {
|
activeLink := s.removeLink(link.ChanID())
|
||||||
log.Errorf("unable to remove "+
|
if activeLink == nil {
|
||||||
"channel link on stop: %v", err)
|
log.Errorf("unable to remove ChannelLink(%v) "+
|
||||||
|
"on stop", link.ChanID())
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
linksToStop = append(linksToStop, activeLink)
|
||||||
}
|
}
|
||||||
for _, link := range s.pendingLinkIndex {
|
for _, link := range s.pendingLinkIndex {
|
||||||
if err := s.removeLink(link.ChanID()); err != nil {
|
pendingLink := s.removeLink(link.ChanID())
|
||||||
log.Errorf("unable to remove pending "+
|
if pendingLink == nil {
|
||||||
"channel link on stop: %v", err)
|
log.Errorf("unable to remove ChannelLink(%v) "+
|
||||||
|
"on stop", link.ChanID())
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
linksToStop = append(linksToStop, pendingLink)
|
||||||
}
|
}
|
||||||
s.indexMtx.Unlock()
|
s.indexMtx.Unlock()
|
||||||
|
|
||||||
|
// Now that all pending and live links have been removed from
|
||||||
|
// the forwarding indexes, stop each one before shutting down.
|
||||||
|
for _, link := range linksToStop {
|
||||||
|
link.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
// Before we exit fully, we'll attempt to flush out any
|
// Before we exit fully, we'll attempt to flush out any
|
||||||
// forwarding events that may still be lingering since the last
|
// forwarding events that may still be lingering since the last
|
||||||
// batch flush.
|
// batch flush.
|
||||||
@ -1721,7 +1752,10 @@ func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
errChan := s.ForwardPackets(switchPackets...)
|
// Since this send isn't tied to a specific link, we pass a nil
|
||||||
|
// link quit channel, meaning the send will fail only if the
|
||||||
|
// switch receives a shutdown request.
|
||||||
|
errChan := s.ForwardPackets(nil, switchPackets...)
|
||||||
go handleBatchFwdErrs(errChan)
|
go handleBatchFwdErrs(errChan)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1776,11 +1810,11 @@ func (s *Switch) AddLink(link ChannelLink) error {
|
|||||||
|
|
||||||
chanID := link.ChanID()
|
chanID := link.ChanID()
|
||||||
|
|
||||||
// If a link already exists, then remove the prior one so we can
|
// First, ensure that this link is not already active in the switch.
|
||||||
// replace it with this fresh instance.
|
|
||||||
_, err := s.getLink(chanID)
|
_, err := s.getLink(chanID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.removeLink(chanID)
|
return fmt.Errorf("unable to add ChannelLink(%v), already "+
|
||||||
|
"active", chanID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get and attach the mailbox for this link, which buffers packets in
|
// Get and attach the mailbox for this link, which buffers packets in
|
||||||
@ -1868,24 +1902,28 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er
|
|||||||
return link, nil
|
return link, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveLink is used to initiate the handling of the remove link command. The
|
// RemoveLink purges the switch of any link associated with chanID. If a pending
|
||||||
// request will be propagated/handled to/in the main goroutine.
|
// or active link is not found, this method does nothing. Otherwise, the method
|
||||||
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error {
|
// returns after the link has been completely shutdown.
|
||||||
|
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
|
||||||
s.indexMtx.Lock()
|
s.indexMtx.Lock()
|
||||||
defer s.indexMtx.Unlock()
|
link := s.removeLink(chanID)
|
||||||
|
s.indexMtx.Unlock()
|
||||||
|
|
||||||
return s.removeLink(chanID)
|
if link != nil {
|
||||||
|
link.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeLink is used to remove and stop the channel link.
|
// removeLink is used to remove and stop the channel link.
|
||||||
//
|
//
|
||||||
// NOTE: This MUST be called with the indexMtx held.
|
// NOTE: This MUST be called with the indexMtx held.
|
||||||
func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
|
func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink {
|
||||||
log.Infof("Removing channel link with ChannelID(%v)", chanID)
|
log.Infof("Removing channel link with ChannelID(%v)", chanID)
|
||||||
|
|
||||||
link, err := s.getLink(chanID)
|
link, err := s.getLink(chanID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the channel from live link indexes.
|
// Remove the channel from live link indexes.
|
||||||
@ -1906,9 +1944,7 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go link.Stop()
|
return link
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateShortChanID updates the short chan ID for an existing channel. This is
|
// UpdateShortChanID updates the short chan ID for an existing channel. This is
|
||||||
|
@ -25,6 +25,63 @@ func genPreimage() ([32]byte, error) {
|
|||||||
return preimage, nil
|
return preimage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSwitchAddDuplicateLink tests that the switch will reject duplicate links
|
||||||
|
// for both pending and live links. It also tests that we can successfully
|
||||||
|
// add a link after having removed it.
|
||||||
|
func TestSwitchAddDuplicateLink(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create alice server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := initSwitchWithDB(testStartingHeight, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to init switch: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Start(); err != nil {
|
||||||
|
t.Fatalf("unable to start switch: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
chanID1, _, aliceChanID, _ := genIDs()
|
||||||
|
|
||||||
|
pendingChanID := lnwire.ShortChannelID{}
|
||||||
|
|
||||||
|
aliceChannelLink := newMockChannelLink(
|
||||||
|
s, chanID1, pendingChanID, alicePeer, false,
|
||||||
|
)
|
||||||
|
if err := s.AddLink(aliceChannelLink); err != nil {
|
||||||
|
t.Fatalf("unable to add alice link: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alice should have a pending link, adding again should fail.
|
||||||
|
if err := s.AddLink(aliceChannelLink); err == nil {
|
||||||
|
t.Fatalf("adding duplicate link should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the short chan id of the channel, so that the link goes live.
|
||||||
|
aliceChannelLink.setLiveShortChanID(aliceChanID)
|
||||||
|
err = s.UpdateShortChanID(chanID1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to update alice short_chan_id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alice should have a live link, adding again should fail.
|
||||||
|
if err := s.AddLink(aliceChannelLink); err == nil {
|
||||||
|
t.Fatalf("adding duplicate link should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the live link to ensure the indexes are cleared.
|
||||||
|
s.RemoveLink(chanID1)
|
||||||
|
|
||||||
|
// Alice has no links, adding should succeed.
|
||||||
|
if err := s.AddLink(aliceChannelLink); err != nil {
|
||||||
|
t.Fatalf("unable to add alice link: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestSwitchSendPending checks the inability of htlc switch to forward adds
|
// TestSwitchSendPending checks the inability of htlc switch to forward adds
|
||||||
// over pending links, and the UpdateShortChanID makes a pending link live.
|
// over pending links, and the UpdateShortChanID makes a pending link live.
|
||||||
func TestSwitchSendPending(t *testing.T) {
|
func TestSwitchSendPending(t *testing.T) {
|
||||||
|
25
peer.go
25
peer.go
@ -469,11 +469,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
|
|||||||
// mailboxes such that we can safely force close
|
// mailboxes such that we can safely force close
|
||||||
// without the link being added again and updates being
|
// without the link being added again and updates being
|
||||||
// applied.
|
// applied.
|
||||||
err := p.server.htlcSwitch.RemoveLink(chanID)
|
p.server.htlcSwitch.RemoveLink(chanID)
|
||||||
if err != nil {
|
|
||||||
peerLog.Errorf("unable to stop link(%v): %v",
|
|
||||||
shortChanID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the error encountered was severe enough, we'll
|
// If the error encountered was severe enough, we'll
|
||||||
// now force close the channel.
|
// now force close the channel.
|
||||||
@ -557,6 +553,12 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
|
|||||||
|
|
||||||
link := htlcswitch.NewChannelLink(linkCfg, lnChan)
|
link := htlcswitch.NewChannelLink(linkCfg, lnChan)
|
||||||
|
|
||||||
|
// Before adding our new link, purge the switch of any pending or live
|
||||||
|
// links going by the same channel id. If one is found, we'll shut it
|
||||||
|
// down to ensure that the mailboxes are only ever under the control of
|
||||||
|
// one link.
|
||||||
|
p.server.htlcSwitch.RemoveLink(link.ChanID())
|
||||||
|
|
||||||
// With the channel link created, we'll now notify the htlc switch so
|
// With the channel link created, we'll now notify the htlc switch so
|
||||||
// this channel can be used to dispatch local payments and also
|
// this channel can be used to dispatch local payments and also
|
||||||
// passively forward payments.
|
// passively forward payments.
|
||||||
@ -1526,8 +1528,8 @@ out:
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peerLog.Errorf("can't register new channel "+
|
peerLog.Errorf("can't register new channel "+
|
||||||
"link(%v) with NodeKey(%x): %v", chanPoint,
|
"link(%v) with NodeKey(%x)", chanPoint,
|
||||||
p.PubKey(), err)
|
p.PubKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
close(newChanReq.done)
|
close(newChanReq.done)
|
||||||
@ -1922,14 +1924,7 @@ func (p *peer) WipeChannel(chanPoint *wire.OutPoint) error {
|
|||||||
|
|
||||||
// Instruct the HtlcSwitch to close this link as the channel is no
|
// Instruct the HtlcSwitch to close this link as the channel is no
|
||||||
// longer active.
|
// longer active.
|
||||||
if err := p.server.htlcSwitch.RemoveLink(chanID); err != nil {
|
p.server.htlcSwitch.RemoveLink(chanID)
|
||||||
if err == htlcswitch.ErrChannelLinkNotFound {
|
|
||||||
peerLog.Warnf("unable remove channel link with "+
|
|
||||||
"ChannelPoint(%v): %v", chanID, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -646,7 +646,8 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
|
|||||||
ChainIO: cc.chainIO,
|
ChainIO: cc.chainIO,
|
||||||
MarkLinkInactive: func(chanPoint wire.OutPoint) error {
|
MarkLinkInactive: func(chanPoint wire.OutPoint) error {
|
||||||
chanID := lnwire.NewChanIDFromOutPoint(&chanPoint)
|
chanID := lnwire.NewChanIDFromOutPoint(&chanPoint)
|
||||||
return s.htlcSwitch.RemoveLink(chanID)
|
s.htlcSwitch.RemoveLink(chanID)
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
IsOurAddress: func(addr btcutil.Address) bool {
|
IsOurAddress: func(addr btcutil.Address) bool {
|
||||||
_, err := cc.wallet.GetPrivKey(addr)
|
_, err := cc.wallet.GetPrivKey(addr)
|
||||||
@ -1960,11 +1961,7 @@ func (s *server) peerTerminationWatcher(p *peer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
err := p.server.htlcSwitch.RemoveLink(link.ChanID())
|
p.server.htlcSwitch.RemoveLink(link.ChanID())
|
||||||
if err != nil {
|
|
||||||
srvrLog.Errorf("unable to remove channel link: %v",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
Loading…
Reference in New Issue
Block a user