From e79e46ed21caf1f1ea80e959e08e32975579f881 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 24 May 2021 19:40:53 +0800 Subject: [PATCH] routing: add mock structs for testing This commit uses the package mock to create new mock structs, replacing the old ones for better control when writing tests. --- routing/mock_test.go | 180 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/routing/mock_test.go b/routing/mock_test.go index 478797bd..9484f6a5 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/mock" ) type mockPaymentAttemptDispatcherOld struct { @@ -529,3 +530,182 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( return nil, errors.New("not implemented") } + +type mockPaymentAttemptDispatcher struct { + mock.Mock +} + +var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) + +func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, + pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error { + + args := m.Called(firstHop, pid, htlcAdd) + return args.Error(0) +} + +func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64, + paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( + <-chan *htlcswitch.PaymentResult, error) { + + args := m.Called(attemptID, paymentHash, deobfuscator) + return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1) +} + +func (m *mockPaymentAttemptDispatcher) CleanStore( + keepPids map[uint64]struct{}) error { + + args := m.Called(keepPids) + return args.Error(0) +} + +type mockPaymentSessionSource struct { + mock.Mock +} + +var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) + +func (m *mockPaymentSessionSource) NewPaymentSession( + payment *LightningPayment) (PaymentSession, error) { + + args := m.Called(m) + return args.Get(0).(PaymentSession), args.Error(1) +} + +func (m *mockPaymentSessionSource) NewPaymentSessionForRoute( + preBuiltRoute *route.Route) PaymentSession { + + args := m.Called(preBuiltRoute) + return args.Get(0).(PaymentSession) +} + +func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession { + args := m.Called() + return args.Get(0).(PaymentSession) +} + +type mockMissionControl struct { + mock.Mock +} + +var _ MissionController = (*mockMissionControl)(nil) + +func (m *mockMissionControl) ReportPaymentFail( + paymentID uint64, rt *route.Route, + failureSourceIdx *int, failure lnwire.FailureMessage) ( + *channeldb.FailureReason, error) { + + args := m.Called(paymentID, rt, failureSourceIdx, failure) + return args.Get(0).(*channeldb.FailureReason), args.Error(1) + +} + +func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, + rt *route.Route) error { + + args := m.Called(paymentID, rt) + return args.Error(0) +} + +func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex, + amt lnwire.MilliSatoshi) float64 { + + args := m.Called(fromNode, toNode, amt) + return args.Get(0).(float64) +} + +type mockPaymentSession struct { + mock.Mock +} + +var _ PaymentSession = (*mockPaymentSession)(nil) + +func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, + activeShards, height uint32) (*route.Route, error) { + args := m.Called(maxAmt, feeLimit, activeShards, height) + return args.Get(0).(*route.Route), args.Error(1) +} + +func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, + pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + + args := m.Called(msg, pubKey, policy) + return args.Bool(0) +} + +func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, + channelID uint64) *channeldb.ChannelEdgePolicy { + + args := m.Called(pubKey, channelID) + return args.Get(0).(*channeldb.ChannelEdgePolicy) +} + +type mockControlTower struct { + mock.Mock +} + +var _ ControlTower = (*mockControlTower)(nil) + +func (m *mockControlTower) InitPayment(phash lntypes.Hash, + c *channeldb.PaymentCreationInfo) error { + + args := m.Called(phash, c) + return args.Error(0) +} + +func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, + a *channeldb.HTLCAttemptInfo) error { + + args := m.Called(phash, a) + return args.Error(0) +} + +func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, + pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( + *channeldb.HTLCAttempt, error) { + + args := m.Called(phash, pid, settleInfo) + return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) +} + +func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, + failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { + + args := m.Called(phash, pid, failInfo) + return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) +} + +func (m *mockControlTower) Fail(phash lntypes.Hash, + reason channeldb.FailureReason) error { + + args := m.Called(phash, reason) + return args.Error(0) +} + +func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( + *channeldb.MPPayment, error) { + + args := m.Called(phash) + + // Type assertion on nil will fail, so we check and return here. + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*channeldb.MPPayment), args.Error(1) + +} + +func (m *mockControlTower) FetchInFlightPayments() ( + []*channeldb.MPPayment, error) { + + args := m.Called() + return args.Get(0).([]*channeldb.MPPayment), args.Error(1) +} + +func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( + *ControlTowerSubscriber, error) { + + args := m.Called(paymentHash) + return args.Get(0).(*ControlTowerSubscriber), args.Error(1) +}