From f13c1d278741d3b4c8b4c494f255848371efb86b Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Fri, 21 Sep 2018 17:03:56 -0700 Subject: [PATCH] channeldb: ensure channel buckets are only created once In this commit, we ensure that we only create the sub-bucket for channels once: at the time of creation. We do this as otherwise it's possible that a method that mutates a channel's state is called after it has already been closed on-chain, leading to the channel bucket being recreated. --- channeldb/channel.go | 126 +++++++++++++++---------------------------- channeldb/error.go | 4 ++ 2 files changed, 47 insertions(+), 83 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index daaf8443..a284509a 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -495,7 +495,7 @@ func (c *OpenChannel) RefreshShortChanID() error { var sid lnwire.ShortChannelID err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -521,59 +521,10 @@ func (c *OpenChannel) RefreshShortChanID() error { return nil } -// updateChanBucket is a helper function that returns a writable bucket that a +// fetchChanBucket is a helper function that returns the bucket where a // channel's data resides in given: the public key for the node, the outpoint, // and the chainhash that the channel resides on. -// -// NOTE: This function assumes that all the relevant descendent buckets already -// exist. -func updateChanBucket(tx *bolt.Tx, nodeKey *btcec.PublicKey, - outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bolt.Bucket, error) { - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return nil, ErrNoChanDBExists - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := nodeKey.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return nil, ErrNoActiveChannels - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(chainHash[:]) - if err != nil { - return nil, ErrNodeNotFound - } - - // With the bucket for the node fetched, we can now go down another - // level, creating the bucket (if it doesn't exist), for this channel - // itself. - var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, fmt.Errorf("unable to write outpoint: %v", err) - } - chanBucket, err := chainBucket.CreateBucketIfNotExists( - chanPointBuf.Bytes(), - ) - if chanBucket == nil { - return nil, fmt.Errorf("unable to find bucket for "+ - "chan_point=%v", outPoint) - } - - return chanBucket, nil -} - -// readChanBucket is a helper function that returns a readable bucket that a -// channel's data resides in given: the public key for the node, the outpoint, -// and the chainhash that the channel resides on. -func readChanBucket(tx *bolt.Tx, nodeKey *btcec.PublicKey, +func fetchChanBucket(tx *bolt.Tx, nodeKey *btcec.PublicKey, outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bolt.Bucket, error) { // First fetch the top level bucket which stores all data related to @@ -598,16 +549,15 @@ func readChanBucket(tx *bolt.Tx, nodeKey *btcec.PublicKey, return nil, ErrNoActiveChannels } - // With the bucket for the node fetched, we can now go down another - // level, for this channel itself. + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. var chanPointBuf bytes.Buffer if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } - chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) if chanBucket == nil { - return nil, ErrNoActiveChannels + return nil, ErrChannelNotFound } return chanBucket, nil @@ -663,7 +613,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { defer c.Unlock() if err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -699,7 +649,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { var status ChannelStatus if err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -743,11 +693,14 @@ func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { var commitPoint *btcec.PublicKey err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) - if err == ErrNoActiveChannels || err == ErrNoChanDBExists { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: return ErrNoCommitPoint - } else if err != nil { + default: return err } @@ -791,7 +744,7 @@ func (c *OpenChannel) MarkCommitmentBroadcasted() error { func (c *OpenChannel) putChanStatus(status ChannelStatus) error { if err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -931,7 +884,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error { defer c.Unlock() err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -1352,7 +1305,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { return c.Db.Update(func(tx *bolt.Tx) error { // First, we'll grab the writable bucket where this channel's // data resides. - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -1400,11 +1353,14 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { var cd *CommitDiff err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) - if err == ErrNoActiveChannels || err == ErrNoChanDBExists { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: return ErrNoPendingCommit - } else if err != nil { + default: return err } @@ -1443,7 +1399,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { c.RemoteNextRevocation = revKey err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -1473,7 +1429,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { var newRemoteCommit *ChannelCommitment err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := updateChanBucket( + chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -1642,8 +1598,7 @@ func (c *OpenChannel) RemoveFwdPkg(height uint64) error { // RevocationLogTail returns the "tail", or the end of the current revocation // log. This entry represents the last previous state for the remote node's -// commitment chain. The ChannelDelta returned by this method will always lag -// one state behind the most current (unrevoked) state of the remote node's +// commitment chain. The ChannelDelta returned by this method will always lag one state behind the most current (unrevoked) state of the remote node's // commitment chain. func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { c.RLock() @@ -1657,8 +1612,9 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { var commit ChannelCommitment if err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) if err != nil { return err } @@ -1705,8 +1661,9 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { err := c.Db.View(func(tx *bolt.Tx) error { // Get the bucket dedicated to storing the metadata for open // channels. - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) if err != nil { return err } @@ -1737,8 +1694,9 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e var commit ChannelCommitment err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) if err != nil { return err } @@ -2030,8 +1988,9 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { // the local commitment, and the second returned is the remote commitment. func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) if err != nil { return err } @@ -2051,8 +2010,9 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen // up to date information required to deliver justice. func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { err := c.Db.View(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, - &c.FundingOutpoint, c.ChainHash) + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) if err != nil { return err } diff --git a/channeldb/error.go b/channeldb/error.go index e4df0a56..15cfa840 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -43,6 +43,10 @@ var ( // specific identity can't be found. ErrNodeNotFound = fmt.Errorf("link node with target identity not found") + // ErrChannelNotFound is returned when we attempt to locate a channel + // for a specific chain, but it is not found. + ErrChannelNotFound = fmt.Errorf("channel not found") + // ErrMetaNotFound is returned when meta bucket hasn't been // created. ErrMetaNotFound = fmt.Errorf("unable to locate meta information")