discovery: add a mutex in order to make deDupedAnnouncements thread-safe

This commit is contained in:
Olaoluwa Osuntokun 2017-11-29 16:21:08 -08:00
parent f4f476fe9f
commit 2dcd2b8a6d
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
2 changed files with 59 additions and 41 deletions

@ -397,11 +397,11 @@ type channelUpdateID struct {
flags uint16
}
// deDupedAnnouncements de-duplicates announcements that have been
// added to the batch. Internally, announcements are stored in three maps
// deDupedAnnouncements de-duplicates announcements that have been added to the
// batch. Internally, announcements are stored in three maps
// (one each for channel announcements, channel updates, and node
// announcements). These maps keep track of unique announcements and
// ensure no announcements are duplicated.
// announcements). These maps keep track of unique announcements and ensure no
// announcements are duplicated.
type deDupedAnnouncements struct {
// channelAnnouncements are identified by the short channel id field.
channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message
@ -411,59 +411,80 @@ type deDupedAnnouncements struct {
// nodeAnnouncements are identified by the Vertex field.
nodeAnnouncements map[routing.Vertex]lnwire.Message
sync.Mutex
}
// Reset operates on deDupedAnnouncements to reset storage of announcements
// Reset operates on deDupedAnnouncements to reset the storage of
// announcements.
func (d *deDupedAnnouncements) Reset() {
d.Lock()
defer d.Unlock()
d.reset()
}
// reset is the private version of the Reset method. We have this so we can
// call this method within method that are already holding the lock.
func (d *deDupedAnnouncements) reset() {
// Storage of each type of announcement (channel anouncements, channel
// updates, node announcements) is set to an empty map where the
// approprate key points to the corresponding lnwire.Message.
// appropriate key points to the corresponding lnwire.Message.
d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message)
d.channelUpdates = make(map[channelUpdateID]lnwire.Message)
d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message)
}
// AddMsg adds a new message to the current batch.
func (d *deDupedAnnouncements) AddMsg(message lnwire.Message) {
// Depending on the message type (channel announcement, channel
// update, or node announcement), the message is added to the
// corresponding map in deDupedAnnouncements. Because each
// identifying key can have at most one value, the announcements
// are de-duplicated, with newer ones replacing older ones.
// addMsg adds a new message to the current batch.
func (d *deDupedAnnouncements) addMsg(message lnwire.Message) {
// Depending on the message type (channel announcement, channel update,
// or node announcement), the message is added to the corresponding map
// in deDupedAnnouncements. Because each identifying key can have at
// most one value, the announcements are de-duplicated, with newer ones
// replacing older ones.
switch msg := message.(type) {
// Channel announcements are identified by the short channel id field.
case *lnwire.ChannelAnnouncement:
// Channel announcements are identified by the short channel
// id field.
d.channelAnnouncements[msg.ShortChannelID] = msg
// Channel updates are identified by the (short channel id, flags)
// tuple.
case *lnwire.ChannelUpdate:
// Channel updates are identified by the (short channel id,
// flags) tuple.
channelUpdateID := channelUpdateID{
msg.ShortChannelID,
msg.Flags,
}
d.channelUpdates[channelUpdateID] = msg
// Node announcements are identified by the Vertex field. Use the
// NodeID to create the corresponding Vertex.
case *lnwire.NodeAnnouncement:
// Node announcements are identified by the Vertex field.
// Use the NodeID to create the corresponding Vertex.
vertex := routing.NewVertex(msg.NodeID)
d.nodeAnnouncements[vertex] = msg
}
}
// AddMsgs is a helper method to add multiple messages to the
// announcement batch.
func (d *deDupedAnnouncements) AddMsgs(msgs []lnwire.Message) {
// AddMsgs is a helper method to add multiple messages to the announcement
// batch.
func (d *deDupedAnnouncements) AddMsgs(msgs ...lnwire.Message) {
d.Lock()
defer d.Unlock()
for _, msg := range msgs {
d.AddMsg(msg)
d.addMsg(msg)
}
}
// Batch returns the set of de-duplicated announcements to be sent out
// during the next announcement epoch, in the order of channel announcements,
// channel updates, and node announcements.
func (d *deDupedAnnouncements) Batch() []lnwire.Message {
// Emit returns the set of de-duplicated announcements to be sent out during
// the next announcement epoch, in the order of channel announcements, channel
// updates, and node announcements. Additionally, the set of stored messages
// are reset.
func (d *deDupedAnnouncements) Emit() []lnwire.Message {
d.Lock()
defer d.Unlock()
// Get the total number of announcements.
numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) +
len(d.nodeAnnouncements)
@ -487,6 +508,8 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
announcements = append(announcements, message)
}
d.reset()
// Return the array of lnwire.messages.
return announcements
}
@ -500,11 +523,6 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
func (d *AuthenticatedGossiper) networkHandler() {
defer d.wg.Done()
// TODO(roasbeef): changes for spec compliance
// * buffer recv'd node ann until after chan ann that includes is
// created
// * can use mostly empty struct in db as place holder
// Initialize empty deDupedAnnouncements to store announcement batch.
announcements := deDupedAnnouncements{}
announcements.Reset()

@ -866,7 +866,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create remote channel announcement: %v", err)
}
announcements.AddMsg(ca)
announcements.AddMsgs(ca)
if len(announcements.channelAnnouncements) != 1 {
t.Fatal("new channel announcement not stored in batch")
}
@ -879,7 +879,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create remote channel announcement: %v", err)
}
announcements.AddMsg(ca2)
announcements.AddMsgs(ca2)
if len(announcements.channelAnnouncements) != 1 {
t.Fatal("channel announcement not replaced in batch")
}
@ -891,7 +891,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create update announcement: %v", err)
}
announcements.AddMsg(ua)
announcements.AddMsgs(ua)
if len(announcements.channelUpdates) != 1 {
t.Fatal("new channel update not stored in batch")
}
@ -902,7 +902,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create update announcement: %v", err)
}
announcements.AddMsg(ua2)
announcements.AddMsgs(ua2)
if len(announcements.channelUpdates) != 1 {
t.Fatal("channel update not replaced in batch")
}
@ -913,7 +913,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create node announcement: %v", err)
}
announcements.AddMsg(na)
announcements.AddMsgs(na)
if len(announcements.nodeAnnouncements) != 1 {
t.Fatal("new node announcement not stored in batch")
}
@ -923,7 +923,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create node announcement: %v", err)
}
announcements.AddMsg(na2)
announcements.AddMsgs(na2)
if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not stored in batch")
}
@ -934,7 +934,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create node announcement: %v", err)
}
announcements.AddMsg(na3)
announcements.AddMsgs(na3)
if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not replaced in batch")
}
@ -946,14 +946,14 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil {
t.Fatalf("can't create node announcement: %v", err)
}
announcements.AddMsg(na4)
announcements.AddMsgs(na4)
if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not replaced again in batch")
}
// Ensure that announcement batch delivers channel announcements,
// channel updates, and node announcements in proper order.
batch := announcements.Batch()
batch := announcements.Emit()
if len(batch) != 4 {
t.Fatal("announcement batch incorrect length")
}