routing: refactor update payment state tests

This commit refactors the resumePayment to extract some logics back to
paymentState so that the code is more testable. It also adds unit tests
for paymentState, and breaks the original MPPayment tests into independent tests
so that it's easier to maintain and debug. All the new tests are built
using mock so that the control flow is eaiser to setup and change.
This commit is contained in:
yyforyongyu 2021-05-21 19:02:49 +08:00
parent e79e46ed21
commit cd35981569
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
4 changed files with 1262 additions and 358 deletions

@ -533,6 +533,8 @@ func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
type mockPaymentAttemptDispatcher struct {
mock.Mock
resultChan chan *htlcswitch.PaymentResult
}
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
@ -548,8 +550,11 @@ func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64,
paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) {
args := m.Called(attemptID, paymentHash, deobfuscator)
return args.Get(0).(<-chan *htlcswitch.PaymentResult), args.Error(1)
m.Called(attemptID, paymentHash, deobfuscator)
// Instead of returning the mocked returned values, we need to return
// the chan resultChan so it can be converted into a read-only chan.
return m.resultChan, nil
}
func (m *mockPaymentAttemptDispatcher) CleanStore(
@ -568,7 +573,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
func (m *mockPaymentSessionSource) NewPaymentSession(
payment *LightningPayment) (PaymentSession, error) {
args := m.Called(m)
args := m.Called(payment)
return args.Get(0).(PaymentSession), args.Error(1)
}
@ -586,6 +591,8 @@ func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
type mockMissionControl struct {
mock.Mock
failReason *channeldb.FailureReason
}
var _ MissionController = (*mockMissionControl)(nil)
@ -596,8 +603,7 @@ func (m *mockMissionControl) ReportPaymentFail(
*channeldb.FailureReason, error) {
args := m.Called(paymentID, rt, failureSourceIdx, failure)
return args.Get(0).(*channeldb.FailureReason), args.Error(1)
return m.failReason, args.Error(1)
}
func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
@ -642,6 +648,7 @@ func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
type mockControlTower struct {
mock.Mock
sync.Mutex
}
var _ ControlTower = (*mockControlTower)(nil)
@ -656,6 +663,9 @@ func (m *mockControlTower) InitPayment(phash lntypes.Hash,
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
a *channeldb.HTLCAttemptInfo) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, a)
return args.Error(0)
}
@ -664,6 +674,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, settleInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}
@ -671,6 +684,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash, pid, failInfo)
return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}
@ -678,6 +694,9 @@ func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
func (m *mockControlTower) Fail(phash lntypes.Hash,
reason channeldb.FailureReason) error {
m.Lock()
defer m.Unlock()
args := m.Called(phash, reason)
return args.Error(0)
}
@ -685,6 +704,8 @@ func (m *mockControlTower) Fail(phash lntypes.Hash,
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash)
// Type assertion on nil will fail, so we check and return here.
@ -692,8 +713,15 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
return nil, args.Error(1)
}
return args.Get(0).(*channeldb.MPPayment), args.Error(1)
// Make a copy of the payment here to avoid data race.
p := args.Get(0).(*channeldb.MPPayment)
payment := &channeldb.MPPayment{
FailureReason: p.FailureReason,
}
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
copy(payment.HTLCs, p.HTLCs)
return payment, args.Error(1)
}
func (m *mockControlTower) FetchInFlightPayments() (

@ -38,21 +38,53 @@ type paymentState struct {
numShardsInFlight int
remainingAmt lnwire.MilliSatoshi
remainingFees lnwire.MilliSatoshi
terminate bool
// terminate indicates the payment is in its final stage and no more
// shards should be launched. This value is true if we have an HTLC
// settled or the payment has an error.
terminate bool
}
// paymentState uses the passed payment to find the latest information we need
// to act on every iteration of the payment loop.
func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// terminated returns a bool to indicate there are no further actions needed
// and we should return what we have, either the payment preimage or the
// payment error.
func (ps paymentState) terminated() bool {
// If the payment is in final stage and we have no in flight shards to
// wait result for, we consider the whole action terminated.
return ps.terminate && ps.numShardsInFlight == 0
}
// needWaitForShards returns a bool to specify whether we need to wait for the
// outcome of the shanrdHandler.
func (ps paymentState) needWaitForShards() bool {
// If we have in flight shards and the payment is in final stage, we
// need to wait for the outcomes from the shards. Or if we have no more
// money to be sent, we need to wait for the already launched shards.
if ps.numShardsInFlight == 0 {
return false
}
return ps.terminate || ps.remainingAmt == 0
}
// updatePaymentState will fetch db for the payment to find the latest
// information we need to act on every iteration of the payment loop and update
// the paymentState.
func (p *paymentLifecycle) updatePaymentState() (*channeldb.MPPayment,
*paymentState, error) {
// Fetch the latest payment from db.
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
if err != nil {
return nil, nil, err
}
// Fetch the total amount and fees that has already been sent in
// settled and still in-flight shards.
sentAmt, fees := payment.SentAmt()
// Sanity check we haven't sent a value larger than the payment amount.
if sentAmt > p.totalAmount {
return nil, fmt.Errorf("amount sent %v exceeds "+
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
"total amount %v", sentAmt, p.totalAmount)
}
@ -74,13 +106,15 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// have returned with a result.
terminate := settle != nil || failure != nil
activeShards := payment.InFlightHTLCs()
return &paymentState{
numShardsInFlight: len(activeShards),
// Update the payment state.
state := &paymentState{
numShardsInFlight: len(payment.InFlightHTLCs()),
remainingAmt: p.totalAmount - sentAmt,
remainingFees: feeBudget,
terminate: terminate,
}, nil
}
return payment, state, nil
}
// resumePayment resumes the paymentLifecycle from the current state.
@ -102,9 +136,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
// If we had any existing attempts outstanding, we'll start by spinning
// up goroutines that'll collect their results and deliver them to the
// lifecycle loop below.
payment, err := p.router.cfg.Control.FetchPayment(
p.identifier,
)
payment, _, err := p.updatePaymentState()
if err != nil {
return [32]byte{}, nil, err
}
@ -128,34 +160,30 @@ lifecycle:
return [32]byte{}, nil, err
}
// We start every iteration by fetching the lastest state of
// the payment from the ControlTower. This ensures that we will
// act on the latest available information, whether we are
// resuming an existing payment or just sent a new attempt.
payment, err := p.router.cfg.Control.FetchPayment(
p.identifier,
)
if err != nil {
return [32]byte{}, nil, err
}
// Using this latest state of the payment, calculate
// information about our active shards and terminal conditions.
state, err := p.paymentState(payment)
// We update the payment state on every iteration. Since the
// payment state is affected by multiple goroutines (ie,
// collectResultAsync), it is NOT guaranteed that we always
// have the latest state here. This is fine as long as the
// state is consistent as a whole.
payment, currentState, err := p.updatePaymentState()
if err != nil {
return [32]byte{}, nil, err
}
log.Debugf("Payment %v in state terminate=%v, "+
"active_shards=%v, rem_value=%v, fee_limit=%v",
p.identifier, state.terminate, state.numShardsInFlight,
state.remainingAmt, state.remainingFees)
p.identifier, currentState.terminate,
currentState.numShardsInFlight,
currentState.remainingAmt, currentState.remainingFees,
)
// TODO(yy): sanity check all the states to make sure
// everything is expected.
switch {
// We have a terminal condition and no active shards, we are
// ready to exit.
case state.terminate && state.numShardsInFlight == 0:
case currentState.terminated():
// Find the first successful shard and return
// the preimage and route.
for _, a := range payment.HTLCs {
@ -170,7 +198,7 @@ lifecycle:
// If we either reached a terminal error condition (but had
// active shards still) or there is no remaining value to send,
// we'll wait for a shard outcome.
case state.terminate || state.remainingAmt == 0:
case currentState.needWaitForShards():
// We still have outstanding shards, so wait for a new
// outcome to be available before re-evaluating our
// state.
@ -212,8 +240,9 @@ lifecycle:
// Create a new payment attempt from the given payment session.
rt, err := p.paySession.RequestRoute(
state.remainingAmt, state.remainingFees,
uint32(state.numShardsInFlight), uint32(p.currentHeight),
currentState.remainingAmt, currentState.remainingFees,
uint32(currentState.numShardsInFlight),
uint32(p.currentHeight),
)
if err != nil {
log.Warnf("Failed to find route for payment %v: %v",
@ -227,7 +256,7 @@ lifecycle:
// There is no route to try, and we have no active
// shards. This means that there is no way for us to
// send the payment, so mark it failed with no route.
if state.numShardsInFlight == 0 {
if currentState.numShardsInFlight == 0 {
failureCode := routeErr.FailureReason()
log.Debugf("Marking payment %v permanently "+
"failed with no route: %v",
@ -253,22 +282,11 @@ lifecycle:
// If this route will consume the last remeining amount to send
// to the receiver, this will be our last shard (for now).
lastShard := rt.ReceiverAmt() == state.remainingAmt
lastShard := rt.ReceiverAmt() == currentState.remainingAmt
// We found a route to try, launch a new shard.
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
switch {
// We may get a terminal error if we've processed a shard with
// a terminal state (settled or permanent failure), while we
// were pathfinding. We know we're in a terminal state here,
// so we can continue and wait for our last shards to return.
case err == channeldb.ErrPaymentTerminal:
log.Infof("Payment %v in terminal state, abandoning "+
"shard", p.identifier)
continue lifecycle
case err != nil:
if err != nil {
return [32]byte{}, nil, err
}
@ -297,6 +315,7 @@ lifecycle:
// Now that the shard was successfully sent, launch a go
// routine that will handle its result when its back.
shardHandler.collectResultAsync(attempt)
}
}
@ -437,12 +456,30 @@ type shardResult struct {
}
// collectResultAsync launches a goroutine that will wait for the result of the
// given HTLC attempt to be available then handle its result. Note that it will
// fail the payment with the control tower if a terminal error is encountered.
// given HTLC attempt to be available then handle its result. It will fail the
// payment with the control tower if a terminal error is encountered.
func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
// errToSend is the error to be sent to sh.shardErrors.
var errToSend error
// handleResultErr is a function closure must be called using defer. It
// finishes collecting result by updating the payment state and send
// the error (or nil) to sh.shardErrors.
handleResultErr := func() {
// Send the error or quit.
select {
case p.shardErrors <- errToSend:
case <-p.router.quit:
case <-p.quit:
}
p.wg.Done()
}
p.wg.Add(1)
go func() {
defer p.wg.Done()
defer handleResultErr()
// Block until the result is available.
result, err := p.collectResult(attempt)
@ -456,32 +493,18 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
attempt.AttemptID, p.identifier, err)
}
select {
case p.shardErrors <- err:
case <-p.router.quit:
case <-p.quit:
}
// Overwrite errToSend and return.
errToSend = err
return
}
// If a non-critical error was encountered handle it and mark
// the payment failed if the failure was terminal.
if result.err != nil {
err := p.handleSendError(attempt, result.err)
if err != nil {
select {
case p.shardErrors <- err:
case <-p.router.quit:
case <-p.quit:
}
return
}
}
select {
case p.shardErrors <- nil:
case <-p.router.quit:
case <-p.quit:
// Overwrite errToSend and return. Notice that the
// errToSend could be nil here.
errToSend = p.handleSendError(attempt, result.err)
return
}
}()
}

@ -195,14 +195,6 @@ func TestRouterPaymentStateMachine(t *testing.T) {
t.Fatalf("unable to create route: %v", err)
}
halfShard, err := createTestRoute(paymentAmt/2, testGraph.aliasMap)
require.NoError(t, err, "unable to create half route")
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
if err != nil {
t.Fatalf("unable to create route: %v", err)
}
tests := []paymentLifecycleTestCase{
{
// Tests a normal payment flow that succeeds.
@ -425,280 +417,6 @@ func TestRouterPaymentStateMachine(t *testing.T) {
routes: []*route.Route{rt},
paymentErr: channeldb.FailureReasonNoRoute,
},
// =====================================
// || MPP scenarios ||
// =====================================
{
// Tests a simple successful MP payment of 4 shards.
name: "MP success",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 1
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 2
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 3
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// All shards succeed.
getPaymentResultSuccess,
getPaymentResultSuccess,
getPaymentResultSuccess,
getPaymentResultSuccess,
// Router should settle them all.
routerSettleAttempt,
routerSettleAttempt,
routerSettleAttempt,
routerSettleAttempt,
// And the final result is obviously
// successful.
paymentSuccess,
},
routes: []*route.Route{shard, shard, shard, shard},
},
{
// An MP payment scenario where we need several extra
// attempts before the payment finally settle.
name: "MP failed shards",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 1
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 2
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 3
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// First two shards fail, two new ones are sent.
getPaymentResultTempFailure,
getPaymentResultTempFailure,
routerFailAttempt,
routerFailAttempt,
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// The four shards settle.
getPaymentResultSuccess,
getPaymentResultSuccess,
getPaymentResultSuccess,
getPaymentResultSuccess,
routerSettleAttempt,
routerSettleAttempt,
routerSettleAttempt,
routerSettleAttempt,
// Overall payment succeeds.
paymentSuccess,
},
routes: []*route.Route{
shard, shard, shard, shard, shard, shard,
},
},
{
// An MP payment scenario where one of the shards fails,
// but we still receive a single success shard.
name: "MP one shard success",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 1
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 0 fails, and should be failed by the
// router.
getPaymentResultTempFailure,
routerFailAttempt,
// We will try one more shard because we haven't
// sent the full payment amount.
routeRelease,
// The second shard succeed against all odds,
// making the overall payment succeed.
getPaymentResultSuccess,
routerSettleAttempt,
paymentSuccess,
},
routes: []*route.Route{halfShard, halfShard},
},
{
// An MP payment scenario a shard fail with a terminal
// error, causing the router to stop attempting.
name: "MP terminal",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 1
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 2
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 3
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// The first shard fail with a terminal error.
getPaymentResultTerminalFailure,
routerFailAttempt,
routerFailPayment,
// Remaining 3 shards fail.
getPaymentResultTempFailure,
getPaymentResultTempFailure,
getPaymentResultTempFailure,
routerFailAttempt,
routerFailAttempt,
routerFailAttempt,
// Payment fails.
paymentError,
},
routes: []*route.Route{
shard, shard, shard, shard, shard, shard,
},
paymentErr: channeldb.FailureReasonPaymentDetails,
},
{
// A MP payment scenario when our path finding returns
// after we've just received a terminal failure, and
// attempts to dispatch a new shard. Testing that we
// correctly abandon the shard and conclude the payment.
name: "MP path found after failure",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// The first shard fail with a terminal error.
getPaymentResultTerminalFailure,
routerFailAttempt,
routerFailPayment,
// shard 1 fails because we've had a terminal
// failure.
routeRelease,
routerRegisterAttempt,
// Payment fails.
paymentError,
},
routes: []*route.Route{
shard, shard,
},
paymentErr: channeldb.FailureReasonPaymentDetails,
},
{
// A MP payment scenario when our path finding returns
// after we've just received a terminal failure, and
// we have another shard still in flight.
name: "MP shard in flight after terminal",
steps: []string{
routerInitPayment,
// shard 0
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 1
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// shard 2
routeRelease,
routerRegisterAttempt,
sendToSwitchSuccess,
// We find a path for another shard.
routeRelease,
// shard 0 fails with a terminal error.
getPaymentResultTerminalFailure,
routerFailAttempt,
routerFailPayment,
// We try to register our final shard after
// processing a terminal failure.
routerRegisterAttempt,
// Our in-flight shards fail.
getPaymentResultTempFailure,
getPaymentResultTempFailure,
routerFailAttempt,
routerFailAttempt,
// Payment fails.
paymentError,
},
routes: []*route.Route{
shard, shard, shard, shard,
},
paymentErr: channeldb.FailureReasonPaymentDetails,
},
}
for _, test := range tests {
@ -1080,3 +798,343 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
t.Fatalf("SendPayment didn't exit")
}
}
// TestPaymentState tests that the logics implemented on paymentState struct
// are as expected. In particular, that the method terminated and
// needWaitForShards return the right values.
func TestPaymentState(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
// Use the following three params, each is equivalent to a bool
// statement, to construct 8 test cases so that we can
// exhaustively catch all possible states.
numShardsInFlight int
remainingAmt lnwire.MilliSatoshi
terminate bool
expectedTerminated bool
expectedNeedWaitForShards bool
}{
{
// If we have active shards and terminate is marked
// false, the state is not terminated. Since the
// remaining amount is zero, we need to wait for shards
// to be finished and launch no more shards.
name: "state 100",
numShardsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have active shards while terminate is marked
// true, the state is not terminated, and we need to
// wait for shards to be finished and launch no more
// shards.
name: "state 101",
numShardsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: true,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have active shards and terminate is marked
// false, the state is not terminated. Since the
// remaining amount is not zero, we don't need to wait
// for shards outcomes and should launch more shards.
name: "state 110",
numShardsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have active shards and terminate is marked
// true, the state is not terminated. Even the
// remaining amount is not zero, we need to wait for
// shards outcomes because state is terminated.
name: "state 111",
numShardsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: true,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have no active shards while terminate is marked
// false, the state is not terminated, and we don't
// need to wait for more shard outcomes because there
// are no active shards.
name: "state 000",
numShardsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// true, the state is terminated, and we don't need to
// wait for shards to be finished.
name: "state 001",
numShardsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: true,
expectedTerminated: true,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// false, the state is not terminated. Since the
// remaining amount is not zero, we don't need to wait
// for shards outcomes and should launch more shards.
name: "state 010",
numShardsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// true, the state is terminated, and we don't need to
// wait for shards outcomes.
name: "state 011",
numShardsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: true,
expectedTerminated: true,
expectedNeedWaitForShards: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ps := &paymentState{
numShardsInFlight: tc.numShardsInFlight,
remainingAmt: tc.remainingAmt,
terminate: tc.terminate,
}
require.Equal(
t, tc.expectedTerminated, ps.terminated(),
"terminated returned wrong value",
)
require.Equal(
t, tc.expectedNeedWaitForShards,
ps.needWaitForShards(),
"needWaitForShards returned wrong value",
)
})
}
}
// TestUpdatePaymentState checks that the method updatePaymentState updates the
// paymentState as expected.
func TestUpdatePaymentState(t *testing.T) {
t.Parallel()
// paymentHash is the identifier on paymentLifecycle.
paymentHash := lntypes.Hash{}
// TODO(yy): make MPPayment into an interface so we can mock it. The
// current design implicitly tests the methods SendAmt, TerminalInfo,
// and InFlightHTLCs on channeldb.MPPayment, which is not good. Once
// MPPayment becomes an interface, we can then mock these methods here.
// SentAmt returns 90, 10
// TerminalInfo returns non-nil, nil
// InFlightHTLCs returns 0
var preimage lntypes.Preimage
paymentSettled := &channeldb.MPPayment{
HTLCs: []channeldb.HTLCAttempt{
makeSettledAttempt(100, 10, preimage),
},
}
// SentAmt returns 0, 0
// TerminalInfo returns nil, non-nil
// InFlightHTLCs returns 0
reason := channeldb.FailureReasonError
paymentFailed := &channeldb.MPPayment{
FailureReason: &reason,
}
// SentAmt returns 90, 10
// TerminalInfo returns nil, nil
// InFlightHTLCs returns 1
paymentActive := &channeldb.MPPayment{
HTLCs: []channeldb.HTLCAttempt{
makeActiveAttempt(100, 10),
makeFailedAttempt(100, 10),
},
}
testCases := []struct {
name string
payment *channeldb.MPPayment
totalAmt int
feeLimit int
expectedState *paymentState
shouldReturnError bool
}{
{
// Test that the error returned from FetchPayment is
// handled properly. We use a nil payment to indicate
// we want to return an error.
name: "fetch payment error",
payment: nil,
shouldReturnError: true,
},
{
// Test that when the sentAmt exceeds totalAmount, the
// error is returned.
name: "amount exceeded error",
payment: paymentSettled,
totalAmt: 1,
shouldReturnError: true,
},
{
// Test that when the fee budget is reached, the
// remaining fee should be zero.
name: "fee budget reached",
payment: paymentActive,
totalAmt: 1000,
feeLimit: 1,
expectedState: &paymentState{
numShardsInFlight: 1,
remainingAmt: 1000 - 90,
remainingFees: 0,
terminate: false,
},
},
{
// Test when the payment is settled, the state should
// be marked as terminated.
name: "payment settled",
payment: paymentSettled,
totalAmt: 1000,
feeLimit: 100,
expectedState: &paymentState{
numShardsInFlight: 0,
remainingAmt: 1000 - 90,
remainingFees: 100 - 10,
terminate: true,
},
},
{
// Test when the payment is failed, the state should be
// marked as terminated.
name: "payment failed",
payment: paymentFailed,
totalAmt: 1000,
feeLimit: 100,
expectedState: &paymentState{
numShardsInFlight: 0,
remainingAmt: 1000,
remainingFees: 100,
terminate: true,
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Create mock control tower and assign it to router.
// We will then use the router and the paymentHash
// above to create our paymentLifecycle for this test.
ct := &mockControlTower{}
rt := &ChannelRouter{cfg: &Config{Control: ct}}
pl := &paymentLifecycle{
router: rt,
identifier: paymentHash,
totalAmount: lnwire.MilliSatoshi(tc.totalAmt),
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
}
if tc.payment == nil {
// A nil payment indicates we want to test an
// error returned from FetchPayment.
dummyErr := errors.New("dummy")
ct.On("FetchPayment", paymentHash).Return(
nil, dummyErr,
)
} else {
// Otherwise we will return the payment.
ct.On("FetchPayment", paymentHash).Return(
tc.payment, nil,
)
}
// Call the method that updates the payment state.
_, state, err := pl.updatePaymentState()
// Assert that the mock method is called as
// intended.
ct.AssertExpectations(t)
if tc.shouldReturnError {
require.Error(t, err, "expect an error")
return
}
require.NoError(t, err, "unexpected error")
require.Equal(
t, tc.expectedState, state,
"state not updated as expected",
)
})
}
}
func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
}
}
func makeSettledAttempt(total, fee int,
preimage lntypes.Preimage) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
Settle: &channeldb.HTLCSettleInfo{Preimage: preimage},
}
}
func makeFailedAttempt(total, fee int) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
Failure: &channeldb.HTLCFailInfo{
Reason: channeldb.HTLCFailInternal,
},
}
}
func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)}
return channeldb.HTLCAttemptInfo{
Route: route.Route{
TotalAmount: lnwire.MilliSatoshi(total),
Hops: []*route.Hop{hop},
},
}
}

