channeldb: return full payment for inflight payments

We might as well return all info, and we'll need the individual HTLCs
in later commits.
This commit is contained in:
Johan T. Halseth 2021-03-30 12:10:30 +02:00
parent 8af00ab0cf
commit 7795353e9f
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 21 additions and 32 deletions

View File

@ -676,16 +676,9 @@ func ensureInFlight(payment *MPPayment) error {
} }
} }
// InFlightPayment is a wrapper around the info for a payment that has status
// InFlight.
type InFlightPayment struct {
// Info is the PaymentCreationInfo of the in-flight payment.
Info *PaymentCreationInfo
}
// FetchInFlightPayments returns all payments with status InFlight. // FetchInFlightPayments returns all payments with status InFlight.
func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { func (p *PaymentControl) FetchInFlightPayments() ([]*MPPayment, error) {
var inFlights []*InFlightPayment var inFlights []*MPPayment
err := kvdb.View(p.db, func(tx kvdb.RTx) error { err := kvdb.View(p.db, func(tx kvdb.RTx) error {
payments := tx.ReadBucket(paymentsRootBucket) payments := tx.ReadBucket(paymentsRootBucket)
if payments == nil { if payments == nil {
@ -708,15 +701,12 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
return nil return nil
} }
inFlight := &InFlightPayment{} p, err := fetchPayment(bucket)
// Get the CreationInfo.
inFlight.Info, err = fetchCreationInfo(bucket)
if err != nil { if err != nil {
return err return err
} }
inFlights = append(inFlights, inFlight) inFlights = append(inFlights, p)
return nil return nil
}) })
}, func() { }, func() {

View File

@ -50,7 +50,7 @@ type ControlTower interface {
Fail(lntypes.Hash, channeldb.FailureReason) error Fail(lntypes.Hash, channeldb.FailureReason) error
// FetchInFlightPayments returns all payments with status InFlight. // FetchInFlightPayments returns all payments with status InFlight.
FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) FetchInFlightPayments() ([]*channeldb.MPPayment, error)
// SubscribePayment subscribes to updates for the payment with the given // SubscribePayment subscribes to updates for the payment with the given
// hash. A first update with the current state of the payment is always // hash. A first update with the current state of the payment is always
@ -213,7 +213,7 @@ func (p *controlTower) Fail(paymentHash lntypes.Hash,
} }
// FetchInFlightPayments returns all payments with status InFlight. // FetchInFlightPayments returns all payments with status InFlight.
func (p *controlTower) FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) { func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
return p.db.FetchInFlightPayments() return p.db.FetchInFlightPayments()
} }

View File

@ -452,6 +452,12 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
return m.fetchPayment(phash)
}
func (m *mockControlTower) fetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
p, ok := m.payments[phash] p, ok := m.payments[phash]
if !ok { if !ok {
return nil, channeldb.ErrPaymentNotInitiated return nil, channeldb.ErrPaymentNotInitiated
@ -468,12 +474,11 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
// Return a copy of the current attempts. // Return a copy of the current attempts.
mp.HTLCs = append(mp.HTLCs, p.attempts...) mp.HTLCs = append(mp.HTLCs, p.attempts...)
return mp, nil return mp, nil
} }
func (m *mockControlTower) FetchInFlightPayments() ( func (m *mockControlTower) FetchInFlightPayments() (
[]*channeldb.InFlightPayment, error) { []*channeldb.MPPayment, error) {
if m.fetchInFlight != nil { if m.fetchInFlight != nil {
m.fetchInFlight <- struct{}{} m.fetchInFlight <- struct{}{}
@ -483,8 +488,8 @@ func (m *mockControlTower) FetchInFlightPayments() (
defer m.Unlock() defer m.Unlock()
// In flight are all payments not successful or failed. // In flight are all payments not successful or failed.
var fl []*channeldb.InFlightPayment var fl []*channeldb.MPPayment
for hash, p := range m.payments { for hash := range m.payments {
if _, ok := m.successful[hash]; ok { if _, ok := m.successful[hash]; ok {
continue continue
} }
@ -492,11 +497,12 @@ func (m *mockControlTower) FetchInFlightPayments() (
continue continue
} }
ifl := channeldb.InFlightPayment{ mp, err := m.fetchPayment(hash)
Info: &p.info, if err != nil {
return nil, err
} }
fl = append(fl, &ifl) fl = append(fl, mp)
} }
return fl, nil return fl, nil

View File

@ -583,14 +583,7 @@ func (r *ChannelRouter) Start() error {
// until the cleaning has finished. // until the cleaning has finished.
toKeep := make(map[uint64]struct{}) toKeep := make(map[uint64]struct{})
for _, p := range payments { for _, p := range payments {
payment, err := r.cfg.Control.FetchPayment( for _, a := range p.HTLCs {
p.Info.PaymentHash,
)
if err != nil {
return err
}
for _, a := range payment.HTLCs {
toKeep[a.AttemptID] = struct{}{} toKeep[a.AttemptID] = struct{}{}
} }
} }
@ -603,7 +596,7 @@ func (r *ChannelRouter) Start() error {
for _, payment := range payments { for _, payment := range payments {
log.Infof("Resuming payment with hash %v", payment.Info.PaymentHash) log.Infof("Resuming payment with hash %v", payment.Info.PaymentHash)
r.wg.Add(1) r.wg.Add(1)
go func(payment *channeldb.InFlightPayment) { go func(payment *channeldb.MPPayment) {
defer r.wg.Done() defer r.wg.Done()
// We create a dummy, empty payment session such that // We create a dummy, empty payment session such that