diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 0e2e51f4..1decc1e3 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -108,28 +108,36 @@ func (n *mockSigner) SignMessage(pubKey *btcec.PublicKey, } type mockGraphSource struct { - nodes []*channeldb.LightningNode - infos map[uint64]*channeldb.ChannelEdgeInfo - edges map[uint64][]*channeldb.ChannelEdgePolicy bestHeight uint32 + + mu sync.Mutex + nodes []channeldb.LightningNode + infos map[uint64]channeldb.ChannelEdgeInfo + edges map[uint64][]channeldb.ChannelEdgePolicy } func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, - infos: make(map[uint64]*channeldb.ChannelEdgeInfo), - edges: make(map[uint64][]*channeldb.ChannelEdgePolicy), + infos: make(map[uint64]channeldb.ChannelEdgeInfo), + edges: make(map[uint64][]channeldb.ChannelEdgePolicy), } } var _ routing.ChannelGraphSource = (*mockGraphSource)(nil) func (r *mockGraphSource) AddNode(node *channeldb.LightningNode) error { - r.nodes = append(r.nodes, node) + r.mu.Lock() + defer r.mu.Unlock() + + r.nodes = append(r.nodes, *node) return nil } func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.infos[info.ChannelID]; ok { return errors.New("info already exist") } @@ -137,15 +145,15 @@ func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo) error { // Usually, the capacity is fetched in the router from the funding txout. // Since the mockGraphSource can't access the txout, assign a default value. info.Capacity = maxBtcFundingAmount - r.infos[info.ChannelID] = info + r.infos[info.ChannelID] = *info return nil } func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy) error { - r.edges[edge.ChannelID] = append( - r.edges[edge.ChannelID], - edge, - ) + r.mu.Lock() + defer r.mu.Unlock() + + r.edges[edge.ChannelID] = append(r.edges[edge.ChannelID], *edge) return nil } @@ -159,11 +167,19 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) { func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, proof *channeldb.ChannelAuthProof) error { - info, ok := r.infos[chanID.ToUint64()] + + r.mu.Lock() + defer r.mu.Unlock() + + chanIDInt := chanID.ToUint64() + info, ok := r.infos[chanIDInt] if !ok { return errors.New("channel does not exist") } + info.AuthProof = proof + r.infos[chanIDInt] = info + return nil } @@ -186,6 +202,9 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { + r.mu.Lock() + defer r.mu.Unlock() + chanInfo, ok := r.infos[chanID.ToUint64()] if !ok { return nil, nil, nil, channeldb.ErrEdgeNotFound @@ -193,14 +212,16 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( edges := r.edges[chanID.ToUint64()] if len(edges) == 0 { - return chanInfo, nil, nil, nil + return &chanInfo, nil, nil, nil } if len(edges) == 1 { - return chanInfo, edges[0], nil, nil + edge1 := edges[0] + return &chanInfo, &edge1, nil, nil } - return chanInfo, edges[0], edges[1], nil + edge1, edge2 := edges[0], edges[1] + return &chanInfo, &edge1, &edge2, nil } func (r *mockGraphSource) FetchLightningNode( @@ -208,7 +229,7 @@ func (r *mockGraphSource) FetchLightningNode( for _, node := range r.nodes { if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) { - return node, nil + return &node, nil } } @@ -218,6 +239,9 @@ func (r *mockGraphSource) FetchLightningNode( // IsStaleNode returns true if the graph source has a node announcement for the // target node with a more recent timestamp. func (r *mockGraphSource) IsStaleNode(nodePub routing.Vertex, timestamp time.Time) bool { + r.mu.Lock() + defer r.mu.Unlock() + for _, node := range r.nodes { if node.PubKeyBytes == nodePub { return node.LastUpdate.After(timestamp) || @@ -258,6 +282,9 @@ func (r *mockGraphSource) IsPublicNode(node routing.Vertex) (bool, error) { // IsKnownEdge returns true if the graph source already knows of the passed // channel ID. func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool { + r.mu.Lock() + defer r.mu.Unlock() + _, ok := r.infos[chanID.ToUint64()] return ok } @@ -267,6 +294,9 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool { func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { + r.mu.Lock() + defer r.mu.Unlock() + edges, ok := r.edges[chanID.ToUint64()] if !ok { return false