@ -15,6 +15,7 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/lightningnetwork/lnd/channeldb"
@ -1069,7 +1070,8 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
_, ok = msg.(*lnwire.FailUnknownNextPeer)
require.True(t, ok, "unexpected fail message")
ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory()
err = ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory()
require.NoError(t, err, "reset history failed")
// Next, we'll modify the SendToSwitch method to indicate that the
// connection between songoku and isn't up.
@ -3436,3 +3438,796 @@ func TestChannelOnChainRejectionZombie(t *testing.T) {
require.Nil(t, err)
assertChanChainRejection(t, ctx, edge, ErrInvalidFundingOutput)
}
func createDummyTestGraph(t *testing.T) *testGraphInstance {
// Setup two simple channels such that we can mock sending along this
// route.
chanCapSat := btcutil.Amount(100000)
testChannels := []*testChannel{
symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
}, 1),
symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat),
}, 2),
}
testGraph, err := createTestGraphFromChannels(testChannels, "a")
require.NoError(t, err, "failed to create graph")
return testGraph
}
func createDummyLightningPayment(t *testing.T,
target route.Vertex, amt lnwire.MilliSatoshi) *LightningPayment {
var preImage lntypes.Preimage
_, err := rand.Read(preImage[:])
require.NoError(t, err, "unable to generate preimage")
payHash := preImage.Hash()
return &LightningPayment{
Target: target,
Amount: amt,
FeeLimit: noFeeLimit,
paymentHash: &payHash,
}
}
// TestSendMPPaymentSucceed tests that we can successfully send a MPPayment via
// router.SendPayment. This test mainly focuses on testing the logic of the
// method resumePayment is implemented as expected.
func TestSendMPPaymentSucceed(t *testing.T) {
const startingBlockHeight = 101
// Create mockers to initialize the router.
controlTower := &mockControlTower{}
sessionSource := &mockPaymentSessionSource{}
missionControl := &mockMissionControl{}
payer := &mockPaymentAttemptDispatcher{}
chain := newMockChain(startingBlockHeight)
chainView := newMockChainView(chain)
testGraph := createDummyTestGraph(t)
// Define the behavior of the mockers to the point where we can
// successfully start the router.
controlTower.On("FetchInFlightPayments").Return(
[]*channeldb.MPPayment{}, nil,
)
payer.On("CleanStore", mock.Anything).Return(nil)
// Create and start the router.
router, err := New(Config{
Control: controlTower,
SessionSource: sessionSource,
MissionControl: missionControl,
Payer: payer,
// TODO(yy): create new mocks for the chain and chainview.
Chain: chain,
ChainView: chainView,
// TODO(yy): mock the graph once it's changed into interface.
Graph: testGraph.graph,
Clock: clock.NewTestClock(time.Unix(1, 0)),
GraphPruneInterval: time.Hour * 2,
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
})
require.NoError(t, err, "failed to create router")
// Make sure the router can start and stop without error.
require.NoError(t, router.Start(), "router failed to start")
defer func() {
require.NoError(t, router.Stop(), "router failed to stop")
}()
// Once the router is started, check that the mocked methods are called
// as expected.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
// Mock the methods to the point where we are inside the function
// resumePayment.
paymentAmt := lnwire.MilliSatoshi(10000)
req := createDummyLightningPayment(
t, testGraph.aliasMap["c"], paymentAmt,
)
identifier := lntypes.Hash(req.Identifier())
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
// Create a buffered chan and it will be returned by GetPaymentResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
payer.On("GetPaymentResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
payer.resultChan <- &htlcswitch.PaymentResult{}
})
// Simple mocking the rest.
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
missionControl.On("ReportPaymentSuccess",
mock.Anything, mock.Anything,
).Return(nil)
// Mock SettleAttempt by changing one of the HTLCs to be settled.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle == nil {
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
}
})
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.
done := make(chan struct{})
var p lntypes.Hash
go func() {
p, _, err = router.SendPayment(req)
close(done)
}()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatalf("SendPayment didn't exit")
}
// Finally, validate the returned values and check that the mock
// methods are called as expected.
require.NoError(t, err, "send payment failed")
require.EqualValues(t, preimage, p, "preimage not match")
// Note that we also implicitly check the methods such as FailAttempt,
// ReportPaymentFail, etc, are not called because we never mocked them
// in this test. If any of the unexpected methods was called, the test
// would fail.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
}
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
// there are failed ones,so that a payment is successfully sent. This test
// mainly focuses on testing the logic of the method resumePayment is
// implemented as expected.
func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
const startingBlockHeight = 101
// Create mockers to initialize the router.
controlTower := &mockControlTower{}
sessionSource := &mockPaymentSessionSource{}
missionControl := &mockMissionControl{}
payer := &mockPaymentAttemptDispatcher{}
chain := newMockChain(startingBlockHeight)
chainView := newMockChainView(chain)
testGraph := createDummyTestGraph(t)
// Define the behavior of the mockers to the point where we can
// successfully start the router.
controlTower.On("FetchInFlightPayments").Return(
[]*channeldb.MPPayment{}, nil,
)
payer.On("CleanStore", mock.Anything).Return(nil)
// Create and start the router.
router, err := New(Config{
Control: controlTower,
SessionSource: sessionSource,
MissionControl: missionControl,
Payer: payer,
// TODO(yy): create new mocks for the chain and chainview.
Chain: chain,
ChainView: chainView,
// TODO(yy): mock the graph once it's changed into interface.
Graph: testGraph.graph,
Clock: clock.NewTestClock(time.Unix(1, 0)),
GraphPruneInterval: time.Hour * 2,
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
})
require.NoError(t, err, "failed to create router")
// Make sure the router can start and stop without error.
require.NoError(t, router.Start(), "router failed to start")
defer func() {
require.NoError(t, router.Stop(), "router failed to stop")
}()
// Once the router is started, check that the mocked methods are called
// as expected.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
// Mock the methods to the point where we are inside the function
// resumePayment.
paymentAmt := lnwire.MilliSatoshi(20000)
req := createDummyLightningPayment(
t, testGraph.aliasMap["c"], paymentAmt,
)
identifier := lntypes.Hash(req.Identifier())
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
// Create a buffered chan and it will be returned by GetPaymentResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetPaymentResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
// Update the counter.
failAttemptCount++
// We will make the first two attempts failed with temporary
// error.
if failAttemptCount <= 2 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
1,
),
}
return
}
// Otherwise we will mark the attempt succeeded.
payer.resultChan <- &htlcswitch.PaymentResult{}
})
// Mock the FailAttempt method to fail one of the attempts.
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(nil, nil)
// Simple mocking the rest.
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
missionControl.On("ReportPaymentSuccess",
mock.Anything, mock.Anything,
).Return(nil)
// Mock SettleAttempt by changing one of the HTLCs to be settled.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
})
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.
done := make(chan struct{})
var p lntypes.Hash
go func() {
p, _, err = router.SendPayment(req)
close(done)
}()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatalf("SendPayment didn't exit")
}
// Finally, validate the returned values and check that the mock
// methods are called as expected.
require.NoError(t, err, "send payment failed")
require.EqualValues(t, preimage, p, "preimage not match")
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
}
// TestSendMPPaymentFailed tests that when one of the shard fails with a
// terminal error, the router will stop attempting and the payment will fail.
// This test mainly focuses on testing the logic of the method resumePayment
// is implemented as expected.
func TestSendMPPaymentFailed(t *testing.T) {
const startingBlockHeight = 101
// Create mockers to initialize the router.
controlTower := &mockControlTower{}
sessionSource := &mockPaymentSessionSource{}
missionControl := &mockMissionControl{}
payer := &mockPaymentAttemptDispatcher{}
chain := newMockChain(startingBlockHeight)
chainView := newMockChainView(chain)
testGraph := createDummyTestGraph(t)
// Define the behavior of the mockers to the point where we can
// successfully start the router.
controlTower.On("FetchInFlightPayments").Return(
[]*channeldb.MPPayment{}, nil,
)
payer.On("CleanStore", mock.Anything).Return(nil)
// Create and start the router.
router, err := New(Config{
Control: controlTower,
SessionSource: sessionSource,
MissionControl: missionControl,
Payer: payer,
// TODO(yy): create new mocks for the chain and chainview.
Chain: chain,
ChainView: chainView,
// TODO(yy): mock the graph once it's changed into interface.
Graph: testGraph.graph,
Clock: clock.NewTestClock(time.Unix(1, 0)),
GraphPruneInterval: time.Hour * 2,
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
})
require.NoError(t, err, "failed to create router")
// Make sure the router can start and stop without error.
require.NoError(t, router.Start(), "router failed to start")
defer func() {
require.NoError(t, router.Stop(), "router failed to stop")
}()
// Once the router is started, check that the mocked methods are called
// as expected.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
// Mock the methods to the point where we are inside the function
// resumePayment.
paymentAmt := lnwire.MilliSatoshi(10000)
req := createDummyLightningPayment(
t, testGraph.aliasMap["c"], paymentAmt,
)
identifier := lntypes.Hash(req.Identifier())
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
// Create a buffered chan and it will be returned by GetPaymentResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetPaymentResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
// Update the counter.
failAttemptCount++
// We fail the first attempt with terminal error.
if failAttemptCount == 1 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailIncorrectDetails{},
1,
),
}
return
}
// We will make the rest attempts failed with temporary error.
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
1,
),
}
})
// Mock the FailAttempt method to fail one of the attempts.
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
var called bool
failureReason := channeldb.FailureReasonPaymentDetails
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(nil, nil).Run(func(args mock.Arguments) {
// We only return the terminal error once, thus when the method
// is called, we will return it with a nil error.
if called {
missionControl.failReason = nil
return
}
// If it's the first time calling this method, we will return a
// terminal error.
missionControl.failReason = &failureReason
payment.FailureReason = &failureReason
called = true
})
// Simple mocking the rest.
controlTower.On("Fail", identifier, failureReason).Return(nil)
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.
done := make(chan struct{})
var p lntypes.Hash
go func() {
p, _, err = router.SendPayment(req)
close(done)
}()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatalf("SendPayment didn't exit")
}
// Finally, validate the returned values and check that the mock
// methods are called as expected.
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
}
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
// terminal state, even if we have shards in flight, we still fail the payment
// and exit. This test mainly focuses on testing the logic of the method
// resumePayment is implemented as expected.
func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
const startingBlockHeight = 101
// Create mockers to initialize the router.
controlTower := &mockControlTower{}
sessionSource := &mockPaymentSessionSource{}
missionControl := &mockMissionControl{}
payer := &mockPaymentAttemptDispatcher{}
chain := newMockChain(startingBlockHeight)
chainView := newMockChainView(chain)
testGraph := createDummyTestGraph(t)
// Define the behavior of the mockers to the point where we can
// successfully start the router.
controlTower.On("FetchInFlightPayments").Return(
[]*channeldb.MPPayment{}, nil,
)
payer.On("CleanStore", mock.Anything).Return(nil)
// Create and start the router.
router, err := New(Config{
Control: controlTower,
SessionSource: sessionSource,
MissionControl: missionControl,
Payer: payer,
// TODO(yy): create new mocks for the chain and chainview.
Chain: chain,
ChainView: chainView,
// TODO(yy): mock the graph once it's changed into interface.
Graph: testGraph.graph,
Clock: clock.NewTestClock(time.Unix(1, 0)),
GraphPruneInterval: time.Hour * 2,
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
})
require.NoError(t, err, "failed to create router")
// Make sure the router can start and stop without error.
require.NoError(t, router.Start(), "router failed to start")
defer func() {
require.NoError(t, router.Stop(), "router failed to stop")
}()
// Once the router is started, check that the mocked methods are called
// as expected.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
// Mock the methods to the point where we are inside the function
// resumePayment.
paymentAmt := lnwire.MilliSatoshi(10000)
req := createDummyLightningPayment(
t, testGraph.aliasMap["c"], paymentAmt,
)
identifier := lntypes.Hash(req.Identifier())
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
// Create a buffered chan and it will be returned by GetPaymentResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetPaymentResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
// Update the counter.
failAttemptCount++
// We fail the first attempt with terminal error.
if failAttemptCount == 1 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailIncorrectDetails{},
1,
),
}
return
}
// For the rest attempts we will NOT send anything to the
// resultChan, thus making all the shards in active state,
// neither settled or failed.
})
// Mock the FailAttempt method to fail EXACTLY once.
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
failedAttempt = payment.HTLCs[0]
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
payment.HTLCs[0] = failedAttempt
}).Once()
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
failureReason := channeldb.FailureReasonPaymentDetails
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(failureReason, nil).Run(func(args mock.Arguments) {
missionControl.failReason = &failureReason
payment.FailureReason = &failureReason
}).Once()
// Simple mocking the rest.
controlTower.On("Fail", identifier, failureReason).Return(nil).Once()
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.
done := make(chan struct{})
var p lntypes.Hash
go func() {
p, _, err = router.SendPayment(req)
close(done)
}()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatalf("SendPayment didn't exit")
}
// Finally, validate the returned values and check that the mock
// methods are called as expected.
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
}