discovery+routing: cancel dependent jobs if parent validation fails

Previously, we would always allow dependent jobs to be processed,
regardless of the result of its parent job's validation. This isn't
correct, as a parent job contains actions necessary to successfully
process a dependent job. A prime example of this can be found within the
AuthenticatedGossiper, where an incoming channel announcement and update
are both processed, but if the channel announcement job fails to
complete, then the gossiper is unable to properly validate the update.
This commit aims to address this by preventing the dependent jobs to
run.
This commit is contained in:
Wilmer Paulino 2021-03-22 15:32:24 -07:00
parent e713205eea
commit 393111cea9
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F
4 changed files with 142 additions and 86 deletions

View File

@ -990,7 +990,7 @@ func (d *AuthenticatedGossiper) networkHandler() {
// Channel announcement signatures are amongst the only // Channel announcement signatures are amongst the only
// messages that we'll process serially. // messages that we'll process serially.
case *lnwire.AnnounceSignatures: case *lnwire.AnnounceSignatures:
emittedAnnouncements := d.processNetworkAnnouncement( emittedAnnouncements, _ := d.processNetworkAnnouncement(
announcement, announcement,
) )
if emittedAnnouncements != nil { if emittedAnnouncements != nil {
@ -1040,14 +1040,14 @@ func (d *AuthenticatedGossiper) networkHandler() {
// determine if this is either a new // determine if this is either a new
// announcement from our PoV or an edges to a // announcement from our PoV or an edges to a
// prior vertex/edge we previously proceeded. // prior vertex/edge we previously proceeded.
emittedAnnouncements := d.processNetworkAnnouncement( emittedAnnouncements, allowDependents := d.processNetworkAnnouncement(
announcement, announcement,
) )
// If this message had any dependencies, then // If this message had any dependencies, then
// we can now signal them to continue. // we can now signal them to continue.
validationBarrier.SignalDependants( validationBarrier.SignalDependants(
announcement.msg, announcement.msg, allowDependents,
) )
// If the announcement was accepted, then add // If the announcement was accepted, then add
@ -1514,9 +1514,11 @@ func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement,
// channel or node announcement or announcements proofs. If the announcement // channel or node announcement or announcements proofs. If the announcement
// didn't affect the internal state due to either being out of date, invalid, // didn't affect the internal state due to either being out of date, invalid,
// or redundant, then nil is returned. Otherwise, the set of announcements will // or redundant, then nil is returned. Otherwise, the set of announcements will
// be returned which should be broadcasted to the rest of the network. // be returned which should be broadcasted to the rest of the network. The
// boolean returned indicates whether any dependents of the announcement should
// attempt to be processed as well.
func (d *AuthenticatedGossiper) processNetworkAnnouncement( func (d *AuthenticatedGossiper) processNetworkAnnouncement(
nMsg *networkMsg) []networkMsg { nMsg *networkMsg) ([]networkMsg, bool) {
isPremature := func(chanID lnwire.ShortChannelID, delta uint32) bool { isPremature := func(chanID lnwire.ShortChannelID, delta uint32) bool {
// TODO(roasbeef) make height delta 6 // TODO(roasbeef) make height delta 6
@ -1546,7 +1548,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// signatures if not required. // signatures if not required.
if d.cfg.Router.IsStaleNode(msg.NodeID, timestamp) { if d.cfg.Router.IsStaleNode(msg.NodeID, timestamp) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
if err := d.addNode(msg, schedulerOp...); err != nil { if err := d.addNode(msg, schedulerOp...); err != nil {
@ -1559,7 +1561,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// In order to ensure we don't leak unadvertised nodes, we'll // In order to ensure we don't leak unadvertised nodes, we'll
@ -1570,7 +1572,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Errorf("Unable to determine if node %x is "+ log.Errorf("Unable to determine if node %x is "+
"advertised: %v", msg.NodeID, err) "advertised: %v", msg.NodeID, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If it does, we'll add their announcement to our batch so that // If it does, we'll add their announcement to our batch so that
@ -1588,7 +1590,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
nMsg.err <- nil nMsg.err <- nil
// TODO(roasbeef): get rid of the above // TODO(roasbeef): get rid of the above
return announcements return announcements, true
// A new channel announcement has arrived, this indicates the // A new channel announcement has arrived, this indicates the
// *creation* of a new channel within the network. This only advertises // *creation* of a new channel within the network. This only advertises
@ -1608,7 +1610,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the advertised inclusionary block is beyond our knowledge // If the advertised inclusionary block is beyond our knowledge
@ -1623,7 +1625,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.bestHeight) d.bestHeight)
d.Unlock() d.Unlock()
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
d.Unlock() d.Unlock()
@ -1632,7 +1634,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// below. // below.
if d.cfg.Router.IsKnownEdge(msg.ShortChannelID) { if d.cfg.Router.IsKnownEdge(msg.ShortChannelID) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// If this is a remote channel announcement, then we'll validate // If this is a remote channel announcement, then we'll validate
@ -1649,7 +1651,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the proof checks out, then we'll save the proof // If the proof checks out, then we'll save the proof
@ -1669,7 +1671,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
if err := msg.Features.Encode(&featureBuf); err != nil { if err := msg.Features.Encode(&featureBuf); err != nil {
log.Errorf("unable to encode features: %v", err) log.Errorf("unable to encode features: %v", err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
edge := &channeldb.ChannelEdgeInfo{ edge := &channeldb.ChannelEdgeInfo{
@ -1720,7 +1722,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{} d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- rErr nMsg.err <- rErr
return nil return nil, false
} }
// If while processing this rejected edge, we // If while processing this rejected edge, we
@ -1729,7 +1731,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// directly. // directly.
if len(anns) != 0 { if len(anns) != 0 {
nMsg.err <- nil nMsg.err <- nil
return anns return anns, true
} }
// Otherwise, this is just a regular rejected // Otherwise, this is just a regular rejected
@ -1742,7 +1744,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If we earlier received any ChannelUpdates for this channel, // If we earlier received any ChannelUpdates for this channel,
@ -1806,7 +1808,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
// A new authenticated channel edge update has arrived. This indicates // A new authenticated channel edge update has arrived. This indicates
// that the directional information for an already known channel has // that the directional information for an already known channel has
@ -1825,7 +1827,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
blockHeight := msg.ShortChannelID.BlockHeight blockHeight := msg.ShortChannelID.BlockHeight
@ -1842,7 +1844,8 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, blockHeight, shortChanID, blockHeight,
d.bestHeight) d.bestHeight)
d.Unlock() d.Unlock()
return nil nMsg.err <- nil
return nil, false
} }
d.Unlock() d.Unlock()
@ -1854,7 +1857,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ShortChannelID, timestamp, msg.ChannelFlags, msg.ShortChannelID, timestamp, msg.ChannelFlags,
) { ) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// Get the node pub key as far as we don't have it in channel // Get the node pub key as far as we don't have it in channel
@ -1893,7 +1896,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"update signature: %v", err) "update signature: %v", err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// With the signature valid, we'll proceed to mark the // With the signature valid, we'll proceed to mark the
@ -1906,7 +1909,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ShortChannelID, err) msg.ShortChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Debugf("Removed edge with chan_id=%v from zombie "+ log.Debugf("Removed edge with chan_id=%v from zombie "+
@ -1949,7 +1952,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// NOTE: We don't return anything on the error channel // NOTE: We don't return anything on the error channel
// for this message, as we expect that will be done when // for this message, as we expect that will be done when
// this ChannelUpdate is later reprocessed. // this ChannelUpdate is later reprocessed.
return nil return nil, false
default: default:
err := fmt.Errorf("unable to validate channel update "+ err := fmt.Errorf("unable to validate channel update "+
@ -1960,7 +1963,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Lock() d.rejectMtx.Lock()
d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{} d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
return nil return nil, false
} }
// The least-significant bit in the flag on the channel update // The least-significant bit in the flag on the channel update
@ -1997,7 +2000,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.cfg.RebroadcastInterval, d.cfg.RebroadcastInterval,
shortChanID) shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
} else { } else {
// If it's not, we'll allow an update per minute // If it's not, we'll allow an update per minute
@ -2024,7 +2027,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, shortChanID,
pubKey.SerializeCompressed()) pubKey.SerializeCompressed())
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
} }
} }
@ -2040,7 +2043,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(rErr) log.Error(rErr)
nMsg.err <- rErr nMsg.err <- rErr
return nil return nil, false
} }
update := &channeldb.ChannelEdgePolicy{ update := &channeldb.ChannelEdgePolicy{
@ -2069,7 +2072,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If this is a local ChannelUpdate without an AuthProof, it // If this is a local ChannelUpdate without an AuthProof, it
@ -2094,7 +2097,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.MsgType(), msg.ShortChannelID, msg.MsgType(), msg.ShortChannelID,
remotePubKey, err) remotePubKey, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
} }
@ -2111,7 +2114,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
// A new signature announcement has been received. This indicates // A new signature announcement has been received. This indicates
// willingness of nodes involved in the funding of a channel to // willingness of nodes involved in the funding of a channel to
@ -2140,7 +2143,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.bestHeight, needBlockHeight) d.bestHeight, needBlockHeight)
d.Unlock() d.Unlock()
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
d.Unlock() d.Unlock()
@ -2166,14 +2169,14 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Infof("Orphan %v proof announcement with "+ log.Infof("Orphan %v proof announcement with "+
"short_chan_id=%v, adding "+ "short_chan_id=%v, adding "+
"to waiting batch", prefix, shortChanID) "to waiting batch", prefix, shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
nodeID := nMsg.source.SerializeCompressed() nodeID := nMsg.source.SerializeCompressed()
@ -2188,7 +2191,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"short_chan_id=%v", shortChanID) "short_chan_id=%v", shortChanID)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If proof was sent by a local sub-system, then we'll // If proof was sent by a local sub-system, then we'll
@ -2212,7 +2215,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.MsgType(), msg.ShortChannelID, msg.MsgType(), msg.ShortChannelID,
remotePubKey, err) remotePubKey, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
} }
@ -2265,7 +2268,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Debugf("Already have proof for channel "+ log.Debugf("Already have proof for channel "+
"with chanID=%v", msg.ChannelID) "with chanID=%v", msg.ChannelID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// Check that we received the opposite proof. If so, then we're // Check that we received the opposite proof. If so, then we're
@ -2283,7 +2286,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
if err == channeldb.ErrWaitingProofNotFound { if err == channeldb.ErrWaitingProofNotFound {
@ -2294,7 +2297,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Infof("1/2 of channel ann proof received for "+ log.Infof("1/2 of channel ann proof received for "+
@ -2302,7 +2305,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID) shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
// We now have both halves of the channel announcement proof, // We now have both halves of the channel announcement proof,
@ -2326,7 +2329,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
if err != nil { if err != nil {
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// With all the necessary components assembled validate the // With all the necessary components assembled validate the
@ -2338,7 +2341,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the channel was returned by the router it means that // If the channel was returned by the router it means that
@ -2354,7 +2357,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"channel chanID=%v: %v", msg.ChannelID, err) "channel chanID=%v: %v", msg.ChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
err = d.cfg.WaitingProofStore.Remove(proof.OppositeKey()) err = d.cfg.WaitingProofStore.Remove(proof.OppositeKey())
@ -2364,7 +2367,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ChannelID, err) msg.ChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// Proof was successfully created and now can announce the // Proof was successfully created and now can announce the
@ -2431,11 +2434,12 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
default: default:
nMsg.err <- errors.New("wrong type of the announcement") err := errors.New("wrong type of the announcement")
return nil nMsg.err <- err
return nil, false
} }
} }

