sweep: create wallet interface

We need access to additional wallet functionality. This commit creates
an interface to prevent passing in multiple function pointers.
This commit is contained in:
Joost Jager 2019-12-10 15:32:57 +01:00
parent b325aae4f2
commit 34c9193bfc
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
5 changed files with 43 additions and 23 deletions

@ -795,7 +795,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
FeeEstimator: cc.feeEstimator,
GenSweepScript: newSweepPkScriptGen(cc.wallet),
Signer: cc.wallet.Cfg.Signer,
PublishTransaction: cc.wallet.PublishTransaction,
Wallet: cc.wallet,
NewBatchTimer: func() <-chan time.Time {
return time.NewTimer(sweep.DefaultBatchWindowDuration).C
},

@ -2,6 +2,8 @@ package sweep
import (
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
@ -11,6 +13,8 @@ import (
// mockBackend simulates a chain backend for realistic behaviour in unit tests
// around double spends.
type mockBackend struct {
t *testing.T
lock sync.Mutex
notifier *MockNotifier
@ -19,14 +23,18 @@ type mockBackend struct {
unconfirmedTxes map[chainhash.Hash]*wire.MsgTx
unconfirmedSpendInputs map[wire.OutPoint]struct{}
publishChan chan wire.MsgTx
}
func newMockBackend(notifier *MockNotifier) *mockBackend {
func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend {
return &mockBackend{
t: t,
notifier: notifier,
unconfirmedTxes: make(map[chainhash.Hash]*wire.MsgTx),
confirmedSpendInputs: make(map[wire.OutPoint]struct{}),
unconfirmedSpendInputs: make(map[wire.OutPoint]struct{}),
publishChan: make(chan wire.MsgTx, 2),
}
}
@ -65,6 +73,17 @@ func (b *mockBackend) publishTransaction(tx *wire.MsgTx) error {
return nil
}
func (b *mockBackend) PublishTransaction(tx *wire.MsgTx) error {
log.Tracef("Publishing tx %v", tx.TxHash())
err := b.publishTransaction(tx)
select {
case b.publishChan <- *tx:
case <-time.After(defaultTestTimeout):
b.t.Fatalf("unexpected tx published")
}
return err
}
func (b *mockBackend) deleteUnconfirmed(txHash chainhash.Hash) {
b.lock.Lock()
defer b.lock.Unlock()

12
sweep/interface.go Normal file

@ -0,0 +1,12 @@
package sweep
import (
"github.com/btcsuite/btcd/wire"
)
// Wallet contains all wallet related functionality required by sweeper.
type Wallet interface {
// PublishTransaction performs cursory validation (dust checks, etc) and
// broadcasts the passed transaction to the Bitcoin network.
PublishTransaction(tx *wire.MsgTx) error
}

@ -210,9 +210,8 @@ type UtxoSweeperConfig struct {
// transaction.
FeeEstimator chainfee.Estimator
// PublishTransaction facilitates the process of broadcasting a signed
// transaction to the appropriate network.
PublishTransaction func(*wire.MsgTx) error
// Wallet contains the wallet functions that sweeper requires.
Wallet Wallet
// NewBatchTimer creates a channel that will be sent on when a certain
// time window has passed. During this time window, new inputs can still
@ -321,7 +320,7 @@ func (s *UtxoSweeper) Start() error {
// Error can be ignored. Because we are starting up, there are
// no pending inputs to update based on the publish result.
err := s.cfg.PublishTransaction(lastTx)
err := s.cfg.Wallet.PublishTransaction(lastTx)
if err != nil && err != lnwallet.ErrDoubleSpend {
log.Errorf("last tx publish: %v", err)
}
@ -886,7 +885,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight,
}),
)
err = s.cfg.PublishTransaction(tx)
err = s.cfg.Wallet.PublishTransaction(tx)
// In case of an unexpected error, don't try to recover.
if err != nil && err != lnwallet.ErrDoubleSpend {

@ -98,14 +98,13 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
store := NewMockSweeperStore()
backend := newMockBackend(notifier)
backend := newMockBackend(t, notifier)
estimator := newMockFeeEstimator(10000, chainfee.FeePerKwFloor)
publishChan := make(chan wire.MsgTx, 2)
ctx := &sweeperTestContext{
notifier: notifier,
publishChan: publishChan,
publishChan: backend.publishChan,
t: t,
estimator: estimator,
backend: backend,
@ -116,16 +115,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
var outputScriptCount byte
ctx.sweeper = New(&UtxoSweeperConfig{
Notifier: notifier,
PublishTransaction: func(tx *wire.MsgTx) error {
log.Tracef("Publishing tx %v", tx.TxHash())
err := backend.publishTransaction(tx)
select {
case publishChan <- *tx:
case <-time.After(defaultTestTimeout):
t.Fatalf("unexpected tx published")
}
return err
},
Wallet: backend,
NewBatchTimer: func() <-chan time.Time {
c := make(chan time.Time, 1)
ctx.timeoutChan <- c