htlcswitch: remove linkControl in favor of a mutex guarding all channel indexes

In this commit, we simplify the switch's code a bit. Rather than having
a set of channels we use to mutate or query for the set of current
links, we'll instead now just use a mutex to guard a set of link
indexes. This serves to simplify the ode, and also make it such that we
don't need to block forwarding in order to add/remove a link.
This commit is contained in:
Olaoluwa Osuntokun 2018-04-03 20:06:57 -07:00
parent 7037d55f65
commit 0a47b2c4ad
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21

@ -168,10 +168,6 @@ type Switch struct {
// forward the settle/fail htlc updates back to the add htlc initiator. // forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap circuits CircuitMap
// links is a map of channel id and channel link which manages
// this channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// mailMtx is a read/write mutex that protects the mailboxes map. // mailMtx is a read/write mutex that protects the mailboxes map.
mailMtx sync.RWMutex mailMtx sync.RWMutex
@ -179,6 +175,14 @@ type Switch struct {
// switch to buffer messages for peers that have not come back online. // switch to buffer messages for peers that have not come back online.
mailboxes map[lnwire.ShortChannelID]MailBox mailboxes map[lnwire.ShortChannelID]MailBox
// indexMtx is a read/write mutex that protects the set of indexes
// below.
indexMtx sync.RWMutex
// links is a map of channel id and channel link which manages
// this channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// forwardingIndex is an index which is consulted by the switch when it // forwardingIndex is an index which is consulted by the switch when it
// needs to locate the next hop to forward an incoming/outgoing HTLC // needs to locate the next hop to forward an incoming/outgoing HTLC
// update to/from. // update to/from.
@ -244,7 +248,6 @@ func New(cfg Config) (*Switch, error) {
htlcPlex: make(chan *plexPacket), htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose), chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg), resolutionMsgs: make(chan *resolutionMsg),
linkControl: make(chan interface{}),
quit: make(chan struct{}), quit: make(chan struct{}),
}, nil }, nil
} }
@ -386,63 +389,47 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy, func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy,
targetChans ...wire.OutPoint) error { targetChans ...wire.OutPoint) error {
errChan := make(chan error, 1) log.Debugf("Updating link policies: %v", newLogClosure(func() string {
select { return spew.Sdump(newPolicy)
case s.linkControl <- &updatePoliciesCmd{ }))
newPolicy: newPolicy,
targetChans: targetChans,
err: errChan,
}:
case <-s.quit:
return fmt.Errorf("switch is shutting down")
}
select { s.indexMtx.RLock()
case err := <-errChan:
return err
case <-s.quit:
return fmt.Errorf("switch is shutting down")
}
}
// updatePoliciesCmd is a message sent to the switch to update the forwarding var linksToUpdate []ChannelLink
// policies of a set of target links.
type updatePoliciesCmd struct {
newPolicy ForwardingPolicy
targetChans []wire.OutPoint
err chan error // If no channels have been targeted, then we'll collect all inks to
} // update their policies.
if len(targetChans) == 0 {
// updateLinkPolicies attempts to update the forwarding policies for the set of
// passed links identified by their channel points. If a nil set of channel
// points is passed, then the forwarding policies for all active links will be
// updated.
func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error {
log.Debugf("Updating link policies: %v", spew.Sdump(c))
// If no channels have been targeted, then we'll update the link policies
// for all active channels
if len(c.targetChans) == 0 {
for _, link := range s.linkIndex { for _, link := range s.linkIndex {
link.UpdateForwardingPolicy(c.newPolicy) linksToUpdate = append(linksToUpdate, link)
}
} else {
// Otherwise, we'll only attempt to update the forwarding
// policies for the set of targeted links.
for _, targetLink := range targetChans {
cid := lnwire.NewChanIDFromOutPoint(&targetLink)
// If we can't locate a link by its converted channel
// ID, then we'll return an error back to the caller.
link, ok := s.linkIndex[cid]
if !ok {
s.indexMtx.RUnlock()
return fmt.Errorf("unable to find "+
"ChannelPoint(%v) to update link "+
"policy", targetLink)
}
linksToUpdate = append(linksToUpdate, link)
} }
} }
// Otherwise, we'll only attempt to update the forwarding policies for the s.indexMtx.RUnlock()
// set of targeted links.
for _, targetLink := range c.targetChans {
cid := lnwire.NewChanIDFromOutPoint(&targetLink)
// If we can't locate a link by its converted channel ID, then we'll // With all the links we need to update collected, we can release the
// return an error back to the caller. // mutex then update each link directly.
link, ok := s.linkIndex[cid] for _, link := range linksToUpdate {
if !ok { link.UpdateForwardingPolicy(newPolicy)
return fmt.Errorf("unable to find ChannelPoint(%v) to "+
"update link policy", targetLink)
}
link.UpdateForwardingPolicy(c.newPolicy)
} }
return nil return nil
@ -715,14 +702,18 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
// appropriate channel link and send the payment over this link. // appropriate channel link and send the payment over this link.
case *lnwire.UpdateAddHTLC: case *lnwire.UpdateAddHTLC:
// Try to find links by node destination. // Try to find links by node destination.
s.indexMtx.RLock()
links, err := s.getLinks(pkt.destNode) links, err := s.getLinks(pkt.destNode)
if err != nil { if err != nil {
s.indexMtx.RUnlock()
log.Errorf("unable to find links by destination %v", err) log.Errorf("unable to find links by destination %v", err)
return &ForwardingError{ return &ForwardingError{
ErrorSource: s.cfg.SelfKey, ErrorSource: s.cfg.SelfKey,
FailureMessage: &lnwire.FailUnknownNextPeer{}, FailureMessage: &lnwire.FailUnknownNextPeer{},
} }
} }
s.indexMtx.RUnlock()
// Try to find destination channel link with appropriate // Try to find destination channel link with appropriate
// bandwidth. // bandwidth.
@ -880,8 +871,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.handleLocalDispatch(packet) return s.handleLocalDispatch(packet)
} }
s.indexMtx.RLock()
targetLink, err := s.getLinkByShortID(packet.outgoingChanID) targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
if err != nil { if err != nil {
s.indexMtx.RUnlock()
// If packet was forwarded from another channel link // If packet was forwarded from another channel link
// than we should notify this link that some error // than we should notify this link that some error
// occurred. // occurred.
@ -892,6 +886,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.failAddPacket(packet, failure, addErr) return s.failAddPacket(packet, failure, addErr)
} }
interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey()) interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey())
s.indexMtx.RUnlock()
// We'll keep track of any HTLC failures during the link // We'll keep track of any HTLC failures during the link
// selection process. This way we can return the error for // selection process. This way we can return the error for
@ -1300,12 +1295,14 @@ func (s *Switch) htlcForwarder() {
// Remove all links once we've been signalled for shutdown. // Remove all links once we've been signalled for shutdown.
defer func() { defer func() {
s.indexMtx.Lock()
for _, link := range s.linkIndex { for _, link := range s.linkIndex {
if err := s.removeLink(link.ChanID()); err != nil { if err := s.removeLink(link.ChanID()); err != nil {
log.Errorf("unable to remove "+ log.Errorf("unable to remove "+
"channel link on stop: %v", err) "channel link on stop: %v", err)
} }
} }
s.indexMtx.Unlock()
// Before we exit fully, we'll attempt to flush out any // Before we exit fully, we'll attempt to flush out any
// forwarding events that may still be lingering since the last // forwarding events that may still be lingering since the last
@ -1336,12 +1333,17 @@ func (s *Switch) htlcForwarder() {
// cooperatively closed (if possible). // cooperatively closed (if possible).
case req := <-s.chanCloseRequests: case req := <-s.chanCloseRequests:
chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint) chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint)
s.indexMtx.RLock()
link, ok := s.linkIndex[chanID] link, ok := s.linkIndex[chanID]
if !ok { if !ok {
s.indexMtx.RUnlock()
req.Err <- errors.Errorf("no peer for channel with "+ req.Err <- errors.Errorf("no peer for channel with "+
"chan_id=%x", chanID[:]) "chan_id=%x", chanID[:])
continue continue
} }
s.indexMtx.RUnlock()
peerPub := link.Peer().PubKey() peerPub := link.Peer().PubKey()
log.Debugf("Requesting local channel close: peer=%v, "+ log.Debugf("Requesting local channel close: peer=%v, "+
@ -1421,6 +1423,7 @@ func (s *Switch) htlcForwarder() {
// Next, we'll run through all the registered links and // Next, we'll run through all the registered links and
// compute their up-to-date forwarding stats. // compute their up-to-date forwarding stats.
s.indexMtx.RLock()
for _, link := range s.linkIndex { for _, link := range s.linkIndex {
// TODO(roasbeef): when links first registered // TODO(roasbeef): when links first registered
// stats printed. // stats printed.
@ -1429,6 +1432,7 @@ func (s *Switch) htlcForwarder() {
newSatSent += sent.ToSatoshis() newSatSent += sent.ToSatoshis()
newSatRecv += recv.ToSatoshis() newSatRecv += recv.ToSatoshis()
} }
s.indexMtx.RUnlock()
var ( var (
diffNumUpdates uint64 diffNumUpdates uint64
@ -1478,28 +1482,6 @@ func (s *Switch) htlcForwarder() {
totalSatSent += diffSatSent totalSatSent += diffSatSent
totalSatRecv += diffSatRecv totalSatRecv += diffSatRecv
case req := <-s.linkControl:
switch cmd := req.(type) {
case *updatePoliciesCmd:
cmd.err <- s.updateLinkPolicies(cmd)
case *addLinkCmd:
cmd.err <- s.addLink(cmd.link)
case *removeLinkCmd:
cmd.err <- s.removeLink(cmd.chanID)
case *getLinkCmd:
link, err := s.getLink(cmd.chanID)
cmd.done <- link
cmd.err <- err
case *getLinksCmd:
links, err := s.getLinks(cmd.peer)
cmd.done <- links
cmd.err <- err
case *updateForwardingIndexCmd:
cmd.err <- s.updateShortChanID(
cmd.chanID, cmd.shortChanID,
)
}
case <-s.quit: case <-s.quit:
return return
} }
@ -1555,8 +1537,7 @@ func (s *Switch) reforwardResponses() error {
// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short // loadChannelFwdPkgs loads all forwarding packages owned by the `source` short
// channel identifier. // channel identifier.
func (s *Switch) loadChannelFwdPkgs( func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
var fwdPkgs []*channeldb.FwdPkg var fwdPkgs []*channeldb.FwdPkg
if err := s.cfg.DB.Update(func(tx *bolt.Tx) error { if err := s.cfg.DB.Update(func(tx *bolt.Tx) error {
@ -1688,38 +1669,11 @@ func (s *Switch) Stop() error {
return nil return nil
} }
// addLinkCmd is a add link command wrapper, it is used to propagate handler
// parameters and return handler error.
type addLinkCmd struct {
link ChannelLink
err chan error
}
// AddLink is used to initiate the handling of the add link command. The // AddLink is used to initiate the handling of the add link command. The
// request will be propagated and handled in the main goroutine. // request will be propagated and handled in the main goroutine.
func (s *Switch) AddLink(link ChannelLink) error { func (s *Switch) AddLink(link ChannelLink) error {
command := &addLinkCmd{ s.indexMtx.Lock()
link: link, defer s.indexMtx.Unlock()
err: make(chan error, 1),
}
select {
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to add link htlc switch was stopped")
}
// addLink is used to add the newly created channel link and start use it to
// handle the channel updates.
func (s *Switch) addLink(link ChannelLink) error {
// TODO(roasbeef): reject if link already tehre?
// First we'll add the link to the linkIndex which lets us quickly look // First we'll add the link to the linkIndex which lets us quickly look
// up a channel when we need to close or register it, and the // up a channel when we need to close or register it, and the
@ -1781,47 +1735,12 @@ func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox {
return mailbox return mailbox
} }
// getLinkCmd is a get link command wrapper, it is used to propagate handler
// parameters and return handler error.
type getLinkCmd struct {
chanID lnwire.ChannelID
err chan error
done chan ChannelLink
}
// GetLink is used to initiate the handling of the get link command. The // GetLink is used to initiate the handling of the get link command. The
// request will be propagated/handled to/in the main goroutine. // request will be propagated/handled to/in the main goroutine.
func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) { func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) {
command := &getLinkCmd{ s.indexMtx.RLock()
chanID: chanID, defer s.indexMtx.RUnlock()
err: make(chan error, 1),
done: make(chan ChannelLink, 1),
}
query:
select {
case s.linkControl <- command:
var link ChannelLink
select {
case link = <-command.done:
case <-s.quit:
break query
}
select {
case err := <-command.err:
return link, err
case <-s.quit:
}
case <-s.quit:
}
return nil, errors.New("unable to get link htlc switch was stopped")
}
// getLink attempts to return the link that has the specified channel ID.
func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
link, ok := s.linkIndex[chanID] link, ok := s.linkIndex[chanID]
if !ok { if !ok {
return nil, ErrChannelLinkNotFound return nil, ErrChannelLinkNotFound
@ -1832,6 +1751,8 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
// getLinkByShortID attempts to return the link which possesses the target // getLinkByShortID attempts to return the link which possesses the target
// short channel ID. // short channel ID.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) { func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) {
link, ok := s.forwardingIndex[chanID] link, ok := s.forwardingIndex[chanID]
if !ok { if !ok {
@ -1841,35 +1762,18 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er
return link, nil return link, nil
} }
// removeLinkCmd is a get link command wrapper, it is used to propagate handler
// parameters and return handler error.
type removeLinkCmd struct {
chanID lnwire.ChannelID
err chan error
}
// RemoveLink is used to initiate the handling of the remove link command. The // RemoveLink is used to initiate the handling of the remove link command. The
// request will be propagated/handled to/in the main goroutine. // request will be propagated/handled to/in the main goroutine.
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error { func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error {
command := &removeLinkCmd{ s.indexMtx.Lock()
chanID: chanID, defer s.indexMtx.Unlock()
err: make(chan error, 1),
}
select { return s.removeLink(chanID)
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to remove link htlc switch was stopped")
} }
// removeLink is used to remove and stop the channel link. // removeLink is used to remove and stop the channel link.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) removeLink(chanID lnwire.ChannelID) error { func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
log.Infof("Removing channel link with ChannelID(%v)", chanID) log.Infof("Removing channel link with ChannelID(%v)", chanID)
@ -1891,50 +1795,21 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
return nil return nil
} }
// updateForwardingIndexCmd is a command sent by outside sub-systems to update
// the forwarding index of the switch in the event that the short channel ID of
// a particular link changes.
type updateForwardingIndexCmd struct {
chanID lnwire.ChannelID
shortChanID lnwire.ShortChannelID
err chan error
}
// UpdateShortChanID updates the short chan ID for an existing channel. This is // UpdateShortChanID updates the short chan ID for an existing channel. This is
// required in the case of a re-org and re-confirmation or a channel, or in the // required in the case of a re-org and re-confirmation or a channel, or in the
// case that a link was added to the switch before its short chan ID was known. // case that a link was added to the switch before its short chan ID was known.
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID, func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID) error { shortChanID lnwire.ShortChannelID) error {
command := &updateForwardingIndexCmd{ s.indexMtx.Lock()
chanID: chanID,
shortChanID: shortChanID,
err: make(chan error, 1),
}
select {
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to update short chan id htlc switch was stopped")
}
// updateShortChanID updates the short chan ID of an existing link.
func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID) error {
// First, we'll extract the current link as is from the link link // First, we'll extract the current link as is from the link link
// index. If the link isn't even in the index, then we'll return an // index. If the link isn't even in the index, then we'll return an
// error. // error.
link, ok := s.linkIndex[chanID] link, ok := s.linkIndex[chanID]
if !ok { if !ok {
s.indexMtx.Unlock()
return fmt.Errorf("link %v not found", chanID) return fmt.Errorf("link %v not found", chanID)
} }
@ -1945,53 +1820,27 @@ func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
// forwarding index with the next short channel ID. // forwarding index with the next short channel ID.
s.forwardingIndex[shortChanID] = link s.forwardingIndex[shortChanID] = link
s.indexMtx.Unlock()
// Finally, we'll notify the link of its new short channel ID. // Finally, we'll notify the link of its new short channel ID.
link.UpdateShortChanID(shortChanID) link.UpdateShortChanID(shortChanID)
return nil return nil
} }
// getLinksCmd is a get links command wrapper, it is used to propagate handler
// parameters and return handler error.
type getLinksCmd struct {
peer [33]byte
err chan error
done chan []ChannelLink
}
// GetLinksByInterface fetches all the links connected to a particular node // GetLinksByInterface fetches all the links connected to a particular node
// identified by the serialized compressed form of its public key. // identified by the serialized compressed form of its public key.
func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) { func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) {
command := &getLinksCmd{ s.indexMtx.RLock()
peer: hop, defer s.indexMtx.RUnlock()
err: make(chan error, 1),
done: make(chan []ChannelLink, 1),
}
query: return s.getLinks(hop)
select {
case s.linkControl <- command:
var links []ChannelLink
select {
case links = <-command.done:
case <-s.quit:
break query
}
select {
case err := <-command.err:
return links, err
case <-s.quit:
}
case <-s.quit:
}
return nil, errors.New("unable to get links htlc switch was stopped")
} }
// getLinks is function which returns the channel links of the peer by hop // getLinks is function which returns the channel links of the peer by hop
// destination id. // destination id.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
links, ok := s.interfaceIndex[destination] links, ok := s.interfaceIndex[destination]
if !ok { if !ok {