diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 27125fab..66c225da 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -32,9 +32,14 @@ const ( expiryGraceDelta = 2 ) -// ErrInternalLinkFailure is a generic error returned to the remote party so as -// to obfuscate the true failure. -var ErrInternalLinkFailure = errors.New("internal link failure") +var ( + // ErrInternalLinkFailure is a generic error returned to the remote + // party so as to obfuscate the true failure. + ErrInternalLinkFailure = errors.New("internal link failure") + + // ErrLinkShuttingDown signals that the link is shutting down. + ErrLinkShuttingDown = errors.New("link shutting down") +) // ForwardingPolicy describes the set of constraints that a given ChannelLink // is to adhere to when forwarding HTLC's. For each incoming HTLC, this set of @@ -444,9 +449,11 @@ func (l *channelLink) Stop() { // EligibleToForward returns a bool indicating if the channel is able to // actively accept requests to forward HTLC's. We're able to forward HTLC's if // we know the remote party's next revocation point. Otherwise, we can't -// initiate new channel state. +// initiate new channel state. We also require that the short channel ID not be +// the all-zero source ID, meaning that the channel has had its ID finalized. func (l *channelLink) EligibleToForward() bool { - return l.channel.RemoteNextRevocation() != nil + return l.channel.RemoteNextRevocation() != nil && + l.ShortChanID() != sourceHop } // sampleNetworkFee samples the current fee rate on the network to get into the @@ -603,7 +610,7 @@ func (l *channelLink) syncChanStates() error { } case <-l.quit: - return fmt.Errorf("shutting down") + return ErrLinkShuttingDown case <-chanSyncDeadline: return fmt.Errorf("didn't receive ChannelReestablish before " + @@ -759,9 +766,12 @@ func (l *channelLink) htlcManager() { // re-synchronize state with the remote peer. settledHtlcs is a map of // HTLC's that we re-settled as part of the channel state sync. if l.cfg.SyncStates { - if err := l.syncChanStates(); err != nil { + err := l.syncChanStates() + if err != nil { l.errorf("unable to synchronize channel states: %v", err) - l.fail(err.Error()) + if err != ErrLinkShuttingDown { + l.fail(err.Error()) + } return } } @@ -1540,18 +1550,31 @@ func (l *channelLink) ShortChanID() lnwire.ShortChannelID { // within the chain. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { +func (l *channelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { + chanID := l.ChanID() + + // Refresh the channel state's short channel ID by loading it from disk. + // This ensures that the channel state accurately reflects the updated + // short channel ID. + err := l.channel.State().RefreshShortChanID() + if err != nil { + l.errorf("unable to refresh short_chan_id for chan_id=%v: %v", + chanID, err) + return sourceHop, err + } + + sid := l.channel.ShortChanID() + + l.infof("Updating to short_chan_id=%v for chan_id=%v", sid, chanID) + l.Lock() - defer l.Unlock() - - log.Infof("Updating short chan ID for ChannelPoint(%v)", l) - l.shortChanID = sid + l.Unlock() go func() { err := l.cfg.UpdateContractSignals(&contractcourt.ContractSignals{ HtlcUpdates: l.htlcUpdates, - ShortChanID: l.channel.ShortChanID(), + ShortChanID: sid, }) if err != nil { log.Errorf("Unable to update signals for "+ @@ -1559,7 +1582,7 @@ func (l *channelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { } }() - return + return sid, nil } // ChanID returns the channel ID for the channel link. The channel ID is a more