diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f11d1fd3..dcb6714d 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1723,6 +1723,13 @@ func (s *Switch) AddLink(link ChannelLink) error { chanID := link.ChanID() + // First, ensure that this link is not already active in the switch. + _, err := s.getLink(chanID) + if err == nil { + return fmt.Errorf("unable to add ChannelLink(%v), already "+ + "active", chanID) + } + // Get and attach the mailbox for this link, which buffers packets in // case there packets that we tried to deliver while this link was // offline. @@ -1772,24 +1779,18 @@ func (s *Switch) addLiveLink(link ChannelLink) { s.interfaceIndex[peerPub][link] = struct{}{} } -// removeLiveLink removes a link from all associated forwarding indexes, this -// prevents it from being a candidate in forwarding. -func (s *Switch) removeLiveLink(link ChannelLink) { - // Remove the channel from live link indexes. - delete(s.linkIndex, link.ChanID()) - delete(s.forwardingIndex, link.ShortChanID()) - - // Remove the channel from channel index. - peerPub := link.Peer().PubKey() - delete(s.interfaceIndex, peerPub) -} - // GetLink is used to initiate the handling of the get link command. The // request will be propagated/handled to/in the main goroutine. func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) { s.indexMtx.RLock() defer s.indexMtx.RUnlock() + return s.getLink(chanID) +} + +// getLink returns the link stored in either the pending index or the live +// lindex. +func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { link, ok := s.linkIndex[chanID] if !ok { link, ok = s.pendingLinkIndex[chanID] @@ -1829,23 +1830,23 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { func (s *Switch) removeLink(chanID lnwire.ChannelID) error { log.Infof("Removing channel link with ChannelID(%v)", chanID) - link, ok := s.linkIndex[chanID] - if ok { - s.removeLiveLink(link) - link.Stop() - - return nil + link, err := s.getLink(chanID) + if err != nil { + return err } - link, ok = s.pendingLinkIndex[chanID] - if ok { - delete(s.pendingLinkIndex, chanID) - link.Stop() + // Remove the channel from live link indexes. + delete(s.pendingLinkIndex, link.ChanID()) + delete(s.linkIndex, link.ChanID()) + delete(s.forwardingIndex, link.ShortChanID()) - return nil - } + // Remove the channel from channel index. + peerPub := link.Peer().PubKey() + delete(s.interfaceIndex, peerPub) - return ErrChannelLinkNotFound + link.Stop() + + return nil } // UpdateShortChanID updates the short chan ID for an existing channel. This is