diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index c820778c..a5023567 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -676,16 +676,9 @@ func ensureInFlight(payment *MPPayment) error { } } -// InFlightPayment is a wrapper around the info for a payment that has status -// InFlight. -type InFlightPayment struct { - // Info is the PaymentCreationInfo of the in-flight payment. - Info *PaymentCreationInfo -} - // FetchInFlightPayments returns all payments with status InFlight. -func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { - var inFlights []*InFlightPayment +func (p *PaymentControl) FetchInFlightPayments() ([]*MPPayment, error) { + var inFlights []*MPPayment err := kvdb.View(p.db, func(tx kvdb.RTx) error { payments := tx.ReadBucket(paymentsRootBucket) if payments == nil { @@ -708,15 +701,12 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { return nil } - inFlight := &InFlightPayment{} - - // Get the CreationInfo. - inFlight.Info, err = fetchCreationInfo(bucket) + p, err := fetchPayment(bucket) if err != nil { return err } - inFlights = append(inFlights, inFlight) + inFlights = append(inFlights, p) return nil }) }, func() { diff --git a/routing/control_tower.go b/routing/control_tower.go index 3e028c18..950b16f3 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -50,7 +50,7 @@ type ControlTower interface { Fail(lntypes.Hash, channeldb.FailureReason) error // FetchInFlightPayments returns all payments with status InFlight. - FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) + FetchInFlightPayments() ([]*channeldb.MPPayment, error) // SubscribePayment subscribes to updates for the payment with the given // hash. A first update with the current state of the payment is always @@ -213,7 +213,7 @@ func (p *controlTower) Fail(paymentHash lntypes.Hash, } // FetchInFlightPayments returns all payments with status InFlight. -func (p *controlTower) FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) { +func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) { return p.db.FetchInFlightPayments() } diff --git a/routing/mock_test.go b/routing/mock_test.go index a477fb10..8186902f 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -452,6 +452,12 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( m.Lock() defer m.Unlock() + return m.fetchPayment(phash) +} + +func (m *mockControlTower) fetchPayment(phash lntypes.Hash) ( + *channeldb.MPPayment, error) { + p, ok := m.payments[phash] if !ok { return nil, channeldb.ErrPaymentNotInitiated @@ -468,12 +474,11 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( // Return a copy of the current attempts. mp.HTLCs = append(mp.HTLCs, p.attempts...) - return mp, nil } func (m *mockControlTower) FetchInFlightPayments() ( - []*channeldb.InFlightPayment, error) { + []*channeldb.MPPayment, error) { if m.fetchInFlight != nil { m.fetchInFlight <- struct{}{} @@ -483,8 +488,8 @@ func (m *mockControlTower) FetchInFlightPayments() ( defer m.Unlock() // In flight are all payments not successful or failed. - var fl []*channeldb.InFlightPayment - for hash, p := range m.payments { + var fl []*channeldb.MPPayment + for hash := range m.payments { if _, ok := m.successful[hash]; ok { continue } @@ -492,11 +497,12 @@ func (m *mockControlTower) FetchInFlightPayments() ( continue } - ifl := channeldb.InFlightPayment{ - Info: &p.info, + mp, err := m.fetchPayment(hash) + if err != nil { + return nil, err } - fl = append(fl, &ifl) + fl = append(fl, mp) } return fl, nil diff --git a/routing/router.go b/routing/router.go index a1f50215..1d251e44 100644 --- a/routing/router.go +++ b/routing/router.go @@ -583,14 +583,7 @@ func (r *ChannelRouter) Start() error { // until the cleaning has finished. toKeep := make(map[uint64]struct{}) for _, p := range payments { - payment, err := r.cfg.Control.FetchPayment( - p.Info.PaymentHash, - ) - if err != nil { - return err - } - - for _, a := range payment.HTLCs { + for _, a := range p.HTLCs { toKeep[a.AttemptID] = struct{}{} } } @@ -603,7 +596,7 @@ func (r *ChannelRouter) Start() error { for _, payment := range payments { log.Infof("Resuming payment with hash %v", payment.Info.PaymentHash) r.wg.Add(1) - go func(payment *channeldb.InFlightPayment) { + go func(payment *channeldb.MPPayment) { defer r.wg.Done() // We create a dummy, empty payment session such that