htlcswitch: remove linkControl in favor of a mutex guarding all channel indexes
In this commit, we simplify the switch's code a bit. Rather than having a set of channels we use to mutate or query for the set of current links, we'll instead now just use a mutex to guard a set of link indexes. This serves to simplify the ode, and also make it such that we don't need to block forwarding in order to add/remove a link.
This commit is contained in:
parent
7037d55f65
commit
0a47b2c4ad
@ -168,10 +168,6 @@ type Switch struct {
|
||||
// forward the settle/fail htlc updates back to the add htlc initiator.
|
||||
circuits CircuitMap
|
||||
|
||||
// links is a map of channel id and channel link which manages
|
||||
// this channel.
|
||||
linkIndex map[lnwire.ChannelID]ChannelLink
|
||||
|
||||
// mailMtx is a read/write mutex that protects the mailboxes map.
|
||||
mailMtx sync.RWMutex
|
||||
|
||||
@ -179,6 +175,14 @@ type Switch struct {
|
||||
// switch to buffer messages for peers that have not come back online.
|
||||
mailboxes map[lnwire.ShortChannelID]MailBox
|
||||
|
||||
// indexMtx is a read/write mutex that protects the set of indexes
|
||||
// below.
|
||||
indexMtx sync.RWMutex
|
||||
|
||||
// links is a map of channel id and channel link which manages
|
||||
// this channel.
|
||||
linkIndex map[lnwire.ChannelID]ChannelLink
|
||||
|
||||
// forwardingIndex is an index which is consulted by the switch when it
|
||||
// needs to locate the next hop to forward an incoming/outgoing HTLC
|
||||
// update to/from.
|
||||
@ -244,7 +248,6 @@ func New(cfg Config) (*Switch, error) {
|
||||
htlcPlex: make(chan *plexPacket),
|
||||
chanCloseRequests: make(chan *ChanClose),
|
||||
resolutionMsgs: make(chan *resolutionMsg),
|
||||
linkControl: make(chan interface{}),
|
||||
quit: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
@ -386,63 +389,47 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
|
||||
func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy,
|
||||
targetChans ...wire.OutPoint) error {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
select {
|
||||
case s.linkControl <- &updatePoliciesCmd{
|
||||
newPolicy: newPolicy,
|
||||
targetChans: targetChans,
|
||||
err: errChan,
|
||||
}:
|
||||
case <-s.quit:
|
||||
return fmt.Errorf("switch is shutting down")
|
||||
}
|
||||
log.Debugf("Updating link policies: %v", newLogClosure(func() string {
|
||||
return spew.Sdump(newPolicy)
|
||||
}))
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-s.quit:
|
||||
return fmt.Errorf("switch is shutting down")
|
||||
}
|
||||
}
|
||||
s.indexMtx.RLock()
|
||||
|
||||
// updatePoliciesCmd is a message sent to the switch to update the forwarding
|
||||
// policies of a set of target links.
|
||||
type updatePoliciesCmd struct {
|
||||
newPolicy ForwardingPolicy
|
||||
targetChans []wire.OutPoint
|
||||
var linksToUpdate []ChannelLink
|
||||
|
||||
err chan error
|
||||
}
|
||||
|
||||
// updateLinkPolicies attempts to update the forwarding policies for the set of
|
||||
// passed links identified by their channel points. If a nil set of channel
|
||||
// points is passed, then the forwarding policies for all active links will be
|
||||
// updated.
|
||||
func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error {
|
||||
log.Debugf("Updating link policies: %v", spew.Sdump(c))
|
||||
|
||||
// If no channels have been targeted, then we'll update the link policies
|
||||
// for all active channels
|
||||
if len(c.targetChans) == 0 {
|
||||
// If no channels have been targeted, then we'll collect all inks to
|
||||
// update their policies.
|
||||
if len(targetChans) == 0 {
|
||||
for _, link := range s.linkIndex {
|
||||
link.UpdateForwardingPolicy(c.newPolicy)
|
||||
linksToUpdate = append(linksToUpdate, link)
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, we'll only attempt to update the forwarding policies for the
|
||||
// set of targeted links.
|
||||
for _, targetLink := range c.targetChans {
|
||||
} else {
|
||||
// Otherwise, we'll only attempt to update the forwarding
|
||||
// policies for the set of targeted links.
|
||||
for _, targetLink := range targetChans {
|
||||
cid := lnwire.NewChanIDFromOutPoint(&targetLink)
|
||||
|
||||
// If we can't locate a link by its converted channel ID, then we'll
|
||||
// return an error back to the caller.
|
||||
// If we can't locate a link by its converted channel
|
||||
// ID, then we'll return an error back to the caller.
|
||||
link, ok := s.linkIndex[cid]
|
||||
if !ok {
|
||||
return fmt.Errorf("unable to find ChannelPoint(%v) to "+
|
||||
"update link policy", targetLink)
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
return fmt.Errorf("unable to find "+
|
||||
"ChannelPoint(%v) to update link "+
|
||||
"policy", targetLink)
|
||||
}
|
||||
|
||||
link.UpdateForwardingPolicy(c.newPolicy)
|
||||
linksToUpdate = append(linksToUpdate, link)
|
||||
}
|
||||
}
|
||||
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
// With all the links we need to update collected, we can release the
|
||||
// mutex then update each link directly.
|
||||
for _, link := range linksToUpdate {
|
||||
link.UpdateForwardingPolicy(newPolicy)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -715,14 +702,18 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
|
||||
// appropriate channel link and send the payment over this link.
|
||||
case *lnwire.UpdateAddHTLC:
|
||||
// Try to find links by node destination.
|
||||
s.indexMtx.RLock()
|
||||
links, err := s.getLinks(pkt.destNode)
|
||||
if err != nil {
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
log.Errorf("unable to find links by destination %v", err)
|
||||
return &ForwardingError{
|
||||
ErrorSource: s.cfg.SelfKey,
|
||||
FailureMessage: &lnwire.FailUnknownNextPeer{},
|
||||
}
|
||||
}
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
// Try to find destination channel link with appropriate
|
||||
// bandwidth.
|
||||
@ -880,8 +871,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
return s.handleLocalDispatch(packet)
|
||||
}
|
||||
|
||||
s.indexMtx.RLock()
|
||||
targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
|
||||
if err != nil {
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
// If packet was forwarded from another channel link
|
||||
// than we should notify this link that some error
|
||||
// occurred.
|
||||
@ -892,6 +886,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
return s.failAddPacket(packet, failure, addErr)
|
||||
}
|
||||
interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey())
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
// We'll keep track of any HTLC failures during the link
|
||||
// selection process. This way we can return the error for
|
||||
@ -1300,12 +1295,14 @@ func (s *Switch) htlcForwarder() {
|
||||
|
||||
// Remove all links once we've been signalled for shutdown.
|
||||
defer func() {
|
||||
s.indexMtx.Lock()
|
||||
for _, link := range s.linkIndex {
|
||||
if err := s.removeLink(link.ChanID()); err != nil {
|
||||
log.Errorf("unable to remove "+
|
||||
"channel link on stop: %v", err)
|
||||
}
|
||||
}
|
||||
s.indexMtx.Unlock()
|
||||
|
||||
// Before we exit fully, we'll attempt to flush out any
|
||||
// forwarding events that may still be lingering since the last
|
||||
@ -1336,12 +1333,17 @@ func (s *Switch) htlcForwarder() {
|
||||
// cooperatively closed (if possible).
|
||||
case req := <-s.chanCloseRequests:
|
||||
chanID := lnwire.NewChanIDFromOutPoint(req.ChanPoint)
|
||||
|
||||
s.indexMtx.RLock()
|
||||
link, ok := s.linkIndex[chanID]
|
||||
if !ok {
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
req.Err <- errors.Errorf("no peer for channel with "+
|
||||
"chan_id=%x", chanID[:])
|
||||
continue
|
||||
}
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
peerPub := link.Peer().PubKey()
|
||||
log.Debugf("Requesting local channel close: peer=%v, "+
|
||||
@ -1421,6 +1423,7 @@ func (s *Switch) htlcForwarder() {
|
||||
|
||||
// Next, we'll run through all the registered links and
|
||||
// compute their up-to-date forwarding stats.
|
||||
s.indexMtx.RLock()
|
||||
for _, link := range s.linkIndex {
|
||||
// TODO(roasbeef): when links first registered
|
||||
// stats printed.
|
||||
@ -1429,6 +1432,7 @@ func (s *Switch) htlcForwarder() {
|
||||
newSatSent += sent.ToSatoshis()
|
||||
newSatRecv += recv.ToSatoshis()
|
||||
}
|
||||
s.indexMtx.RUnlock()
|
||||
|
||||
var (
|
||||
diffNumUpdates uint64
|
||||
@ -1478,28 +1482,6 @@ func (s *Switch) htlcForwarder() {
|
||||
totalSatSent += diffSatSent
|
||||
totalSatRecv += diffSatRecv
|
||||
|
||||
case req := <-s.linkControl:
|
||||
switch cmd := req.(type) {
|
||||
case *updatePoliciesCmd:
|
||||
cmd.err <- s.updateLinkPolicies(cmd)
|
||||
case *addLinkCmd:
|
||||
cmd.err <- s.addLink(cmd.link)
|
||||
case *removeLinkCmd:
|
||||
cmd.err <- s.removeLink(cmd.chanID)
|
||||
case *getLinkCmd:
|
||||
link, err := s.getLink(cmd.chanID)
|
||||
cmd.done <- link
|
||||
cmd.err <- err
|
||||
case *getLinksCmd:
|
||||
links, err := s.getLinks(cmd.peer)
|
||||
cmd.done <- links
|
||||
cmd.err <- err
|
||||
case *updateForwardingIndexCmd:
|
||||
cmd.err <- s.updateShortChanID(
|
||||
cmd.chanID, cmd.shortChanID,
|
||||
)
|
||||
}
|
||||
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
@ -1555,8 +1537,7 @@ func (s *Switch) reforwardResponses() error {
|
||||
|
||||
// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short
|
||||
// channel identifier.
|
||||
func (s *Switch) loadChannelFwdPkgs(
|
||||
source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
|
||||
func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
|
||||
|
||||
var fwdPkgs []*channeldb.FwdPkg
|
||||
if err := s.cfg.DB.Update(func(tx *bolt.Tx) error {
|
||||
@ -1688,38 +1669,11 @@ func (s *Switch) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLinkCmd is a add link command wrapper, it is used to propagate handler
|
||||
// parameters and return handler error.
|
||||
type addLinkCmd struct {
|
||||
link ChannelLink
|
||||
err chan error
|
||||
}
|
||||
|
||||
// AddLink is used to initiate the handling of the add link command. The
|
||||
// request will be propagated and handled in the main goroutine.
|
||||
func (s *Switch) AddLink(link ChannelLink) error {
|
||||
command := &addLinkCmd{
|
||||
link: link,
|
||||
err: make(chan error, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case s.linkControl <- command:
|
||||
select {
|
||||
case err := <-command.err:
|
||||
return err
|
||||
case <-s.quit:
|
||||
}
|
||||
case <-s.quit:
|
||||
}
|
||||
|
||||
return errors.New("unable to add link htlc switch was stopped")
|
||||
}
|
||||
|
||||
// addLink is used to add the newly created channel link and start use it to
|
||||
// handle the channel updates.
|
||||
func (s *Switch) addLink(link ChannelLink) error {
|
||||
// TODO(roasbeef): reject if link already tehre?
|
||||
s.indexMtx.Lock()
|
||||
defer s.indexMtx.Unlock()
|
||||
|
||||
// First we'll add the link to the linkIndex which lets us quickly look
|
||||
// up a channel when we need to close or register it, and the
|
||||
@ -1781,47 +1735,12 @@ func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox {
|
||||
return mailbox
|
||||
}
|
||||
|
||||
// getLinkCmd is a get link command wrapper, it is used to propagate handler
|
||||
// parameters and return handler error.
|
||||
type getLinkCmd struct {
|
||||
chanID lnwire.ChannelID
|
||||
err chan error
|
||||
done chan ChannelLink
|
||||
}
|
||||
|
||||
// GetLink is used to initiate the handling of the get link command. The
|
||||
// request will be propagated/handled to/in the main goroutine.
|
||||
func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelLink, error) {
|
||||
command := &getLinkCmd{
|
||||
chanID: chanID,
|
||||
err: make(chan error, 1),
|
||||
done: make(chan ChannelLink, 1),
|
||||
}
|
||||
s.indexMtx.RLock()
|
||||
defer s.indexMtx.RUnlock()
|
||||
|
||||
query:
|
||||
select {
|
||||
case s.linkControl <- command:
|
||||
|
||||
var link ChannelLink
|
||||
select {
|
||||
case link = <-command.done:
|
||||
case <-s.quit:
|
||||
break query
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-command.err:
|
||||
return link, err
|
||||
case <-s.quit:
|
||||
}
|
||||
case <-s.quit:
|
||||
}
|
||||
|
||||
return nil, errors.New("unable to get link htlc switch was stopped")
|
||||
}
|
||||
|
||||
// getLink attempts to return the link that has the specified channel ID.
|
||||
func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
|
||||
link, ok := s.linkIndex[chanID]
|
||||
if !ok {
|
||||
return nil, ErrChannelLinkNotFound
|
||||
@ -1832,6 +1751,8 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
|
||||
|
||||
// getLinkByShortID attempts to return the link which possesses the target
|
||||
// short channel ID.
|
||||
//
|
||||
// NOTE: This MUST be called with the indexMtx held.
|
||||
func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) {
|
||||
link, ok := s.forwardingIndex[chanID]
|
||||
if !ok {
|
||||
@ -1841,35 +1762,18 @@ func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, er
|
||||
return link, nil
|
||||
}
|
||||
|
||||
// removeLinkCmd is a get link command wrapper, it is used to propagate handler
|
||||
// parameters and return handler error.
|
||||
type removeLinkCmd struct {
|
||||
chanID lnwire.ChannelID
|
||||
err chan error
|
||||
}
|
||||
|
||||
// RemoveLink is used to initiate the handling of the remove link command. The
|
||||
// request will be propagated/handled to/in the main goroutine.
|
||||
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) error {
|
||||
command := &removeLinkCmd{
|
||||
chanID: chanID,
|
||||
err: make(chan error, 1),
|
||||
}
|
||||
s.indexMtx.Lock()
|
||||
defer s.indexMtx.Unlock()
|
||||
|
||||
select {
|
||||
case s.linkControl <- command:
|
||||
select {
|
||||
case err := <-command.err:
|
||||
return err
|
||||
case <-s.quit:
|
||||
}
|
||||
case <-s.quit:
|
||||
}
|
||||
|
||||
return errors.New("unable to remove link htlc switch was stopped")
|
||||
return s.removeLink(chanID)
|
||||
}
|
||||
|
||||
// removeLink is used to remove and stop the channel link.
|
||||
//
|
||||
// NOTE: This MUST be called with the indexMtx held.
|
||||
func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
|
||||
log.Infof("Removing channel link with ChannelID(%v)", chanID)
|
||||
|
||||
@ -1891,50 +1795,21 @@ func (s *Switch) removeLink(chanID lnwire.ChannelID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateForwardingIndexCmd is a command sent by outside sub-systems to update
|
||||
// the forwarding index of the switch in the event that the short channel ID of
|
||||
// a particular link changes.
|
||||
type updateForwardingIndexCmd struct {
|
||||
chanID lnwire.ChannelID
|
||||
shortChanID lnwire.ShortChannelID
|
||||
|
||||
err chan error
|
||||
}
|
||||
|
||||
// UpdateShortChanID updates the short chan ID for an existing channel. This is
|
||||
// required in the case of a re-org and re-confirmation or a channel, or in the
|
||||
// case that a link was added to the switch before its short chan ID was known.
|
||||
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
|
||||
shortChanID lnwire.ShortChannelID) error {
|
||||
|
||||
command := &updateForwardingIndexCmd{
|
||||
chanID: chanID,
|
||||
shortChanID: shortChanID,
|
||||
err: make(chan error, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case s.linkControl <- command:
|
||||
select {
|
||||
case err := <-command.err:
|
||||
return err
|
||||
case <-s.quit:
|
||||
}
|
||||
case <-s.quit:
|
||||
}
|
||||
|
||||
return errors.New("unable to update short chan id htlc switch was stopped")
|
||||
}
|
||||
|
||||
// updateShortChanID updates the short chan ID of an existing link.
|
||||
func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
|
||||
shortChanID lnwire.ShortChannelID) error {
|
||||
s.indexMtx.Lock()
|
||||
|
||||
// First, we'll extract the current link as is from the link link
|
||||
// index. If the link isn't even in the index, then we'll return an
|
||||
// error.
|
||||
link, ok := s.linkIndex[chanID]
|
||||
if !ok {
|
||||
s.indexMtx.Unlock()
|
||||
|
||||
return fmt.Errorf("link %v not found", chanID)
|
||||
}
|
||||
|
||||
@ -1945,53 +1820,27 @@ func (s *Switch) updateShortChanID(chanID lnwire.ChannelID,
|
||||
// forwarding index with the next short channel ID.
|
||||
s.forwardingIndex[shortChanID] = link
|
||||
|
||||
s.indexMtx.Unlock()
|
||||
|
||||
// Finally, we'll notify the link of its new short channel ID.
|
||||
link.UpdateShortChanID(shortChanID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLinksCmd is a get links command wrapper, it is used to propagate handler
|
||||
// parameters and return handler error.
|
||||
type getLinksCmd struct {
|
||||
peer [33]byte
|
||||
err chan error
|
||||
done chan []ChannelLink
|
||||
}
|
||||
|
||||
// GetLinksByInterface fetches all the links connected to a particular node
|
||||
// identified by the serialized compressed form of its public key.
|
||||
func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) {
|
||||
command := &getLinksCmd{
|
||||
peer: hop,
|
||||
err: make(chan error, 1),
|
||||
done: make(chan []ChannelLink, 1),
|
||||
}
|
||||
s.indexMtx.RLock()
|
||||
defer s.indexMtx.RUnlock()
|
||||
|
||||
query:
|
||||
select {
|
||||
case s.linkControl <- command:
|
||||
|
||||
var links []ChannelLink
|
||||
select {
|
||||
case links = <-command.done:
|
||||
case <-s.quit:
|
||||
break query
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-command.err:
|
||||
return links, err
|
||||
case <-s.quit:
|
||||
}
|
||||
case <-s.quit:
|
||||
}
|
||||
|
||||
return nil, errors.New("unable to get links htlc switch was stopped")
|
||||
return s.getLinks(hop)
|
||||
}
|
||||
|
||||
// getLinks is function which returns the channel links of the peer by hop
|
||||
// destination id.
|
||||
//
|
||||
// NOTE: This MUST be called with the indexMtx held.
|
||||
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
|
||||
links, ok := s.interfaceIndex[destination]
|
||||
if !ok {
|
||||
|
Loading…
Reference in New Issue
Block a user