routing: add mock structs for testing
This commit uses the package mock to create new mock structs, replacing the old ones for better control when writing tests.
This commit is contained in:
parent
f5de56a40d
commit
e79e46ed21
@ -11,6 +11,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockPaymentAttemptDispatcherOld struct {
|
||||
@ -529,3 +530,182 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type mockPaymentAttemptDispatcher struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
|
||||
pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error {
|
||||
|
||||
args := m.Called(firstHop, pid, htlcAdd)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64,
|
||||
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
|
||||
<-chan *htlcswitch.PaymentResult, error) {
|
||||
|
||||
args := m.Called(attemptID, paymentHash, deobfuscator)
|
||||
return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) CleanStore(
|
||||
keepPids map[uint64]struct{}) error {
|
||||
|
||||
args := m.Called(keepPids)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type mockPaymentSessionSource struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
|
||||
|
||||
func (m *mockPaymentSessionSource) NewPaymentSession(
|
||||
payment *LightningPayment) (PaymentSession, error) {
|
||||
|
||||
args := m.Called(m)
|
||||
return args.Get(0).(PaymentSession), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
|
||||
preBuiltRoute *route.Route) PaymentSession {
|
||||
|
||||
args := m.Called(preBuiltRoute)
|
||||
return args.Get(0).(PaymentSession)
|
||||
}
|
||||
|
||||
func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
|
||||
args := m.Called()
|
||||
return args.Get(0).(PaymentSession)
|
||||
}
|
||||
|
||||
type mockMissionControl struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ MissionController = (*mockMissionControl)(nil)
|
||||
|
||||
func (m *mockMissionControl) ReportPaymentFail(
|
||||
paymentID uint64, rt *route.Route,
|
||||
failureSourceIdx *int, failure lnwire.FailureMessage) (
|
||||
*channeldb.FailureReason, error) {
|
||||
|
||||
args := m.Called(paymentID, rt, failureSourceIdx, failure)
|
||||
return args.Get(0).(*channeldb.FailureReason), args.Error(1)
|
||||
|
||||
}
|
||||
|
||||
func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
|
||||
rt *route.Route) error {
|
||||
|
||||
args := m.Called(paymentID, rt)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex,
|
||||
amt lnwire.MilliSatoshi) float64 {
|
||||
|
||||
args := m.Called(fromNode, toNode, amt)
|
||||
return args.Get(0).(float64)
|
||||
}
|
||||
|
||||
type mockPaymentSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ PaymentSession = (*mockPaymentSession)(nil)
|
||||
|
||||
func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
|
||||
activeShards, height uint32) (*route.Route, error) {
|
||||
args := m.Called(maxAmt, feeLimit, activeShards, height)
|
||||
return args.Get(0).(*route.Route), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
|
||||
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool {
|
||||
|
||||
args := m.Called(msg, pubKey, policy)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
|
||||
channelID uint64) *channeldb.ChannelEdgePolicy {
|
||||
|
||||
args := m.Called(pubKey, channelID)
|
||||
return args.Get(0).(*channeldb.ChannelEdgePolicy)
|
||||
}
|
||||
|
||||
type mockControlTower struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ ControlTower = (*mockControlTower)(nil)
|
||||
|
||||
func (m *mockControlTower) InitPayment(phash lntypes.Hash,
|
||||
c *channeldb.PaymentCreationInfo) error {
|
||||
|
||||
args := m.Called(phash, c)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
|
||||
a *channeldb.HTLCAttemptInfo) error {
|
||||
|
||||
args := m.Called(phash, a)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
|
||||
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
|
||||
*channeldb.HTLCAttempt, error) {
|
||||
|
||||
args := m.Called(phash, pid, settleInfo)
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
|
||||
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
|
||||
|
||||
args := m.Called(phash, pid, failInfo)
|
||||
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) Fail(phash lntypes.Hash,
|
||||
reason channeldb.FailureReason) error {
|
||||
|
||||
args := m.Called(phash, reason)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
|
||||
*channeldb.MPPayment, error) {
|
||||
|
||||
args := m.Called(phash)
|
||||
|
||||
// Type assertion on nil will fail, so we check and return here.
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*channeldb.MPPayment), args.Error(1)
|
||||
|
||||
}
|
||||
|
||||
func (m *mockControlTower) FetchInFlightPayments() (
|
||||
[]*channeldb.MPPayment, error) {
|
||||
|
||||
args := m.Called()
|
||||
return args.Get(0).([]*channeldb.MPPayment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
|
||||
*ControlTowerSubscriber, error) {
|
||||
|
||||
args := m.Called(paymentHash)
|
||||
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user