Merge pull request #1419 from cfromknecht/update-channel-packager-sid

Update channel packager sid
This commit is contained in:
Olaoluwa Osuntokun 2018-06-21 13:49:00 +01:00 committed by GitHub
commit 18f17ad49b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 0 deletions

@ -497,6 +497,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
} }
c.ShortChannelID = sid c.ShortChannelID = sid
c.Packager = NewChannelPackager(sid)
return nil return nil
} }
@ -665,6 +666,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
c.IsPending = false c.IsPending = false
c.ShortChannelID = openLoc c.ShortChannelID = openLoc
c.Packager = NewChannelPackager(openLoc)
return nil return nil
} }
@ -1474,6 +1476,9 @@ func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) {
// processed, and returns their deserialized log updates in map indexed by the // processed, and returns their deserialized log updates in map indexed by the
// remote commitment height at which the updates were locked in. // remote commitment height at which the updates were locked in.
func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
c.RLock()
defer c.RUnlock()
var fwdPkgs []*FwdPkg var fwdPkgs []*FwdPkg
if err := c.Db.View(func(tx *bolt.Tx) error { if err := c.Db.View(func(tx *bolt.Tx) error {
var err error var err error
@ -1489,6 +1494,9 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
// SetFwdFilter atomically sets the forwarding filter for the forwarding package // SetFwdFilter atomically sets the forwarding filter for the forwarding package
// identified by `height`. // identified by `height`.
func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
c.Lock()
defer c.Unlock()
return c.Db.Update(func(tx *bolt.Tx) error { return c.Db.Update(func(tx *bolt.Tx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter) return c.Packager.SetFwdFilter(tx, height, fwdFilter)
}) })
@ -1499,6 +1507,9 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
// //
// NOTE: This method should only be called on packages marked FwdStateCompleted. // NOTE: This method should only be called on packages marked FwdStateCompleted.
func (c *OpenChannel) RemoveFwdPkg(height uint64) error { func (c *OpenChannel) RemoveFwdPkg(height uint64) error {
c.Lock()
defer c.Unlock()
return c.Db.Update(func(tx *bolt.Tx) error { return c.Db.Update(func(tx *bolt.Tx) error {
return c.Packager.RemovePkg(tx, height) return c.Packager.RemovePkg(tx, height)
}) })

@ -898,6 +898,16 @@ func TestRefreshShortChanID(t *testing.T) {
"updated before refreshing short_chan_id") "updated before refreshing short_chan_id")
} }
// Now that the receiver's short channel id has been updated, check to
// ensure that the channel packager's source has been updated as well.
// This ensures that the packager will read and write to buckets
// corresponding to the new short chan id, instead of the prior.
if state.Packager.(*ChannelPackager).source != chanOpenLoc {
t.Fatalf("channel packager source was not updated: want %v, "+
"got %v", chanOpenLoc,
state.Packager.(*ChannelPackager).source)
}
// Now, refresh the short channel ID of the pending channel. // Now, refresh the short channel ID of the pending channel.
err = pendingChannel.RefreshShortChanID() err = pendingChannel.RefreshShortChanID()
if err != nil { if err != nil {
@ -911,4 +921,14 @@ func TestRefreshShortChanID(t *testing.T) {
"refreshed: want %v, got %v", state.ShortChanID(), "refreshed: want %v, got %v", state.ShortChanID(),
pendingChannel.ShortChanID()) pendingChannel.ShortChanID())
} }
// Check to ensure that the _other_ OpenChannel channel packager's
// source has also been updated after the refresh. This ensures that the
// other packagers will read and write to buckets corresponding to the
// updated short chan id.
if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
t.Fatalf("channel packager source was not updated: want %v, "+
"got %v", chanOpenLoc,
pendingChannel.Packager.(*ChannelPackager).source)
}
} }