discovery/gossiper_test: prevent race conditions within mockGraphSource

This commit is contained in:
Wilmer Paulino 2019-02-05 17:18:41 -08:00
parent 73b4bc4b68
commit 6e556aa897
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F

@ -108,28 +108,36 @@ func (n *mockSigner) SignMessage(pubKey *btcec.PublicKey,
} }
type mockGraphSource struct { type mockGraphSource struct {
nodes []*channeldb.LightningNode
infos map[uint64]*channeldb.ChannelEdgeInfo
edges map[uint64][]*channeldb.ChannelEdgePolicy
bestHeight uint32 bestHeight uint32
mu sync.Mutex
nodes []channeldb.LightningNode
infos map[uint64]channeldb.ChannelEdgeInfo
edges map[uint64][]channeldb.ChannelEdgePolicy
} }
func newMockRouter(height uint32) *mockGraphSource { func newMockRouter(height uint32) *mockGraphSource {
return &mockGraphSource{ return &mockGraphSource{
bestHeight: height, bestHeight: height,
infos: make(map[uint64]*channeldb.ChannelEdgeInfo), infos: make(map[uint64]channeldb.ChannelEdgeInfo),
edges: make(map[uint64][]*channeldb.ChannelEdgePolicy), edges: make(map[uint64][]channeldb.ChannelEdgePolicy),
} }
} }
var _ routing.ChannelGraphSource = (*mockGraphSource)(nil) var _ routing.ChannelGraphSource = (*mockGraphSource)(nil)
func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error { func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error {
r.nodes = append(r.nodes, node) r.mu.Lock()
defer r.mu.Unlock()
r.nodes = append(r.nodes, *node)
return nil return nil
} }
func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error { func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.infos[info.ChannelID]; ok { if _, ok := r.infos[info.ChannelID]; ok {
return errors.New("info already exist") return errors.New("info already exist")
} }
@ -137,15 +145,15 @@ func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
// Usually, the capacity is fetched in the router from the funding txout. // Usually, the capacity is fetched in the router from the funding txout.
// Since the mockGraphSource can't access the txout, assign a default value. // Since the mockGraphSource can't access the txout, assign a default value.
info.Capacity = maxBtcFundingAmount info.Capacity = maxBtcFundingAmount
r.infos[info.ChannelID] = info r.infos[info.ChannelID] = *info
return nil return nil
} }
func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error { func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error {
r.edges[edge.ChannelID] = append( r.mu.Lock()
r.edges[edge.ChannelID], defer r.mu.Unlock()
edge,
) r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge)
return nil return nil
} }
@ -159,11 +167,19 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
proof *channeldb.ChannelAuthProof) error { proof *channeldb.ChannelAuthProof) error {
info, ok := r.infos[chanID.ToUint64()]
r.mu.Lock()
defer r.mu.Unlock()
chanIDInt := chanID.ToUint64()
info, ok := r.infos[chanIDInt]
if !ok { if !ok {
return errors.New("channel does not exist") return errors.New("channel does not exist")
} }
info.AuthProof = proof info.AuthProof = proof
r.infos[chanIDInt] = info
return nil return nil
} }
@ -186,6 +202,9 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
*channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy,
*channeldb.ChannelEdgePolicy, error) { *channeldb.ChannelEdgePolicy, error) {
r.mu.Lock()
defer r.mu.Unlock()
chanInfo, ok := r.infos[chanID.ToUint64()] chanInfo, ok := r.infos[chanID.ToUint64()]
if !ok { if !ok {
return nil, nil, nil, channeldb.ErrEdgeNotFound return nil, nil, nil, channeldb.ErrEdgeNotFound
@ -193,14 +212,16 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
edges := r.edges[chanID.ToUint64()] edges := r.edges[chanID.ToUint64()]
if len(edges) == 0 { if len(edges) == 0 {
return chanInfo, nil, nil, nil return &chanInfo, nil, nil, nil
} }
if len(edges) == 1 { if len(edges) == 1 {
return chanInfo, edges[0], nil, nil edge1 := edges[0]
return &chanInfo, &edge1, nil, nil
} }
return chanInfo, edges[0], edges[1], nil edge1, edge2 := edges[0], edges[1]
return &chanInfo, &edge1, &edge2, nil
} }
func (r *mockGraphSource) FetchLightningNode( func (r *mockGraphSource) FetchLightningNode(
@ -208,7 +229,7 @@ func (r *mockGraphSource) FetchLightningNode(
for _, node := range r.nodes { for _, node := range r.nodes {
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) { if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
return node, nil return &node, nil
} }
} }
@ -218,6 +239,9 @@ func (r *mockGraphSource) FetchLightningNode(
// IsStaleNode returns true if the graph source has a node announcement for the // IsStaleNode returns true if the graph source has a node announcement for the
// target node with a more recent timestamp. // target node with a more recent timestamp.
func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool { func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool {
r.mu.Lock()
defer r.mu.Unlock()
for _, node := range r.nodes { for _, node := range r.nodes {
if node.PubKeyBytes == nodePub { if node.PubKeyBytes == nodePub {
return node.LastUpdate.After(timestamp) || return node.LastUpdate.After(timestamp) ||
@ -258,6 +282,9 @@ func (r *mockGraphSource) IsPublicNode(node routing.Vertex) (bool, error) {
// IsKnownEdge returns true if the graph source already knows of the passed // IsKnownEdge returns true if the graph source already knows of the passed
// channel ID. // channel ID.
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool { func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
r.mu.Lock()
defer r.mu.Unlock()
_, ok := r.infos[chanID.ToUint64()] _, ok := r.infos[chanID.ToUint64()]
return ok return ok
} }
@ -267,6 +294,9 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
r.mu.Lock()
defer r.mu.Unlock()
edges, ok := r.edges[chanID.ToUint64()] edges, ok := r.edges[chanID.ToUint64()]
if !ok { if !ok {
return false return false