260 lines
6.1 KiB
Go
260 lines
6.1 KiB
Go
package sweep
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
"github.com/btcsuite/btcd/wire"
|
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
|
"github.com/lightningnetwork/lnd/input"
|
|
"github.com/lightningnetwork/lnd/lnwallet"
|
|
)
|
|
|
|
// defaultTestTimeout bounds how long the mocks in this file wait for a
// notification to be consumed before giving up.
var defaultTestTimeout = 5 * time.Second

// mockChainIOHeight is the best block height reported by mockChainIO.
var mockChainIOHeight int32 = 100
|
|
|
|
// mockSigner is a stateless signer stub for tests; its methods return empty,
// non-nil results and never fail.
type mockSigner struct{}
|
|
|
|
func (m *mockSigner) SignOutputRaw(tx *wire.MsgTx,
|
|
signDesc *input.SignDescriptor) ([]byte, error) {
|
|
|
|
return []byte{}, nil
|
|
}
|
|
|
|
func (m *mockSigner) ComputeInputScript(tx *wire.MsgTx,
|
|
signDesc *input.SignDescriptor) (*input.Script, error) {
|
|
|
|
return &input.Script{}, nil
|
|
}
|
|
|
|
// MockNotifier simulates the chain notifier for test purposes. This type is
|
|
// exported because it is used in nursery tests.
|
|
type MockNotifier struct {
|
|
confChannel map[chainhash.Hash]chan *chainntnfs.TxConfirmation
|
|
epochChan map[chan *chainntnfs.BlockEpoch]int32
|
|
spendChan map[wire.OutPoint][]chan *chainntnfs.SpendDetail
|
|
spends map[wire.OutPoint]*wire.MsgTx
|
|
mutex sync.RWMutex
|
|
t *testing.T
|
|
}
|
|
|
|
// NewMockNotifier instantiates a new mock notifier.
|
|
func NewMockNotifier(t *testing.T) *MockNotifier {
|
|
return &MockNotifier{
|
|
confChannel: make(map[chainhash.Hash]chan *chainntnfs.TxConfirmation),
|
|
epochChan: make(map[chan *chainntnfs.BlockEpoch]int32),
|
|
spendChan: make(map[wire.OutPoint][]chan *chainntnfs.SpendDetail),
|
|
spends: make(map[wire.OutPoint]*wire.MsgTx),
|
|
t: t,
|
|
}
|
|
}
|
|
|
|
// NotifyEpoch simulates a new epoch arriving.
|
|
func (m *MockNotifier) NotifyEpoch(height int32) {
|
|
m.t.Helper()
|
|
|
|
for epochChan, chanHeight := range m.epochChan {
|
|
// Only send notifications if the height is greater than the
|
|
// height the caller passed into the register call.
|
|
if chanHeight >= height {
|
|
continue
|
|
}
|
|
|
|
log.Debugf("Notifying height %v to listener", height)
|
|
|
|
select {
|
|
case epochChan <- &chainntnfs.BlockEpoch{
|
|
Height: height,
|
|
}:
|
|
case <-time.After(defaultTestTimeout):
|
|
m.t.Fatal("epoch event not consumed")
|
|
}
|
|
}
|
|
}
|
|
|
|
// ConfirmTx simulates a tx confirming.
|
|
func (m *MockNotifier) ConfirmTx(txid *chainhash.Hash, height uint32) error {
|
|
confirm := &chainntnfs.TxConfirmation{
|
|
BlockHeight: height,
|
|
}
|
|
select {
|
|
case m.getConfChannel(txid) <- confirm:
|
|
case <-time.After(defaultTestTimeout):
|
|
return fmt.Errorf("confirmation not consumed")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SpendOutpoint simulates a utxo being spent.
|
|
func (m *MockNotifier) SpendOutpoint(outpoint wire.OutPoint,
|
|
spendingTx wire.MsgTx) {
|
|
|
|
log.Debugf("Spending outpoint %v", outpoint)
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
channels, ok := m.spendChan[outpoint]
|
|
if ok {
|
|
for _, channel := range channels {
|
|
m.sendSpend(channel, &outpoint, &spendingTx)
|
|
}
|
|
}
|
|
|
|
m.spends[outpoint] = &spendingTx
|
|
}
|
|
|
|
func (m *MockNotifier) sendSpend(channel chan *chainntnfs.SpendDetail,
|
|
outpoint *wire.OutPoint,
|
|
spendingTx *wire.MsgTx) {
|
|
|
|
spenderTxHash := spendingTx.TxHash()
|
|
channel <- &chainntnfs.SpendDetail{
|
|
SpenderTxHash: &spenderTxHash,
|
|
SpendingTx: spendingTx,
|
|
SpentOutPoint: outpoint,
|
|
}
|
|
}
|
|
|
|
// RegisterConfirmationsNtfn registers for tx confirm notifications.
|
|
func (m *MockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
|
|
_ []byte, numConfs, heightHint uint32) (*chainntnfs.ConfirmationEvent,
|
|
error) {
|
|
|
|
return &chainntnfs.ConfirmationEvent{
|
|
Confirmed: m.getConfChannel(txid),
|
|
}, nil
|
|
}
|
|
|
|
func (m *MockNotifier) getConfChannel(
|
|
txid *chainhash.Hash) chan *chainntnfs.TxConfirmation {
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
channel, ok := m.confChannel[*txid]
|
|
if ok {
|
|
return channel
|
|
}
|
|
channel = make(chan *chainntnfs.TxConfirmation)
|
|
m.confChannel[*txid] = channel
|
|
|
|
return channel
|
|
}
|
|
|
|
// RegisterBlockEpochNtfn registers a block notification.
|
|
func (m *MockNotifier) RegisterBlockEpochNtfn(
|
|
bestBlock *chainntnfs.BlockEpoch) (*chainntnfs.BlockEpochEvent, error) {
|
|
|
|
log.Tracef("Mock block ntfn registered")
|
|
|
|
m.mutex.Lock()
|
|
epochChan := make(chan *chainntnfs.BlockEpoch, 0)
|
|
bestHeight := int32(0)
|
|
if bestBlock != nil {
|
|
bestHeight = bestBlock.Height
|
|
}
|
|
m.epochChan[epochChan] = bestHeight
|
|
m.mutex.Unlock()
|
|
|
|
return &chainntnfs.BlockEpochEvent{
|
|
Epochs: epochChan,
|
|
Cancel: func() {
|
|
log.Tracef("Mock block ntfn canceled")
|
|
m.mutex.Lock()
|
|
delete(m.epochChan, epochChan)
|
|
m.mutex.Unlock()
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Start the notifier.
|
|
func (m *MockNotifier) Start() error {
|
|
return nil
|
|
}
|
|
|
|
// Stop the notifier.
|
|
func (m *MockNotifier) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
// RegisterSpendNtfn registers for spend notifications.
|
|
func (m *MockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
|
|
_ []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
|
|
|
|
// Add channel to global spend ntfn map.
|
|
m.mutex.Lock()
|
|
|
|
channels, ok := m.spendChan[*outpoint]
|
|
if !ok {
|
|
channels = make([]chan *chainntnfs.SpendDetail, 0)
|
|
}
|
|
|
|
channel := make(chan *chainntnfs.SpendDetail, 1)
|
|
channels = append(channels, channel)
|
|
m.spendChan[*outpoint] = channels
|
|
|
|
// Check if this output has already been spent.
|
|
spendingTx, spent := m.spends[*outpoint]
|
|
|
|
m.mutex.Unlock()
|
|
|
|
// If output has been spent already, signal now. Do this outside the
|
|
// lock to prevent a dead lock.
|
|
if spent {
|
|
m.sendSpend(channel, outpoint, spendingTx)
|
|
}
|
|
|
|
return &chainntnfs.SpendEvent{
|
|
Spend: channel,
|
|
Cancel: func() {
|
|
log.Infof("Cancelling RegisterSpendNtfn for %v",
|
|
outpoint)
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
channels := m.spendChan[*outpoint]
|
|
for i, c := range channels {
|
|
if c == channel {
|
|
channels[i] = channels[len(channels)-1]
|
|
m.spendChan[*outpoint] =
|
|
channels[:len(channels)-1]
|
|
}
|
|
}
|
|
|
|
close(channel)
|
|
|
|
log.Infof("Spend ntfn channel closed for %v",
|
|
outpoint)
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// mockChainIO is a stateless chain backend stub for tests. GetBestBlock
// reports mockChainIOHeight; its other methods return nil results.
type mockChainIO struct{}
|
|
|
|
// Compile-time check that *mockChainIO satisfies lnwallet.BlockChainIO.
var _ lnwallet.BlockChainIO = (*mockChainIO)(nil)
|
|
|
|
func (m *mockChainIO) GetBestBlock() (*chainhash.Hash, int32, error) {
|
|
return nil, mockChainIOHeight, nil
|
|
}
|
|
|
|
func (m *mockChainIO) GetUtxo(op *wire.OutPoint, pkScript []byte,
|
|
heightHint uint32, _ <-chan struct{}) (*wire.TxOut, error) {
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockChainIO) GetBlockHash(blockHeight int64) (*chainhash.Hash, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
|
|
return nil, nil
|
|
}
|