discovery: add a mutex in order to make deDupedAnnouncements thread-safe

This commit is contained in:
Olaoluwa Osuntokun 2017-11-29 16:21:08 -08:00
parent f4f476fe9f
commit 2dcd2b8a6d
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
2 changed files with 59 additions and 41 deletions

@ -397,11 +397,11 @@ type channelUpdateID struct {
flags uint16 flags uint16
} }
// deDupedAnnouncements de-duplicates announcements that have been // deDupedAnnouncements de-duplicates announcements that have been added to the
// added to the batch. Internally, announcements are stored in three maps // batch. Internally, announcements are stored in three maps
// (one each for channel announcements, channel updates, and node // (one each for channel announcements, channel updates, and node
// announcements). These maps keep track of unique announcements and // announcements). These maps keep track of unique announcements and ensure no
// ensure no announcements are duplicated. // announcements are duplicated.
type deDupedAnnouncements struct { type deDupedAnnouncements struct {
// channelAnnouncements are identified by the short channel id field. // channelAnnouncements are identified by the short channel id field.
channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message
@ -411,59 +411,80 @@ type deDupedAnnouncements struct {
// nodeAnnouncements are identified by the Vertex field. // nodeAnnouncements are identified by the Vertex field.
nodeAnnouncements map[routing.Vertex]lnwire.Message nodeAnnouncements map[routing.Vertex]lnwire.Message
sync.Mutex
} }
// Reset operates on deDupedAnnouncements to reset storage of announcements // Reset operates on deDupedAnnouncements to reset the storage of
// announcements.
func (d *deDupedAnnouncements) Reset() { func (d *deDupedAnnouncements) Reset() {
d.Lock()
defer d.Unlock()
d.reset()
}
// reset is the private version of the Reset method. We have this so we can
// call this method within method that are already holding the lock.
func (d *deDupedAnnouncements) reset() {
// Storage of each type of announcement (channel anouncements, channel // Storage of each type of announcement (channel anouncements, channel
// updates, node announcements) is set to an empty map where the // updates, node announcements) is set to an empty map where the
// approprate key points to the corresponding lnwire.Message. // appropriate key points to the corresponding lnwire.Message.
d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message) d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message)
d.channelUpdates = make(map[channelUpdateID]lnwire.Message) d.channelUpdates = make(map[channelUpdateID]lnwire.Message)
d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message) d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message)
} }
// AddMsg adds a new message to the current batch. // addMsg adds a new message to the current batch.
func (d *deDupedAnnouncements) AddMsg(message lnwire.Message) { func (d *deDupedAnnouncements) addMsg(message lnwire.Message) {
// Depending on the message type (channel announcement, channel // Depending on the message type (channel announcement, channel update,
// update, or node announcement), the message is added to the // or node announcement), the message is added to the corresponding map
// corresponding map in deDupedAnnouncements. Because each // in deDupedAnnouncements. Because each identifying key can have at
// identifying key can have at most one value, the announcements // most one value, the announcements are de-duplicated, with newer ones
// are de-duplicated, with newer ones replacing older ones. // replacing older ones.
switch msg := message.(type) { switch msg := message.(type) {
// Channel announcements are identified by the short channel id field.
case *lnwire.ChannelAnnouncement: case *lnwire.ChannelAnnouncement:
// Channel announcements are identified by the short channel
// id field.
d.channelAnnouncements[msg.ShortChannelID] = msg d.channelAnnouncements[msg.ShortChannelID] = msg
// Channel updates are identified by the (short channel id, flags)
// tuple.
case *lnwire.ChannelUpdate: case *lnwire.ChannelUpdate:
// Channel updates are identified by the (short channel id,
// flags) tuple.
channelUpdateID := channelUpdateID{ channelUpdateID := channelUpdateID{
msg.ShortChannelID, msg.ShortChannelID,
msg.Flags, msg.Flags,
} }
d.channelUpdates[channelUpdateID] = msg d.channelUpdates[channelUpdateID] = msg
// Node announcements are identified by the Vertex field. Use the
// NodeID to create the corresponding Vertex.
case *lnwire.NodeAnnouncement: case *lnwire.NodeAnnouncement:
// Node announcements are identified by the Vertex field.
// Use the NodeID to create the corresponding Vertex.
vertex := routing.NewVertex(msg.NodeID) vertex := routing.NewVertex(msg.NodeID)
d.nodeAnnouncements[vertex] = msg d.nodeAnnouncements[vertex] = msg
} }
} }
// AddMsgs is a helper method to add multiple messages to the // AddMsgs is a helper method to add multiple messages to the announcement
// announcement batch. // batch.
func (d *deDupedAnnouncements) AddMsgs(msgs []lnwire.Message) { func (d *deDupedAnnouncements) AddMsgs(msgs ...lnwire.Message) {
d.Lock()
defer d.Unlock()
for _, msg := range msgs { for _, msg := range msgs {
d.AddMsg(msg) d.addMsg(msg)
} }
} }
// Batch returns the set of de-duplicated announcements to be sent out // Emit returns the set of de-duplicated announcements to be sent out during
// during the next announcement epoch, in the order of channel announcements, // the next announcement epoch, in the order of channel announcements, channel
// channel updates, and node announcements. // updates, and node announcements. Additionally, the set of stored messages
func (d *deDupedAnnouncements) Batch() []lnwire.Message { // are reset.
func (d *deDupedAnnouncements) Emit() []lnwire.Message {
d.Lock()
defer d.Unlock()
// Get the total number of announcements. // Get the total number of announcements.
numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) + numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) +
len(d.nodeAnnouncements) len(d.nodeAnnouncements)
@ -487,6 +508,8 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
announcements = append(announcements, message) announcements = append(announcements, message)
} }
d.reset()
// Return the array of lnwire.messages. // Return the array of lnwire.messages.
return announcements return announcements
} }
@ -500,11 +523,6 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
func (d *AuthenticatedGossiper) networkHandler() { func (d *AuthenticatedGossiper) networkHandler() {
defer d.wg.Done() defer d.wg.Done()
// TODO(roasbeef): changes for spec compliance
// * buffer recv'd node ann until after chan ann that includes is
// created
// * can use mostly empty struct in db as place holder
// Initialize empty deDupedAnnouncements to store announcement batch. // Initialize empty deDupedAnnouncements to store announcement batch.
announcements := deDupedAnnouncements{} announcements := deDupedAnnouncements{}
announcements.Reset() announcements.Reset()

@ -866,7 +866,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create remote channel announcement: %v", err) t.Fatalf("can't create remote channel announcement: %v", err)
} }
announcements.AddMsg(ca) announcements.AddMsgs(ca)
if len(announcements.channelAnnouncements) != 1 { if len(announcements.channelAnnouncements) != 1 {
t.Fatal("new channel announcement not stored in batch") t.Fatal("new channel announcement not stored in batch")
} }
@ -879,7 +879,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create remote channel announcement: %v", err) t.Fatalf("can't create remote channel announcement: %v", err)
} }
announcements.AddMsg(ca2) announcements.AddMsgs(ca2)
if len(announcements.channelAnnouncements) != 1 { if len(announcements.channelAnnouncements) != 1 {
t.Fatal("channel announcement not replaced in batch") t.Fatal("channel announcement not replaced in batch")
} }
@ -891,7 +891,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create update announcement: %v", err) t.Fatalf("can't create update announcement: %v", err)
} }
announcements.AddMsg(ua) announcements.AddMsgs(ua)
if len(announcements.channelUpdates) != 1 { if len(announcements.channelUpdates) != 1 {
t.Fatal("new channel update not stored in batch") t.Fatal("new channel update not stored in batch")
} }
@ -902,7 +902,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create update announcement: %v", err) t.Fatalf("can't create update announcement: %v", err)
} }
announcements.AddMsg(ua2) announcements.AddMsgs(ua2)
if len(announcements.channelUpdates) != 1 { if len(announcements.channelUpdates) != 1 {
t.Fatal("channel update not replaced in batch") t.Fatal("channel update not replaced in batch")
} }
@ -913,7 +913,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create node announcement: %v", err) t.Fatalf("can't create node announcement: %v", err)
} }
announcements.AddMsg(na) announcements.AddMsgs(na)
if len(announcements.nodeAnnouncements) != 1 { if len(announcements.nodeAnnouncements) != 1 {
t.Fatal("new node announcement not stored in batch") t.Fatal("new node announcement not stored in batch")
} }
@ -923,7 +923,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create node announcement: %v", err) t.Fatalf("can't create node announcement: %v", err)
} }
announcements.AddMsg(na2) announcements.AddMsgs(na2)
if len(announcements.nodeAnnouncements) != 2 { if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not stored in batch") t.Fatal("second node announcement not stored in batch")
} }
@ -934,7 +934,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create node announcement: %v", err) t.Fatalf("can't create node announcement: %v", err)
} }
announcements.AddMsg(na3) announcements.AddMsgs(na3)
if len(announcements.nodeAnnouncements) != 2 { if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not replaced in batch") t.Fatal("second node announcement not replaced in batch")
} }
@ -946,14 +946,14 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("can't create node announcement: %v", err) t.Fatalf("can't create node announcement: %v", err)
} }
announcements.AddMsg(na4) announcements.AddMsgs(na4)
if len(announcements.nodeAnnouncements) != 2 { if len(announcements.nodeAnnouncements) != 2 {
t.Fatal("second node announcement not replaced again in batch") t.Fatal("second node announcement not replaced again in batch")
} }
// Ensure that announcement batch delivers channel announcements, // Ensure that announcement batch delivers channel announcements,
// channel updates, and node announcements in proper order. // channel updates, and node announcements in proper order.
batch := announcements.Batch() batch := announcements.Emit()
if len(batch) != 4 { if len(batch) != 4 {
t.Fatal("announcement batch incorrect length") t.Fatal("announcement batch incorrect length")
} }