diff --git a/channeldb/payments.go b/channeldb/payments.go index 73624902..5c2475bd 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "math" "sort" "time" @@ -511,64 +510,72 @@ type PaymentsResponse struct { func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { var resp PaymentsResponse - allPayments, err := db.FetchPayments() - if err != nil { + if err := kvdb.View(db, func(tx kvdb.RTx) error { + // Get the root payments bucket. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + // Get the index bucket which maps sequence number -> payment + // hash and duplicate bool. If we have a payments bucket, we + // should have an indexes bucket as well. + indexes := tx.ReadBucket(paymentsIndexBucket) + if indexes == nil { + return fmt.Errorf("index bucket does not exist") + } + + // accumulatePayments gets payments with the sequence number + // and hash provided and adds them to our list of payments if + // they meet the criteria of our query. It returns the number + // of payments that were added. + accumulatePayments := func(sequenceKey, hash []byte) (bool, + error) { + + r := bytes.NewReader(hash) + paymentHash, err := deserializePaymentIndex(r) + if err != nil { + return false, err + } + + payment, err := fetchPaymentWithSequenceNumber( + tx, paymentHash, sequenceKey, + ) + if err != nil { + return false, err + } + + // To keep compatibility with the old API, we only + // return non-succeeded payments if requested. + if payment.Status != StatusSucceeded && + !query.IncludeIncomplete { + + return false, err + } + + // At this point, we've exhausted the offset, so we'll + // begin collecting invoices found within the range. + resp.Payments = append(resp.Payments, payment) + return true, nil + } + + // Create a paginator which reads from our sequence index bucket + // with the parameters provided by the payments query. + paginator := newPaginator( + indexes.ReadCursor(), query.Reversed, query.IndexOffset, + query.MaxPayments, + ) + + // Run a paginated query, adding payments to our response. + if err := paginator.query(accumulatePayments); err != nil { + return err + } + + return nil + }); err != nil { return resp, err } - if len(allPayments) == 0 { - return resp, nil - } - - indexExclusiveLimit := query.IndexOffset - // In backward pagination, if the index limit is the default 0 value, - // we set our limit to maxint to include all payments from the highest - // sequence number on. - if query.Reversed && indexExclusiveLimit == 0 { - indexExclusiveLimit = math.MaxInt64 - } - - for i := range allPayments { - var payment *MPPayment - - // If we have the max number of payments we want, exit. - if uint64(len(resp.Payments)) == query.MaxPayments { - break - } - - if query.Reversed { - payment = allPayments[len(allPayments)-1-i] - - // In the reversed direction, skip over all payments - // that have sequence numbers greater than or equal to - // the index offset. We skip payments with equal index - // because the offset is exclusive. - if payment.SequenceNum >= indexExclusiveLimit { - continue - } - } else { - payment = allPayments[i] - - // In the forward direction, skip over all payments that - // have sequence numbers less than or equal to the index - // offset. We skip payments with equal indexes because - // the index offset is exclusive. - if payment.SequenceNum <= indexExclusiveLimit { - continue - } - } - - // To keep compatibility with the old API, we only return - // non-succeeded payments if requested. - if payment.Status != StatusSucceeded && - !query.IncludeIncomplete { - - continue - } - - resp.Payments = append(resp.Payments, payment) - } - // Need to swap the payments slice order if reversed order. if query.Reversed { for l, r := 0, len(resp.Payments)-1; l < r; l, r = l+1, r-1 { @@ -585,7 +592,7 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { resp.Payments[len(resp.Payments)-1].SequenceNum } - return resp, err + return resp, nil } // fetchPaymentWithSequenceNumber get the payment which matches the payment hash diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 532003bb..9e790c3e 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -165,18 +165,24 @@ func TestRouteSerialization(t *testing.T) { } // deletePayment removes a payment with paymentHash from the payments database. -func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { +func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) { t.Helper() err := kvdb.Update(db, func(tx kvdb.RwTx) error { payments := tx.ReadWriteBucket(paymentsRootBucket) + // Delete the payment bucket. err := payments.DeleteNestedBucket(paymentHash[:]) if err != nil { return err } - return nil + key := make([]byte, 8) + byteOrder.PutUint64(key, seqNr) + + // Delete the index that references this payment. + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Delete(key) }) if err != nil { @@ -350,6 +356,42 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, }, + { + name: "query payments reverse before index gap", + query: PaymentsQuery{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, } for _, tt := range tests { @@ -358,11 +400,16 @@ func TestQueryPayments(t *testing.T) { t.Parallel() db, cleanup, err := makeTestDB() - defer cleanup() - if err != nil { t.Fatalf("unable to init db: %v", err) } + defer cleanup() + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := db.QueryPayments(tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) // Populate the database with a set of test payments. // We create 6 original payments, deleting the payment @@ -391,7 +438,13 @@ func TestQueryPayments(t *testing.T) { // Immediately delete the payment with index 2. if i == 1 { - deletePayment(t, db, info.PaymentHash) + pmt, err := pControl.FetchPayment( + info.PaymentHash, + ) + require.NoError(t, err) + + deletePayment(t, db, info.PaymentHash, + pmt.SequenceNum) } // If we are on the last payment entry, add a