diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f6120556..e1235f6b 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -168,10 +168,6 @@ type Switch struct { // forward the settle/fail htlc updates back to the add htlc initiator. circuits CircuitMap - // links is a map of channel id and channel link which manages - // this channel. - linkIndex map[lnwire.ChannelID]ChannelLink - // mailMtx is a read/write mutex that protects the mailboxes map. mailMtx sync.RWMutex @@ -179,6 +175,14 @@ type Switch struct { // switch to buffer messages for peers that have not come back online. mailboxes map[lnwire.ShortChannelID]MailBox + // indexMtx is a read/write mutex that protects the set of indexes + // below. + indexMtx sync.RWMutex + + // links is a map of channel id and channel link which manages + // this channel. + linkIndex map[lnwire.ChannelID]ChannelLink + // forwardingIndex is an index which is consulted by the switch when it // needs to locate the next hop to forward an incoming/outgoing HTLC // update to/from. @@ -244,7 +248,6 @@ func New(cfg Config) (*Switch, error) { htlcPlex: make(chan *plexPacket), chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), - linkControl: make(chan interface{}), quit: make(chan struct{}), }, nil } @@ -386,63 +389,47 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy, targetChans ...wire.OutPoint) error { - errChan := make(chan error, 1) - select { - case s.linkControl <- &updatePoliciesCmd{ - newPolicy: newPolicy, - targetChans: targetChans, - err: errChan, - }: - case <-s.quit: - return fmt.Errorf("switch is shutting down") - } + log.Debugf("Updating link policies: %v", newLogClosure(func() string { + return spew.Sdump(newPolicy) + })) - select { - case err := <-errChan: - return err - case <-s.quit: - return fmt.Errorf("switch is shutting down") - } -} + s.indexMtx.RLock() -// updatePoliciesCmd is a message sent to the switch to update the forwarding -// policies of a set of target links. -type updatePoliciesCmd struct { - newPolicy ForwardingPolicy - targetChans []wire.OutPoint + var linksToUpdate []ChannelLink - err chan error -} - -// updateLinkPolicies attempts to update the forwarding policies for the set of -// passed links identified by their channel points. If a nil set of channel -// points is passed, then the forwarding policies for all active links will be -// updated. -func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error { - log.Debugf("Updating link policies: %v", spew.Sdump(c)) - - // If no channels have been targeted, then we'll update the link policies - // for all active channels - if len(c.targetChans) == 0 { + // If no channels have been targeted, then we'll collect all inks to + // update their policies. + if len(targetChans) == 0 { for _, link := range s.linkIndex { - link.UpdateForwardingPolicy(c.newPolicy) + linksToUpdate = append(linksToUpdate, link) + } + } else { + // Otherwise, we'll only attempt to update the forwarding + // policies for the set of targeted links. + for _, targetLink := range targetChans { + cid := lnwire.NewChanIDFromOutPoint(&targetLink) + + // If we can't locate a link by its converted channel + // ID, then we'll return an error back to the caller. + link, ok := s.linkIndex[cid] + if !ok { + s.indexMtx.RUnlock() + + return fmt.Errorf("unable to find "+ + "ChannelPoint(%v) to update link "+ + "policy", targetLink) + } + + linksToUpdate = append(linksToUpdate, link) } } - // Otherwise, we'll only attempt to update the forwarding policies for the - // set of targeted links. - for _, targetLink := range c.targetChans { - cid := lnwire.NewChanIDFromOutPoint(&targetLink) + s.indexMtx.RUnlock() - // If we can't locate a link by its converted channel ID, then we'll - // return an error back to the caller. - link, ok := s.linkIndex[cid] - if !ok { - return fmt.Errorf("unable to find ChannelPoint(%v) to "+ - "update link policy", targetLink) - } - - link.UpdateForwardingPolicy(c.newPolicy) + // With all the links we need to update collected, we can release the + // mutex then update each link directly. + for _, link := range linksToUpdate { + link.UpdateForwardingPolicy(newPolicy) } return nil @@ -715,14 +702,18 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // appropriate channel link and send the payment over this link. case *lnwire.UpdateAddHTLC: // Try to find links by node destination. + s.indexMtx.RLock() links, err := s.getLinks(pkt.destNode) if err != nil { + s.indexMtx.RUnlock() + log.Errorf("unable to find links by destination %v", err) return &ForwardingError{ ErrorSource: s.cfg.SelfKey, FailureMessage: &lnwire.FailUnknownNextPeer{}, } } + s.indexMtx.RUnlock() // Try to find destination channel link with appropriate // bandwidth. @@ -880,8 +871,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.handleLocalDispatch(packet) } + s.indexMtx.RLock() targetLink, err := s.getLinkByShortID(packet.outgoingChanID) if err != nil { + s.indexMtx.RUnlock() + // If packet was forwarded from another channel link // than we should notify this link that some error // occurred. @@ -892,6 +886,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.failAddPacket(packet, failure, addErr) } interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey()) + s.indexMtx.RUnlock() // We'll keep track of any HTLC failures during the link // selection process. This way we can return the error for @@ -1300,12 +1295,14 @@ func (s *Switch) htlcForwarder() { // Remove all links once we've been signalled for shutdown. defer func() { + s.indexMtx.Lock() for _, link := range s.linkIndex { if err := s.removeLink(link.ChanID()); err != nil { log.Errorf("unable to remove "+ "channel link on stop: %v", err) } } + s.indexMtx.Unlock() // Before we exit fully, we'll attempt to flush out any // forwarding events that may still be lingering since the last @@ -1336,12 +1333,17 @@ func (s *Switch) htlcForwarder() { // cooperatively closed (if possible). case req := <-s.chanCloseRequests: chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint) + + s.indexMtx.RLock() link, ok := s.linkIndex[chanID] if !ok { + s.indexMtx.RUnlock() + req.Err <- errors.Errorf("no peer for channel with "+ "chan_id=%x", chanID[:]) continue } + s.indexMtx.RUnlock() peerPub := link.Peer().PubKey() log.Debugf("Requesting local channel close: peer=%v, "+ @@ -1421,6 +1423,7 @@ func (s *Switch) htlcForwarder() { // Next, we'll run through all the registered links and // compute their up-to-date forwarding stats. + s.indexMtx.RLock() for _, link := range s.linkIndex { // TODO(roasbeef): when links first registered // stats printed. @@ -1429,6 +1432,7 @@ func (s *Switch) htlcForwarder() { newSatSent += sent.ToSatoshis() newSatRecv += recv.ToSatoshis() } + s.indexMtx.RUnlock() var ( diffNumUpdates uint64 @@ -1478,28 +1482,6 @@ func (s *Switch) htlcForwarder() { totalSatSent += diffSatSent totalSatRecv += diffSatRecv - case req := <-s.linkControl: - switch cmd := req.(type) { - case *updatePoliciesCmd: - cmd.err <- s.updateLinkPolicies(cmd) - case *addLinkCmd: - cmd.err <- s.addLink(cmd.link) - case *removeLinkCmd: - cmd.err <- s.removeLink(cmd.chanID) - case *getLinkCmd: - link, err := s.getLink(cmd.chanID) - cmd.done <- link - cmd.err <- err - case *getLinksCmd: - links, err := s.getLinks(cmd.peer) - cmd.done <- links - cmd.err <- err - case *updateForwardingIndexCmd: - cmd.err <- s.updateShortChanID( - cmd.chanID, cmd.shortChanID, - ) - } - case <-s.quit: return } @@ -1555,8 +1537,7 @@ func (s *Switch) reforwardResponses() error { // loadChannelFwdPkgs loads all forwarding packages owned by the `source` short // channel identifier. -func (s *Switch) loadChannelFwdPkgs( - source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) { +func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) { var fwdPkgs []*channeldb.FwdPkg if err := s.cfg.DB.Update(func(tx *bolt.Tx) error { @@ -1688,38 +1669,11 @@ func (s *Switch) Stop() error { return nil } -// addLinkCmd is a add link command wrapper, it is used to propagate handler -// parameters and return handler error. -type addLinkCmd struct { - link ChannelLink - err chan error -} - // AddLink is used to initiate the handling of the add link command. The // request will be propagated and handled in the main goroutine. func (s *Switch) AddLink(link ChannelLink) error { - command := &addLinkCmd{ - link: link, - err: make(chan error, 1), - } - - select { - case s.linkControl <- command: - select { - case err := <-command.err: - return err - case <-s.quit: - } - case <-s.quit: - } - - return errors.New("unable to add link htlc switch was stopped") -} - -// addLink is used to add the newly created channel link and start use it to -// handle the channel updates. -func (s *Switch) addLink(link ChannelLink) error { - // TODO(roasbeef): reject if link already tehre? + s.indexMtx.Lock() + defer s.indexMtx.Unlock() // First we'll add the link to the linkIndex which lets us quickly look // up a channel when we need to close or register it, and the @@ -1781,47 +1735,12 @@ func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox { return mailbox } -// getLinkCmd is a get link command wrapper, it is used to propagate handler -// parameters and return handler error. -type getLinkCmd struct { - chanID lnwire.ChannelID - err chan error - done chan ChannelLink -} - // GetLink is used to initiate the handling of the get link command. The // request will be propagated/handled to/in the main goroutine. func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) { - command := &getLinkCmd{ - chanID: chanID, - err: make(chan error, 1), - done: make(chan ChannelLink, 1), - } + s.indexMtx.RLock() + defer s.indexMtx.RUnlock() -query: - select { - case s.linkControl <- command: - - var link ChannelLink - select { - case link = <-command.done: - case <-s.quit: - break query - } - - select { - case err := <-command.err: - return link, err - case <-s.quit: - } - case <-s.quit: - } - - return nil, errors.New("unable to get link htlc switch was stopped") -} - -// getLink attempts to return the link that has the specified channel ID. -func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { link, ok := s.linkIndex[chanID] if !ok { return nil, ErrChannelLinkNotFound @@ -1832,6 +1751,8 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { // getLinkByShortID attempts to return the link which possesses the target // short channel ID. +// +// NOTE: This MUST be called with the indexMtx held. func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) { link, ok := s.forwardingIndex[chanID] if !ok { @@ -1841,35 +1762,18 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er return link, nil } -// removeLinkCmd is a get link command wrapper, it is used to propagate handler -// parameters and return handler error. -type removeLinkCmd struct { - chanID lnwire.ChannelID - err chan error -} - // RemoveLink is used to initiate the handling of the remove link command. The // request will be propagated/handled to/in the main goroutine. func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { - command := &removeLinkCmd{ - chanID: chanID, - err: make(chan error, 1), - } + s.indexMtx.Lock() + defer s.indexMtx.Unlock() - select { - case s.linkControl <- command: - select { - case err := <-command.err: - return err - case <-s.quit: - } - case <-s.quit: - } - - return errors.New("unable to remove link htlc switch was stopped") + return s.removeLink(chanID) } // removeLink is used to remove and stop the channel link. +// +// NOTE: This MUST be called with the indexMtx held. func (s *Switch) removeLink(chanID lnwire.ChannelID) error { log.Infof("Removing channel link with ChannelID(%v)", chanID) @@ -1891,50 +1795,21 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error { return nil } -// updateForwardingIndexCmd is a command sent by outside sub-systems to update -// the forwarding index of the switch in the event that the short channel ID of -// a particular link changes. -type updateForwardingIndexCmd struct { - chanID lnwire.ChannelID - shortChanID lnwire.ShortChannelID - - err chan error -} - // UpdateShortChanID updates the short chan ID for an existing channel. This is // required in the case of a re-org and re-confirmation or a channel, or in the // case that a link was added to the switch before its short chan ID was known. func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID) error { - command := &updateForwardingIndexCmd{ - chanID: chanID, - shortChanID: shortChanID, - err: make(chan error, 1), - } - - select { - case s.linkControl <- command: - select { - case err := <-command.err: - return err - case <-s.quit: - } - case <-s.quit: - } - - return errors.New("unable to update short chan id htlc switch was stopped") -} - -// updateShortChanID updates the short chan ID of an existing link. -func (s *Switch) updateShortChanID(chanID lnwire.ChannelID, - shortChanID lnwire.ShortChannelID) error { + s.indexMtx.Lock() // First, we'll extract the current link as is from the link link // index. If the link isn't even in the index, then we'll return an // error. link, ok := s.linkIndex[chanID] if !ok { + s.indexMtx.Unlock() + return fmt.Errorf("link %v not found", chanID) } @@ -1945,53 +1820,27 @@ func (s *Switch) updateShortChanID(chanID lnwire.ChannelID, // forwarding index with the next short channel ID. s.forwardingIndex[shortChanID] = link + s.indexMtx.Unlock() + // Finally, we'll notify the link of its new short channel ID. link.UpdateShortChanID(shortChanID) return nil } -// getLinksCmd is a get links command wrapper, it is used to propagate handler -// parameters and return handler error. -type getLinksCmd struct { - peer [33]byte - err chan error - done chan []ChannelLink -} - // GetLinksByInterface fetches all the links connected to a particular node // identified by the serialized compressed form of its public key. func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) { - command := &getLinksCmd{ - peer: hop, - err: make(chan error, 1), - done: make(chan []ChannelLink, 1), - } + s.indexMtx.RLock() + defer s.indexMtx.RUnlock() -query: - select { - case s.linkControl <- command: - - var links []ChannelLink - select { - case links = <-command.done: - case <-s.quit: - break query - } - - select { - case err := <-command.err: - return links, err - case <-s.quit: - } - case <-s.quit: - } - - return nil, errors.New("unable to get links htlc switch was stopped") + return s.getLinks(hop) } // getLinks is function which returns the channel links of the peer by hop // destination id. +// +// NOTE: This MUST be called with the indexMtx held. func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { links, ok := s.interfaceIndex[destination] if !ok {