Merge pull request #5133 from wpaulino/routing-validation-cancel-deps

discovery+routing: cancel dependent jobs if parent validation fails
This commit is contained in:
Olaoluwa Osuntokun 2021-04-01 18:32:58 -07:00 committed by GitHub
commit a329c80612
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 154 additions and 119 deletions

View File

@ -990,7 +990,7 @@ func (d *AuthenticatedGossiper) networkHandler() {
// Channel announcement signatures are amongst the only // Channel announcement signatures are amongst the only
// messages that we'll process serially. // messages that we'll process serially.
case *lnwire.AnnounceSignatures: case *lnwire.AnnounceSignatures:
emittedAnnouncements := d.processNetworkAnnouncement( emittedAnnouncements, _ := d.processNetworkAnnouncement(
announcement, announcement,
) )
if emittedAnnouncements != nil { if emittedAnnouncements != nil {
@ -1040,14 +1040,14 @@ func (d *AuthenticatedGossiper) networkHandler() {
// determine if this is either a new // determine if this is either a new
// announcement from our PoV or an edges to a // announcement from our PoV or an edges to a
// prior vertex/edge we previously proceeded. // prior vertex/edge we previously proceeded.
emittedAnnouncements := d.processNetworkAnnouncement( emittedAnnouncements, allowDependents := d.processNetworkAnnouncement(
announcement, announcement,
) )
// If this message had any dependencies, then // If this message had any dependencies, then
// we can now signal them to continue. // we can now signal them to continue.
validationBarrier.SignalDependants( validationBarrier.SignalDependants(
announcement.msg, announcement.msg, allowDependents,
) )
// If the announcement was accepted, then add // If the announcement was accepted, then add
@ -1514,9 +1514,11 @@ func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement,
// channel or node announcement or announcements proofs. If the announcement // channel or node announcement or announcements proofs. If the announcement
// didn't affect the internal state due to either being out of date, invalid, // didn't affect the internal state due to either being out of date, invalid,
// or redundant, then nil is returned. Otherwise, the set of announcements will // or redundant, then nil is returned. Otherwise, the set of announcements will
// be returned which should be broadcasted to the rest of the network. // be returned which should be broadcasted to the rest of the network. The
// boolean returned indicates whether any dependents of the announcement should
// attempt to be processed as well.
func (d *AuthenticatedGossiper) processNetworkAnnouncement( func (d *AuthenticatedGossiper) processNetworkAnnouncement(
nMsg *networkMsg) []networkMsg { nMsg *networkMsg) ([]networkMsg, bool) {
isPremature := func(chanID lnwire.ShortChannelID, delta uint32) bool { isPremature := func(chanID lnwire.ShortChannelID, delta uint32) bool {
// TODO(roasbeef) make height delta 6 // TODO(roasbeef) make height delta 6
@ -1546,7 +1548,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// signatures if not required. // signatures if not required.
if d.cfg.Router.IsStaleNode(msg.NodeID, timestamp) { if d.cfg.Router.IsStaleNode(msg.NodeID, timestamp) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
if err := d.addNode(msg, schedulerOp...); err != nil { if err := d.addNode(msg, schedulerOp...); err != nil {
@ -1559,7 +1561,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// In order to ensure we don't leak unadvertised nodes, we'll // In order to ensure we don't leak unadvertised nodes, we'll
@ -1570,7 +1572,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Errorf("Unable to determine if node %x is "+ log.Errorf("Unable to determine if node %x is "+
"advertised: %v", msg.NodeID, err) "advertised: %v", msg.NodeID, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If it does, we'll add their announcement to our batch so that // If it does, we'll add their announcement to our batch so that
@ -1588,7 +1590,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
nMsg.err <- nil nMsg.err <- nil
// TODO(roasbeef): get rid of the above // TODO(roasbeef): get rid of the above
return announcements return announcements, true
// A new channel announcement has arrived, this indicates the // A new channel announcement has arrived, this indicates the
// *creation* of a new channel within the network. This only advertises // *creation* of a new channel within the network. This only advertises
@ -1608,12 +1610,11 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the advertised inclusionary block is beyond our knowledge // If the advertised inclusionary block is beyond our knowledge
// of the chain tip, then we'll put the announcement in limbo // of the chain tip, then we'll ignore for it now.
// to be fully verified once we advance forward in the chain.
d.Lock() d.Lock()
if nMsg.isRemote && isPremature(msg.ShortChannelID, 0) { if nMsg.isRemote && isPremature(msg.ShortChannelID, 0) {
log.Infof("Announcement for chan_id=(%v), is "+ log.Infof("Announcement for chan_id=(%v), is "+
@ -1623,7 +1624,8 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ShortChannelID.BlockHeight, msg.ShortChannelID.BlockHeight,
d.bestHeight) d.bestHeight)
d.Unlock() d.Unlock()
return nil nMsg.err <- nil
return nil, false
} }
d.Unlock() d.Unlock()
@ -1632,7 +1634,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// below. // below.
if d.cfg.Router.IsKnownEdge(msg.ShortChannelID) { if d.cfg.Router.IsKnownEdge(msg.ShortChannelID) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// If this is a remote channel announcement, then we'll validate // If this is a remote channel announcement, then we'll validate
@ -1649,7 +1651,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the proof checks out, then we'll save the proof // If the proof checks out, then we'll save the proof
@ -1669,7 +1671,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
if err := msg.Features.Encode(&featureBuf); err != nil { if err := msg.Features.Encode(&featureBuf); err != nil {
log.Errorf("unable to encode features: %v", err) log.Errorf("unable to encode features: %v", err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
edge := &channeldb.ChannelEdgeInfo{ edge := &channeldb.ChannelEdgeInfo{
@ -1720,7 +1722,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{} d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- rErr nMsg.err <- rErr
return nil return nil, false
} }
// If while processing this rejected edge, we // If while processing this rejected edge, we
@ -1729,7 +1731,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// directly. // directly.
if len(anns) != 0 { if len(anns) != 0 {
nMsg.err <- nil nMsg.err <- nil
return anns return anns, true
} }
// Otherwise, this is just a regular rejected // Otherwise, this is just a regular rejected
@ -1742,7 +1744,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If we earlier received any ChannelUpdates for this channel, // If we earlier received any ChannelUpdates for this channel,
@ -1806,7 +1808,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
// A new authenticated channel edge update has arrived. This indicates // A new authenticated channel edge update has arrived. This indicates
// that the directional information for an already known channel has // that the directional information for an already known channel has
@ -1825,7 +1827,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
blockHeight := msg.ShortChannelID.BlockHeight blockHeight := msg.ShortChannelID.BlockHeight
@ -1842,7 +1844,8 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, blockHeight, shortChanID, blockHeight,
d.bestHeight) d.bestHeight)
d.Unlock() d.Unlock()
return nil nMsg.err <- nil
return nil, false
} }
d.Unlock() d.Unlock()
@ -1854,7 +1857,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ShortChannelID, timestamp, msg.ChannelFlags, msg.ShortChannelID, timestamp, msg.ChannelFlags,
) { ) {
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// Get the node pub key as far as we don't have it in channel // Get the node pub key as far as we don't have it in channel
@ -1893,7 +1896,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"update signature: %v", err) "update signature: %v", err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// With the signature valid, we'll proceed to mark the // With the signature valid, we'll proceed to mark the
@ -1906,7 +1909,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ShortChannelID, err) msg.ShortChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Debugf("Removed edge with chan_id=%v from zombie "+ log.Debugf("Removed edge with chan_id=%v from zombie "+
@ -1949,7 +1952,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// NOTE: We don't return anything on the error channel // NOTE: We don't return anything on the error channel
// for this message, as we expect that will be done when // for this message, as we expect that will be done when
// this ChannelUpdate is later reprocessed. // this ChannelUpdate is later reprocessed.
return nil return nil, false
default: default:
err := fmt.Errorf("unable to validate channel update "+ err := fmt.Errorf("unable to validate channel update "+
@ -1960,7 +1963,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.rejectMtx.Lock() d.rejectMtx.Lock()
d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{} d.recentRejects[msg.ShortChannelID.ToUint64()] = struct{}{}
d.rejectMtx.Unlock() d.rejectMtx.Unlock()
return nil return nil, false
} }
// The least-significant bit in the flag on the channel update // The least-significant bit in the flag on the channel update
@ -1997,7 +2000,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
d.cfg.RebroadcastInterval, d.cfg.RebroadcastInterval,
shortChanID) shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
} else { } else {
// If it's not, we'll allow an update per minute // If it's not, we'll allow an update per minute
@ -2024,7 +2027,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, shortChanID,
pubKey.SerializeCompressed()) pubKey.SerializeCompressed())
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
} }
} }
@ -2040,7 +2043,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(rErr) log.Error(rErr)
nMsg.err <- rErr nMsg.err <- rErr
return nil return nil, false
} }
update := &channeldb.ChannelEdgePolicy{ update := &channeldb.ChannelEdgePolicy{
@ -2069,7 +2072,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If this is a local ChannelUpdate without an AuthProof, it // If this is a local ChannelUpdate without an AuthProof, it
@ -2094,7 +2097,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.MsgType(), msg.ShortChannelID, msg.MsgType(), msg.ShortChannelID,
remotePubKey, err) remotePubKey, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
} }
@ -2111,7 +2114,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
// A new signature announcement has been received. This indicates // A new signature announcement has been received. This indicates
// willingness of nodes involved in the funding of a channel to // willingness of nodes involved in the funding of a channel to
@ -2132,17 +2135,15 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// By the specification, channel announcement proofs should be // By the specification, channel announcement proofs should be
// sent after some number of confirmations after channel was // sent after some number of confirmations after channel was
// registered in bitcoin blockchain. Therefore, we check if the // registered in bitcoin blockchain. Therefore, we check if the
// proof is premature. If so we'll halt processing until the // proof is premature.
// expected announcement height. This allows us to be tolerant
// to other clients if this constraint was changed.
d.Lock() d.Lock()
if isPremature(msg.ShortChannelID, d.cfg.ProofMatureDelta) { if isPremature(msg.ShortChannelID, d.cfg.ProofMatureDelta) {
log.Infof("Premature proof announcement, "+ log.Infof("Premature proof announcement, current "+
"current block height lower than needed: %v <"+ "block height lower than needed: %v < %v",
" %v, add announcement to reprocessing batch",
d.bestHeight, needBlockHeight) d.bestHeight, needBlockHeight)
d.Unlock() d.Unlock()
return nil nMsg.err <- nil
return nil, false
} }
d.Unlock() d.Unlock()
@ -2168,14 +2169,14 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Infof("Orphan %v proof announcement with "+ log.Infof("Orphan %v proof announcement with "+
"short_chan_id=%v, adding "+ "short_chan_id=%v, adding "+
"to waiting batch", prefix, shortChanID) "to waiting batch", prefix, shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
nodeID := nMsg.source.SerializeCompressed() nodeID := nMsg.source.SerializeCompressed()
@ -2190,7 +2191,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"short_chan_id=%v", shortChanID) "short_chan_id=%v", shortChanID)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If proof was sent by a local sub-system, then we'll // If proof was sent by a local sub-system, then we'll
@ -2214,7 +2215,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.MsgType(), msg.ShortChannelID, msg.MsgType(), msg.ShortChannelID,
remotePubKey, err) remotePubKey, err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
} }
@ -2267,7 +2268,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Debugf("Already have proof for channel "+ log.Debugf("Already have proof for channel "+
"with chanID=%v", msg.ChannelID) "with chanID=%v", msg.ChannelID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, true
} }
// Check that we received the opposite proof. If so, then we're // Check that we received the opposite proof. If so, then we're
@ -2285,7 +2286,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
if err == channeldb.ErrWaitingProofNotFound { if err == channeldb.ErrWaitingProofNotFound {
@ -2296,7 +2297,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID, err) shortChanID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
log.Infof("1/2 of channel ann proof received for "+ log.Infof("1/2 of channel ann proof received for "+
@ -2304,7 +2305,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
shortChanID) shortChanID)
nMsg.err <- nil nMsg.err <- nil
return nil return nil, false
} }
// We now have both halves of the channel announcement proof, // We now have both halves of the channel announcement proof,
@ -2328,7 +2329,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
if err != nil { if err != nil {
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// With all the necessary components assembled validate the // With all the necessary components assembled validate the
@ -2340,7 +2341,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// If the channel was returned by the router it means that // If the channel was returned by the router it means that
@ -2356,7 +2357,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
"channel chanID=%v: %v", msg.ChannelID, err) "channel chanID=%v: %v", msg.ChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
err = d.cfg.WaitingProofStore.Remove(proof.OppositeKey()) err = d.cfg.WaitingProofStore.Remove(proof.OppositeKey())
@ -2366,7 +2367,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
msg.ChannelID, err) msg.ChannelID, err)
log.Error(err) log.Error(err)
nMsg.err <- err nMsg.err <- err
return nil return nil, false
} }
// Proof was successfully created and now can announce the // Proof was successfully created and now can announce the
@ -2433,11 +2434,12 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
} }
nMsg.err <- nil nMsg.err <- nil
return announcements return announcements, true
default: default:
nMsg.err <- errors.New("wrong type of the announcement") err := errors.New("wrong type of the announcement")
return nil nMsg.err <- err
return nil, false
} }
} }

