diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 71e38594..7db498e6 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -93,7 +93,7 @@ type chanPolicyUpdateRequest struct { targetChans []wire.OutPoint newSchema routing.ChannelPolicy - errResp chan error + chanPolicies chan updatedChanPolicies } // Config defines the configuration for the service. ALL elements within the @@ -338,27 +338,40 @@ func New(cfg Config, selfKey *btcec.PublicKey) *AuthenticatedGossiper { return gossiper } +// updatedChanPolicies is a set of channel policies that have been successfully +// updated and written to disk, or an error if the policy update failed. This +// struct's map field is intended to be used for updating channel policies on +// the link layer. +type updatedChanPolicies struct { + chanPolicies map[wire.OutPoint]*channeldb.ChannelEdgePolicy + err error +} + // PropagateChanPolicyUpdate signals the AuthenticatedGossiper to update the // channel forwarding policies for the specified channels. If no channels are // specified, then the update will be applied to all outgoing channels from the // source node. Policy updates are done in two stages: first, the // AuthenticatedGossiper ensures the update has been committed by dependent -// sub-systems, then it signs and broadcasts new updates to the network. +// sub-systems, then it signs and broadcasts new updates to the network. A +// mapping between outpoints and updated channel policies is returned, which is +// used to update the forwarding policies of the underlying links. func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate( - newSchema routing.ChannelPolicy, chanPoints ...wire.OutPoint) error { + newSchema routing.ChannelPolicy, chanPoints ...wire.OutPoint) ( + map[wire.OutPoint]*channeldb.ChannelEdgePolicy, error) { - errChan := make(chan error, 1) + chanPolicyChan := make(chan updatedChanPolicies, 1) policyUpdate := &chanPolicyUpdateRequest{ - targetChans: chanPoints, - newSchema: newSchema, - errResp: errChan, + targetChans: chanPoints, + newSchema: newSchema, + chanPolicies: chanPolicyChan, } select { case d.chanPolicyUpdates <- policyUpdate: - return <-errChan + updatedPolicies := <-chanPolicyChan + return updatedPolicies.chanPolicies, updatedPolicies.err case <-d.quit: - return fmt.Errorf("AuthenticatedGossiper shutting down") + return nil, fmt.Errorf("AuthenticatedGossiper shutting down") } } @@ -895,13 +908,17 @@ func (d *AuthenticatedGossiper) networkHandler() { // First, we'll now create new fully signed updates for // the affected channels and also update the underlying // graph with the new state. - newChanUpdates, err := d.processChanPolicyUpdate( + chanPolicies, newChanUpdates, err := d.processChanPolicyUpdate( policyUpdate, ) + update := updatedChanPolicies{ + chanPolicies, + err, + } + policyUpdate.chanPolicies <- update if err != nil { log.Errorf("Unable to craft policy updates: %v", err) - policyUpdate.errResp <- err continue } @@ -910,8 +927,6 @@ func (d *AuthenticatedGossiper) networkHandler() { // start of the next epoch. announcements.AddMsgs(newChanUpdates...) - policyUpdate.errResp <- nil - case announcement := <-d.networkMsgs: // We should only broadcast this message forward if it // originated from us or it wasn't received as part of @@ -1244,7 +1259,9 @@ func (d *AuthenticatedGossiper) retransmitStaleChannels() error { // // TODO(roasbeef): generalize into generic for any channel update func (d *AuthenticatedGossiper) processChanPolicyUpdate( - policyUpdate *chanPolicyUpdateRequest) ([]networkMsg, error) { + policyUpdate *chanPolicyUpdateRequest) ( + map[wire.OutPoint]*channeldb.ChannelEdgePolicy, []networkMsg, error) { + // First, we'll construct a set of all the channels that need to be // updated. chansToUpdate := make(map[wire.OutPoint]struct{}) @@ -1252,6 +1269,10 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( chansToUpdate[chanPoint] = struct{}{} } + // Next, we'll create a mapping from outpoint to edge policy that will + // be used by each edge's underlying link to update its policy. + chanPolicies := make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy) + haveChanFilter := len(chansToUpdate) != 0 if haveChanFilter { log.Infof("Updating routing policies for chan_points=%v", @@ -1295,7 +1316,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( return nil }) if err != nil { - return nil, err + return nil, nil, err } // With the set of edges we need to update retrieved, we'll now re-sign @@ -1309,9 +1330,13 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( edgeInfo.info, edgeInfo.edge, ) if err != nil { - return nil, err + return nil, nil, err } + // Since the update succeeded, add the edge to our policy + // mapping. + chanPolicies[edgeInfo.info.ChannelPoint] = edgeInfo.edge + // We'll avoid broadcasting any updates for private channels to // avoid directly giving away their existence. Instead, we'll // send the update directly to the remote party. @@ -1340,7 +1365,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( }) } - return chanUpdates, nil + return chanPolicies, chanUpdates, nil } // processRejectedEdge examines a rejected edge to see if we can extract any diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index be27e9e7..4c98366e 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -3391,8 +3391,14 @@ func TestPropagateChanPolicyUpdate(t *testing.T) { // the channel ann proof from the first channel in order to have it be // marked as private channel. firstChanID := channelsToAnnounce[0].localChanAnn.ShortChannelID - for _, batch := range channelsToAnnounce { - sendLocalMsg(t, ctx, batch.localChanAnn, localKey) + for i, batch := range channelsToAnnounce { + // channelPoint ensures that each channel policy in the map + // returned by PropagateChanPolicyUpdate has a unique key. Since + // the map is keyed by wire.OutPoint, we want to ensure that + // each channel has a unique channel point. + channelPoint := ChannelPoint(wire.OutPoint{Index: uint32(i)}) + + sendLocalMsg(t, ctx, batch.localChanAnn, localKey, channelPoint) sendLocalMsg(t, ctx, batch.chanUpdAnn1, localKey) sendLocalMsg(t, ctx, batch.nodeAnn1, localKey) @@ -3430,11 +3436,19 @@ out: newPolicy := routing.ChannelPolicy{ TimeLockDelta: newTimeLockDelta, } - err = ctx.gossiper.PropagateChanPolicyUpdate(newPolicy) + newChanPolicies, err := ctx.gossiper.PropagateChanPolicyUpdate(newPolicy) if err != nil { t.Fatalf("unable to chan policies: %v", err) } + // Ensure that the updated channel policies are as expected. + for _, dbPolicy := range newChanPolicies { + if dbPolicy.TimeLockDelta != uint16(newPolicy.TimeLockDelta) { + t.Fatalf("wrong delta: expected %v, got %v", + newPolicy.TimeLockDelta, dbPolicy.TimeLockDelta) + } + } + // Two channel updates should now be broadcast, with neither of them // being the channel our first private channel. for i := 0; i < numChannels-1; i++ { diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 6f1de368..3c396814 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2156,31 +2156,16 @@ func (l *channelLink) AttachMailBox(mailbox MailBox) { // UpdateForwardingPolicy updates the forwarding policy for the target // ChannelLink. Once updated, the link will use the new forwarding policy to -// govern if it an incoming HTLC should be forwarded or not. Note that this -// processing of the new policy will ensure that uninitialized fields in the -// passed policy won't override already initialized fields in the current -// policy. +// govern if it an incoming HTLC should be forwarded or not. We assume that +// fields that are zero are intentionally set to zero, so we'll use newPolicy to +// update all of the link's FwrdingPolicy's values. // // NOTE: Part of the ChannelLink interface. func (l *channelLink) UpdateForwardingPolicy(newPolicy ForwardingPolicy) { l.Lock() defer l.Unlock() - // In order to avoid overriding a valid policy with a "null" field in - // the new policy, we'll only update to the set sub policy if the new - // value isn't uninitialized. - if newPolicy.BaseFee != 0 { - l.cfg.FwrdingPolicy.BaseFee = newPolicy.BaseFee - } - if newPolicy.FeeRate != 0 { - l.cfg.FwrdingPolicy.FeeRate = newPolicy.FeeRate - } - if newPolicy.TimeLockDelta != 0 { - l.cfg.FwrdingPolicy.TimeLockDelta = newPolicy.TimeLockDelta - } - if newPolicy.MinHTLC != 0 { - l.cfg.FwrdingPolicy.MinHTLC = newPolicy.MinHTLC - } + l.cfg.FwrdingPolicy = newPolicy } // HtlcSatifiesPolicy should return a nil error if the passed HTLC details diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index a3c7ed07..2fcfa2c6 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -5530,9 +5530,9 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) { // Now that each of the links are up, we'll modify the link from Alice // -> Bob to have a greater time lock delta than that of the link of // Bob -> Carol. - n.firstBobChannelLink.UpdateForwardingPolicy(ForwardingPolicy{ - TimeLockDelta: 7, - }) + newPolicy := n.firstBobChannelLink.cfg.FwrdingPolicy + newPolicy.TimeLockDelta = 7 + n.firstBobChannelLink.UpdateForwardingPolicy(newPolicy) // Now that the Alice -> Bob link has been updated, we'll craft and // send a payment from Alice -> Carol. This should succeed as normal, diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index e22f3962..b241d76e 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -439,60 +439,51 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, } // UpdateForwardingPolicies sends a message to the switch to update the -// forwarding policies for the set of target channels. If the set of targeted -// channels is nil, then the forwarding policies for all active channels with -// be updated. +// forwarding policies for the set of target channels, keyed in chanPolicies. // // NOTE: This function is synchronous and will block until either the // forwarding policies for all links have been updated, or the switch shuts // down. -func (s *Switch) UpdateForwardingPolicies(newPolicy ForwardingPolicy, - targetChans ...wire.OutPoint) error { +func (s *Switch) UpdateForwardingPolicies( + chanPolicies map[wire.OutPoint]*channeldb.ChannelEdgePolicy) { - log.Debugf("Updating link policies: %v", newLogClosure(func() string { - return spew.Sdump(newPolicy) + log.Tracef("Updating link policies: %v", newLogClosure(func() string { + return spew.Sdump(chanPolicies) })) - var linksToUpdate []ChannelLink - s.indexMtx.RLock() - // If no channels have been targeted, then we'll collect all inks to - // update their policies. - if len(targetChans) == 0 { - for _, link := range s.linkIndex { - linksToUpdate = append(linksToUpdate, link) + // Update each link in chanPolicies. + for targetLink := range chanPolicies { + cid := lnwire.NewChanIDFromOutPoint(&targetLink) + + link, ok := s.linkIndex[cid] + if !ok { + log.Debugf("Unable to find ChannelPoint(%v) to update "+ + "link policy", targetLink) + continue } - } else { - // Otherwise, we'll only attempt to update the forwarding - // policies for the set of targeted links. - for _, targetLink := range targetChans { - cid := lnwire.NewChanIDFromOutPoint(&targetLink) - // If we can't locate a link by its converted channel - // ID, then we'll return an error back to the caller. - link, ok := s.linkIndex[cid] - if !ok { - s.indexMtx.RUnlock() - - return fmt.Errorf("unable to find "+ - "ChannelPoint(%v) to update link "+ - "policy", targetLink) - } - - linksToUpdate = append(linksToUpdate, link) - } - } - - s.indexMtx.RUnlock() - - // With all the links we need to update collected, we can release the - // mutex then update each link directly. - for _, link := range linksToUpdate { + newPolicy := dbPolicyToFwdingPolicy( + chanPolicies[*link.ChannelPoint()], + ) link.UpdateForwardingPolicy(newPolicy) } - return nil + s.indexMtx.RUnlock() +} + +// dbPolicyToFwdingPolicy is a helper function that converts a channeldb +// ChannelEdgePolicy into a ForwardingPolicy struct for the purpose of updating +// the forwarding policy of a link. +func dbPolicyToFwdingPolicy(policy *channeldb.ChannelEdgePolicy) ForwardingPolicy { + return ForwardingPolicy{ + BaseFee: policy.FeeBaseMSat, + FeeRate: policy.FeeProportionalMillionths, + TimeLockDelta: uint32(policy.TimeLockDelta), + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + } } // forward is used in order to find next channel link and apply htlc update. diff --git a/rpcserver.go b/rpcserver.go index e11a13f4..e187169f 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4550,6 +4550,12 @@ func (r *rpcServer) FeeReport(ctx context.Context, // 0.000001, or 0.0001%. const minFeeRate = 1e-6 +// policyUpdateLock ensures that the database and the link do not fall out of +// sync if there are concurrent fee update calls. Without it, there is a chance +// that policy A updates the database, then policy B updates the database, then +// policy B updates the link, then policy A updates the link. +var policyUpdateLock sync.Mutex + // UpdateChannelPolicy allows the caller to update the channel forwarding policy // for all channels globally, or a particular channel. func (r *rpcServer) UpdateChannelPolicy(ctx context.Context, @@ -4617,30 +4623,18 @@ func (r *rpcServer) UpdateChannelPolicy(ctx context.Context, // With the scope resolved, we'll now send this to the // AuthenticatedGossiper so it can propagate the new policy for our // target channel(s). - err := r.server.authGossiper.PropagateChanPolicyUpdate( + policyUpdateLock.Lock() + defer policyUpdateLock.Unlock() + chanPolicies, err := r.server.authGossiper.PropagateChanPolicyUpdate( chanPolicy, targetChans..., ) if err != nil { return nil, err } - // Finally, we'll apply the set of active links amongst the target - // channels. - // - // We create a partially policy as the logic won't overwrite a valid - // sub-policy with a "nil" one. - p := htlcswitch.ForwardingPolicy{ - BaseFee: baseFeeMsat, - FeeRate: lnwire.MilliSatoshi(feeRateFixed), - TimeLockDelta: req.TimeLockDelta, - } - err = r.server.htlcSwitch.UpdateForwardingPolicies(p, targetChans...) - if err != nil { - // If we're unable update the fees due to the links not being - // online, then we don't need to fail the call. We'll simply - // log the failure. - rpcsLog.Warnf("Unable to update link fees: %v", err) - } + // Finally, we'll apply the set of channel policies to the target + // channels' links. + r.server.htlcSwitch.UpdateForwardingPolicies(chanPolicies) return &lnrpc.PolicyUpdateResponse{}, nil }