routing: add mock structs for testing

This commit uses the package mock to create new mock structs, replacing
the old ones for better control when writing tests.
This commit is contained in:
yyforyongyu 2021-05-24 19:40:53 +08:00
parent f5de56a40d
commit e79e46ed21
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868

@ -11,6 +11,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/mock"
) )
type mockPaymentAttemptDispatcherOld struct { type mockPaymentAttemptDispatcherOld struct {
@ -529,3 +530,182 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
type mockPaymentAttemptDispatcher struct {
mock.Mock
}
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error {
args := m.Called(firstHop, pid, htlcAdd)
return args.Error(0)
}
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64,
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) {
args := m.Called(attemptID, paymentHash, deobfuscator)
return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1)
}
func (m *mockPaymentAttemptDispatcher) CleanStore(
keepPids map[uint64]struct{}) error {
args := m.Called(keepPids)
return args.Error(0)
}
type mockPaymentSessionSource struct {
mock.Mock
}
var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
func (m *mockPaymentSessionSource) NewPaymentSession(
payment *LightningPayment) (PaymentSession, error) {
args := m.Called(m)
return args.Get(0).(PaymentSession), args.Error(1)
}
func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
preBuiltRoute *route.Route) PaymentSession {
args := m.Called(preBuiltRoute)
return args.Get(0).(PaymentSession)
}
func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
args := m.Called()
return args.Get(0).(PaymentSession)
}
type mockMissionControl struct {
mock.Mock
}
var _ MissionController = (*mockMissionControl)(nil)
func (m *mockMissionControl) ReportPaymentFail(
paymentID uint64, rt *route.Route,
failureSourceIdx *int, failure lnwire.FailureMessage) (
*channeldb.FailureReason, error) {
args := m.Called(paymentID, rt, failureSourceIdx, failure)
return args.Get(0).(*channeldb.FailureReason), args.Error(1)
}
func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
rt *route.Route) error {
args := m.Called(paymentID, rt)
return args.Error(0)
}
func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex,
amt lnwire.MilliSatoshi) float64 {
args := m.Called(fromNode, toNode, amt)
return args.Get(0).(float64)
}
type mockPaymentSession struct {
mock.Mock
}
var _ PaymentSession = (*mockPaymentSession)(nil)
func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32) (*route.Route, error) {
args := m.Called(maxAmt, feeLimit, activeShards, height)
return args.Get(0).(*route.Route), args.Error(1)
}
func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool {
args := m.Called(msg, pubKey, policy)
return args.Bool(0)
}
func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy {
args := m.Called(pubKey, channelID)
return args.Get(0).(*channeldb.ChannelEdgePolicy)
}
type mockControlTower struct {
mock.Mock
}
var _ ControlTower = (*mockControlTower)(nil)
func (m *mockControlTower) InitPayment(phash lntypes.Hash,
c *channeldb.PaymentCreationInfo) error {
args := m.Called(phash, c)
return args.Error(0)
}
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
a *channeldb.HTLCAttemptInfo) error {
args := m.Called(phash, a)
return args.Error(0)
}
func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) {
args := m.Called(phash, pid, settleInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
args := m.Called(phash, pid, failInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}
func (m *mockControlTower) Fail(phash lntypes.Hash,
reason channeldb.FailureReason) error {
args := m.Called(phash, reason)
return args.Error(0)
}
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
args := m.Called(phash)
// Type assertion on nil will fail, so we check and return here.
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*channeldb.MPPayment), args.Error(1)
}
func (m *mockControlTower) FetchInFlightPayments() (
[]*channeldb.MPPayment, error) {
args := m.Called()
return args.Get(0).([]*channeldb.MPPayment), args.Error(1)
}
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
*ControlTowerSubscriber, error) {
args := m.Called(paymentHash)
return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
}