diff --git a/routing/mock_test.go b/routing/mock_test.go index 9484f6a5..430c7fad 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -533,6 +533,8 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( type mockPaymentAttemptDispatcher struct { mock.Mock + + resultChan chan *htlcswitch.PaymentResult } var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) @@ -548,8 +550,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64, paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) { - args := m.Called(attemptID, paymentHash, deobfuscator) - return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1) + m.Called(attemptID, paymentHash, deobfuscator) + + // Instead of returning the mocked returned values, we need to return + // the chan resultChan so it can be converted into a read-only chan. + return m.resultChan, nil } func (m *mockPaymentAttemptDispatcher) CleanStore( @@ -568,7 +573,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment) (PaymentSession, error) { - args := m.Called(m) + args := m.Called(payment) return args.Get(0).(PaymentSession), args.Error(1) } @@ -586,6 +591,8 @@ func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession { type mockMissionControl struct { mock.Mock + + failReason *channeldb.FailureReason } var _ MissionController = (*mockMissionControl)(nil) @@ -596,8 +603,7 @@ func (m *mockMissionControl) ReportPaymentFail( *channeldb.FailureReason, error) { args := m.Called(paymentID, rt, failureSourceIdx, failure) - return args.Get(0).(*channeldb.FailureReason), args.Error(1) - + return m.failReason, args.Error(1) } func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, @@ -642,6 +648,7 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, type mockControlTower struct { mock.Mock + sync.Mutex } var _ ControlTower = (*mockControlTower)(nil) @@ -656,6 +663,9 @@ func (m *mockControlTower) InitPayment(phash lntypes.Hash, func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, a *channeldb.HTLCAttemptInfo) error { + m.Lock() + defer m.Unlock() + args := m.Called(phash, a) return args.Error(0) } @@ -664,6 +674,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( *channeldb.HTLCAttempt, error) { + m.Lock() + defer m.Unlock() + args := m.Called(phash, pid, settleInfo) return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } @@ -671,6 +684,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { + m.Lock() + defer m.Unlock() + args := m.Called(phash, pid, failInfo) return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) } @@ -678,6 +694,9 @@ func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, func (m *mockControlTower) Fail(phash lntypes.Hash, reason channeldb.FailureReason) error { + m.Lock() + defer m.Unlock() + args := m.Called(phash, reason) return args.Error(0) } @@ -685,6 +704,8 @@ func (m *mockControlTower) Fail(phash lntypes.Hash, func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( *channeldb.MPPayment, error) { + m.Lock() + defer m.Unlock() args := m.Called(phash) // Type assertion on nil will fail, so we check and return here. @@ -692,8 +713,15 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( return nil, args.Error(1) } - return args.Get(0).(*channeldb.MPPayment), args.Error(1) + // Make a copy of the payment here to avoid data race. + p := args.Get(0).(*channeldb.MPPayment) + payment := &channeldb.MPPayment{ + FailureReason: p.FailureReason, + } + payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs)) + copy(payment.HTLCs, p.HTLCs) + return payment, args.Error(1) } func (m *mockControlTower) FetchInFlightPayments() ( diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 155b0f37..5699e95f 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -38,21 +38,53 @@ type paymentState struct { numShardsInFlight int remainingAmt lnwire.MilliSatoshi remainingFees lnwire.MilliSatoshi - terminate bool + + // terminate indicates the payment is in its final stage and no more + // shards should be launched. This value is true if we have an HTLC + // settled or the payment has an error. + terminate bool } -// paymentState uses the passed payment to find the latest information we need -// to act on every iteration of the payment loop. -func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) ( +// terminated returns a bool to indicate there are no further actions needed +// and we should return what we have, either the payment preimage or the +// payment error. +func (ps paymentState) terminated() bool { + // If the payment is in final stage and we have no in flight shards to + // wait result for, we consider the whole action terminated. + return ps.terminate && ps.numShardsInFlight == 0 +} + +// needWaitForShards returns a bool to specify whether we need to wait for the +// outcome of the shanrdHandler. +func (ps paymentState) needWaitForShards() bool { + // If we have in flight shards and the payment is in final stage, we + // need to wait for the outcomes from the shards. Or if we have no more + // money to be sent, we need to wait for the already launched shards. + if ps.numShardsInFlight == 0 { + return false + } + return ps.terminate || ps.remainingAmt == 0 +} + +// updatePaymentState will fetch db for the payment to find the latest +// information we need to act on every iteration of the payment loop and update +// the paymentState. +func (p *paymentLifecycle) updatePaymentState() (*channeldb.MPPayment, *paymentState, error) { + // Fetch the latest payment from db. + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, nil, err + } + // Fetch the total amount and fees that has already been sent in // settled and still in-flight shards. sentAmt, fees := payment.SentAmt() // Sanity check we haven't sent a value larger than the payment amount. if sentAmt > p.totalAmount { - return nil, fmt.Errorf("amount sent %v exceeds "+ + return nil, nil, fmt.Errorf("amount sent %v exceeds "+ "total amount %v", sentAmt, p.totalAmount) } @@ -74,13 +106,15 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) ( // have returned with a result. terminate := settle != nil || failure != nil - activeShards := payment.InFlightHTLCs() - return &paymentState{ - numShardsInFlight: len(activeShards), + // Update the payment state. + state := &paymentState{ + numShardsInFlight: len(payment.InFlightHTLCs()), remainingAmt: p.totalAmount - sentAmt, remainingFees: feeBudget, terminate: terminate, - }, nil + } + + return payment, state, nil } // resumePayment resumes the paymentLifecycle from the current state. @@ -102,9 +136,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the // lifecycle loop below. - payment, err := p.router.cfg.Control.FetchPayment( - p.identifier, - ) + payment, _, err := p.updatePaymentState() if err != nil { return [32]byte{}, nil, err } @@ -128,34 +160,30 @@ lifecycle: return [32]byte{}, nil, err } - // We start every iteration by fetching the lastest state of - // the payment from the ControlTower. This ensures that we will - // act on the latest available information, whether we are - // resuming an existing payment or just sent a new attempt. - payment, err := p.router.cfg.Control.FetchPayment( - p.identifier, - ) - if err != nil { - return [32]byte{}, nil, err - } - - // Using this latest state of the payment, calculate - // information about our active shards and terminal conditions. - state, err := p.paymentState(payment) + // We update the payment state on every iteration. Since the + // payment state is affected by multiple goroutines (ie, + // collectResultAsync), it is NOT guaranteed that we always + // have the latest state here. This is fine as long as the + // state is consistent as a whole. + payment, currentState, err := p.updatePaymentState() if err != nil { return [32]byte{}, nil, err } log.Debugf("Payment %v in state terminate=%v, "+ "active_shards=%v, rem_value=%v, fee_limit=%v", - p.identifier, state.terminate, state.numShardsInFlight, - state.remainingAmt, state.remainingFees) + p.identifier, currentState.terminate, + currentState.numShardsInFlight, + currentState.remainingAmt, currentState.remainingFees, + ) + // TODO(yy): sanity check all the states to make sure + // everything is expected. switch { // We have a terminal condition and no active shards, we are // ready to exit. - case state.terminate && state.numShardsInFlight == 0: + case currentState.terminated(): // Find the first successful shard and return // the preimage and route. for _, a := range payment.HTLCs { @@ -170,7 +198,7 @@ lifecycle: // If we either reached a terminal error condition (but had // active shards still) or there is no remaining value to send, // we'll wait for a shard outcome. - case state.terminate || state.remainingAmt == 0: + case currentState.needWaitForShards(): // We still have outstanding shards, so wait for a new // outcome to be available before re-evaluating our // state. @@ -212,8 +240,9 @@ lifecycle: // Create a new payment attempt from the given payment session. rt, err := p.paySession.RequestRoute( - state.remainingAmt, state.remainingFees, - uint32(state.numShardsInFlight), uint32(p.currentHeight), + currentState.remainingAmt, currentState.remainingFees, + uint32(currentState.numShardsInFlight), + uint32(p.currentHeight), ) if err != nil { log.Warnf("Failed to find route for payment %v: %v", @@ -227,7 +256,7 @@ lifecycle: // There is no route to try, and we have no active // shards. This means that there is no way for us to // send the payment, so mark it failed with no route. - if state.numShardsInFlight == 0 { + if currentState.numShardsInFlight == 0 { failureCode := routeErr.FailureReason() log.Debugf("Marking payment %v permanently "+ "failed with no route: %v", @@ -253,22 +282,11 @@ lifecycle: // If this route will consume the last remeining amount to send // to the receiver, this will be our last shard (for now). - lastShard := rt.ReceiverAmt() == state.remainingAmt + lastShard := rt.ReceiverAmt() == currentState.remainingAmt // We found a route to try, launch a new shard. attempt, outcome, err := shardHandler.launchShard(rt, lastShard) - switch { - // We may get a terminal error if we've processed a shard with - // a terminal state (settled or permanent failure), while we - // were pathfinding. We know we're in a terminal state here, - // so we can continue and wait for our last shards to return. - case err == channeldb.ErrPaymentTerminal: - log.Infof("Payment %v in terminal state, abandoning "+ - "shard", p.identifier) - - continue lifecycle - - case err != nil: + if err != nil { return [32]byte{}, nil, err } @@ -297,6 +315,7 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. shardHandler.collectResultAsync(attempt) + } } @@ -437,12 +456,30 @@ type shardResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. Note that it will -// fail the payment with the control tower if a terminal error is encountered. +// given HTLC attempt to be available then handle its result. It will fail the +// payment with the control tower if a terminal error is encountered. func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) { + + // errToSend is the error to be sent to sh.shardErrors. + var errToSend error + + // handleResultErr is a function closure must be called using defer. It + // finishes collecting result by updating the payment state and send + // the error (or nil) to sh.shardErrors. + handleResultErr := func() { + // Send the error or quit. + select { + case p.shardErrors <- errToSend: + case <-p.router.quit: + case <-p.quit: + } + + p.wg.Done() + } + p.wg.Add(1) go func() { - defer p.wg.Done() + defer handleResultErr() // Block until the result is available. result, err := p.collectResult(attempt) @@ -456,32 +493,18 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) { attempt.AttemptID, p.identifier, err) } - select { - case p.shardErrors <- err: - case <-p.router.quit: - case <-p.quit: - } + // Overwrite errToSend and return. + errToSend = err return } // If a non-critical error was encountered handle it and mark // the payment failed if the failure was terminal. if result.err != nil { - err := p.handleSendError(attempt, result.err) - if err != nil { - select { - case p.shardErrors <- err: - case <-p.router.quit: - case <-p.quit: - } - return - } - } - - select { - case p.shardErrors <- nil: - case <-p.router.quit: - case <-p.quit: + // Overwrite errToSend and return. Notice that the + // errToSend could be nil here. + errToSend = p.handleSendError(attempt, result.err) + return } }() } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 7f860284..74f6dfbd 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -195,14 +195,6 @@ func TestRouterPaymentStateMachine(t *testing.T) { t.Fatalf("unable to create route: %v", err) } - halfShard, err := createTestRoute(paymentAmt/2, testGraph.aliasMap) - require.NoError(t, err, "unable to create half route") - - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - if err != nil { - t.Fatalf("unable to create route: %v", err) - } - tests := []paymentLifecycleTestCase{ { // Tests a normal payment flow that succeeds. @@ -425,280 +417,6 @@ func TestRouterPaymentStateMachine(t *testing.T) { routes: []*route.Route{rt}, paymentErr: channeldb.FailureReasonNoRoute, }, - - // ===================================== - // || MPP scenarios || - // ===================================== - { - // Tests a simple successful MP payment of 4 shards. - name: "MP success", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // All shards succeed. - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - - // Router should settle them all. - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - - // And the final result is obviously - // successful. - paymentSuccess, - }, - routes: []*route.Route{shard, shard, shard, shard}, - }, - { - // An MP payment scenario where we need several extra - // attempts before the payment finally settle. - name: "MP failed shards", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // First two shards fail, two new ones are sent. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The four shards settle. - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - getPaymentResultSuccess, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - routerSettleAttempt, - - // Overall payment succeeds. - paymentSuccess, - }, - routes: []*route.Route{ - shard, shard, shard, shard, shard, shard, - }, - }, - { - // An MP payment scenario where one of the shards fails, - // but we still receive a single success shard. - name: "MP one shard success", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 0 fails, and should be failed by the - // router. - getPaymentResultTempFailure, - routerFailAttempt, - - // We will try one more shard because we haven't - // sent the full payment amount. - routeRelease, - - // The second shard succeed against all odds, - // making the overall payment succeed. - getPaymentResultSuccess, - routerSettleAttempt, - paymentSuccess, - }, - routes: []*route.Route{halfShard, halfShard}, - }, - { - // An MP payment scenario a shard fail with a terminal - // error, causing the router to stop attempting. - name: "MP terminal", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 3 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The first shard fail with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // Remaining 3 shards fail. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - routerFailAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, shard, shard, shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, - { - // A MP payment scenario when our path finding returns - // after we've just received a terminal failure, and - // attempts to dispatch a new shard. Testing that we - // correctly abandon the shard and conclude the payment. - name: "MP path found after failure", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // The first shard fail with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // shard 1 fails because we've had a terminal - // failure. - routeRelease, - routerRegisterAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, - { - // A MP payment scenario when our path finding returns - // after we've just received a terminal failure, and - // we have another shard still in flight. - name: "MP shard in flight after terminal", - - steps: []string{ - routerInitPayment, - - // shard 0 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 1 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // shard 2 - routeRelease, - routerRegisterAttempt, - sendToSwitchSuccess, - - // We find a path for another shard. - routeRelease, - - // shard 0 fails with a terminal error. - getPaymentResultTerminalFailure, - routerFailAttempt, - routerFailPayment, - - // We try to register our final shard after - // processing a terminal failure. - routerRegisterAttempt, - - // Our in-flight shards fail. - getPaymentResultTempFailure, - getPaymentResultTempFailure, - routerFailAttempt, - routerFailAttempt, - - // Payment fails. - paymentError, - }, - routes: []*route.Route{ - shard, shard, shard, shard, - }, - paymentErr: channeldb.FailureReasonPaymentDetails, - }, } for _, test := range tests { @@ -1080,3 +798,343 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, t.Fatalf("SendPayment didn't exit") } } + +// TestPaymentState tests that the logics implemented on paymentState struct +// are as expected. In particular, that the method terminated and +// needWaitForShards return the right values. +func TestPaymentState(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + // Use the following three params, each is equivalent to a bool + // statement, to construct 8 test cases so that we can + // exhaustively catch all possible states. + numShardsInFlight int + remainingAmt lnwire.MilliSatoshi + terminate bool + + expectedTerminated bool + expectedNeedWaitForShards bool + }{ + { + // If we have active shards and terminate is marked + // false, the state is not terminated. Since the + // remaining amount is zero, we need to wait for shards + // to be finished and launch no more shards. + name: "state 100", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + { + // If we have active shards while terminate is marked + // true, the state is not terminated, and we need to + // wait for shards to be finished and launch no more + // shards. + name: "state 101", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: true, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + + { + // If we have active shards and terminate is marked + // false, the state is not terminated. Since the + // remaining amount is not zero, we don't need to wait + // for shards outcomes and should launch more shards. + name: "state 110", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have active shards and terminate is marked + // true, the state is not terminated. Even the + // remaining amount is not zero, we need to wait for + // shards outcomes because state is terminated. + name: "state 111", + numShardsInFlight: 1, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: true, + expectedTerminated: false, + expectedNeedWaitForShards: true, + }, + { + // If we have no active shards while terminate is marked + // false, the state is not terminated, and we don't + // need to wait for more shard outcomes because there + // are no active shards. + name: "state 000", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // true, the state is terminated, and we don't need to + // wait for shards to be finished. + name: "state 001", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(0), + terminate: true, + expectedTerminated: true, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // false, the state is not terminated. Since the + // remaining amount is not zero, we don't need to wait + // for shards outcomes and should launch more shards. + name: "state 010", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: false, + expectedTerminated: false, + expectedNeedWaitForShards: false, + }, + { + // If we have no active shards while terminate is marked + // true, the state is terminated, and we don't need to + // wait for shards outcomes. + name: "state 011", + numShardsInFlight: 0, + remainingAmt: lnwire.MilliSatoshi(1), + terminate: true, + expectedTerminated: true, + expectedNeedWaitForShards: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ps := &paymentState{ + numShardsInFlight: tc.numShardsInFlight, + remainingAmt: tc.remainingAmt, + terminate: tc.terminate, + } + + require.Equal( + t, tc.expectedTerminated, ps.terminated(), + "terminated returned wrong value", + ) + require.Equal( + t, tc.expectedNeedWaitForShards, + ps.needWaitForShards(), + "needWaitForShards returned wrong value", + ) + }) + } + +} + +// TestUpdatePaymentState checks that the method updatePaymentState updates the +// paymentState as expected. +func TestUpdatePaymentState(t *testing.T) { + t.Parallel() + + // paymentHash is the identifier on paymentLifecycle. + paymentHash := lntypes.Hash{} + + // TODO(yy): make MPPayment into an interface so we can mock it. The + // current design implicitly tests the methods SendAmt, TerminalInfo, + // and InFlightHTLCs on channeldb.MPPayment, which is not good. Once + // MPPayment becomes an interface, we can then mock these methods here. + + // SentAmt returns 90, 10 + // TerminalInfo returns non-nil, nil + // InFlightHTLCs returns 0 + var preimage lntypes.Preimage + paymentSettled := &channeldb.MPPayment{ + HTLCs: []channeldb.HTLCAttempt{ + makeSettledAttempt(100, 10, preimage), + }, + } + + // SentAmt returns 0, 0 + // TerminalInfo returns nil, non-nil + // InFlightHTLCs returns 0 + reason := channeldb.FailureReasonError + paymentFailed := &channeldb.MPPayment{ + FailureReason: &reason, + } + + // SentAmt returns 90, 10 + // TerminalInfo returns nil, nil + // InFlightHTLCs returns 1 + paymentActive := &channeldb.MPPayment{ + HTLCs: []channeldb.HTLCAttempt{ + makeActiveAttempt(100, 10), + makeFailedAttempt(100, 10), + }, + } + + testCases := []struct { + name string + payment *channeldb.MPPayment + totalAmt int + feeLimit int + + expectedState *paymentState + shouldReturnError bool + }{ + { + // Test that the error returned from FetchPayment is + // handled properly. We use a nil payment to indicate + // we want to return an error. + name: "fetch payment error", + payment: nil, + shouldReturnError: true, + }, + { + // Test that when the sentAmt exceeds totalAmount, the + // error is returned. + name: "amount exceeded error", + payment: paymentSettled, + totalAmt: 1, + shouldReturnError: true, + }, + { + // Test that when the fee budget is reached, the + // remaining fee should be zero. + name: "fee budget reached", + payment: paymentActive, + totalAmt: 1000, + feeLimit: 1, + expectedState: &paymentState{ + numShardsInFlight: 1, + remainingAmt: 1000 - 90, + remainingFees: 0, + terminate: false, + }, + }, + { + // Test when the payment is settled, the state should + // be marked as terminated. + name: "payment settled", + payment: paymentSettled, + totalAmt: 1000, + feeLimit: 100, + expectedState: &paymentState{ + numShardsInFlight: 0, + remainingAmt: 1000 - 90, + remainingFees: 100 - 10, + terminate: true, + }, + }, + { + // Test when the payment is failed, the state should be + // marked as terminated. + name: "payment failed", + payment: paymentFailed, + totalAmt: 1000, + feeLimit: 100, + expectedState: &paymentState{ + numShardsInFlight: 0, + remainingAmt: 1000, + remainingFees: 100, + terminate: true, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create mock control tower and assign it to router. + // We will then use the router and the paymentHash + // above to create our paymentLifecycle for this test. + ct := &mockControlTower{} + rt := &ChannelRouter{cfg: &Config{Control: ct}} + pl := &paymentLifecycle{ + router: rt, + identifier: paymentHash, + totalAmount: lnwire.MilliSatoshi(tc.totalAmt), + feeLimit: lnwire.MilliSatoshi(tc.feeLimit), + } + + if tc.payment == nil { + // A nil payment indicates we want to test an + // error returned from FetchPayment. + dummyErr := errors.New("dummy") + ct.On("FetchPayment", paymentHash).Return( + nil, dummyErr, + ) + + } else { + // Otherwise we will return the payment. + ct.On("FetchPayment", paymentHash).Return( + tc.payment, nil, + ) + } + + // Call the method that updates the payment state. + _, state, err := pl.updatePaymentState() + + // Assert that the mock method is called as + // intended. + ct.AssertExpectations(t) + + if tc.shouldReturnError { + require.Error(t, err, "expect an error") + return + } + + require.NoError(t, err, "unexpected error") + require.Equal( + t, tc.expectedState, state, + "state not updated as expected", + ) + + }) + } + +} + +func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt { + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + } +} + +func makeSettledAttempt(total, fee int, + preimage lntypes.Preimage) channeldb.HTLCAttempt { + + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, + } +} + +func makeFailedAttempt(total, fee int) channeldb.HTLCAttempt { + return channeldb.HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Failure: &channeldb.HTLCFailInfo{ + Reason: channeldb.HTLCFailInternal, + }, + } +} + +func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo { + hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} + return channeldb.HTLCAttemptInfo{ + Route: route.Route{ + TotalAmount: lnwire.MilliSatoshi(total), + Hops: []*route.Hop{hop}, + }, + } +} diff --git a/routing/router_test.go b/routing/router_test.go index ab59739c..29ad342d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/lightningnetwork/lnd/channeldb" @@ -1069,7 +1070,8 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { _, ok = msg.(*lnwire.FailUnknownNextPeer) require.True(t, ok, "unexpected fail message") - ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() + err = ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() + require.NoError(t, err, "reset history failed") // Next, we'll modify the SendToSwitch method to indicate that the // connection between songoku and isn't up. @@ -3436,3 +3438,796 @@ func TestChannelOnChainRejectionZombie(t *testing.T) { require.Nil(t, err) assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput) } + +func createDummyTestGraph(t *testing.T) *testGraphInstance { + // Setup two simple channels such that we can mock sending along this + // route. + chanCapSat := btcutil.Amount(100000) + testChannels := []*testChannel{ + symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 1), + symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 2), + } + + testGraph, err := createTestGraphFromChannels(testChannels, "a") + require.NoError(t, err, "failed to create graph") + return testGraph +} + +func createDummyLightningPayment(t *testing.T, + target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment { + + var preImage lntypes.Preimage + _, err := rand.Read(preImage[:]) + require.NoError(t, err, "unable to generate preimage") + + payHash := preImage.Hash() + + return &LightningPayment{ + Target: target, + Amount: amt, + FeeLimit: noFeeLimit, + paymentHash: &payHash, + } +} + +// TestSendMPPaymentSucceed tests that we can successfully send a MPPayment via +// router.SendPayment. This test mainly focuses on testing the logic of the +// method resumePayment is implemented as expected. +func TestSendMPPaymentSucceed(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + payer.resultChan <- &htlcswitch.PaymentResult{} + }) + + // Simple mocking the rest. + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + missionControl.On("ReportPaymentSuccess", + mock.Anything, mock.Anything, + ).Return(nil) + + // Mock SettleAttempt by changing one of the HTLCs to be settled. + preimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt( + int(paymentAmt/4), 0, preimage, + ) + controlTower.On("SettleAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt settled and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle == nil { + attempt.Settle = &channeldb.HTLCSettleInfo{ + Preimage: preimage, + } + payment.HTLCs[i] = attempt + return + } + } + }) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.NoError(t, err, "send payment failed") + require.EqualValues(t, preimage, p, "preimage not match") + + // Note that we also implicitly check the methods such as FailAttempt, + // ReportPaymentFail, etc, are not called because we never mocked them + // in this test. If any of the unexpected methods was called, the test + // would fail. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if +// there are failed ones,so that a payment is successfully sent. This test +// mainly focuses on testing the logic of the method resumePayment is +// implemented as expected. +func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(20000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We will make the first two attempts failed with temporary + // error. + if failAttemptCount <= 2 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, + 1, + ), + } + return + } + + // Otherwise we will mark the attempt succeeded. + payer.resultChan <- &htlcswitch.PaymentResult{} + }) + + // Mock the FailAttempt method to fail one of the attempts. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Failure = &channeldb.HTLCFailInfo{} + failedAttempt = attempt + payment.HTLCs[i] = attempt + return + } + + }) + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, nil) + + // Simple mocking the rest. + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + missionControl.On("ReportPaymentSuccess", + mock.Anything, mock.Anything, + ).Return(nil) + + // Mock SettleAttempt by changing one of the HTLCs to be settled. + preimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt( + int(paymentAmt/4), 0, preimage, + ) + controlTower.On("SettleAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt settled and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Settle = &channeldb.HTLCSettleInfo{ + Preimage: preimage, + } + payment.HTLCs[i] = attempt + return + } + }) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.NoError(t, err, "send payment failed") + require.EqualValues(t, preimage, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentFailed tests that when one of the shard fails with a +// terminal error, the router will stop attempting and the payment will fail. +// This test mainly focuses on testing the logic of the method resumePayment +// is implemented as expected. +func TestSendMPPaymentFailed(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We fail the first attempt with terminal error. + if failAttemptCount == 1 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailIncorrectDetails{}, + 1, + ), + } + return + + } + + // We will make the rest attempts failed with temporary error. + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailTemporaryChannelFailure{}, + 1, + ), + } + }) + + // Mock the FailAttempt method to fail one of the attempts. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + for i, attempt := range payment.HTLCs { + if attempt.Settle != nil || attempt.Failure != nil { + continue + } + + attempt.Failure = &channeldb.HTLCFailInfo{} + failedAttempt = attempt + payment.HTLCs[i] = attempt + return + } + + }) + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + var called bool + failureReason := channeldb.FailureReasonPaymentDetails + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(nil, nil).Run(func(args mock.Arguments) { + // We only return the terminal error once, thus when the method + // is called, we will return it with a nil error. + if called { + missionControl.failReason = nil + return + } + + // If it's the first time calling this method, we will return a + // terminal error. + missionControl.failReason = &failureReason + payment.FailureReason = &failureReason + called = true + }) + + // Simple mocking the rest. + controlTower.On("Fail", identifier, failureReason).Return(nil) + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.Error(t, err, "expected send payment error") + require.EqualValues(t, [32]byte{}, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +} + +// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in +// terminal state, even if we have shards in flight, we still fail the payment +// and exit. This test mainly focuses on testing the logic of the method +// resumePayment is implemented as expected. +func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { + const startingBlockHeight = 101 + + // Create mockers to initialize the router. + controlTower := &mockControlTower{} + sessionSource := &mockPaymentSessionSource{} + missionControl := &mockMissionControl{} + payer := &mockPaymentAttemptDispatcher{} + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + testGraph := createDummyTestGraph(t) + + // Define the behavior of the mockers to the point where we can + // successfully start the router. + controlTower.On("FetchInFlightPayments").Return( + []*channeldb.MPPayment{}, nil, + ) + payer.On("CleanStore", mock.Anything).Return(nil) + + // Create and start the router. + router, err := New(Config{ + Control: controlTower, + SessionSource: sessionSource, + MissionControl: missionControl, + Payer: payer, + + // TODO(yy): create new mocks for the chain and chainview. + Chain: chain, + ChainView: chainView, + + // TODO(yy): mock the graph once it's changed into interface. + Graph: testGraph.graph, + + Clock: clock.NewTestClock(time.Unix(1, 0)), + GraphPruneInterval: time.Hour * 2, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + require.NoError(t, err, "failed to create router") + + // Make sure the router can start and stop without error. + require.NoError(t, router.Start(), "router failed to start") + defer func() { + require.NoError(t, router.Stop(), "router failed to stop") + }() + + // Once the router is started, check that the mocked methods are called + // as expected. + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + + // Mock the methods to the point where we are inside the function + // resumePayment. + paymentAmt := lnwire.MilliSatoshi(10000) + req := createDummyLightningPayment( + t, testGraph.aliasMap["c"], paymentAmt, + ) + identifier := lntypes.Hash(req.Identifier()) + session := &mockPaymentSession{} + sessionSource.On("NewPaymentSession", req).Return(session, nil) + controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) + + // The following mocked methods are called inside resumePayment. Note + // that the payment object below will determine the state of the + // paymentLifecycle. + payment := &channeldb.MPPayment{} + controlTower.On("FetchPayment", identifier).Return(payment, nil) + + // Create a route that can send 1/4 of the total amount. This value + // will be returned by calling RequestRoute. + shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) + require.NoError(t, err, "failed to create route") + session.On("RequestRoute", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(shard, nil) + + // Make a new htlc attempt with zero fee and append it to the payment's + // HTLCs when calling RegisterAttempt. + activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) + controlTower.On("RegisterAttempt", + identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + payment.HTLCs = append(payment.HTLCs, activeAttempt) + }) + + // Create a buffered chan and it will be returned by GetPaymentResult. + payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) + + // We use the failAttemptCount to track how many attempts we want to + // fail. Each time the following mock method is called, the count gets + // updated. + failAttemptCount := 0 + payer.On("GetPaymentResult", + mock.Anything, identifier, mock.Anything, + ).Run(func(args mock.Arguments) { + // Before the mock method is returned, we send the result to + // the read-only chan. + + // Update the counter. + failAttemptCount++ + + // We fail the first attempt with terminal error. + if failAttemptCount == 1 { + payer.resultChan <- &htlcswitch.PaymentResult{ + Error: htlcswitch.NewForwardingError( + &lnwire.FailIncorrectDetails{}, + 1, + ), + } + return + + } + + // For the rest attempts we will NOT send anything to the + // resultChan, thus making all the shards in active state, + // neither settled or failed. + }) + + // Mock the FailAttempt method to fail EXACTLY once. + var failedAttempt channeldb.HTLCAttempt + controlTower.On("FailAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mark the first + // active attempt as failed and exit. + failedAttempt = payment.HTLCs[0] + failedAttempt.Failure = &channeldb.HTLCFailInfo{} + payment.HTLCs[0] = failedAttempt + }).Once() + + // Setup ReportPaymentFail to return nil reason and error so the + // payment won't fail. + failureReason := channeldb.FailureReasonPaymentDetails + missionControl.On("ReportPaymentFail", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(failureReason, nil).Run(func(args mock.Arguments) { + missionControl.failReason = &failureReason + payment.FailureReason = &failureReason + }).Once() + + // Simple mocking the rest. + controlTower.On("Fail", identifier, failureReason).Return(nil).Once() + payer.On("SendHTLC", + mock.Anything, mock.Anything, mock.Anything, + ).Return(nil) + + // Call the actual method SendPayment on router. This is place inside a + // goroutine so we can set a timeout for the whole test, in case + // anything goes wrong and the test never finishes. + done := make(chan struct{}) + var p lntypes.Hash + go func() { + p, _, err = router.SendPayment(req) + close(done) + }() + + select { + case <-done: + case <-time.After(testTimeout): + t.Fatalf("SendPayment didn't exit") + } + + // Finally, validate the returned values and check that the mock + // methods are called as expected. + require.Error(t, err, "expected send payment error") + require.EqualValues(t, [32]byte{}, p, "preimage not match") + + controlTower.AssertExpectations(t) + payer.AssertExpectations(t) + sessionSource.AssertExpectations(t) + session.AssertExpectations(t) + missionControl.AssertExpectations(t) +}