View File

@ -967,7 +967,8 @@ func (r *ChannelRouter) networkHandler() {
update.msg, update.msg,
) )
if err != nil { if err != nil {
if err != ErrVBarrierShuttingDown { if err != ErrVBarrierShuttingDown &&
err != ErrParentValidationFailed {
log.Warnf("unexpected error "+ log.Warnf("unexpected error "+
"during validation "+ "during validation "+
"barrier shutdown: %v", "barrier shutdown: %v",
@ -985,7 +986,11 @@ func (r *ChannelRouter) networkHandler() {
// If this message had any dependencies, then // If this message had any dependencies, then
// we can now signal them to continue. // we can now signal them to continue.
validationBarrier.SignalDependants(update.msg) allowDependents := err == nil ||
IsError(err, ErrIgnored, ErrOutdated)
validationBarrier.SignalDependants(
update.msg, allowDependents,
)
if err != nil { if err != nil {
return return
} }

View File

@ -9,10 +9,28 @@ import (
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
// ErrVBarrierShuttingDown signals that the barrier has been requested to var (
// shutdown, and that the caller should not treat the wait condition as // ErrVBarrierShuttingDown signals that the barrier has been requested
// fulfilled. // to shutdown, and that the caller should not treat the wait condition
var ErrVBarrierShuttingDown = errors.New("validation barrier shutting down") // as fulfilled.
ErrVBarrierShuttingDown = errors.New("validation barrier shutting down")
// ErrParentValidationFailed signals that the validation of a
// dependent's parent failed, so the dependent must not be processed.
ErrParentValidationFailed = errors.New("parent validation failed")
)
// validationSignals contains two signals which allows the ValidationBarrier to
// communicate back to the caller whether a dependent should be processed or not
// based on whether its parent was successfully validated. Only one of these
// signals is to be used at a time.
type validationSignals struct {
// allow is the signal used to allow a dependent to be processed.
allow chan struct{}
// deny is the signal used to prevent a dependent from being processed.
deny chan struct{}
}
// ValidationBarrier is a barrier used to ensure proper validation order while // ValidationBarrier is a barrier used to ensure proper validation order while
// concurrently validating new announcements for channel edges, and the // concurrently validating new announcements for channel edges, and the
@ -31,19 +49,19 @@ type ValidationBarrier struct {
// ChannelAnnouncement like validation job going on. Once the job has // ChannelAnnouncement like validation job going on. Once the job has
// been completed, the channel will be closed unblocking any // been completed, the channel will be closed unblocking any
// dependants. // dependants.
chanAnnFinSignal map[lnwire.ShortChannelID]chan struct{} chanAnnFinSignal map[lnwire.ShortChannelID]*validationSignals
// chanEdgeDependencies tracks any channel edge updates which should // chanEdgeDependencies tracks any channel edge updates which should
// wait until the completion of the ChannelAnnouncement before // wait until the completion of the ChannelAnnouncement before
// proceeding. This is a dependency, as we can't validate the update // proceeding. This is a dependency, as we can't validate the update
// before we validate the announcement which creates the channel // before we validate the announcement which creates the channel
// itself. // itself.
chanEdgeDependencies map[lnwire.ShortChannelID]chan struct{} chanEdgeDependencies map[lnwire.ShortChannelID]*validationSignals
// nodeAnnDependencies tracks any pending NodeAnnouncement validation // nodeAnnDependencies tracks any pending NodeAnnouncement validation
// jobs which should wait until the completion of the // jobs which should wait until the completion of the
// ChannelAnnouncement before proceeding. // ChannelAnnouncement before proceeding.
nodeAnnDependencies map[route.Vertex]chan struct{} nodeAnnDependencies map[route.Vertex]*validationSignals
quit chan struct{} quit chan struct{}
sync.Mutex sync.Mutex
@ -56,9 +74,9 @@ func NewValidationBarrier(numActiveReqs int,
quitChan chan struct{}) *ValidationBarrier { quitChan chan struct{}) *ValidationBarrier {
v := &ValidationBarrier{ v := &ValidationBarrier{
chanAnnFinSignal: make(map[lnwire.ShortChannelID]chan struct{}), chanAnnFinSignal: make(map[lnwire.ShortChannelID]*validationSignals),
chanEdgeDependencies: make(map[lnwire.ShortChannelID]chan struct{}), chanEdgeDependencies: make(map[lnwire.ShortChannelID]*validationSignals),
nodeAnnDependencies: make(map[route.Vertex]chan struct{}), nodeAnnDependencies: make(map[route.Vertex]*validationSignals),
quit: quitChan, quit: quitChan,
} }
@ -107,24 +125,31 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) {
// validate this announcement. All dependants will // validate this announcement. All dependants will
// point to this same channel, so they'll be unblocked // point to this same channel, so they'll be unblocked
// at the same time. // at the same time.
annFinCond := make(chan struct{}) signals := &validationSignals{
v.chanAnnFinSignal[msg.ShortChannelID] = annFinCond allow: make(chan struct{}),
v.chanEdgeDependencies[msg.ShortChannelID] = annFinCond deny: make(chan struct{}),
}
v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = annFinCond v.chanAnnFinSignal[msg.ShortChannelID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = annFinCond v.chanEdgeDependencies[msg.ShortChannelID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = signals
} }
case *channeldb.ChannelEdgeInfo: case *channeldb.ChannelEdgeInfo:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
if _, ok := v.chanAnnFinSignal[shortID]; !ok { if _, ok := v.chanAnnFinSignal[shortID]; !ok {
annFinCond := make(chan struct{}) signals := &validationSignals{
allow: make(chan struct{}),
deny: make(chan struct{}),
}
v.chanAnnFinSignal[shortID] = annFinCond v.chanAnnFinSignal[shortID] = signals
v.chanEdgeDependencies[shortID] = annFinCond v.chanEdgeDependencies[shortID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeKey1Bytes)] = annFinCond v.nodeAnnDependencies[route.Vertex(msg.NodeKey1Bytes)] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeKey2Bytes)] = annFinCond v.nodeAnnDependencies[route.Vertex(msg.NodeKey2Bytes)] = signals
} }
// These other types don't have any dependants, so no further // These other types don't have any dependants, so no further
@ -162,8 +187,8 @@ func (v *ValidationBarrier) CompleteJob() {
func (v *ValidationBarrier) WaitForDependants(job interface{}) error { func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
var ( var (
signal chan struct{} signals *validationSignals
ok bool ok bool
) )
v.Lock() v.Lock()
@ -173,15 +198,15 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
// completion of any active ChannelAnnouncement jobs related to them. // completion of any active ChannelAnnouncement jobs related to them.
case *channeldb.ChannelEdgePolicy: case *channeldb.ChannelEdgePolicy:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
signal, ok = v.chanEdgeDependencies[shortID] signals, ok = v.chanEdgeDependencies[shortID]
case *channeldb.LightningNode: case *channeldb.LightningNode:
vertex := route.Vertex(msg.PubKeyBytes) vertex := route.Vertex(msg.PubKeyBytes)
signal, ok = v.nodeAnnDependencies[vertex] signals, ok = v.nodeAnnDependencies[vertex]
case *lnwire.ChannelUpdate: case *lnwire.ChannelUpdate:
signal, ok = v.chanEdgeDependencies[msg.ShortChannelID] signals, ok = v.chanEdgeDependencies[msg.ShortChannelID]
case *lnwire.NodeAnnouncement: case *lnwire.NodeAnnouncement:
vertex := route.Vertex(msg.NodeID) vertex := route.Vertex(msg.NodeID)
signal, ok = v.nodeAnnDependencies[vertex] signals, ok = v.nodeAnnDependencies[vertex]
// Other types of jobs can be executed immediately, so we'll just // Other types of jobs can be executed immediately, so we'll just
// return directly. // return directly.
@ -204,7 +229,9 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
select { select {
case <-v.quit: case <-v.quit:
return ErrVBarrierShuttingDown return ErrVBarrierShuttingDown
case <-signal: case <-signals.deny:
return ErrParentValidationFailed
case <-signals.allow:
return nil return nil
} }
} }
@ -212,10 +239,10 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
return nil return nil
} }
// SignalDependants will signal any jobs that are dependent on this job that // SignalDependants will allow/deny any jobs that are dependent on this job that
// they can continue execution. If the job doesn't have any dependants, then // they can continue execution. If the job doesn't have any dependants, then
// this function sill exit immediately. // this function sill exit immediately.
func (v *ValidationBarrier) SignalDependants(job interface{}) { func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
@ -223,18 +250,26 @@ func (v *ValidationBarrier) SignalDependants(job interface{}) {
// If we've just finished executing a ChannelAnnouncement, then we'll // If we've just finished executing a ChannelAnnouncement, then we'll
// close out the signal, and remove the signal from the map of active // close out the signal, and remove the signal from the map of active
// ones. This will allow any dependent jobs to continue execution. // ones. This will allow/deny any dependent jobs to continue execution.
case *channeldb.ChannelEdgeInfo: case *channeldb.ChannelEdgeInfo:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
finSignal, ok := v.chanAnnFinSignal[shortID] finSignals, ok := v.chanAnnFinSignal[shortID]
if ok { if ok {
close(finSignal) if allow {
close(finSignals.allow)
} else {
close(finSignals.deny)
}
delete(v.chanAnnFinSignal, shortID) delete(v.chanAnnFinSignal, shortID)
} }
case *lnwire.ChannelAnnouncement: case *lnwire.ChannelAnnouncement:
finSignal, ok := v.chanAnnFinSignal[msg.ShortChannelID] finSignals, ok := v.chanAnnFinSignal[msg.ShortChannelID]
if ok { if ok {
close(finSignal) if allow {
close(finSignals.allow)
} else {
close(finSignals.deny)
}
delete(v.chanAnnFinSignal, msg.ShortChannelID) delete(v.chanAnnFinSignal, msg.ShortChannelID)
} }

View File

@ -12,6 +12,8 @@ import (
// TestValidationBarrierSemaphore checks basic properties of the validation // TestValidationBarrierSemaphore checks basic properties of the validation
// barrier's semaphore wrt. enqueuing/dequeuing. // barrier's semaphore wrt. enqueuing/dequeuing.
func TestValidationBarrierSemaphore(t *testing.T) { func TestValidationBarrierSemaphore(t *testing.T) {
t.Parallel()
const ( const (
numTasks = 8 numTasks = 8
numPendingTasks = 8 numPendingTasks = 8
@ -59,6 +61,8 @@ func TestValidationBarrierSemaphore(t *testing.T) {
// TestValidationBarrierQuit checks that pending validation tasks will return an // TestValidationBarrierQuit checks that pending validation tasks will return an
// error from WaitForDependants if the barrier's quit signal is canceled. // error from WaitForDependants if the barrier's quit signal is canceled.
func TestValidationBarrierQuit(t *testing.T) { func TestValidationBarrierQuit(t *testing.T) {
t.Parallel()
const ( const (
numTasks = 8 numTasks = 8
timeout = 50 * time.Millisecond timeout = 50 * time.Millisecond
@ -113,9 +117,14 @@ func TestValidationBarrierQuit(t *testing.T) {
// with the correct error. // with the correct error.
for i := 0; i < numTasks; i++ { for i := 0; i < numTasks; i++ {
switch { switch {
// First half, signal completion and task semaphore // Signal completion for the first half of tasks, but only allow
// dependents to be processed as well for the second quarter.
case i < numTasks/4:
barrier.SignalDependants(anns[i], false)
barrier.CompleteJob()
case i < numTasks/2: case i < numTasks/2:
barrier.SignalDependants(anns[i]) barrier.SignalDependants(anns[i], true)
barrier.CompleteJob() barrier.CompleteJob()
// At midpoint, quit the validation barrier. // At midpoint, quit the validation barrier.
@ -132,7 +141,10 @@ func TestValidationBarrierQuit(t *testing.T) {
switch { switch {
// First half should return without failure. // First half should return without failure.
case i < numTasks/2 && err != nil: case i < numTasks/4 && err != routing.ErrParentValidationFailed:
t.Fatalf("unexpected failure while waiting: %v", err)
case i >= numTasks/4 && i < numTasks/2 && err != nil:
t.Fatalf("unexpected failure while waiting: %v", err) t.Fatalf("unexpected failure while waiting: %v", err)
// Last half should return the shutdown error. // Last half should return the shutdown error.