discovery: add a mutex in order to make deDupedAnnouncements thread-safe
This commit is contained in:
parent
f4f476fe9f
commit
2dcd2b8a6d
@ -397,11 +397,11 @@ type channelUpdateID struct {
|
|||||||
flags uint16
|
flags uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// deDupedAnnouncements de-duplicates announcements that have been
|
// deDupedAnnouncements de-duplicates announcements that have been added to the
|
||||||
// added to the batch. Internally, announcements are stored in three maps
|
// batch. Internally, announcements are stored in three maps
|
||||||
// (one each for channel announcements, channel updates, and node
|
// (one each for channel announcements, channel updates, and node
|
||||||
// announcements). These maps keep track of unique announcements and
|
// announcements). These maps keep track of unique announcements and ensure no
|
||||||
// ensure no announcements are duplicated.
|
// announcements are duplicated.
|
||||||
type deDupedAnnouncements struct {
|
type deDupedAnnouncements struct {
|
||||||
// channelAnnouncements are identified by the short channel id field.
|
// channelAnnouncements are identified by the short channel id field.
|
||||||
channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message
|
channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message
|
||||||
@ -411,59 +411,80 @@ type deDupedAnnouncements struct {
|
|||||||
|
|
||||||
// nodeAnnouncements are identified by the Vertex field.
|
// nodeAnnouncements are identified by the Vertex field.
|
||||||
nodeAnnouncements map[routing.Vertex]lnwire.Message
|
nodeAnnouncements map[routing.Vertex]lnwire.Message
|
||||||
|
|
||||||
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset operates on deDupedAnnouncements to reset storage of announcements
|
// Reset operates on deDupedAnnouncements to reset the storage of
|
||||||
|
// announcements.
|
||||||
func (d *deDupedAnnouncements) Reset() {
|
func (d *deDupedAnnouncements) Reset() {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
|
||||||
|
d.reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset is the private version of the Reset method. We have this so we can
|
||||||
|
// call this method within method that are already holding the lock.
|
||||||
|
func (d *deDupedAnnouncements) reset() {
|
||||||
// Storage of each type of announcement (channel anouncements, channel
|
// Storage of each type of announcement (channel anouncements, channel
|
||||||
// updates, node announcements) is set to an empty map where the
|
// updates, node announcements) is set to an empty map where the
|
||||||
// approprate key points to the corresponding lnwire.Message.
|
// appropriate key points to the corresponding lnwire.Message.
|
||||||
d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message)
|
d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message)
|
||||||
d.channelUpdates = make(map[channelUpdateID]lnwire.Message)
|
d.channelUpdates = make(map[channelUpdateID]lnwire.Message)
|
||||||
d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message)
|
d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMsg adds a new message to the current batch.
|
// addMsg adds a new message to the current batch.
|
||||||
func (d *deDupedAnnouncements) AddMsg(message lnwire.Message) {
|
func (d *deDupedAnnouncements) addMsg(message lnwire.Message) {
|
||||||
// Depending on the message type (channel announcement, channel
|
// Depending on the message type (channel announcement, channel update,
|
||||||
// update, or node announcement), the message is added to the
|
// or node announcement), the message is added to the corresponding map
|
||||||
// corresponding map in deDupedAnnouncements. Because each
|
// in deDupedAnnouncements. Because each identifying key can have at
|
||||||
// identifying key can have at most one value, the announcements
|
// most one value, the announcements are de-duplicated, with newer ones
|
||||||
// are de-duplicated, with newer ones replacing older ones.
|
// replacing older ones.
|
||||||
switch msg := message.(type) {
|
switch msg := message.(type) {
|
||||||
|
|
||||||
|
// Channel announcements are identified by the short channel id field.
|
||||||
case *lnwire.ChannelAnnouncement:
|
case *lnwire.ChannelAnnouncement:
|
||||||
// Channel announcements are identified by the short channel
|
|
||||||
// id field.
|
|
||||||
d.channelAnnouncements[msg.ShortChannelID] = msg
|
d.channelAnnouncements[msg.ShortChannelID] = msg
|
||||||
|
|
||||||
|
// Channel updates are identified by the (short channel id, flags)
|
||||||
|
// tuple.
|
||||||
case *lnwire.ChannelUpdate:
|
case *lnwire.ChannelUpdate:
|
||||||
// Channel updates are identified by the (short channel id,
|
|
||||||
// flags) tuple.
|
|
||||||
channelUpdateID := channelUpdateID{
|
channelUpdateID := channelUpdateID{
|
||||||
msg.ShortChannelID,
|
msg.ShortChannelID,
|
||||||
msg.Flags,
|
msg.Flags,
|
||||||
}
|
}
|
||||||
|
|
||||||
d.channelUpdates[channelUpdateID] = msg
|
d.channelUpdates[channelUpdateID] = msg
|
||||||
|
|
||||||
|
// Node announcements are identified by the Vertex field. Use the
|
||||||
|
// NodeID to create the corresponding Vertex.
|
||||||
case *lnwire.NodeAnnouncement:
|
case *lnwire.NodeAnnouncement:
|
||||||
// Node announcements are identified by the Vertex field.
|
|
||||||
// Use the NodeID to create the corresponding Vertex.
|
|
||||||
vertex := routing.NewVertex(msg.NodeID)
|
vertex := routing.NewVertex(msg.NodeID)
|
||||||
d.nodeAnnouncements[vertex] = msg
|
d.nodeAnnouncements[vertex] = msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMsgs is a helper method to add multiple messages to the
|
// AddMsgs is a helper method to add multiple messages to the announcement
|
||||||
// announcement batch.
|
// batch.
|
||||||
func (d *deDupedAnnouncements) AddMsgs(msgs []lnwire.Message) {
|
func (d *deDupedAnnouncements) AddMsgs(msgs ...lnwire.Message) {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
d.AddMsg(msg)
|
d.addMsg(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch returns the set of de-duplicated announcements to be sent out
|
// Emit returns the set of de-duplicated announcements to be sent out during
|
||||||
// during the next announcement epoch, in the order of channel announcements,
|
// the next announcement epoch, in the order of channel announcements, channel
|
||||||
// channel updates, and node announcements.
|
// updates, and node announcements. Additionally, the set of stored messages
|
||||||
func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
// are reset.
|
||||||
|
func (d *deDupedAnnouncements) Emit() []lnwire.Message {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
|
||||||
// Get the total number of announcements.
|
// Get the total number of announcements.
|
||||||
numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) +
|
numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) +
|
||||||
len(d.nodeAnnouncements)
|
len(d.nodeAnnouncements)
|
||||||
@ -487,6 +508,8 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
|||||||
announcements = append(announcements, message)
|
announcements = append(announcements, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.reset()
|
||||||
|
|
||||||
// Return the array of lnwire.messages.
|
// Return the array of lnwire.messages.
|
||||||
return announcements
|
return announcements
|
||||||
}
|
}
|
||||||
@ -500,11 +523,6 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message {
|
|||||||
func (d *AuthenticatedGossiper) networkHandler() {
|
func (d *AuthenticatedGossiper) networkHandler() {
|
||||||
defer d.wg.Done()
|
defer d.wg.Done()
|
||||||
|
|
||||||
// TODO(roasbeef): changes for spec compliance
|
|
||||||
// * buffer recv'd node ann until after chan ann that includes is
|
|
||||||
// created
|
|
||||||
// * can use mostly empty struct in db as place holder
|
|
||||||
|
|
||||||
// Initialize empty deDupedAnnouncements to store announcement batch.
|
// Initialize empty deDupedAnnouncements to store announcement batch.
|
||||||
announcements := deDupedAnnouncements{}
|
announcements := deDupedAnnouncements{}
|
||||||
announcements.Reset()
|
announcements.Reset()
|
||||||
|
@ -866,7 +866,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create remote channel announcement: %v", err)
|
t.Fatalf("can't create remote channel announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(ca)
|
announcements.AddMsgs(ca)
|
||||||
if len(announcements.channelAnnouncements) != 1 {
|
if len(announcements.channelAnnouncements) != 1 {
|
||||||
t.Fatal("new channel announcement not stored in batch")
|
t.Fatal("new channel announcement not stored in batch")
|
||||||
}
|
}
|
||||||
@ -879,7 +879,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create remote channel announcement: %v", err)
|
t.Fatalf("can't create remote channel announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(ca2)
|
announcements.AddMsgs(ca2)
|
||||||
if len(announcements.channelAnnouncements) != 1 {
|
if len(announcements.channelAnnouncements) != 1 {
|
||||||
t.Fatal("channel announcement not replaced in batch")
|
t.Fatal("channel announcement not replaced in batch")
|
||||||
}
|
}
|
||||||
@ -891,7 +891,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create update announcement: %v", err)
|
t.Fatalf("can't create update announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(ua)
|
announcements.AddMsgs(ua)
|
||||||
if len(announcements.channelUpdates) != 1 {
|
if len(announcements.channelUpdates) != 1 {
|
||||||
t.Fatal("new channel update not stored in batch")
|
t.Fatal("new channel update not stored in batch")
|
||||||
}
|
}
|
||||||
@ -902,7 +902,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create update announcement: %v", err)
|
t.Fatalf("can't create update announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(ua2)
|
announcements.AddMsgs(ua2)
|
||||||
if len(announcements.channelUpdates) != 1 {
|
if len(announcements.channelUpdates) != 1 {
|
||||||
t.Fatal("channel update not replaced in batch")
|
t.Fatal("channel update not replaced in batch")
|
||||||
}
|
}
|
||||||
@ -913,7 +913,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create node announcement: %v", err)
|
t.Fatalf("can't create node announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(na)
|
announcements.AddMsgs(na)
|
||||||
if len(announcements.nodeAnnouncements) != 1 {
|
if len(announcements.nodeAnnouncements) != 1 {
|
||||||
t.Fatal("new node announcement not stored in batch")
|
t.Fatal("new node announcement not stored in batch")
|
||||||
}
|
}
|
||||||
@ -923,7 +923,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create node announcement: %v", err)
|
t.Fatalf("can't create node announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(na2)
|
announcements.AddMsgs(na2)
|
||||||
if len(announcements.nodeAnnouncements) != 2 {
|
if len(announcements.nodeAnnouncements) != 2 {
|
||||||
t.Fatal("second node announcement not stored in batch")
|
t.Fatal("second node announcement not stored in batch")
|
||||||
}
|
}
|
||||||
@ -934,7 +934,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create node announcement: %v", err)
|
t.Fatalf("can't create node announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(na3)
|
announcements.AddMsgs(na3)
|
||||||
if len(announcements.nodeAnnouncements) != 2 {
|
if len(announcements.nodeAnnouncements) != 2 {
|
||||||
t.Fatal("second node announcement not replaced in batch")
|
t.Fatal("second node announcement not replaced in batch")
|
||||||
}
|
}
|
||||||
@ -946,14 +946,14 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't create node announcement: %v", err)
|
t.Fatalf("can't create node announcement: %v", err)
|
||||||
}
|
}
|
||||||
announcements.AddMsg(na4)
|
announcements.AddMsgs(na4)
|
||||||
if len(announcements.nodeAnnouncements) != 2 {
|
if len(announcements.nodeAnnouncements) != 2 {
|
||||||
t.Fatal("second node announcement not replaced again in batch")
|
t.Fatal("second node announcement not replaced again in batch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that announcement batch delivers channel announcements,
|
// Ensure that announcement batch delivers channel announcements,
|
||||||
// channel updates, and node announcements in proper order.
|
// channel updates, and node announcements in proper order.
|
||||||
batch := announcements.Batch()
|
batch := announcements.Emit()
|
||||||
if len(batch) != 4 {
|
if len(batch) != 4 {
|
||||||
t.Fatal("announcement batch incorrect length")
|
t.Fatal("announcement batch incorrect length")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user