diff --git a/routing/control_tower.go b/routing/control_tower.go index 5ade611c..da7d4d2b 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -34,6 +34,10 @@ type ControlTower interface { // FailAttempt marks the given payment attempt failed. FailAttempt(lntypes.Hash, uint64, *channeldb.HTLCFailInfo) error + // FetchPayment fetches the payment corresponding to the given payment + // hash. + FetchPayment(paymentHash lntypes.Hash) (*channeldb.MPPayment, error) + // Fail transitions a payment into the Failed state, and records the // ultimate reason the payment failed. Note that this should only be // called when all active active attempts are already failed. After @@ -132,6 +136,13 @@ func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, return p.db.FailAttempt(paymentHash, attemptID, failInfo) } +// FetchPayment fetches the payment corresponding to the given payment hash. +func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( + *channeldb.MPPayment, error) { + + return p.db.FetchPayment(paymentHash) +} + // createSuccessResult creates a success result to send to subscribers. func createSuccessResult(htlcs []channeldb.HTLCAttempt) *PaymentResult { // Extract any preimage from the list of HTLCs. diff --git a/routing/mock_test.go b/routing/mock_test.go index c6a780bb..c66c70b1 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -188,9 +188,15 @@ type failArgs struct { reason channeldb.FailureReason } +type testPayment struct { + info channeldb.PaymentCreationInfo + attempts []channeldb.HTLCAttempt +} + type mockControlTower struct { - inflights map[lntypes.Hash]channeldb.InFlightPayment + payments map[lntypes.Hash]*testPayment successful map[lntypes.Hash]struct{} + failed map[lntypes.Hash]channeldb.FailureReason init chan initArgs register chan registerArgs @@ -205,8 +211,9 @@ var _ ControlTower = (*mockControlTower)(nil) func makeMockControlTower() *mockControlTower { return &mockControlTower{ - inflights: make(map[lntypes.Hash]channeldb.InFlightPayment), + payments: make(map[lntypes.Hash]*testPayment), successful: make(map[lntypes.Hash]struct{}), + failed: make(map[lntypes.Hash]channeldb.FailureReason), } } @@ -220,18 +227,22 @@ func (m *mockControlTower) InitPayment(phash lntypes.Hash, m.init <- initArgs{c} } + // Don't allow re-init a successful payment. if _, ok := m.successful[phash]; ok { - return fmt.Errorf("already successful") + return channeldb.ErrAlreadyPaid } - _, ok := m.inflights[phash] - if ok { - return fmt.Errorf("in flight") + _, failed := m.failed[phash] + _, ok := m.payments[phash] + + // If the payment is known, only allow re-init if failed. + if ok && !failed { + return channeldb.ErrPaymentInFlight } - m.inflights[phash] = channeldb.InFlightPayment{ - Info: c, - Attempts: make([]channeldb.HTLCAttemptInfo, 0), + delete(m.failed, phash) + m.payments[phash] = &testPayment{ + info: *c, } return nil @@ -247,13 +258,24 @@ func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, m.register <- registerArgs{a} } - p, ok := m.inflights[phash] - if !ok { - return fmt.Errorf("not in flight") + // Cannot register attempts for successful or failed payments. + if _, ok := m.successful[phash]; ok { + return channeldb.ErrPaymentAlreadySucceeded } - p.Attempts = append(p.Attempts, *a) - m.inflights[phash] = p + if _, ok := m.failed[phash]; ok { + return channeldb.ErrPaymentAlreadyFailed + } + + p, ok := m.payments[phash] + if !ok { + return channeldb.ErrPaymentNotInitiated + } + + p.attempts = append(p.attempts, channeldb.HTLCAttempt{ + HTLCAttemptInfo: *a, + }) + m.payments[phash] = p return nil } @@ -268,9 +290,69 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, m.success <- successArgs{settleInfo.Preimage} } - delete(m.inflights, phash) - m.successful[phash] = struct{}{} - return nil + // Only allow setting attempts for payments not yet succeeded or + // failed. + if _, ok := m.successful[phash]; ok { + return channeldb.ErrPaymentAlreadySucceeded + } + + if _, ok := m.failed[phash]; ok { + return channeldb.ErrPaymentAlreadyFailed + } + + p, ok := m.payments[phash] + if !ok { + return channeldb.ErrPaymentNotInitiated + } + + // Find the attempt with this pid, and set the settle info. + for i, a := range p.attempts { + if a.AttemptID != pid { + continue + } + + p.attempts[i].Settle = settleInfo + + // Mark the payment successful on first settled attempt. + m.successful[phash] = struct{}{} + return nil + } + + return fmt.Errorf("pid not found") +} + +func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, + failInfo *channeldb.HTLCFailInfo) error { + + m.Lock() + defer m.Unlock() + + // Only allow failing attempts for payments not yet succeeded or + // failed. + if _, ok := m.successful[phash]; ok { + return channeldb.ErrPaymentAlreadySucceeded + } + + if _, ok := m.failed[phash]; ok { + return channeldb.ErrPaymentAlreadyFailed + } + + p, ok := m.payments[phash] + if !ok { + return channeldb.ErrPaymentNotInitiated + } + + // Find the attempt with this pid, and set the failure info. + for i, a := range p.attempts { + if a.AttemptID != pid { + continue + } + + p.attempts[i].Failure = failInfo + return nil + } + + return fmt.Errorf("pid not found") } func (m *mockControlTower) Fail(phash lntypes.Hash, @@ -283,10 +365,50 @@ func (m *mockControlTower) Fail(phash lntypes.Hash, m.fail <- failArgs{reason} } - delete(m.inflights, phash) + // Cannot fail already successful or failed payments. + if _, ok := m.successful[phash]; ok { + return channeldb.ErrPaymentAlreadySucceeded + } + + if _, ok := m.failed[phash]; ok { + return channeldb.ErrPaymentAlreadyFailed + } + + if _, ok := m.payments[phash]; !ok { + return channeldb.ErrPaymentNotInitiated + } + + m.failed[phash] = reason + return nil } +func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( + *channeldb.MPPayment, error) { + + m.Lock() + defer m.Unlock() + + p, ok := m.payments[phash] + if !ok { + return nil, channeldb.ErrPaymentNotInitiated + } + + mp := &channeldb.MPPayment{ + Info: &p.info, + } + + reason, ok := m.failed[phash] + if ok { + mp.FailureReason = &reason + } + + // Return a copy of the current attempts. + mp.HTLCs = append(mp.HTLCs, p.attempts...) + + return mp, nil +} + func (m *mockControlTower) FetchInFlightPayments() ( []*channeldb.InFlightPayment, error) { @@ -297,8 +419,25 @@ func (m *mockControlTower) FetchInFlightPayments() ( m.fetchInFlight <- struct{}{} } + // In flight are all payments not successful or failed. var fl []*channeldb.InFlightPayment - for _, ifl := range m.inflights { + for hash, p := range m.payments { + if _, ok := m.successful[hash]; ok { + continue + } + if _, ok := m.failed[hash]; ok { + continue + } + + var attempts []channeldb.HTLCAttemptInfo + for _, a := range p.attempts { + attempts = append(attempts, a.HTLCAttemptInfo) + } + ifl := channeldb.InFlightPayment{ + Info: &p.info, + Attempts: attempts, + } + fl = append(fl, &ifl) } @@ -310,9 +449,3 @@ func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( return false, nil, errors.New("not implemented") } - -func (m *mockControlTower) FailAttempt(hash lntypes.Hash, pid uint64, - failInfo *channeldb.HTLCFailInfo) error { - - return nil -}