From 639c9875b2a573ab74c80d4d331b90886b0770de Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 19 Jun 2018 13:36:12 +0100 Subject: [PATCH 1/2] channeldb/channel_test: test packager source updated --- channeldb/channel_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index f7f331af..efd36abb 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -898,6 +898,16 @@ func TestRefreshShortChanID(t *testing.T) { "updated before refreshing short_chan_id") } + // Now that the receiver's short channel id has been updated, check to + // ensure that the channel packager's source has been updated as well. + // This ensures that the packager will read and write to buckets + // corresponding to the new short chan id, instead of the prior. + if state.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + state.Packager.(*ChannelPackager).source) + } + // Now, refresh the short channel ID of the pending channel. err = pendingChannel.RefreshShortChanID() if err != nil { @@ -911,4 +921,14 @@ func TestRefreshShortChanID(t *testing.T) { "refreshed: want %v, got %v", state.ShortChanID(), pendingChannel.ShortChanID()) } + + // Check to ensure that the _other_ OpenChannel channel packager's + // source has also been updated after the refresh. This ensures that the + // other packagers will read and write to buckets corresponding to the + // updated short chan id. + if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + pendingChannel.Packager.(*ChannelPackager).source) + } } From 56e5eed0372a7505d1dcf7bb5031a6a4a46651b6 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 19 Jun 2018 13:12:59 +0100 Subject: [PATCH 2/2] channeldb/channel: update short chan id for fwd packager --- channeldb/channel.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/channeldb/channel.go b/channeldb/channel.go index f91c914d..9d7d01ba 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -497,6 +497,7 @@ func (c *OpenChannel) RefreshShortChanID() error { } c.ShortChannelID = sid + c.Packager = NewChannelPackager(sid) return nil } @@ -665,6 +666,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.IsPending = false c.ShortChannelID = openLoc + c.Packager = NewChannelPackager(openLoc) return nil } @@ -1474,6 +1476,9 @@ func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { // processed, and returns their deserialized log updates in map indexed by the // remote commitment height at which the updates were locked in. func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + var fwdPkgs []*FwdPkg if err := c.Db.View(func(tx *bolt.Tx) error { var err error @@ -1489,6 +1494,9 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { // SetFwdFilter atomically sets the forwarding filter for the forwarding package // identified by `height`. func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }) @@ -1499,6 +1507,9 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { // // NOTE: This method should only be called on packages marked FwdStateCompleted. func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.RemovePkg(tx, height) })