diff --git a/chancloser.go b/chancloser.go index 3536f1ce..00dff1bf 100644 --- a/chancloser.go +++ b/chancloser.go @@ -155,6 +155,9 @@ type channelCloser struct { // remoteDeliveryScript is the script that we'll send the remote // party's settled channel funds to. remoteDeliveryScript []byte + + // locallyInitiated is true if we initiated the channel close. + locallyInitiated bool } // newChannelCloser creates a new instance of the channel closure given the @@ -162,7 +165,7 @@ type channelCloser struct { // only be populated iff, we're the initiator of this closing request. func newChannelCloser(cfg chanCloseCfg, deliveryScript []byte, idealFeePerKw chainfee.SatPerKWeight, negotiationHeight uint32, - closeReq *htlcswitch.ChanClose) *channelCloser { + closeReq *htlcswitch.ChanClose, locallyInitiated bool) *channelCloser { // Given the target fee-per-kw, we'll compute what our ideal _total_ // fee will be starting at for this fee negotiation. @@ -198,6 +201,7 @@ func newChannelCloser(cfg chanCloseCfg, deliveryScript []byte, idealFeeSat: idealFeeSat, localDeliveryScript: deliveryScript, priorFeeOffers: make(map[btcutil.Amount]*lnwire.ClosingSigned), + locallyInitiated: locallyInitiated, } } @@ -224,7 +228,7 @@ func (c *channelCloser) initChanShutdown() (*lnwire.Shutdown, error) { // guarantees that our listchannels rpc will be externally consistent, // and reflect that the channel is being shutdown by the time the // closing request returns. - err := c.cfg.channel.MarkCoopBroadcasted(nil) + err := c.cfg.channel.MarkCoopBroadcasted(nil, c.locallyInitiated) if err != nil { return nil, err } @@ -511,7 +515,9 @@ func (c *channelCloser) ProcessCloseMsg(msg lnwire.Message) ([]lnwire.Message, b // Before publishing the closing tx, we persist it to the // database, such that it can be republished if something goes // wrong. - err = c.cfg.channel.MarkCoopBroadcasted(closeTx) + err = c.cfg.channel.MarkCoopBroadcasted( + closeTx, c.locallyInitiated, + ) if err != nil { return nil, false, err } diff --git a/channeldb/channel.go b/channeldb/channel.go index 3e340390..bc205488 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -401,20 +401,35 @@ var ( // will have. ChanStatusRestored ChannelStatus = 1 << 3 - // ChanStatusCoopBroadcasted indicates that a cooperative close for this - // channel has been broadcasted. + // ChanStatusCoopBroadcasted indicates that a cooperative close for + // this channel has been broadcasted. Older cooperatively closed + // channels will only have this status set. Newer ones will also have + // close initiator information stored using the local/remote initiator + // status. This status is set in conjunction with the initiator status + // so that we do not need to check multiple channel statues for + // cooperative closes. ChanStatusCoopBroadcasted ChannelStatus = 1 << 4 + + // ChanStatusLocalCloseInitiator indicates that we initiated closing + // the channel. + ChanStatusLocalCloseInitiator ChannelStatus = 1 << 5 + + // ChanStatusRemoteCloseInitiator indicates that the remote node + // initiated closing the channel. + ChanStatusRemoteCloseInitiator ChannelStatus = 1 << 6 ) // chanStatusStrings maps a ChannelStatus to a human friendly string that // describes that status. var chanStatusStrings = map[ChannelStatus]string{ - ChanStatusDefault: "ChanStatusDefault", - ChanStatusBorked: "ChanStatusBorked", - ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", - ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", - ChanStatusRestored: "ChanStatusRestored", - ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted", + ChanStatusDefault: "ChanStatusDefault", + ChanStatusBorked: "ChanStatusBorked", + ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", + ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", + ChanStatusRestored: "ChanStatusRestored", + ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted", + ChanStatusLocalCloseInitiator: "ChanStatusLocalCloseInitiator", + ChanStatusRemoteCloseInitiator: "ChanStatusRemoteCloseInitiator", } // orderedChanStatusFlags is an in-order list of all that channel status flags. @@ -425,6 +440,8 @@ var orderedChanStatusFlags = []ChannelStatus{ ChanStatusLocalDataLoss, ChanStatusRestored, ChanStatusCoopBroadcasted, + ChanStatusLocalCloseInitiator, + ChanStatusRemoteCloseInitiator, } // String returns a human-readable representation of the ChannelStatus. @@ -974,30 +991,37 @@ func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) { // closing tx _we believe_ will appear in the chain. This is only used to // republish this tx at startup to ensure propagation, and we should still // handle the case where a different tx actually hits the chain. -func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error { +func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, + locallyInitiated bool) error { + return c.markBroadcasted( ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx, + locallyInitiated, ) } // MarkCoopBroadcasted marks the channel to indicate that a cooperative close // transaction has been broadcast, either our own or the remote, and that we -// should wach the chain for it to confirm before taking further action. It +// should watch the chain for it to confirm before taking further action. It // takes as argument a cooperative close tx that could appear on chain, and -// should be rebroadcast upon startup. This is only used to republish and ensure -// propagation, and we should still handle the case where a different tx +// should be rebroadcast upon startup. This is only used to republish and +// ensure propagation, and we should still handle the case where a different tx // actually hits the chain. -func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx) error { +func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, + locallyInitiated bool) error { + return c.markBroadcasted( ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx, + locallyInitiated, ) } // markBroadcasted is a helper function which modifies the channel status of the // receiving channel and inserts a close transaction under the requested key, -// which should specify either a coop or force close. +// which should specify either a coop or force close. It adds a status which +// indicates the party that initiated the channel close. func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, - closeTx *wire.MsgTx) error { + closeTx *wire.MsgTx, locallyInitiated bool) error { c.Lock() defer c.Unlock() @@ -1016,6 +1040,15 @@ func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, } } + // Add the initiator status to the status provided. These statuses are + // set in addition to the broadcast status so that we do not need to + // migrate the original logic which does not store initiator. + if locallyInitiated { + status |= ChanStatusLocalCloseInitiator + } else { + status |= ChanStatusRemoteCloseInitiator + } + return c.putChanStatus(status, putClosingTx) } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index cb29b521..ad7a6975 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -1089,13 +1089,13 @@ func TestFetchWaitingCloseChannels(t *testing.T) { }, ) - if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil { + if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { t.Fatalf("unable to mark commitment broadcast: %v", err) } // Now try to marking a coop close with a nil tx. This should // succeed, but it shouldn't exit when queried. - if err = channel.MarkCoopBroadcasted(nil); err != nil { + if err = channel.MarkCoopBroadcasted(nil, true); err != nil { t.Fatalf("unable to mark nil coop broadcast: %v", err) } _, err := channel.BroadcastedCooperative() @@ -1107,7 +1107,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // it as coop closed. Later we will test that distinct // transactions are returned for both coop and force closes. closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 - if err := channel.MarkCoopBroadcasted(closeTx); err != nil { + if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { t.Fatalf("unable to mark coop broadcast: %v", err) } } @@ -1255,3 +1255,105 @@ func TestRefreshShortChanID(t *testing.T) { t.Fatalf("channel pending state wasn't updated: want false got true") } } + +// TestCloseInitiator tests the setting of close initiator statuses for +// cooperative closes and local force closes. +func TestCloseInitiator(t *testing.T) { + tests := []struct { + name string + // updateChannel is called to update the channel as broadcast, + // cooperatively or not, based on the test's requirements. + updateChannel func(c *OpenChannel) error + expectedStatuses []ChannelStatus + }{ + { + name: "local coop close", + // Mark the channel as cooperatively closed, initiated + // by the local party. + updateChannel: func(c *OpenChannel) error { + return c.MarkCoopBroadcasted( + &wire.MsgTx{}, true, + ) + }, + expectedStatuses: []ChannelStatus{ + ChanStatusLocalCloseInitiator, + ChanStatusCoopBroadcasted, + }, + }, + { + name: "remote coop close", + // Mark the channel as cooperatively closed, initiated + // by the remote party. + updateChannel: func(c *OpenChannel) error { + return c.MarkCoopBroadcasted( + &wire.MsgTx{}, false, + ) + }, + expectedStatuses: []ChannelStatus{ + ChanStatusRemoteCloseInitiator, + ChanStatusCoopBroadcasted, + }, + }, + { + name: "local force close", + // Mark the channel's commitment as broadcast with + // local initiator. + updateChannel: func(c *OpenChannel) error { + return c.MarkCommitmentBroadcasted( + &wire.MsgTx{}, true, + ) + }, + expectedStatuses: []ChannelStatus{ + ChanStatusLocalCloseInitiator, + ChanStatusCommitBroadcasted, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", + err) + } + defer cleanUp() + + // Create an open channel. + channel := createTestChannel( + t, cdb, openChannelOption(), + ) + + err = test.updateChannel(channel) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Lookup open channels in the database. + dbChans, err := fetchChannels( + cdb, pendingChannelFilter(false), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(dbChans) != 1 { + t.Fatalf("expected 1 channel, got: %v", + len(dbChans)) + } + + // Check that the statuses that we expect were written + // to disk. + for _, status := range test.expectedStatuses { + if !dbChans[0].HasChanStatus(status) { + t.Fatalf("expected channel to have "+ + "status: %v, has status: %v", + status, dbChans[0].chanStatus) + } + } + }) + } +} diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 4d678303..e678d2a5 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -647,7 +647,7 @@ func TestFetchChannels(t *testing.T) { channelIDOption(pendingWaitingChan), ) - err = pendingClosing.MarkCoopBroadcasted(nil) + err = pendingClosing.MarkCoopBroadcasted(nil, true) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -667,7 +667,7 @@ func TestFetchChannels(t *testing.T) { channelIDOption(openWaitingChan), openChannelOption(), ) - err = openClosing.MarkCoopBroadcasted(nil) + err = openClosing.MarkCoopBroadcasted(nil, true) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 5710b14a..0ac7484a 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -62,12 +62,12 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { for i := 0; i < numChans/2; i++ { closeTx := channels[i].FundingTxn.Copy() closeTx.TxIn[0].PreviousOutPoint = channels[i].FundingOutpoint - err := channels[i].MarkCommitmentBroadcasted(closeTx) + err := channels[i].MarkCommitmentBroadcasted(closeTx, true) if err != nil { t.Fatal(err) } - err = channels[i].MarkCoopBroadcasted(closeTx) + err = channels[i].MarkCoopBroadcasted(closeTx, true) if err != nil { t.Fatal(err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index b11c81aa..6db666c9 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -98,7 +98,7 @@ type ChannelArbitratorConfig struct { // MarkCommitmentBroadcasted should mark the channel as the commitment // being broadcast, and we are waiting for the commitment to confirm. - MarkCommitmentBroadcasted func(*wire.MsgTx) error + MarkCommitmentBroadcasted func(*wire.MsgTx, bool) error // MarkChannelClosed marks the channel closed in the database, with the // passed close summary. After this method successfully returns we can @@ -797,8 +797,10 @@ func (c *ChannelArbitrator) stateStep( // Before publishing the transaction, we store it to the // database, such that we can re-publish later in case it - // didn't propagate. - if err := c.cfg.MarkCommitmentBroadcasted(closeTx); err != nil { + // didn't propagate. We initiated the force close, so we + // mark broadcast with local initiator set to true. + err = c.cfg.MarkCommitmentBroadcasted(closeTx, true) + if err != nil { log.Errorf("ChannelArbitrator(%v): unable to "+ "mark commitment broadcasted: %v", c.cfg.ChanPoint, err) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 9375d895..0abfb381 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -339,7 +339,7 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestC } return summary, nil }, - MarkCommitmentBroadcasted: func(_ *wire.MsgTx) error { + MarkCommitmentBroadcasted: func(_ *wire.MsgTx, _ bool) error { return nil }, MarkChannelClosed: func(*channeldb.ChannelCloseSummary) error { diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 08ec5ebc..fa96f8c4 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -6433,22 +6433,28 @@ func (lc *LightningChannel) MarkBorked() error { // MarkCommitmentBroadcasted marks the channel as a commitment transaction has // been broadcast, either our own or the remote, and we should watch the chain -// for it to confirm before taking any further action. -func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx) error { +// for it to confirm before taking any further action. It takes a boolean which +// indicates whether we initiated the close. +func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, + locallyInitiated bool) error { + lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCommitmentBroadcasted(tx) + return lc.channelState.MarkCommitmentBroadcasted(tx, locallyInitiated) } // MarkCoopBroadcasted marks the channel as a cooperative close transaction has // been broadcast, and that we should watch the chain for it to confirm before -// taking any further action. -func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx) error { +// taking any further action. It takes a locally initiated bool which is true +// if we initiated the cooperative close. +func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx, + localInitiated bool) error { + lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCoopBroadcasted(tx) + return lc.channelState.MarkCoopBroadcasted(tx, localInitiated) } // MarkDataLoss marks sets the channel status to LocalDataLoss and stores the diff --git a/peer.go b/peer.go index ea482d1f..02544fe4 100644 --- a/peer.go +++ b/peer.go @@ -2119,6 +2119,7 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e feePerKw, uint32(startingHeight), nil, + false, ) p.activeChanCloses[chanID] = chanCloser } @@ -2231,6 +2232,7 @@ func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { req.TargetFeePerKw, uint32(startingHeight), req, + true, ) p.activeChanCloses[chanID] = chanCloser