From cd3598156971d9b8b385751b9324b770062040f8 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Fri, 21 May 2021 19:02:49 +0800 Subject: [PATCH] routing: refactor update payment state tests This commit refactors the resumePayment to extract some logics back to paymentState so that the code is more testable. It also adds unit tests for paymentState, and breaks the original MPPayment tests into independent tests so that it's easier to maintain and debug. All the new tests are built using mock so that the control flow is eaiser to setup and change. --- routing/mock_test.go | 40 +- routing/payment_lifecycle.go | 161 +++--- routing/payment_lifecycle_test.go | 622 ++++++++++++----------- routing/router_test.go | 797 +++++++++++++++++++++++++++++- 4 files changed, 1262 insertions(+), 358 deletions(-) diff --git a/routing/mock_test.go b/routing/mock_test.go index 9484f6a5..430c7fad 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -533,6 +533,8 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( type mockPaymentAttemptDispatcher struct { mock.Mock + + resultChan chan *htlcswitch.PaymentResult } var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) @@ -548,8 +550,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64, paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) { - args := m.Called(attemptID, paymentHash, deobfuscator) - return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1) + m.Called(attemptID, paymentHash, deobfuscator) + + // Instead of returning the mocked returned values, we need to return + // the chan resultChan so it can be converted into a read-only chan. + return m.resultChan, nil } func (m *mockPaymentAttemptDispatcher) CleanStore( @@ -568,7 +573,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment) (PaymentSession, error) { - args := m.Called(m) + args := m.Called(payment) return args.Get(0).(PaymentSession), args.Error(1) } @@ -586,6 +591,8 @@ func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession { type mockMissionControl struct { mock.Mock + + failReason *channeldb.FailureReason } var _ MissionController = (*mockMissionControl)(nil) @@ -596,8 +603,7 @@ func (m *mockMissionControl) ReportPaymentFail( *channeldb.FailureReason, error) { args := m.Called(paymentID, rt, failureSourceIdx, failure) - return args.Get(0).(*channeldb.FailureReason), args.Error(1) - + return m.failReason, args.Error(1) } func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, @@ -642,6 +648,7 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, type mockControlTower struct { mock.Mock + sync.Mutex } var _ ControlTower = (*mockControlTower)(nil) @@ -656,6 +663,9 @@ func (m *mockControlTower) InitPayment(phash lntypes.Hash, func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, a *channeldb.HTLCAttemptInfo) error { + m.Lock() + defer m.Unlock() + args := m.Called(phash, a) return args.Error(0) } @@ -664,6 +674,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( *channeldb.HTLCAttempt, error) { + m.Lock() + defer m.Unlock() + args := m.Called(phash, pid, settleInfo) return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } @@ -671,6 +684,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { + m.Lock() + defer m.Unlock() + args := m.Called(phash, pid, failInfo) return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } @@ -678,6 +694,9 @@ func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, func (m *mockControlTower) Fail(phash lntypes.Hash, reason channeldb.FailureReason) error { + m.Lock() + defer m.Unlock() + args := m.Called(phash, reason) return args.Error(0) } @@ -685,6 +704,8 @@ func (m *mockControlTower) Fail(phash lntypes.Hash, func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( *channeldb.MPPayment, error) { + m.Lock() + defer m.Unlock() args := m.Called(phash) // Type assertion on nil will fail, so we check and return here. @@ -692,8 +713,15 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( return nil, args.Error(1) } - return args.Get(0).(*channeldb.MPPayment), args.Error(1) + // Make a copy of the payment here to avoid data race. + p := args.Get(0).(*channeldb.MPPayment) + payment := &channeldb.MPPayment{ + FailureReason: p.FailureReason, + } + payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs)) + copy(payment.HTLCs, p.HTLCs) + return payment, args.Error(1) } func (m *mockControlTower) FetchInFlightPayments() ( diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 155b0f37..5699e95f 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -38,21 +38,53 @@ type paymentState struct { numShardsInFlight int remainingAmt lnwire.MilliSatoshi remainingFees lnwire.MilliSatoshi - terminate bool + + // terminate indicates the payment is in its final stage and no more + // shards should be launched. This value is true if we have an HTLC + // settled or the payment has an error. + terminate bool } -// paymentState uses the passed payment to find the latest information we need -// to act on every iteration of the payment loop. -func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) ( +// terminated returns a bool to indicate there are no further actions needed +// and we should return what we have, either the payment preimage or the +// payment error. +func (ps paymentState) terminated() bool { + // If the payment is in final stage and we have no in flight shards to + // wait result for, we consider the whole action terminated. + return ps.terminate && ps.numShardsInFlight == 0 +} + +// needWaitForShards returns a bool to specify whether we need to wait for the +// outcome of the shanrdHandler. +func (ps paymentState) needWaitForShards() bool { + // If we have in flight shards and the payment is in final stage, we + // need to wait for the outcomes from the shards. Or if we have no more + // money to be sent, we need to wait for the already launched shards. + if ps.numShardsInFlight == 0 { + return false + } + return ps.terminate || ps.remainingAmt == 0 +} + +// updatePaymentState will fetch db for the payment to find the latest +// information we need to act on every iteration of the payment loop and update +// the paymentState. +func (p *paymentLifecycle) updatePaymentState() (*channeldb.MPPayment, *paymentState, error) { + // Fetch the latest payment from db. + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, nil, err + } + // Fetch the total amount and fees that has already been sent in // settled and still in-flight shards. sentAmt, fees := payment.SentAmt() // Sanity check we haven't sent a value larger than the payment amount. if sentAmt > p.totalAmount { - return nil, fmt.Errorf("amount sent %v exceeds "+ + return nil, nil, fmt.Errorf("amount sent %v exceeds "+ "total amount %v", sentAmt, p.totalAmount) } @@ -74,13 +106,15 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) ( // have returned with a result. terminate := settle != nil || failure != nil - activeShards := payment.InFlightHTLCs() - return &paymentState{ - numShardsInFlight: len(activeShards), + // Update the payment state. + state := &paymentState{ + numShardsInFlight: len(payment.InFlightHTLCs()), remainingAmt: p.totalAmount - sentAmt, remainingFees: feeBudget, terminate: terminate, - }, nil + } + + return payment, state, nil } // resumePayment resumes the paymentLifecycle from the current state. @@ -102,9 +136,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the // lifecycle loop below. - payment, err := p.router.cfg.Control.FetchPayment( - p.identifier, - ) + payment, _, err := p.updatePaymentState() if err != nil { return [32]byte{}, nil, err } @@ -128,34 +160,30 @@ lifecycle: return [32]byte{}, nil, err } - // We start every iteration by fetching the lastest state of - // the payment from the ControlTower. This ensures that we will - // act on the latest available information, whether we are - // resuming an existing payment or just sent a new attempt. - payment, err := p.router.cfg.Control.FetchPayment( - p.identifier, - ) - if err != nil { - return [32]byte{}, nil, err - } - - // Using this latest state of the payment, calculate - // information about our active shards and terminal conditions. - state, err := p.paymentState(payment) + // We update the payment state on every iteration. Since the + // payment state is affected by multiple goroutines (ie, + // collectResultAsync), it is NOT guaranteed that we always + // have the latest state here. This is fine as long as the + // state is consistent as a whole. + payment, currentState, err := p.updatePaymentState() if err != nil { return [32]byte{}, nil, err } log.Debugf("Payment %v in state terminate=%v, "+ "active_shards=%v, rem_value=%v, fee_limit=%v", - p.identifier, state.terminate, state.numShardsInFlight, - state.remainingAmt, state.remainingFees) + p.identifier, currentState.terminate, + currentState.numShardsInFlight, + currentState.remainingAmt, currentState.remainingFees, + ) + // TODO(yy): sanity check all the states to make sure + // everything is expected. switch { // We have a terminal condition and no active shards, we are // ready to exit. - case state.terminate && state.numShardsInFlight == 0: + case currentState.terminated(): // Find the first successful shard and return // the preimage and route. for _, a := range payment.HTLCs { @@ -170,7 +198,7 @@ lifecycle: // If we either reached a terminal error condition (but had // active shards still) or there is no remaining value to send, // we'll wait for a shard outcome. - case state.terminate || state.remainingAmt == 0: + case currentState.needWaitForShards(): // We still have outstanding shards, so wait for a new // outcome to be available before re-evaluating our // state. @@ -212,8 +240,9 @@ lifecycle: // Create a new payment attempt from the given payment session. rt, err := p.paySession.RequestRoute( - state.remainingAmt, state.remainingFees, - uint32(state.numShardsInFlight), uint32(p.currentHeight), + currentState.remainingAmt, currentState.remainingFees, + uint32(currentState.numShardsInFlight), + uint32(p.currentHeight), ) if err != nil { log.Warnf("Failed to find route for payment %v: %v", @@ -227,7 +256,7 @@ lifecycle: // There is no route to try, and we have no active // shards. This means that there is no way for us to // send the payment, so mark it failed with no route. - if state.numShardsInFlight == 0 { + if currentState.numShardsInFlight == 0 { failureCode := routeErr.FailureReason() log.Debugf("Marking payment %v permanently "+ "failed with no route: %v", @@ -253,22 +282,11 @@ lifecycle: // If this route will consume the last remeining amount to send // to the receiver, this will be our last shard (for now). - lastShard := rt.ReceiverAmt() == state.remainingAmt + lastShard := rt.ReceiverAmt() == currentState.remainingAmt // We found a route to try, launch a new shard. attempt, outcome, err := shardHandler.launchShard(rt, lastShard) - switch { - // We may get a terminal error if we've processed a shard with - // a terminal state (settled or permanent failure), while we - // were pathfinding. We know we're in a terminal state here, - // so we can continue and wait for our last shards to return. - case err == channeldb.ErrPaymentTerminal: - log.Infof("Payment %v in terminal state, abandoning "+ - "shard", p.identifier) - - continue lifecycle - - case err != nil: + if err != nil { return [32]byte{}, nil, err } @@ -297,6 +315,7 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. shardHandler.collectResultAsync(attempt) + } } @@ -437,12 +456,30 @@ type shardResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. Note that it will -// fail the payment with the control tower if a terminal error is encountered. +// given HTLC attempt to be available then handle its result. It will fail the +// payment with the control tower if a terminal error is encountered. func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) { + + // errToSend is the error to be sent to sh.shardErrors. + var errToSend error + + // handleResultErr is a function closure must be called using defer. It + // finishes collecting result by updating the payment state and send + // the error (or nil) to sh.shardErrors. + handleResultErr := func() { + // Send the error or quit. + select { + case p.shardErrors <- errToSend: + case <-p.router.quit: + case <-p.quit: + } + + p.wg.Done() + } + p.wg.Add(1) go func() { - defer p.wg.Done() + defer handleResultErr() // Block until the result is available. result, err := p.collectResult(attempt) @@ -456,32 +493,18 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) { attempt.AttemptID, p.identifier, err) } - select { - case p.shardErrors <- err: - case <-p.router.quit: - case <-p.quit: - } + // Overwrite errToSend and return. + errToSend = err return } // If a non-critical error was encountered handle it and mark // the payment failed if the failure was terminal. if result.err != nil { - err := p.handleSendError(attempt, result.err) - if err != nil { - select { - case p.shardErrors <- err: - case <-p.router.quit: - case <-p.quit: - } - return - } - } - - select { - case p.shardErrors <- nil: - case <-p.router.quit: - case <-p.quit: + // Overwrite errToSend and return. Notice that the + // errToSend could be nil here. + errToSend = p.handleSendError(attempt, result.err) + return } }() } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 7f860284..74f6dfbd 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -195,14 +195,6 @@ func TestRouterPaymentStateMachine(t *testing.T) { t.Fatalf("unable to create route: %v", err) } - halfShard, err := createTestRoute(paymentAmt/2, testGraph.aliasMap) - require.NoError(t, err, "unable to create half route") - - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - if err != nil { - t.Fatalf("unable to create route: %v", err) - } - tests := []paymentLifecycleTestCase{ { // Tests a normal payment flow that succeeds. @@ -425,280 +417,6 @@ func TestRouterPaymentStateMachine(t *testing.T) { routes: []*route.Route{rt}, paymentErr: channeldb.FailureReasonNoRoute, }, - - // ===================================== - // || MPP scenarios || - // ===================================== - { - // Tests a simple successful MP payment of 4 shards. - name: "MP success", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // All shards succeed. - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - - // Router should settle them all. - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - - // And the final result is obviously - // successful. - paymentSuccess, - }, - routes: []*route.Route{shard, shard, shard, shard}, - }, - { - // An MP payment scenario where we need several extra - // attempts before the payment finally settle. - name: "MP failed shards", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // First two shards fail, two new ones are sent. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The four shards settle. - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - - // Overall payment succeeds. - paymentSuccess, - }, - routes: []*route.Route{ - shard, shard, shard, shard, shard, shard, - }, - }, - { - // An MP payment scenario where one of the shards fails, - // but we still receive a single success shard. - name: "MP one shard success", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 0 fails, and should be failed by the - // router. - getPaymentResultTempFailure, - routerFailAttempt, - - // We will try one more shard because we haven't - // sent the full payment amount. - routeRelease, - - // The second shard succeed against all odds, - // making the overall payment succeed. - getPaymentResultSuccess, - routerSettleAttempt, - paymentSuccess, - }, - routes: []*route.Route{halfShard, halfShard}, - }, - { - // An MP payment scenario a shard fail with a terminal - // error, causing the router to stop attempting. - name: "MP terminal", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The first shard fail with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // Remaining 3 shards fail. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - routerFailAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, shard, shard, shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, - { - // A MP payment scenario when our path finding returns - // after we've just received a terminal failure, and - // attempts to dispatch a new shard. Testing that we - // correctly abandon the shard and conclude the payment. - name: "MP path found after failure", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The first shard fail with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // shard 1 fails because we've had a terminal - // failure. - routeRelease, - routerRegisterAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, - { - // A MP payment scenario when our path finding returns - // after we've just received a terminal failure, and - // we have another shard still in flight. - name: "MP shard in flight after terminal", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // We find a path for another shard. - routeRelease, - - // shard 0 fails with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // We try to register our final shard after - // processing a terminal failure. - routerRegisterAttempt, - - // Our in-flight shards fail. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, } for _, test := range tests { @@ -1080,3 +798,343 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, t.Fatalf("SendPayment didn't exit") } } + +// TestPaymentState tests that the logics implemented on paymentState struct +// are as expected. In particular, that the method terminated and +// needWaitForShards return the right values. +func TestPaymentState(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + // Use the following three params, each is equivalent to a bool + // statement, to construct 8 test cases so that we can + // exhaustively catch all possible states. + numShardsInFlight int + remainingAmt lnwire.MilliSatoshi + terminate bool + + expectedTerminated bool + expectedNeedWaitForShards bool + }{ + { + // If we have active shards and terminate is marked + // false, the state is not terminated. Since the + // remaining amount is zero, we need to wait for shards + // to be finished and launch no more shards. + name: "state 100", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + { + // If we have active shards while terminate is marked + // true, the state is not terminated, and we need to + // wait for shards to be finished and launch no more + // shards. + name: "state 101", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: true, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + + { + // If we have active shards and terminate is marked + // false, the state is not terminated. Since the + // remaining amount is not zero, we don't need to wait + // for shards outcomes and should launch more shards. + name: "state 110", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have active shards and terminate is marked + // true, the state is not terminated. Even the + // remaining amount is not zero, we need to wait for + // shards outcomes because state is terminated. + name: "state 111", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: true, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + { + // If we have no active shards while terminate is marked + // false, the state is not terminated, and we don't + // need to wait for more shard outcomes because there + // are no active shards. + name: "state 000", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // true, the state is terminated, and we don't need to + // wait for shards to be finished. + name: "state 001", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: true, + expectedTerminated: true, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // false, the state is not terminated. Since the + // remaining amount is not zero, we don't need to wait + // for shards outcomes and should launch more shards. + name: "state 010", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // true, the state is terminated, and we don't need to + // wait for shards outcomes. + name: "state 011", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: true, + expectedTerminated: true, + expectedNeedWaitForShards: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ps := &paymentState{ + numShardsInFlight: tc.numShardsInFlight, + remainingAmt: tc.remainingAmt, + terminate: tc.terminate, + } + + require.Equal( + t, tc.expectedTerminated, ps.terminated(), + "terminated returned wrong value", + ) + require.Equal( + t, tc.expectedNeedWaitForShards, + ps.needWaitForShards(), + "needWaitForShards returned wrong value", + ) + }) + } + +} + +// TestUpdatePaymentState checks that the method updatePaymentState updates the +// paymentState as expected. +func TestUpdatePaymentState(t *testing.T) { + t.Parallel() + + // paymentHash is the identifier on paymentLifecycle. + paymentHash := lntypes.Hash{} + + // TODO(yy): make MPPayment into an interface so we can mock it. The + // current design implicitly tests the methods SendAmt, TerminalInfo, + // and InFlightHTLCs on channeldb.MPPayment, which is not good. Once + // MPPayment becomes an interface, we can then mock these methods here. + + // SentAmt returns 90, 10 + // TerminalInfo returns non-nil, nil + // InFlightHTLCs returns 0 + var preimage lntypes.Preimage + paymentSettled := &channeldb.MPPayment{ + HTLCs: []channeldb.HTLCAttempt{ + makeSettledAttempt(100, 10, preimage), + }, + } + + // SentAmt returns 0, 0 + // TerminalInfo returns nil, non-nil + // InFlightHTLCs returns 0 + reason := channeldb.FailureReasonError + paymentFailed := &channeldb.MPPayment{ + FailureReason: &reason, + } + + // SentAmt returns 90, 10 + // TerminalInfo returns nil, nil + // InFlightHTLCs returns 1 + paymentActive := &channeldb.MPPayment{ + HTLCs: []channeldb.HTLCAttempt{ + makeActiveAttempt(100, 10), + makeFailedAttempt(100, 10), + }, + } + + testCases := []struct { + name string + payment *channeldb.MPPayment + totalAmt int + feeLimit int + + expectedState *paymentState + shouldReturnError bool + }{ + { + // Test that the error returned from FetchPayment is + // handled properly. We use a nil payment to indicate + // we want to return an error. + name: "fetch payment error", + payment: nil, + shouldReturnError: true, + }, + { + // Test that when the sentAmt exceeds totalAmount, the + // error is returned. + name: "amount exceeded error", + payment: paymentSettled, + totalAmt: 1, + shouldReturnError: true, + }, + { + // Test that when the fee budget is reached, the + // remaining fee should be zero. + name: "fee budget reached", + payment: paymentActive, + totalAmt: 1000, + feeLimit: 1, + expectedState: &paymentState{ + numShardsInFlight: 1, + remainingAmt: 1000 - 90, + remainingFees: 0, + terminate: false, + }, + }, + { + // Test when the payment is settled, the state should + // be marked as terminated. + name: "payment settled", + payment: paymentSettled, + totalAmt: 1000, + feeLimit: 100, + expectedState: &paymentState{ + numShardsInFlight: 0, + remainingAmt: 1000 - 90, + remainingFees: 100 - 10, + terminate: true, + }, + }, + { + // Test when the payment is failed, the state should be + // marked as terminated. + name: "payment failed", + payment: paymentFailed, + totalAmt: 1000, + feeLimit: 100, + expectedState: &paymentState{ + numShardsInFlight: 0, + remainingAmt: 1000, + remainingFees: 100, + terminate: true, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create mock control tower and assign it to router. + // We will then use the router and the paymentHash + // above to create our paymentLifecycle for this test. + ct := &mockControlTower{} + rt := &ChannelRouter{cfg: &Config{Control: ct}} + pl := &paymentLifecycle{ + router: rt, + identifier: paymentHash, + totalAmount: lnwire.MilliSatoshi(tc.totalAmt), + feeLimit: lnwire.MilliSatoshi(tc.feeLimit), + } + + if tc.payment == nil { + // A nil payment indicates we want to test an + // error returned from FetchPayment. + dummyErr := errors.New("dummy") + ct.On("FetchPayment", paymentHash).Return( + nil, dummyErr, + ) + + } else { + // Otherwise we will return the payment. + ct.On("FetchPayment", paymentHash).Return( + tc.payment, nil, + ) + } + + // Call the method that updates the payment state. + _, state, err := pl.updatePaymentState() + + // Assert that the mock method is called as + // intended. + ct.AssertExpectations(t) + + if tc.shouldReturnError { + require.Error(t, err, "expect an error") + return + } + + require.NoError(t, err, "unexpected error") + require.Equal( + t, tc.expectedState, state, + "state not updated as expected", + ) + + }) + } + +} + +func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt { + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + } +} + +func makeSettledAttempt(total, fee int, + preimage lntypes.Preimage) channeldb.HTLCAttempt { + + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, + } +} + +func makeFailedAttempt(total, fee int) channeldb.HTLCAttempt { + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Failure: &channeldb.HTLCFailInfo{ + Reason: channeldb.HTLCFailInternal, + }, + } +} + +func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo { + hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} + return channeldb.HTLCAttemptInfo{ + Route: route.Route{ + TotalAmount: lnwire.MilliSatoshi(total), + Hops: []*route.Hop{hop}, + }, + } +} diff --git a/routing/router_test.go b/routing/router_test.go index ab59739c..29ad342d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/lightningnetwork/lnd/channeldb" @@ -1069,7 +1070,8 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { _, ok = msg.(*lnwire.FailUnknownNextPeer) require.True(t, ok, "unexpected fail message") - ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() + err = ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() + require.NoError(t, err, "reset history failed") // Next, we'll modify the SendToSwitch method to indicate that the // connection between songoku and isn't up. @@ -3436,3 +3438,796 @@ func TestChannelOnChainRejectionZombie(t *testing.T) { require.Nil(t, err) assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput) } + +func createDummyTestGraph(t *testing.T) *testGraphInstance { + // Setup two simple channels such that we can mock sending along this + // route. + chanCapSat := btcutil.Amount(100000) + testChannels := []*testChannel{ + symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 1), + symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 2), + } + + testGraph, err := createTestGraphFromChannels(testChannels, "a") + require.NoError(t, err, "failed to create graph") + return testGraph +} + +func createDummyLightningPayment(t *testing.T, + target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment { + + var preImage lntypes.Preimage + _, err := rand.Read(preImage[:]) + require.NoError(t, err, "unable to generate preimage") + + payHash := preImage.Hash() + + return &LightningPayment{ + Target: target, + Amount: amt, + FeeLimit: noFeeLimit, + paymentHash: &payHash, + } +} + +// TestSendMPPaymentSucceed tests that we can successfully send a MPPayment via +// router.SendPayment. This test mainly focuses on testing the logic of the +// method resumePayment is implemented as expected. +func TestSendMPPaymentSucceed(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + payer.resultChan <- &htlcswitch.PaymentResult{} + }) + + // Simple mocking the rest. + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + missionControl.On("ReportPaymentSuccess", + mock.Anything, mock.Anything, + ).Return(nil) + + // Mock SettleAttempt by changing one of the HTLCs to be settled. + preimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt( + int(paymentAmt/4), 0, preimage, + ) + controlTower.On("SettleAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt settled and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle == nil { + attempt.Settle = &channeldb.HTLCSettleInfo{ + Preimage: preimage, + } + payment.HTLCs[i] = attempt + return + } + } + }) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.NoError(t, err, "send payment failed") + require.EqualValues(t, preimage, p, "preimage not match") + + // Note that we also implicitly check the methods such as FailAttempt, + // ReportPaymentFail, etc, are not called because we never mocked them + // in this test. If any of the unexpected methods was called, the test + // would fail. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if +// there are failed ones,so that a payment is successfully sent. This test +// mainly focuses on testing the logic of the method resumePayment is +// implemented as expected. +func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(20000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We will make the first two attempts failed with temporary + // error. + if failAttemptCount <= 2 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, + 1, + ), + } + return + } + + // Otherwise we will mark the attempt succeeded. + payer.resultChan <- &htlcswitch.PaymentResult{} + }) + + // Mock the FailAttempt method to fail one of the attempts. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Failure = &channeldb.HTLCFailInfo{} + failedAttempt = attempt + payment.HTLCs[i] = attempt + return + } + + }) + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, nil) + + // Simple mocking the rest. + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + missionControl.On("ReportPaymentSuccess", + mock.Anything, mock.Anything, + ).Return(nil) + + // Mock SettleAttempt by changing one of the HTLCs to be settled. + preimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt( + int(paymentAmt/4), 0, preimage, + ) + controlTower.On("SettleAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt settled and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Settle = &channeldb.HTLCSettleInfo{ + Preimage: preimage, + } + payment.HTLCs[i] = attempt + return + } + }) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.NoError(t, err, "send payment failed") + require.EqualValues(t, preimage, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentFailed tests that when one of the shard fails with a +// terminal error, the router will stop attempting and the payment will fail. +// This test mainly focuses on testing the logic of the method resumePayment +// is implemented as expected. +func TestSendMPPaymentFailed(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We fail the first attempt with terminal error. + if failAttemptCount == 1 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailIncorrectDetails{}, + 1, + ), + } + return + + } + + // We will make the rest attempts failed with temporary error. + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, + 1, + ), + } + }) + + // Mock the FailAttempt method to fail one of the attempts. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Failure = &channeldb.HTLCFailInfo{} + failedAttempt = attempt + payment.HTLCs[i] = attempt + return + } + + }) + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + var called bool + failureReason := channeldb.FailureReasonPaymentDetails + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, nil).Run(func(args mock.Arguments) { + // We only return the terminal error once, thus when the method + // is called, we will return it with a nil error. + if called { + missionControl.failReason = nil + return + } + + // If it's the first time calling this method, we will return a + // terminal error. + missionControl.failReason = &failureReason + payment.FailureReason = &failureReason + called = true + }) + + // Simple mocking the rest. + controlTower.On("Fail", identifier, failureReason).Return(nil) + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.Error(t, err, "expected send payment error") + require.EqualValues(t, [32]byte{}, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in +// terminal state, even if we have shards in flight, we still fail the payment +// and exit. This test mainly focuses on testing the logic of the method +// resumePayment is implemented as expected. +func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We fail the first attempt with terminal error. + if failAttemptCount == 1 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailIncorrectDetails{}, + 1, + ), + } + return + + } + + // For the rest attempts we will NOT send anything to the + // resultChan, thus making all the shards in active state, + // neither settled or failed. + }) + + // Mock the FailAttempt method to fail EXACTLY once. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + failedAttempt = payment.HTLCs[0] + failedAttempt.Failure = &channeldb.HTLCFailInfo{} + payment.HTLCs[0] = failedAttempt + }).Once() + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + failureReason := channeldb.FailureReasonPaymentDetails + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(failureReason, nil).Run(func(args mock.Arguments) { + missionControl.failReason = &failureReason + payment.FailureReason = &failureReason + }).Once() + + // Simple mocking the rest. + controlTower.On("Fail", identifier, failureReason).Return(nil).Once() + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.Error(t, err, "expected send payment error") + require.EqualValues(t, [32]byte{}, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +}