discovery/gossiper_test: prevent race conditions within mockGraphSource
This commit is contained in:
parent
73b4bc4b68
commit
6e556aa897
@ -108,28 +108,36 @@ func (n *mockSigner) SignMessage(pubKey *btcec.PublicKey,
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockGraphSource struct {
|
type mockGraphSource struct {
|
||||||
nodes []*channeldb.LightningNode
|
|
||||||
infos map[uint64]*channeldb.ChannelEdgeInfo
|
|
||||||
edges map[uint64][]*channeldb.ChannelEdgePolicy
|
|
||||||
bestHeight uint32
|
bestHeight uint32
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
nodes []channeldb.LightningNode
|
||||||
|
infos map[uint64]channeldb.ChannelEdgeInfo
|
||||||
|
edges map[uint64][]channeldb.ChannelEdgePolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockRouter(height uint32) *mockGraphSource {
|
func newMockRouter(height uint32) *mockGraphSource {
|
||||||
return &mockGraphSource{
|
return &mockGraphSource{
|
||||||
bestHeight: height,
|
bestHeight: height,
|
||||||
infos: make(map[uint64]*channeldb.ChannelEdgeInfo),
|
infos: make(map[uint64]channeldb.ChannelEdgeInfo),
|
||||||
edges: make(map[uint64][]*channeldb.ChannelEdgePolicy),
|
edges: make(map[uint64][]channeldb.ChannelEdgePolicy),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ routing.ChannelGraphSource = (*mockGraphSource)(nil)
|
var _ routing.ChannelGraphSource = (*mockGraphSource)(nil)
|
||||||
|
|
||||||
func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error {
|
func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error {
|
||||||
r.nodes = append(r.nodes, node)
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
r.nodes = append(r.nodes, *node)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
|
func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if _, ok := r.infos[info.ChannelID]; ok {
|
if _, ok := r.infos[info.ChannelID]; ok {
|
||||||
return errors.New("info already exist")
|
return errors.New("info already exist")
|
||||||
}
|
}
|
||||||
@ -137,15 +145,15 @@ func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error {
|
|||||||
// Usually, the capacity is fetched in the router from the funding txout.
|
// Usually, the capacity is fetched in the router from the funding txout.
|
||||||
// Since the mockGraphSource can't access the txout, assign a default value.
|
// Since the mockGraphSource can't access the txout, assign a default value.
|
||||||
info.Capacity = maxBtcFundingAmount
|
info.Capacity = maxBtcFundingAmount
|
||||||
r.infos[info.ChannelID] = info
|
r.infos[info.ChannelID] = *info
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error {
|
func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error {
|
||||||
r.edges[edge.ChannelID] = append(
|
r.mu.Lock()
|
||||||
r.edges[edge.ChannelID],
|
defer r.mu.Unlock()
|
||||||
edge,
|
|
||||||
)
|
r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,11 +167,19 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) {
|
|||||||
|
|
||||||
func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
|
func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID,
|
||||||
proof *channeldb.ChannelAuthProof) error {
|
proof *channeldb.ChannelAuthProof) error {
|
||||||
info, ok := r.infos[chanID.ToUint64()]
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
chanIDInt := chanID.ToUint64()
|
||||||
|
info, ok := r.infos[chanIDInt]
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("channel does not exist")
|
return errors.New("channel does not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
info.AuthProof = proof
|
info.AuthProof = proof
|
||||||
|
r.infos[chanIDInt] = info
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,6 +202,9 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
|||||||
*channeldb.ChannelEdgePolicy,
|
*channeldb.ChannelEdgePolicy,
|
||||||
*channeldb.ChannelEdgePolicy, error) {
|
*channeldb.ChannelEdgePolicy, error) {
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
chanInfo, ok := r.infos[chanID.ToUint64()]
|
chanInfo, ok := r.infos[chanID.ToUint64()]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, nil, channeldb.ErrEdgeNotFound
|
return nil, nil, nil, channeldb.ErrEdgeNotFound
|
||||||
@ -193,14 +212,16 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
|
|||||||
|
|
||||||
edges := r.edges[chanID.ToUint64()]
|
edges := r.edges[chanID.ToUint64()]
|
||||||
if len(edges) == 0 {
|
if len(edges) == 0 {
|
||||||
return chanInfo, nil, nil, nil
|
return &chanInfo, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(edges) == 1 {
|
if len(edges) == 1 {
|
||||||
return chanInfo, edges[0], nil, nil
|
edge1 := edges[0]
|
||||||
|
return &chanInfo, &edge1, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return chanInfo, edges[0], edges[1], nil
|
edge1, edge2 := edges[0], edges[1]
|
||||||
|
return &chanInfo, &edge1, &edge2, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mockGraphSource) FetchLightningNode(
|
func (r *mockGraphSource) FetchLightningNode(
|
||||||
@ -208,7 +229,7 @@ func (r *mockGraphSource) FetchLightningNode(
|
|||||||
|
|
||||||
for _, node := range r.nodes {
|
for _, node := range r.nodes {
|
||||||
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
|
if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) {
|
||||||
return node, nil
|
return &node, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,6 +239,9 @@ func (r *mockGraphSource) FetchLightningNode(
|
|||||||
// IsStaleNode returns true if the graph source has a node announcement for the
|
// IsStaleNode returns true if the graph source has a node announcement for the
|
||||||
// target node with a more recent timestamp.
|
// target node with a more recent timestamp.
|
||||||
func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool {
|
func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
for _, node := range r.nodes {
|
for _, node := range r.nodes {
|
||||||
if node.PubKeyBytes == nodePub {
|
if node.PubKeyBytes == nodePub {
|
||||||
return node.LastUpdate.After(timestamp) ||
|
return node.LastUpdate.After(timestamp) ||
|
||||||
@ -258,6 +282,9 @@ func (r *mockGraphSource) IsPublicNode(node routing.Vertex) (bool, error) {
|
|||||||
// IsKnownEdge returns true if the graph source already knows of the passed
|
// IsKnownEdge returns true if the graph source already knows of the passed
|
||||||
// channel ID.
|
// channel ID.
|
||||||
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
_, ok := r.infos[chanID.ToUint64()]
|
_, ok := r.infos[chanID.ToUint64()]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
@ -267,6 +294,9 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
|
|||||||
func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
|
func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
|
||||||
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
|
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
edges, ok := r.edges[chanID.ToUint64()]
|
edges, ok := r.edges[chanID.ToUint64()]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
|
Loading…
Reference in New Issue
Block a user