diff --git a/server.go b/server.go index ccf71ad3..5aa49662 100644 --- a/server.go +++ b/server.go @@ -780,7 +780,6 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, return time.NewTimer(sweep.DefaultBatchWindowDuration).C }, Notifier: cc.chainNotifier, - ChainIO: cc.chainIO, Store: sweeperStore, MaxInputsPerTx: sweep.DefaultMaxInputsPerTx, MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 03fcf6ca..2e3a1e1c 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -212,9 +212,6 @@ type UtxoSweeperConfig struct { // certain on-chain events. Notifier chainntnfs.ChainNotifier - // ChainIO is used to determine the current block height. - ChainIO lnwallet.BlockChainIO - // Store stores the published sweeper txes. Store SweeperStore @@ -323,20 +320,10 @@ func (s *UtxoSweeper) Start() error { // not change from here on. s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW() - // Register for block epochs to retry sweeping every block. - bestHash, bestHeight, err := s.cfg.ChainIO.GetBestBlock() - if err != nil { - return fmt.Errorf("get best block: %v", err) - } - - log.Debugf("Best height: %v", bestHeight) - - blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn( - &chainntnfs.BlockEpoch{ - Height: bestHeight, - Hash: bestHash, - }, - ) + // We need to register for block epochs and retry sweeping every block. + // We should get a notification with the current best block immediately + // if we don't provide any epoch. We'll wait for that in the collector. + blockEpochs, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil) if err != nil { return fmt.Errorf("register block epoch ntfn: %v", err) } @@ -347,10 +334,7 @@ func (s *UtxoSweeper) Start() error { defer blockEpochs.Cancel() defer s.wg.Done() - err := s.collector(blockEpochs.Epochs, bestHeight) - if err != nil { - log.Errorf("sweeper stopped: %v", err) - } + s.collector(blockEpochs.Epochs) }() return nil @@ -445,8 +429,18 @@ func (s *UtxoSweeper) feeRateForPreference( // collector is the sweeper main loop. It processes new inputs, spend // notifications and counts down to publication of the sweep tx. -func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch, - bestHeight int32) error { +func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { + // We registered for the block epochs with a nil request. The notifier + // should send us the current best block immediately. So we need to wait + // for it here because we need to know the current best height. + var bestHeight int32 + select { + case bestBlock := <-blockEpochs: + bestHeight = bestBlock.Height + + case <-s.quit: + return + } for { select { @@ -622,7 +616,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch, // sweep. case epoch, ok := <-blockEpochs: if !ok { - return nil + return } bestHeight = epoch.Height @@ -635,7 +629,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch, } case <-s.quit: - return nil + return } } } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index b4139fba..501ec3f8 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -130,9 +130,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { ctx.timeoutChan <- c return c }, - Store: store, - Signer: &mockSigner{}, - ChainIO: &mockChainIO{}, + Store: store, + Signer: &mockSigner{}, GenSweepScript: func() ([]byte, error) { script := []byte{outputScriptCount} outputScriptCount++ diff --git a/sweep/test_utils.go b/sweep/test_utils.go index 46ee6dc3..df44cd10 100644 --- a/sweep/test_utils.go +++ b/sweep/test_utils.go @@ -10,12 +10,12 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/lnwallet" ) var ( defaultTestTimeout = 5 * time.Second - mockChainIOHeight = int32(100) + mockChainHash, _ = chainhash.NewHashFromStr("00aabbccddeeff") + mockChainHeight = int32(100) ) type mockSigner struct { @@ -155,12 +155,22 @@ func (m *MockNotifier) RegisterBlockEpochNtfn( log.Tracef("Mock block ntfn registered") m.mutex.Lock() - epochChan := make(chan *chainntnfs.BlockEpoch, 0) - bestHeight := int32(0) - if bestBlock != nil { - bestHeight = bestBlock.Height + epochChan := make(chan *chainntnfs.BlockEpoch, 1) + + // The real notifier returns a notification with the current block hash + // and height immediately if no best block hash or height is specified + // in the request. We want to emulate this behaviour as well for the + // mock. + switch { + case bestBlock == nil: + epochChan <- &chainntnfs.BlockEpoch{ + Hash: mockChainHash, + Height: mockChainHeight, + } + m.epochChan[epochChan] = mockChainHeight + default: + m.epochChan[epochChan] = bestBlock.Height } - m.epochChan[epochChan] = bestHeight m.mutex.Unlock() return &chainntnfs.BlockEpochEvent{ @@ -235,25 +245,3 @@ func (m *MockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, }, }, nil } - -type mockChainIO struct{} - -var _ lnwallet.BlockChainIO = (*mockChainIO)(nil) - -func (m *mockChainIO) GetBestBlock() (*chainhash.Hash, int32, error) { - return nil, mockChainIOHeight, nil -} - -func (m *mockChainIO) GetUtxo(op *wire.OutPoint, pkScript []byte, - heightHint uint32, _ <-chan struct{}) (*wire.TxOut, error) { - - return nil, nil -} - -func (m *mockChainIO) GetBlockHash(blockHeight int64) (*chainhash.Hash, error) { - return nil, nil -} - -func (m *mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) { - return nil, nil -}