View File

@ -897,9 +897,8 @@ func TestProcessAnnouncement(t *testing.T) {
} }
} }
// TestPrematureAnnouncement checks that premature announcements are // TestPrematureAnnouncement checks that premature announcements are not
// not propagated to the router subsystem until block with according // propagated to the router subsystem.
// block height received.
func TestPrematureAnnouncement(t *testing.T) { func TestPrematureAnnouncement(t *testing.T) {
t.Parallel() t.Parallel()
@ -920,8 +919,8 @@ func TestPrematureAnnouncement(t *testing.T) {
// Pretending that we receive the valid channel announcement from // Pretending that we receive the valid channel announcement from
// remote side, but block height of this announcement is greater than // remote side, but block height of this announcement is greater than
// highest know to us, for that reason it should be added to the // highest know to us, for that reason it should be ignored and not
// repeat/premature batch. // added to the router.
ca, err := createRemoteChannelAnnouncement(1) ca, err := createRemoteChannelAnnouncement(1)
if err != nil { if err != nil {
t.Fatalf("can't create channel announcement: %v", err) t.Fatalf("can't create channel announcement: %v", err)
@ -929,31 +928,13 @@ func TestPrematureAnnouncement(t *testing.T) {
select { select {
case <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer): case <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer):
t.Fatal("announcement was proceeded") case <-time.After(time.Second):
case <-time.After(100 * time.Millisecond): t.Fatal("announcement was not processed")
} }
if len(ctx.router.infos) != 0 { if len(ctx.router.infos) != 0 {
t.Fatal("edge was added to router") t.Fatal("edge was added to router")
} }
// Pretending that we receive the valid channel update announcement from
// remote side, but block height of this announcement is greater than
// highest known to us, so it should be rejected.
ua, err := createUpdateAnnouncement(1, 0, remoteKeyPriv1, timestamp)
if err != nil {
t.Fatalf("can't create update announcement: %v", err)
}
select {
case <-ctx.gossiper.ProcessRemoteAnnouncement(ua, nodePeer):
t.Fatal("announcement was proceeded")
case <-time.After(100 * time.Millisecond):
}
if len(ctx.router.edges) != 0 {
t.Fatal("edge update was added to router")
}
} }
// TestSignatureAnnouncementLocalFirst ensures that the AuthenticatedGossiper // TestSignatureAnnouncementLocalFirst ensures that the AuthenticatedGossiper

View File

@ -992,7 +992,8 @@ func (r *ChannelRouter) networkHandler() {
update.msg, update.msg,
) )
if err != nil { if err != nil {
if err != ErrVBarrierShuttingDown { if err != ErrVBarrierShuttingDown &&
err != ErrParentValidationFailed {
log.Warnf("unexpected error "+ log.Warnf("unexpected error "+
"during validation "+ "during validation "+
"barrier shutdown: %v", "barrier shutdown: %v",
@ -1010,7 +1011,11 @@ func (r *ChannelRouter) networkHandler() {
// If this message had any dependencies, then // If this message had any dependencies, then
// we can now signal them to continue. // we can now signal them to continue.
validationBarrier.SignalDependants(update.msg) allowDependents := err == nil ||
IsError(err, ErrIgnored, ErrOutdated)
validationBarrier.SignalDependants(
update.msg, allowDependents,
)
if err != nil { if err != nil {
return return
} }

View File

@ -9,10 +9,28 @@ import (
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
// ErrVBarrierShuttingDown signals that the barrier has been requested to var (
// shutdown, and that the caller should not treat the wait condition as // ErrVBarrierShuttingDown signals that the barrier has been requested
// fulfilled. // to shutdown, and that the caller should not treat the wait condition
var ErrVBarrierShuttingDown = errors.New("validation barrier shutting down") // as fulfilled.
ErrVBarrierShuttingDown = errors.New("validation barrier shutting down")
// ErrParentValidationFailed signals that the validation of a
// dependent's parent failed, so the dependent must not be processed.
ErrParentValidationFailed = errors.New("parent validation failed")
)
// validationSignals contains two signals which allows the ValidationBarrier to
// communicate back to the caller whether a dependent should be processed or not
// based on whether its parent was successfully validated. Only one of these
// signals is to be used at a time.
type validationSignals struct {
// allow is the signal used to allow a dependent to be processed.
allow chan struct{}
// deny is the signal used to prevent a dependent from being processed.
deny chan struct{}
}
// ValidationBarrier is a barrier used to ensure proper validation order while // ValidationBarrier is a barrier used to ensure proper validation order while
// concurrently validating new announcements for channel edges, and the // concurrently validating new announcements for channel edges, and the
@ -31,19 +49,19 @@ type ValidationBarrier struct {
// ChannelAnnouncement like validation job going on. Once the job has // ChannelAnnouncement like validation job going on. Once the job has
// been completed, the channel will be closed unblocking any // been completed, the channel will be closed unblocking any
// dependants. // dependants.
chanAnnFinSignal map[lnwire.ShortChannelID]chan struct{} chanAnnFinSignal map[lnwire.ShortChannelID]*validationSignals
// chanEdgeDependencies tracks any channel edge updates which should // chanEdgeDependencies tracks any channel edge updates which should
// wait until the completion of the ChannelAnnouncement before // wait until the completion of the ChannelAnnouncement before
// proceeding. This is a dependency, as we can't validate the update // proceeding. This is a dependency, as we can't validate the update
// before we validate the announcement which creates the channel // before we validate the announcement which creates the channel
// itself. // itself.
chanEdgeDependencies map[lnwire.ShortChannelID]chan struct{} chanEdgeDependencies map[lnwire.ShortChannelID]*validationSignals
// nodeAnnDependencies tracks any pending NodeAnnouncement validation // nodeAnnDependencies tracks any pending NodeAnnouncement validation
// jobs which should wait until the completion of the // jobs which should wait until the completion of the
// ChannelAnnouncement before proceeding. // ChannelAnnouncement before proceeding.
nodeAnnDependencies map[route.Vertex]chan struct{} nodeAnnDependencies map[route.Vertex]*validationSignals
quit chan struct{} quit chan struct{}
sync.Mutex sync.Mutex
@ -56,9 +74,9 @@ func NewValidationBarrier(numActiveReqs int,
quitChan chan struct{}) *ValidationBarrier { quitChan chan struct{}) *ValidationBarrier {
v := &ValidationBarrier{ v := &ValidationBarrier{
chanAnnFinSignal: make(map[lnwire.ShortChannelID]chan struct{}), chanAnnFinSignal: make(map[lnwire.ShortChannelID]*validationSignals),
chanEdgeDependencies: make(map[lnwire.ShortChannelID]chan struct{}), chanEdgeDependencies: make(map[lnwire.ShortChannelID]*validationSignals),
nodeAnnDependencies: make(map[route.Vertex]chan struct{}), nodeAnnDependencies: make(map[route.Vertex]*validationSignals),
quit: quitChan, quit: quitChan,
} }
@ -107,24 +125,31 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) {
// validate this announcement. All dependants will // validate this announcement. All dependants will
// point to this same channel, so they'll be unblocked // point to this same channel, so they'll be unblocked
// at the same time. // at the same time.
annFinCond := make(chan struct{}) signals := &validationSignals{
v.chanAnnFinSignal[msg.ShortChannelID] = annFinCond allow: make(chan struct{}),
v.chanEdgeDependencies[msg.ShortChannelID] = annFinCond deny: make(chan struct{}),
}
v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = annFinCond v.chanAnnFinSignal[msg.ShortChannelID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = annFinCond v.chanEdgeDependencies[msg.ShortChannelID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = signals
} }
case *channeldb.ChannelEdgeInfo: case *channeldb.ChannelEdgeInfo:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
if _, ok := v.chanAnnFinSignal[shortID]; !ok { if _, ok := v.chanAnnFinSignal[shortID]; !ok {
annFinCond := make(chan struct{}) signals := &validationSignals{
allow: make(chan struct{}),
deny: make(chan struct{}),
}
v.chanAnnFinSignal[shortID] = annFinCond v.chanAnnFinSignal[shortID] = signals
v.chanEdgeDependencies[shortID] = annFinCond v.chanEdgeDependencies[shortID] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeKey1Bytes)] = annFinCond v.nodeAnnDependencies[route.Vertex(msg.NodeKey1Bytes)] = signals
v.nodeAnnDependencies[route.Vertex(msg.NodeKey2Bytes)] = annFinCond v.nodeAnnDependencies[route.Vertex(msg.NodeKey2Bytes)] = signals
} }
// These other types don't have any dependants, so no further // These other types don't have any dependants, so no further
@ -162,8 +187,8 @@ func (v *ValidationBarrier) CompleteJob() {
func (v *ValidationBarrier) WaitForDependants(job interface{}) error { func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
var ( var (
signal chan struct{} signals *validationSignals
ok bool ok bool
) )
v.Lock() v.Lock()
@ -173,15 +198,15 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
// completion of any active ChannelAnnouncement jobs related to them. // completion of any active ChannelAnnouncement jobs related to them.
case *channeldb.ChannelEdgePolicy: case *channeldb.ChannelEdgePolicy:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
signal, ok = v.chanEdgeDependencies[shortID] signals, ok = v.chanEdgeDependencies[shortID]
case *channeldb.LightningNode: case *channeldb.LightningNode:
vertex := route.Vertex(msg.PubKeyBytes) vertex := route.Vertex(msg.PubKeyBytes)
signal, ok = v.nodeAnnDependencies[vertex] signals, ok = v.nodeAnnDependencies[vertex]
case *lnwire.ChannelUpdate: case *lnwire.ChannelUpdate:
signal, ok = v.chanEdgeDependencies[msg.ShortChannelID] signals, ok = v.chanEdgeDependencies[msg.ShortChannelID]
case *lnwire.NodeAnnouncement: case *lnwire.NodeAnnouncement:
vertex := route.Vertex(msg.NodeID) vertex := route.Vertex(msg.NodeID)
signal, ok = v.nodeAnnDependencies[vertex] signals, ok = v.nodeAnnDependencies[vertex]
// Other types of jobs can be executed immediately, so we'll just // Other types of jobs can be executed immediately, so we'll just
// return directly. // return directly.
@ -204,7 +229,9 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
select { select {
case <-v.quit: case <-v.quit:
return ErrVBarrierShuttingDown return ErrVBarrierShuttingDown
case <-signal: case <-signals.deny:
return ErrParentValidationFailed
case <-signals.allow:
return nil return nil
} }
} }
@ -212,10 +239,10 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
return nil return nil
} }
// SignalDependants will signal any jobs that are dependent on this job that // SignalDependants will allow/deny any jobs that are dependent on this job that
// they can continue execution. If the job doesn't have any dependants, then // they can continue execution. If the job doesn't have any dependants, then
// this function sill exit immediately. // this function sill exit immediately.
func (v *ValidationBarrier) SignalDependants(job interface{}) { func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
@ -223,18 +250,26 @@ func (v *ValidationBarrier) SignalDependants(job interface{}) {
// If we've just finished executing a ChannelAnnouncement, then we'll // If we've just finished executing a ChannelAnnouncement, then we'll
// close out the signal, and remove the signal from the map of active // close out the signal, and remove the signal from the map of active
// ones. This will allow any dependent jobs to continue execution. // ones. This will allow/deny any dependent jobs to continue execution.
case *channeldb.ChannelEdgeInfo: case *channeldb.ChannelEdgeInfo:
shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
finSignal, ok := v.chanAnnFinSignal[shortID] finSignals, ok := v.chanAnnFinSignal[shortID]
if ok { if ok {
close(finSignal) if allow {
close(finSignals.allow)
} else {
close(finSignals.deny)
}
delete(v.chanAnnFinSignal, shortID) delete(v.chanAnnFinSignal, shortID)
} }
case *lnwire.ChannelAnnouncement: case *lnwire.ChannelAnnouncement:
finSignal, ok := v.chanAnnFinSignal[msg.ShortChannelID] finSignals, ok := v.chanAnnFinSignal[msg.ShortChannelID]
if ok { if ok {
close(finSignal) if allow {
close(finSignals.allow)
} else {
close(finSignals.deny)
}
delete(v.chanAnnFinSignal, msg.ShortChannelID) delete(v.chanAnnFinSignal, msg.ShortChannelID)
} }

View File

@ -12,6 +12,8 @@ import (
// TestValidationBarrierSemaphore checks basic properties of the validation // TestValidationBarrierSemaphore checks basic properties of the validation
// barrier's semaphore wrt. enqueuing/dequeuing. // barrier's semaphore wrt. enqueuing/dequeuing.
func TestValidationBarrierSemaphore(t *testing.T) { func TestValidationBarrierSemaphore(t *testing.T) {
t.Parallel()
const ( const (
numTasks = 8 numTasks = 8
numPendingTasks = 8 numPendingTasks = 8
@ -59,6 +61,8 @@ func TestValidationBarrierSemaphore(t *testing.T) {
// TestValidationBarrierQuit checks that pending validation tasks will return an // TestValidationBarrierQuit checks that pending validation tasks will return an
// error from WaitForDependants if the barrier's quit signal is canceled. // error from WaitForDependants if the barrier's quit signal is canceled.
func TestValidationBarrierQuit(t *testing.T) { func TestValidationBarrierQuit(t *testing.T) {
t.Parallel()
const ( const (
numTasks = 8 numTasks = 8
timeout = 50 * time.Millisecond timeout = 50 * time.Millisecond
@ -113,9 +117,14 @@ func TestValidationBarrierQuit(t *testing.T) {
// with the correct error. // with the correct error.
for i := 0; i < numTasks; i++ { for i := 0; i < numTasks; i++ {
switch { switch {
// First half, signal completion and task semaphore // Signal completion for the first half of tasks, but only allow
// dependents to be processed as well for the second quarter.
case i < numTasks/4:
barrier.SignalDependants(anns[i], false)
barrier.CompleteJob()
case i < numTasks/2: case i < numTasks/2:
barrier.SignalDependants(anns[i]) barrier.SignalDependants(anns[i], true)
barrier.CompleteJob() barrier.CompleteJob()
// At midpoint, quit the validation barrier. // At midpoint, quit the validation barrier.
@ -132,7 +141,10 @@ func TestValidationBarrierQuit(t *testing.T) {
switch { switch {
// First half should return without failure. // First half should return without failure.
case i < numTasks/2 && err != nil: case i < numTasks/4 && err != routing.ErrParentValidationFailed:
t.Fatalf("unexpected failure while waiting: %v", err)
case i >= numTasks/4 && i < numTasks/2 && err != nil:
t.Fatalf("unexpected failure while waiting: %v", err) t.Fatalf("unexpected failure while waiting: %v", err)
// Last half should return the shutdown error. // Last half should return the shutdown error.