diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 1785d70b..6fe39855 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -46,6 +46,11 @@ var ( // for the configured max number of attempts. ErrTooManyAttempts = errors.New("sweep failed after max attempts") + // ErrSweeperShuttingDown is an error returned when a client attempts to + // make a request to the UtxoSweeper, but it is unable to handle it as + // it is/has already been stoppepd. + ErrSweeperShuttingDown = errors.New("utxo sweeper shutting down") + // DefaultMaxSweepAttempts specifies the default maximum number of times // an input is included in a publish attempt before giving up and // returning an error to the caller. @@ -96,6 +101,38 @@ type inputCluster struct { inputs pendingInputs } +// pendingSweepsReq is an internal message we'll use to represent an external +// caller's intent to retrieve all of the pending inputs the UtxoSweeper is +// attempting to sweep. +type pendingSweepsReq struct { + respChan chan map[wire.OutPoint]*PendingInput +} + +// PendingInput contains information about an input that is currently being +// swept by the UtxoSweeper. +type PendingInput struct { + // OutPoint is the identify outpoint of the input being swept. + OutPoint wire.OutPoint + + // WitnessType is the witness type of the input being swept. + WitnessType input.WitnessType + + // Amount is the amount of the input being swept. + Amount btcutil.Amount + + // LastFeeRate is the most recent fee rate used for the input being + // swept within a transaction broadcast to the network. + LastFeeRate lnwallet.SatPerKWeight + + // BroadcastAttempts is the number of attempts we've made to sweept the + // input. + BroadcastAttempts int + + // NextBroadcastHeight is the next height of the chain at which we'll + // attempt to broadcast a transaction sweeping the input. + NextBroadcastHeight uint32 +} + // UtxoSweeper is responsible for sweeping outputs back into the wallet type UtxoSweeper struct { started uint32 // To be used atomically. @@ -106,6 +143,11 @@ type UtxoSweeper struct { newInputs chan *sweepInputMessage spendChan chan *chainntnfs.SpendDetail + // pendingSweepsReq is a channel that will be sent requests by external + // callers in order to retrieve the set of pending inputs the + // UtxoSweeper is attempting to sweep. + pendingSweepsReqs chan *pendingSweepsReq + // pendingInputs is the total set of inputs the UtxoSweeper has been // requested to sweep. pendingInputs pendingInputs @@ -212,11 +254,12 @@ type sweepInputMessage struct { // New returns a new Sweeper instance. func New(cfg *UtxoSweeperConfig) *UtxoSweeper { return &UtxoSweeper{ - cfg: cfg, - newInputs: make(chan *sweepInputMessage), - spendChan: make(chan *chainntnfs.SpendDetail), - quit: make(chan struct{}), - pendingInputs: make(pendingInputs), + cfg: cfg, + newInputs: make(chan *sweepInputMessage), + spendChan: make(chan *chainntnfs.SpendDetail), + pendingSweepsReqs: make(chan *pendingSweepsReq), + quit: make(chan struct{}), + pendingInputs: make(pendingInputs), } } @@ -340,7 +383,7 @@ func (s *UtxoSweeper) SweepInput(input input.Input, select { case s.newInputs <- sweeperInput: case <-s.quit: - return nil, fmt.Errorf("sweeper shutting down") + return nil, ErrSweeperShuttingDown } return sweeperInput.resultChan, nil @@ -484,6 +527,11 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch, log.Errorf("schedule sweep: %v", err) } + // A new external request has been received to retrieve all of + // the inputs we're currently attempting to sweep. + case req := <-s.pendingSweepsReqs: + req.respChan <- s.handlePendingSweepsReq(req) + // The timer expires and we are going to (re)sweep. case <-s.timer: log.Debugf("Sweep timer expired") @@ -886,6 +934,51 @@ func (s *UtxoSweeper) waitForSpend(outpoint wire.OutPoint, return spendEvent.Cancel, nil } +// PendingInputs returns the set of inputs that the UtxoSweeper is currently +// attempting to sweep. +func (s *UtxoSweeper) PendingInputs() (map[wire.OutPoint]*PendingInput, error) { + respChan := make(chan map[wire.OutPoint]*PendingInput, 1) + select { + case s.pendingSweepsReqs <- &pendingSweepsReq{ + respChan: respChan, + }: + case <-s.quit: + return nil, ErrSweeperShuttingDown + } + + select { + case pendingSweeps := <-respChan: + return pendingSweeps, nil + case <-s.quit: + return nil, ErrSweeperShuttingDown + } +} + +// handlePendingSweepsReq handles a request to retrieve all pending inputs the +// UtxoSweeper is attempting to sweep. +func (s *UtxoSweeper) handlePendingSweepsReq( + req *pendingSweepsReq) map[wire.OutPoint]*PendingInput { + + pendingInputs := make(map[wire.OutPoint]*PendingInput, len(s.pendingInputs)) + for _, pendingInput := range s.pendingInputs { + // Only the exported fields are set, as we expect the response + // to only be consumed externally. + op := *pendingInput.input.OutPoint() + pendingInputs[op] = &PendingInput{ + OutPoint: op, + WitnessType: pendingInput.input.WitnessType(), + Amount: btcutil.Amount( + pendingInput.input.SignDesc().Output.Value, + ), + LastFeeRate: pendingInput.lastFeeRate, + BroadcastAttempts: pendingInput.publishAttempts, + NextBroadcastHeight: uint32(pendingInput.minPublishHeight), + } + } + + return pendingInputs +} + // CreateSweepTx accepts a list of inputs and signs and generates a txn that // spends from them. This method also makes an accurate fee estimate before // generating the required witnesses. diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index f7ec2173..d494b996 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -263,6 +263,29 @@ func (ctx *sweeperTestContext) expectResult(c chan Result, expected error) { } } +func (ctx *sweeperTestContext) assertPendingInputs(inputs ...input.Input) { + ctx.t.Helper() + + inputSet := make(map[wire.OutPoint]struct{}, len(inputs)) + for _, input := range inputs { + inputSet[*input.OutPoint()] = struct{}{} + } + + pendingInputs, err := ctx.sweeper.PendingInputs() + if err != nil { + ctx.t.Fatal(err) + } + if len(pendingInputs) != len(inputSet) { + ctx.t.Fatalf("expected %d pending inputs, got %d", + len(inputSet), len(pendingInputs)) + } + for input := range pendingInputs { + if _, ok := inputSet[input]; !ok { + ctx.t.Fatalf("found unexpected input %v", input) + } + } +} + // receiveSpendTx receives the transaction sent through the given resultChan. func receiveSpendTx(t *testing.T, resultChan chan Result) *wire.MsgTx { t.Helper() @@ -1032,3 +1055,71 @@ func TestDifferentFeePreferences(t *testing.T) { ctx.finish(1) } + +// TestPendingInputs ensures that the sweeper correctly determines the inputs +// pending to be swept. +func TestPendingInputs(t *testing.T) { + ctx := createSweeperTestContext(t) + + // Throughout this test, we'll be attempting to sweep three inputs, two + // with the higher fee preference, and the last with the lower. We do + // this to ensure the sweeper can return all pending inputs, even those + // with different fee preferences. + const ( + lowFeeRate = 5000 + highFeeRate = 10000 + ) + + lowFeePref := FeePreference{ + ConfTarget: 12, + } + ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate + + highFeePref := FeePreference{ + ConfTarget: 6, + } + ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate + + input1 := spendableInputs[0] + resultChan1, err := ctx.sweeper.SweepInput(input1, highFeePref) + if err != nil { + t.Fatal(err) + } + input2 := spendableInputs[1] + if _, err := ctx.sweeper.SweepInput(input2, highFeePref); err != nil { + t.Fatal(err) + } + input3 := spendableInputs[2] + resultChan3, err := ctx.sweeper.SweepInput(input3, lowFeePref) + if err != nil { + t.Fatal(err) + } + + // We should expect to see all inputs pending. + ctx.assertPendingInputs(input1, input2, input3) + + // We should expect to see both sweep transactions broadcast. The higher + // fee rate sweep should be broadcast first. We'll remove the lower fee + // rate sweep to ensure we can detect pending inputs after a sweep. + // Once the higher fee rate sweep confirms, we should no longer see + // those inputs pending. + ctx.tick() + ctx.receiveTx() + lowFeeRateTx := ctx.receiveTx() + ctx.backend.deleteUnconfirmed(lowFeeRateTx.TxHash()) + ctx.backend.mine() + ctx.expectResult(resultChan1, nil) + ctx.assertPendingInputs(input3) + + // We'll then trigger a new block to rebroadcast the lower fee rate + // sweep. Once again we'll ensure those inputs are no longer pending + // once the sweep transaction confirms. + ctx.backend.notifier.NotifyEpoch(101) + ctx.tick() + ctx.receiveTx() + ctx.backend.mine() + ctx.expectResult(resultChan3, nil) + ctx.assertPendingInputs() + + ctx.finish(1) +}