discovery: add a mutex in order to make deDupedAnnouncements thread-safe
This commit is contained in:
parent
f4f476fe9f
commit
2dcd2b8a6d
@ -397,11 +397,11 @@ type channelUpdateID struct {
|
||||
flags uint16
|
||||
}
|
||||
|
||||
// deDupedAnnouncements de-duplicates announcements that have been
|
||||
// added to the batch. Internally, announcements are stored in three maps
|
||||
// deDupedAnnouncements de-duplicates announcements that have been added to the
|
||||
// batch. Internally, announcements are stored in three maps
|
||||
// (one each for channel announcements, channel updates, and node
|
||||
// announcements). These maps keep track of unique announcements and
|
||||
// ensure no announcements are duplicated.
|
||||
// announcements). These maps keep track of unique announcements and ensure no
|
||||
// announcements are duplicated.
|
||||
type deDupedAnnouncements struct {
|
||||
// channelAnnouncements are identified by the short channel id field.
|
||||
channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message
|
||||
@ -411,59 +411,80 @@ type deDupedAnnouncements struct {
|
||||
|
||||
// nodeAnnouncements are identified by the Vertex field.
|
||||
nodeAnnouncements map[routing.Vertex]lnwire.Message
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// Reset operates on deDupedAnnouncements to reset storage of announcements
|
||||
// Reset operates on deDupedAnnouncements to reset the storage of
|
||||
// announcements.
|
||||
func (d *deDupedAnnouncements) Reset() {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
d.reset()
|
||||
}
|
||||
|
||||
// reset is the private version of the Reset method. We have this so we can
|
||||
// call this method within method that are already holding the lock.
|
||||
func (d *deDupedAnnouncements) reset() {
|
||||
// Storage of each type of announcement (channel anouncements, channel
|
||||
// updates, node announcements) is set to an empty map where the
|
||||
// approprate key points to the corresponding lnwire.Message.
|
||||
// appropriate key points to the corresponding lnwire.Message.
|
||||
d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message)
|
||||
d.channelUpdates = make(map[channelUpdateID]lnwire.Message)
|
||||
d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message)
|
||||
}
|
||||
|
||||
// AddMsg adds a new message to the current batch.
|
||||
func (d *deDupedAnnouncements) AddMsg(message lnwire.Message) {
|
||||
// Depending on the message type (channel announcement, channel
|
||||
// update, or node announcement), the message is added to the
|
||||
// corresponding map in deDupedAnnouncements. Because each
|
||||
// identifying key can have at most one value, the announcements
|
||||
// are de-duplicated, with newer ones replacing older ones.
|
||||
// addMsg adds a new message to the current batch.
|
||||
func (d *deDupedAnnouncements) addMsg(message lnwire.Message) {
|
||||
// Depending on the message type (channel announcement, channel update,
|
||||
// or node announcement), the message is added to the corresponding map
|
||||
// in deDupedAnnouncements. Because each identifying key can have at
|
||||
// most one value, the announcements are de-duplicated, with newer ones
|
||||
// replacing older ones.
|
||||
switch msg := message.(type) {
|
||||
|
||||
// Channel announcements are identified by the short channel id field.
|
||||
case *lnwire.ChannelAnnouncement:
|
||||
// Channel announcements are identified by the short channel
|
||||
// id field.
|
||||
d.channelAnnouncements[msg.ShortChannelID] = msg
|
||||
|
||||
// Channel updates are identified by the (short channel id, flags)
|
||||
// tuple.
|
||||
case *lnwire.ChannelUpdate:
|
||||
// Channel updates are identified by the (short channel id,
|
||||
// flags) tuple.
|
||||
channelUpdateID := channelUpdateID{
|
||||
msg.ShortChannelID,
|
||||
msg.Flags,
|
||||
}
|
||||
|
||||
d.channelUpdates[channelUpdateID] = msg
|
||||
|
||||
// Node announcements are identified by the Vertex field. Use the
|
||||
// NodeID to create the corresponding Vertex.
|
||||
case *lnwire.NodeAnnouncement:
|
||||
// Node announcements are identified by the Vertex field.
|
||||
// Use the NodeID to create the corresponding Vertex.
|
||||
vertex := routing.NewVertex(msg.NodeID)
|
||||
d.nodeAnnouncements[vertex] = msg
|
||||
}
|
||||
}
|
||||
|
||||
// AddMsgs is a helper method to add multiple messages to the
|
||||
// announcement batch.
|
||||
func (d *deDupedAnnouncements) AddMsgs(msgs []lnwire.Message) {
|
||||
// AddMsgs is a helper method to add multiple messages to the announcement
|
||||
// batch.
|
||||
func (d *deDupedAnnouncements) AddMsgs(msgs ...lnwire.Message) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
for _, msg := range msgs {
|
||||
d.AddMsg(msg)
|
||||
d.addMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch returns the set of de-duplicated announcements to be sent out
|
||||
// during the next announcement epoch, in the order of channel announcements,
|
||||
// channel updates, and node announcements.
|
||||
func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
||||
// Emit returns the set of de-duplicated announcements to be sent out during
|
||||
// the next announcement epoch, in the order of channel announcements, channel
|
||||
// updates, and node announcements. Additionally, the set of stored messages
|
||||
// are reset.
|
||||
func (d *deDupedAnnouncements) Emit() []lnwire.Message {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
// Get the total number of announcements.
|
||||
numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) +
|
||||
len(d.nodeAnnouncements)
|
||||
@ -487,6 +508,8 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
||||
announcements = append(announcements, message)
|
||||
}
|
||||
|
||||
d.reset()
|
||||
|
||||
// Return the array of lnwire.messages.
|
||||
return announcements
|
||||
}
|
||||
@ -500,11 +523,6 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
||||
func (d *AuthenticatedGossiper) networkHandler() {
|
||||
defer d.wg.Done()
|
||||
|
||||
// TODO(roasbeef): changes for spec compliance
|
||||
// * buffer recv'd node ann until after chan ann that includes is
|
||||
// created
|
||||
// * can use mostly empty struct in db as place holder
|
||||
|
||||
// Initialize empty deDupedAnnouncements to store announcement batch.
|
||||
announcements := deDupedAnnouncements{}
|
||||
announcements.Reset()
|
||||
|
@ -866,7 +866,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create remote channel announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(ca)
|
||||
announcements.AddMsgs(ca)
|
||||
if len(announcements.channelAnnouncements) != 1 {
|
||||
t.Fatal("new channel announcement not stored in batch")
|
||||
}
|
||||
@ -879,7 +879,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create remote channel announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(ca2)
|
||||
announcements.AddMsgs(ca2)
|
||||
if len(announcements.channelAnnouncements) != 1 {
|
||||
t.Fatal("channel announcement not replaced in batch")
|
||||
}
|
||||
@ -891,7 +891,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create update announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(ua)
|
||||
announcements.AddMsgs(ua)
|
||||
if len(announcements.channelUpdates) != 1 {
|
||||
t.Fatal("new channel update not stored in batch")
|
||||
}
|
||||
@ -902,7 +902,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create update announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(ua2)
|
||||
announcements.AddMsgs(ua2)
|
||||
if len(announcements.channelUpdates) != 1 {
|
||||
t.Fatal("channel update not replaced in batch")
|
||||
}
|
||||
@ -913,7 +913,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create node announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(na)
|
||||
announcements.AddMsgs(na)
|
||||
if len(announcements.nodeAnnouncements) != 1 {
|
||||
t.Fatal("new node announcement not stored in batch")
|
||||
}
|
||||
@ -923,7 +923,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create node announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(na2)
|
||||
announcements.AddMsgs(na2)
|
||||
if len(announcements.nodeAnnouncements) != 2 {
|
||||
t.Fatal("second node announcement not stored in batch")
|
||||
}
|
||||
@ -934,7 +934,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create node announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(na3)
|
||||
announcements.AddMsgs(na3)
|
||||
if len(announcements.nodeAnnouncements) != 2 {
|
||||
t.Fatal("second node announcement not replaced in batch")
|
||||
}
|
||||
@ -946,14 +946,14 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("can't create node announcement: %v", err)
|
||||
}
|
||||
announcements.AddMsg(na4)
|
||||
announcements.AddMsgs(na4)
|
||||
if len(announcements.nodeAnnouncements) != 2 {
|
||||
t.Fatal("second node announcement not replaced again in batch")
|
||||
}
|
||||
|
||||
// Ensure that announcement batch delivers channel announcements,
|
||||
// channel updates, and node announcements in proper order.
|
||||
batch := announcements.Batch()
|
||||
batch := announcements.Emit()
|
||||
if len(batch) != 4 {
|
||||
t.Fatal("announcement batch incorrect length")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user