htlcswitch: remove linkControl in favor of a mutex guarding all channel indexes

In this commit, we simplify the switch's code a bit. Rather than having
a set of channels we use to mutate or query for the set of current
links, we'll instead now just use a mutex to guard a set of link
indexes. This serves to simplify the ode, and also make it such that we
don't need to block forwarding in order to add/remove a link.
This commit is contained in:
Olaoluwa Osuntokun 2018-04-03 20:06:57 -07:00
parent 7037d55f65
commit 0a47b2c4ad
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21

@ -168,10 +168,6 @@ type Switch struct {
// forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap
// links is a map of channel id and channel link which manages
// this channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// mailMtx is a read/write mutex that protects the mailboxes map.
mailMtx sync.RWMutex
@ -179,6 +175,14 @@ type Switch struct {
// switch to buffer messages for peers that have not come back online.
mailboxes map[lnwire.ShortChannelID]MailBox
// indexMtx is a read/write mutex that protects the set of indexes
// below.
indexMtx sync.RWMutex
// links is a map of channel id and channel link which manages
// this channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// forwardingIndex is an index which is consulted by the switch when it
// needs to locate the next hop to forward an incoming/outgoing HTLC
// update to/from.
@ -244,7 +248,6 @@ func New(cfg Config) (*Switch, error) {
htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg),
linkControl: make(chan interface{}),
quit: make(chan struct{}),
}, nil
}
@ -386,63 +389,47 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy,
targetChans ...wire.OutPoint) error {
errChan := make(chan error, 1)
select {
case s.linkControl <- &updatePoliciesCmd{
newPolicy: newPolicy,
targetChans: targetChans,
err: errChan,
}:
case <-s.quit:
return fmt.Errorf("switch is shutting down")
}
log.Debugf("Updating link policies: %v", newLogClosure(func() string {
return spew.Sdump(newPolicy)
}))
select {
case err := <-errChan:
return err
case <-s.quit:
return fmt.Errorf("switch is shutting down")
}
}
s.indexMtx.RLock()
// updatePoliciesCmd is a message sent to the switch to update the forwarding
// policies of a set of target links.
type updatePoliciesCmd struct {
newPolicy ForwardingPolicy
targetChans []wire.OutPoint
var linksToUpdate []ChannelLink
err chan error
}
// updateLinkPolicies attempts to update the forwarding policies for the set of
// passed links identified by their channel points. If a nil set of channel
// points is passed, then the forwarding policies for all active links will be
// updated.
func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error {
log.Debugf("Updating link policies: %v", spew.Sdump(c))
// If no channels have been targeted, then we'll update the link policies
// for all active channels
if len(c.targetChans) == 0 {
// If no channels have been targeted, then we'll collect all inks to
// update their policies.
if len(targetChans) == 0 {
for _, link := range s.linkIndex {
link.UpdateForwardingPolicy(c.newPolicy)
linksToUpdate = append(linksToUpdate, link)
}
} else {
// Otherwise, we'll only attempt to update the forwarding
// policies for the set of targeted links.
for _, targetLink := range targetChans {
cid := lnwire.NewChanIDFromOutPoint(&targetLink)
// If we can't locate a link by its converted channel
// ID, then we'll return an error back to the caller.
link, ok := s.linkIndex[cid]
if !ok {
s.indexMtx.RUnlock()
return fmt.Errorf("unable to find "+
"ChannelPoint(%v) to update link "+
"policy", targetLink)
}
linksToUpdate = append(linksToUpdate, link)
}
}
// Otherwise, we'll only attempt to update the forwarding policies for the
// set of targeted links.
for _, targetLink := range c.targetChans {
cid := lnwire.NewChanIDFromOutPoint(&targetLink)
s.indexMtx.RUnlock()
// If we can't locate a link by its converted channel ID, then we'll
// return an error back to the caller.
link, ok := s.linkIndex[cid]
if !ok {
return fmt.Errorf("unable to find ChannelPoint(%v) to "+
"update link policy", targetLink)
}
link.UpdateForwardingPolicy(c.newPolicy)
// With all the links we need to update collected, we can release the
// mutex then update each link directly.
for _, link := range linksToUpdate {
link.UpdateForwardingPolicy(newPolicy)
}
return nil
@ -715,14 +702,18 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
// appropriate channel link and send the payment over this link.
case *lnwire.UpdateAddHTLC:
// Try to find links by node destination.
s.indexMtx.RLock()
links, err := s.getLinks(pkt.destNode)
if err != nil {
s.indexMtx.RUnlock()
log.Errorf("unable to find links by destination %v", err)
return &ForwardingError{
ErrorSource: s.cfg.SelfKey,
FailureMessage: &lnwire.FailUnknownNextPeer{},
}
}
s.indexMtx.RUnlock()
// Try to find destination channel link with appropriate
// bandwidth.
@ -880,8 +871,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.handleLocalDispatch(packet)
}
s.indexMtx.RLock()
targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
if err != nil {
s.indexMtx.RUnlock()
// If packet was forwarded from another channel link
// than we should notify this link that some error
// occurred.
@ -892,6 +886,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.failAddPacket(packet, failure, addErr)
}
interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey())
s.indexMtx.RUnlock()
// We'll keep track of any HTLC failures during the link
// selection process. This way we can return the error for
@ -1300,12 +1295,14 @@ func (s *Switch) htlcForwarder() {
// Remove all links once we've been signalled for shutdown.
defer func() {
s.indexMtx.Lock()
for _, link := range s.linkIndex {
if err := s.removeLink(link.ChanID()); err != nil {
log.Errorf("unable to remove "+
"channel link on stop: %v", err)
}
}
s.indexMtx.Unlock()
// Before we exit fully, we'll attempt to flush out any
// forwarding events that may still be lingering since the last
@ -1336,12 +1333,17 @@ func (s *Switch) htlcForwarder() {
// cooperatively closed (if possible).
case req := <-s.chanCloseRequests:
chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint)
s.indexMtx.RLock()
link, ok := s.linkIndex[chanID]
if !ok {
s.indexMtx.RUnlock()
req.Err <- errors.Errorf("no peer for channel with "+
"chan_id=%x", chanID[:])
continue
}
s.indexMtx.RUnlock()
peerPub := link.Peer().PubKey()
log.Debugf("Requesting local channel close: peer=%v, "+
@ -1421,6 +1423,7 @@ func (s *Switch) htlcForwarder() {
// Next, we'll run through all the registered links and
// compute their up-to-date forwarding stats.
s.indexMtx.RLock()
for _, link := range s.linkIndex {
// TODO(roasbeef): when links first registered
// stats printed.
@ -1429,6 +1432,7 @@ func (s *Switch) htlcForwarder() {
newSatSent += sent.ToSatoshis()
newSatRecv += recv.ToSatoshis()
}
s.indexMtx.RUnlock()
var (
diffNumUpdates uint64
@ -1478,28 +1482,6 @@ func (s *Switch) htlcForwarder() {
totalSatSent += diffSatSent
totalSatRecv += diffSatRecv
case req := <-s.linkControl:
switch cmd := req.(type) {
case *updatePoliciesCmd:
cmd.err <- s.updateLinkPolicies(cmd)
case *addLinkCmd:
cmd.err <- s.addLink(cmd.link)
case *removeLinkCmd:
cmd.err <- s.removeLink(cmd.chanID)
case *getLinkCmd:
link, err := s.getLink(cmd.chanID)
cmd.done <- link
cmd.err <- err
case *getLinksCmd:
links, err := s.getLinks(cmd.peer)
cmd.done <- links
cmd.err <- err
case *updateForwardingIndexCmd:
cmd.err <- s.updateShortChanID(
cmd.chanID, cmd.shortChanID,
)
}
case <-s.quit:
return
}
@ -1555,8 +1537,7 @@ func (s *Switch) reforwardResponses() error {
// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short
// channel identifier.
func (s *Switch) loadChannelFwdPkgs(
source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
var fwdPkgs []*channeldb.FwdPkg
if err := s.cfg.DB.Update(func(tx *bolt.Tx) error {
@ -1688,38 +1669,11 @@ func (s *Switch) Stop() error {
return nil
}
// addLinkCmd is a add link command wrapper, it is used to propagate handler
// parameters and return handler error.
type addLinkCmd struct {
link ChannelLink
err chan error
}
// AddLink is used to initiate the handling of the add link command. The
// request will be propagated and handled in the main goroutine.
func (s *Switch) AddLink(link ChannelLink) error {
command := &addLinkCmd{
link: link,
err: make(chan error, 1),
}
select {
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to add link htlc switch was stopped")
}
// addLink is used to add the newly created channel link and start use it to
// handle the channel updates.
func (s *Switch) addLink(link ChannelLink) error {
// TODO(roasbeef): reject if link already tehre?
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
// First we'll add the link to the linkIndex which lets us quickly look
// up a channel when we need to close or register it, and the
@ -1781,47 +1735,12 @@ func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox {
return mailbox
}
// getLinkCmd is a get link command wrapper, it is used to propagate handler
// parameters and return handler error.
type getLinkCmd struct {
chanID lnwire.ChannelID
err chan error
done chan ChannelLink
}
// GetLink is used to initiate the handling of the get link command. The
// request will be propagated/handled to/in the main goroutine.
func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) {
command := &getLinkCmd{
chanID: chanID,
err: make(chan error, 1),
done: make(chan ChannelLink, 1),
}
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
query:
select {
case s.linkControl <- command:
var link ChannelLink
select {
case link = <-command.done:
case <-s.quit:
break query
}
select {
case err := <-command.err:
return link, err
case <-s.quit:
}
case <-s.quit:
}
return nil, errors.New("unable to get link htlc switch was stopped")
}
// getLink attempts to return the link that has the specified channel ID.
func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
link, ok := s.linkIndex[chanID]
if !ok {
return nil, ErrChannelLinkNotFound
@ -1832,6 +1751,8 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
// getLinkByShortID attempts to return the link which possesses the target
// short channel ID.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) {
link, ok := s.forwardingIndex[chanID]
if !ok {
@ -1841,35 +1762,18 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er
return link, nil
}
// removeLinkCmd is a get link command wrapper, it is used to propagate handler
// parameters and return handler error.
type removeLinkCmd struct {
chanID lnwire.ChannelID
err chan error
}
// RemoveLink is used to initiate the handling of the remove link command. The
// request will be propagated/handled to/in the main goroutine.
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error {
command := &removeLinkCmd{
chanID: chanID,
err: make(chan error, 1),
}
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
select {
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to remove link htlc switch was stopped")
return s.removeLink(chanID)
}
// removeLink is used to remove and stop the channel link.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
log.Infof("Removing channel link with ChannelID(%v)", chanID)
@ -1891,50 +1795,21 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
return nil
}
// updateForwardingIndexCmd is a command sent by outside sub-systems to update
// the forwarding index of the switch in the event that the short channel ID of
// a particular link changes.
type updateForwardingIndexCmd struct {
chanID lnwire.ChannelID
shortChanID lnwire.ShortChannelID
err chan error
}
// UpdateShortChanID updates the short chan ID for an existing channel. This is
// required in the case of a re-org and re-confirmation or a channel, or in the
// case that a link was added to the switch before its short chan ID was known.
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID) error {
command := &updateForwardingIndexCmd{
chanID: chanID,
shortChanID: shortChanID,
err: make(chan error, 1),
}
select {
case s.linkControl <- command:
select {
case err := <-command.err:
return err
case <-s.quit:
}
case <-s.quit:
}
return errors.New("unable to update short chan id htlc switch was stopped")
}
// updateShortChanID updates the short chan ID of an existing link.
func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID) error {
s.indexMtx.Lock()
// First, we'll extract the current link as is from the link link
// index. If the link isn't even in the index, then we'll return an
// error.
link, ok := s.linkIndex[chanID]
if !ok {
s.indexMtx.Unlock()
return fmt.Errorf("link %v not found", chanID)
}
@ -1945,53 +1820,27 @@ func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
// forwarding index with the next short channel ID.
s.forwardingIndex[shortChanID] = link
s.indexMtx.Unlock()
// Finally, we'll notify the link of its new short channel ID.
link.UpdateShortChanID(shortChanID)
return nil
}
// getLinksCmd is a get links command wrapper, it is used to propagate handler
// parameters and return handler error.
type getLinksCmd struct {
peer [33]byte
err chan error
done chan []ChannelLink
}
// GetLinksByInterface fetches all the links connected to a particular node
// identified by the serialized compressed form of its public key.
func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) {
command := &getLinksCmd{
peer: hop,
err: make(chan error, 1),
done: make(chan []ChannelLink, 1),
}
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
query:
select {
case s.linkControl <- command:
var links []ChannelLink
select {
case links = <-command.done:
case <-s.quit:
break query
}
select {
case err := <-command.err:
return links, err
case <-s.quit:
}
case <-s.quit:
}
return nil, errors.New("unable to get links htlc switch was stopped")
return s.getLinks(hop)
}
// getLinks is function which returns the channel links of the peer by hop
// destination id.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
links, ok := s.interfaceIndex[destination]
if !ok {