Merge pull request #4261 from carlaKC/4164-indexpayments
channeldb: Index payments by sequence number
This commit is contained in:
commit
d47d17b5d4
@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case paymentIndexType:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case lnwire.FundingFlag:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
case *paymentIndexType:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *lnwire.FundingFlag:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
mig "github.com/lightningnetwork/lnd/channeldb/migration"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration12"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration13"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration16"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
@ -144,6 +145,19 @@ var (
|
||||
number: 14,
|
||||
migration: mig.CreateTLB(payAddrIndexBucket),
|
||||
},
|
||||
{
|
||||
// Initialize payment index bucket which will be used
|
||||
// to index payments by sequence number. This index will
|
||||
// be used to allow more efficient ListPayments queries.
|
||||
number: 15,
|
||||
migration: mig.CreateTLB(paymentsIndexBucket),
|
||||
},
|
||||
{
|
||||
// Add our existing payments to the index bucket created
|
||||
// in migration 15.
|
||||
number: 16,
|
||||
migration: migration16.MigrateSequenceIndex,
|
||||
},
|
||||
}
|
||||
|
||||
// Big endian is the preferred byte order, due to cursor scans over
|
||||
@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{
|
||||
fwdPackagesKey,
|
||||
invoiceBucket,
|
||||
payAddrIndexBucket,
|
||||
paymentsIndexBucket,
|
||||
nodeInfoBucket,
|
||||
nodeBucket,
|
||||
edgeBucket,
|
||||
|
@ -1007,6 +1007,18 @@ func TestQueryInvoices(t *testing.T) {
|
||||
// still pending.
|
||||
expected: pendingInvoices[len(pendingInvoices)-15:],
|
||||
},
|
||||
// Fetch all invoices paginating backwards, with an index offset
|
||||
// that is beyond our last offset. We expect all invoices to be
|
||||
// returned.
|
||||
{
|
||||
query: InvoiceQuery{
|
||||
IndexOffset: numInvoices * 2,
|
||||
PendingOnly: false,
|
||||
Reversed: true,
|
||||
NumMaxInvoices: numInvoices,
|
||||
},
|
||||
expected: invoices,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
|
@ -839,85 +839,47 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// Get the add index bucket which we will use to iterate through
|
||||
// our indexed invoices.
|
||||
invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
|
||||
if invoiceAddIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// keyForIndex is a helper closure that retrieves the invoice
|
||||
// key for the given add index of an invoice.
|
||||
keyForIndex := func(c kvdb.RCursor, index uint64) []byte {
|
||||
var keyIndex [8]byte
|
||||
byteOrder.PutUint64(keyIndex[:], index)
|
||||
_, invoiceKey := c.Seek(keyIndex[:])
|
||||
return invoiceKey
|
||||
}
|
||||
// Create a paginator which reads from our add index bucket with
|
||||
// the parameters provided by the invoice query.
|
||||
paginator := newPaginator(
|
||||
invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
|
||||
q.NumMaxInvoices,
|
||||
)
|
||||
|
||||
// nextKey is a helper closure to determine what the next
|
||||
// invoice key is when iterating over the invoice add index.
|
||||
nextKey := func(c kvdb.RCursor) ([]byte, []byte) {
|
||||
if q.Reversed {
|
||||
return c.Prev()
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// We'll be using a cursor to seek into the database and return
|
||||
// a slice of invoices. We'll need to determine where to start
|
||||
// our cursor depending on the parameters set within the query.
|
||||
c := invoiceAddIndex.ReadCursor()
|
||||
invoiceKey := keyForIndex(c, q.IndexOffset+1)
|
||||
|
||||
// If the query is specifying reverse iteration, then we must
|
||||
// handle a few offset cases.
|
||||
if q.Reversed {
|
||||
switch q.IndexOffset {
|
||||
|
||||
// This indicates the default case, where no offset was
|
||||
// specified. In that case we just start from the last
|
||||
// invoice.
|
||||
case 0:
|
||||
_, invoiceKey = c.Last()
|
||||
|
||||
// This indicates the offset being set to the very
|
||||
// first invoice. Since there are no invoices before
|
||||
// this offset, and the direction is reversed, we can
|
||||
// return without adding any invoices to the response.
|
||||
case 1:
|
||||
return nil
|
||||
|
||||
// Otherwise we start iteration at the invoice prior to
|
||||
// the offset.
|
||||
default:
|
||||
invoiceKey = keyForIndex(c, q.IndexOffset-1)
|
||||
}
|
||||
}
|
||||
|
||||
// If we know that a set of invoices exists, then we'll begin
|
||||
// our seek through the bucket in order to satisfy the query.
|
||||
// We'll continue until either we reach the end of the range, or
|
||||
// reach our max number of invoices.
|
||||
for ; invoiceKey != nil; _, invoiceKey = nextKey(c) {
|
||||
// If our current return payload exceeds the max number
|
||||
// of invoices, then we'll exit now.
|
||||
if uint64(len(resp.Invoices)) >= q.NumMaxInvoices {
|
||||
break
|
||||
}
|
||||
|
||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
||||
// accumulateInvoices looks up an invoice based on the index we
|
||||
// are given, adds it to our set of invoices if it has the right
|
||||
// characteristics for our query and returns the number of items
|
||||
// we have added to our set of invoices.
|
||||
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
|
||||
invoice, err := fetchInvoice(indexValue, invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Skip any settled or canceled invoices if the caller is
|
||||
// only interested in pending ones.
|
||||
// Skip any settled or canceled invoices if the caller
|
||||
// is only interested in pending ones.
|
||||
if q.PendingOnly && !invoice.IsPending() {
|
||||
continue
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// At this point, we've exhausted the offset, so we'll
|
||||
// begin collecting invoices found within the range.
|
||||
resp.Invoices = append(resp.Invoices, invoice)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Query our paginator using accumulateInvoices to build up a
|
||||
// set of invoices.
|
||||
if err := paginator.query(accumulateInvoices); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we iterated through the add index in reverse order, then
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
mig "github.com/lightningnetwork/lnd/channeldb/migration"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration12"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration13"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration16"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
|
||||
)
|
||||
|
||||
@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) {
|
||||
migration_01_to_11.UseLogger(logger)
|
||||
migration12.UseLogger(logger)
|
||||
migration13.UseLogger(logger)
|
||||
migration16.UseLogger(logger)
|
||||
}
|
||||
|
14
channeldb/migration16/log.go
Normal file
14
channeldb/migration16/log.go
Normal file
@ -0,0 +1,14 @@
|
||||
package migration16
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized as disabled. This means the package will
|
||||
// not perform any logging by default until a logger is set.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
191
channeldb/migration16/migration.go
Normal file
191
channeldb/migration16/migration.go
Normal file
@ -0,0 +1,191 @@
|
||||
package migration16
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
paymentsRootBucket = []byte("payments-root-bucket")
|
||||
|
||||
paymentSequenceKey = []byte("payment-sequence-key")
|
||||
|
||||
duplicatePaymentsBucket = []byte("payment-duplicate-bucket")
|
||||
|
||||
paymentsIndexBucket = []byte("payments-index-bucket")
|
||||
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// paymentIndexType indicates the type of index we have recorded in the payment
|
||||
// indexes bucket.
|
||||
type paymentIndexType uint8
|
||||
|
||||
// paymentIndexTypeHash is a payment index type which indicates that we have
|
||||
// created an index of payment sequence number to payment hash.
|
||||
const paymentIndexTypeHash paymentIndexType = 0
|
||||
|
||||
// paymentIndex stores all the information we require to create an index by
|
||||
// sequence number for a payment.
|
||||
type paymentIndex struct {
|
||||
// paymentHash is the hash of the payment, which is its key in the
|
||||
// payment root bucket.
|
||||
paymentHash []byte
|
||||
|
||||
// sequenceNumbers is the set of sequence numbers associated with this
|
||||
// payment hash. There will be more than one sequence number in the
|
||||
// case where duplicate payments are present.
|
||||
sequenceNumbers [][]byte
|
||||
}
|
||||
|
||||
// MigrateSequenceIndex migrates the payments db to contain a new bucket which
|
||||
// provides an index from sequence number to payment hash. This is required
|
||||
// for more efficient sequential lookup of payments, which are keyed by payment
|
||||
// hash before this migration.
|
||||
func MigrateSequenceIndex(tx kvdb.RwTx) error {
|
||||
log.Infof("Migrating payments to add sequence number index")
|
||||
|
||||
// Get a list of indices we need to write.
|
||||
indexList, err := getPaymentIndexList(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the top level bucket that we will use to index payments in.
|
||||
bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write an index for each of our payments.
|
||||
for _, index := range indexList {
|
||||
// Write indexes for each of our sequence numbers.
|
||||
for _, seqNr := range index.sequenceNumbers {
|
||||
err := putIndex(bucket, seqNr, index.paymentHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// putIndex performs a sanity check that ensures we are not writing duplicate
|
||||
// indexes to disk then creates the index provided.
|
||||
func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error {
|
||||
// Add a sanity check that we do not already have an entry with
|
||||
// this sequence number.
|
||||
existingEntry := bucket.Get(sequenceNr)
|
||||
if existingEntry != nil {
|
||||
return fmt.Errorf("sequence number: %x duplicated",
|
||||
sequenceNr)
|
||||
}
|
||||
|
||||
bytes, err := serializePaymentIndexEntry(paymentHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put(sequenceNr, bytes)
|
||||
}
|
||||
|
||||
// serializePaymentIndexEntry serializes a payment hash typed index. The value
|
||||
// produced contains a payment index type (which can be used in future to
|
||||
// signal different payment index types) and the payment hash.
|
||||
func serializePaymentIndexEntry(hash []byte) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
|
||||
err := binary.Write(&b, byteOrder, paymentIndexTypeHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := wire.WriteVarBytes(&b, 0, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// getPaymentIndexList gets a list of indices we need to write for our current
|
||||
// set of payments.
|
||||
func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) {
|
||||
// Iterate over all payments and store their indexing keys. This is
|
||||
// needed, because no modifications are allowed inside a Bucket.ForEach
|
||||
// loop.
|
||||
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
|
||||
if paymentsBucket == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var indexList []paymentIndex
|
||||
err := paymentsBucket.ForEach(func(k, v []byte) error {
|
||||
// Get the bucket which contains the payment, fail if the key
|
||||
// does not have a bucket.
|
||||
bucket := paymentsBucket.NestedReadBucket(k)
|
||||
if bucket == nil {
|
||||
return fmt.Errorf("non bucket element in " +
|
||||
"payments bucket")
|
||||
}
|
||||
seqBytes := bucket.Get(paymentSequenceKey)
|
||||
if seqBytes == nil {
|
||||
return fmt.Errorf("nil sequence number bytes")
|
||||
}
|
||||
|
||||
seqNrs, err := fetchSequenceNumbers(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an index object with our payment hash and sequence
|
||||
// numbers and append it to our set of indexes.
|
||||
index := paymentIndex{
|
||||
paymentHash: k,
|
||||
sequenceNumbers: seqNrs,
|
||||
}
|
||||
|
||||
indexList = append(indexList, index)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return indexList, nil
|
||||
}
|
||||
|
||||
// fetchSequenceNumbers fetches all the sequence numbers associated with a
|
||||
// payment, including those belonging to any duplicate payments.
|
||||
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
|
||||
seqNum := paymentBucket.Get(paymentSequenceKey)
|
||||
if seqNum == nil {
|
||||
return nil, errors.New("expected sequence number")
|
||||
}
|
||||
|
||||
sequenceNumbers := [][]byte{seqNum}
|
||||
|
||||
// Get the duplicate payments bucket, if it has no duplicates, just
|
||||
// return early with the payment sequence number.
|
||||
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
|
||||
if duplicates == nil {
|
||||
return sequenceNumbers, nil
|
||||
}
|
||||
|
||||
// If we do have duplicated, they are keyed by sequence number, so we
|
||||
// iterate through the duplicates bucket and add them to our set of
|
||||
// sequence numbers.
|
||||
if err := duplicates.ForEach(func(k, v []byte) error {
|
||||
sequenceNumbers = append(sequenceNumbers, k)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sequenceNumbers, nil
|
||||
}
|
144
channeldb/migration16/migration_test.go
Normal file
144
channeldb/migration16/migration_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
package migration16
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/migtest"
|
||||
)
|
||||
|
||||
var (
|
||||
hexStr = migtest.Hex
|
||||
|
||||
hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e"
|
||||
hash1 = hexStr(hash1Str)
|
||||
paymentID1 = hexStr("0000000000000001")
|
||||
|
||||
hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c"
|
||||
hash2 = hexStr(hash2Str)
|
||||
paymentID2 = hexStr("0000000000000002")
|
||||
|
||||
paymentID3 = hexStr("0000000000000003")
|
||||
|
||||
// pre is the data in the payments root bucket in database version 13 format.
|
||||
pre = map[string]interface{}{
|
||||
// A payment without duplicates.
|
||||
hash1: map[string]interface{}{
|
||||
"payment-sequence-key": paymentID1,
|
||||
},
|
||||
|
||||
// A payment with a duplicate.
|
||||
hash2: map[string]interface{}{
|
||||
"payment-sequence-key": paymentID2,
|
||||
"payment-duplicate-bucket": map[string]interface{}{
|
||||
paymentID3: map[string]interface{}{
|
||||
"payment-sequence-key": paymentID3,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
preFails = map[string]interface{}{
|
||||
// A payment without duplicates.
|
||||
hash1: map[string]interface{}{
|
||||
"payment-sequence-key": paymentID1,
|
||||
"payment-duplicate-bucket": map[string]interface{}{
|
||||
paymentID1: map[string]interface{}{
|
||||
"payment-sequence-key": paymentID1,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// post is the expected data after migration.
|
||||
post = map[string]interface{}{
|
||||
paymentID1: paymentHashIndex(hash1Str),
|
||||
paymentID2: paymentHashIndex(hash2Str),
|
||||
paymentID3: paymentHashIndex(hash2Str),
|
||||
}
|
||||
)
|
||||
|
||||
// paymentHashIndex produces a string that represents the value we expect for
|
||||
// our payment indexes from a hex encoded payment hash string.
|
||||
func paymentHashIndex(hashStr string) string {
|
||||
hash, err := hex.DecodeString(hashStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
bytes, err := serializePaymentIndexEntry(hash)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// MigrateSequenceIndex asserts that the database is properly migrated to
|
||||
// contain a payments index.
|
||||
func TestMigrateSequenceIndex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shouldFail bool
|
||||
pre map[string]interface{}
|
||||
post map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "migration ok",
|
||||
shouldFail: false,
|
||||
pre: pre,
|
||||
post: post,
|
||||
},
|
||||
{
|
||||
name: "duplicate sequence number",
|
||||
shouldFail: true,
|
||||
pre: preFails,
|
||||
post: post,
|
||||
},
|
||||
{
|
||||
name: "no payments",
|
||||
shouldFail: false,
|
||||
pre: nil,
|
||||
post: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Before the migration we have a payments bucket.
|
||||
before := func(tx kvdb.RwTx) error {
|
||||
return migtest.RestoreDB(
|
||||
tx, paymentsRootBucket, test.pre,
|
||||
)
|
||||
}
|
||||
|
||||
// After the migration, we should have an untouched
|
||||
// payments bucket and a new index bucket.
|
||||
after := func(tx kvdb.RwTx) error {
|
||||
if err := migtest.VerifyDB(
|
||||
tx, paymentsRootBucket, test.pre,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we expect our migration to fail, we don't
|
||||
// expect an index bucket.
|
||||
if test.shouldFail {
|
||||
return nil
|
||||
}
|
||||
|
||||
return migtest.VerifyDB(
|
||||
tx, paymentsIndexBucket, test.post,
|
||||
)
|
||||
}
|
||||
|
||||
migtest.ApplyMigration(
|
||||
t, before, after, MigrateSequenceIndex,
|
||||
test.shouldFail,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
140
channeldb/paginate.go
Normal file
140
channeldb/paginate.go
Normal file
@ -0,0 +1,140 @@
|
||||
package channeldb
|
||||
|
||||
import "github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
|
||||
type paginator struct {
|
||||
// cursor is the cursor which we are using to iterate through a bucket.
|
||||
cursor kvdb.RCursor
|
||||
|
||||
// reversed indicates whether we are paginating forwards or backwards.
|
||||
reversed bool
|
||||
|
||||
// indexOffset is the index from which we will begin querying.
|
||||
indexOffset uint64
|
||||
|
||||
// totalItems is the total number of items we allow in our response.
|
||||
totalItems uint64
|
||||
}
|
||||
|
||||
// newPaginator returns a struct which can be used to query an indexed bucket
|
||||
// in pages.
|
||||
func newPaginator(c kvdb.RCursor, reversed bool,
|
||||
indexOffset, totalItems uint64) paginator {
|
||||
|
||||
return paginator{
|
||||
cursor: c,
|
||||
reversed: reversed,
|
||||
indexOffset: indexOffset,
|
||||
totalItems: totalItems,
|
||||
}
|
||||
}
|
||||
|
||||
// keyValueForIndex seeks our cursor to a given index and returns the key and
|
||||
// value at that position.
|
||||
func (p paginator) keyValueForIndex(index uint64) ([]byte, []byte) {
|
||||
var keyIndex [8]byte
|
||||
byteOrder.PutUint64(keyIndex[:], index)
|
||||
return p.cursor.Seek(keyIndex[:])
|
||||
}
|
||||
|
||||
// lastIndex returns the last value in our index, if our index is empty it
|
||||
// returns 0.
|
||||
func (p paginator) lastIndex() uint64 {
|
||||
keyIndex, _ := p.cursor.Last()
|
||||
if keyIndex == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return byteOrder.Uint64(keyIndex)
|
||||
}
|
||||
|
||||
// nextKey is a helper closure to determine what key we should use next when
|
||||
// we are iterating, depending on whether we are iterating forwards or in
|
||||
// reverse.
|
||||
func (p paginator) nextKey() ([]byte, []byte) {
|
||||
if p.reversed {
|
||||
return p.cursor.Prev()
|
||||
}
|
||||
return p.cursor.Next()
|
||||
}
|
||||
|
||||
// cursorStart gets the index key and value for the first item we are looking
|
||||
// up, taking into account that we may be paginating in reverse. The index
|
||||
// offset provided is *excusive* so we will start with the item after the offset
|
||||
// for forwards queries, and the item before the index for backwards queries.
|
||||
func (p paginator) cursorStart() ([]byte, []byte) {
|
||||
indexKey, indexValue := p.keyValueForIndex(p.indexOffset + 1)
|
||||
|
||||
// If the query is specifying reverse iteration, then we must
|
||||
// handle a few offset cases.
|
||||
if p.reversed {
|
||||
switch {
|
||||
|
||||
// This indicates the default case, where no offset was
|
||||
// specified. In that case we just start from the last
|
||||
// entry.
|
||||
case p.indexOffset == 0:
|
||||
indexKey, indexValue = p.cursor.Last()
|
||||
|
||||
// This indicates the offset being set to the very
|
||||
// first entry. Since there are no entries before
|
||||
// this offset, and the direction is reversed, we can
|
||||
// return without adding any invoices to the response.
|
||||
case p.indexOffset == 1:
|
||||
return nil, nil
|
||||
|
||||
// If we have been given an index offset that is beyond our last
|
||||
// index value, we just return the last indexed value in our set
|
||||
// since we are querying in reverse. We do not cover the case
|
||||
// where our index offset equals our last index value, because
|
||||
// index offset is exclusive, so we would want to start at the
|
||||
// value before our last index.
|
||||
case p.indexOffset > p.lastIndex():
|
||||
return p.cursor.Last()
|
||||
|
||||
// Otherwise we have an index offset which is within our set of
|
||||
// indexed keys, and we want to start at the item before our
|
||||
// offset. We seek to our index offset, then return the element
|
||||
// before it. We do this rather than p.indexOffset-1 to account
|
||||
// for indexes that have gaps.
|
||||
default:
|
||||
p.keyValueForIndex(p.indexOffset)
|
||||
indexKey, indexValue = p.cursor.Prev()
|
||||
}
|
||||
}
|
||||
|
||||
return indexKey, indexValue
|
||||
}
|
||||
|
||||
// query gets the start point for our index offset and iterates through keys
|
||||
// in our index until we reach the total number of items required for the query
|
||||
// or we run out of cursor values. This function takes a fetchAndAppend function
|
||||
// which is responsible for looking up the entry at that index, adding the entry
|
||||
// to its set of return items (if desired) and return a boolean which indicates
|
||||
// whether the item was added. This is required to allow the paginator to
|
||||
// determine when the response has the maximum number of required items.
|
||||
func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error {
|
||||
indexKey, indexValue := p.cursorStart()
|
||||
|
||||
var totalItems int
|
||||
for ; indexKey != nil; indexKey, indexValue = p.nextKey() {
|
||||
// If our current return payload exceeds the max number
|
||||
// of invoices, then we'll exit now.
|
||||
if uint64(totalItems) >= p.totalItems {
|
||||
break
|
||||
}
|
||||
|
||||
added, err := fetchAndAppend(indexKey, indexValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we added an item to our set in the latest fetch and append
|
||||
// we increment our total count.
|
||||
if added {
|
||||
totalItems++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
@ -74,6 +75,11 @@ var (
|
||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
||||
"inflight payment")
|
||||
|
||||
// errNoSequenceNrIndex is returned when an attempt to lookup a payment
|
||||
// index is made for a sequence number that is not indexed.
|
||||
errNoSequenceNrIndex = errors.New("payment sequence number index " +
|
||||
"does not exist")
|
||||
)
|
||||
|
||||
// PaymentControl implements persistence for payments and payment attempts.
|
||||
@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
|
||||
return err
|
||||
}
|
||||
|
||||
// Before we set our new sequence number, we check whether this
|
||||
// payment has a previously set sequence number and remove its
|
||||
// index entry if it exists. This happens in the case where we
|
||||
// have a previously attempted payment which was left in a state
|
||||
// where we can retry.
|
||||
seqBytes := bucket.Get(paymentSequenceKey)
|
||||
if seqBytes != nil {
|
||||
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
|
||||
if err := indexBucket.Delete(seqBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Once we have obtained a sequence number, we add an entry
|
||||
// to our index bucket which will map the sequence number to
|
||||
// our payment hash.
|
||||
err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = bucket.Put(paymentSequenceKey, sequenceNum)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
|
||||
return updateErr
|
||||
}
|
||||
|
||||
// paymentIndexTypeHash is a payment index type which indicates that we have
|
||||
// created an index of payment sequence number to payment hash.
|
||||
type paymentIndexType uint8
|
||||
|
||||
// paymentIndexTypeHash is a payment index type which indicates that we have
|
||||
// created an index of payment sequence number to payment hash.
|
||||
const paymentIndexTypeHash paymentIndexType = 0
|
||||
|
||||
// createPaymentIndexEntry creates a payment hash typed index for a payment. The
|
||||
// index produced contains a payment index type (which can be used in future to
|
||||
// signal different payment index types) and the payment hash.
|
||||
func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte,
|
||||
hash lntypes.Hash) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
|
||||
return indexes.Put(sequenceNumber, b.Bytes())
|
||||
}
|
||||
|
||||
// deserializePaymentIndex deserializes a payment index entry. This function
|
||||
// currently only supports deserialization of payment hash indexes, and will
|
||||
// fail for other types.
|
||||
func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) {
|
||||
var (
|
||||
indexType paymentIndexType
|
||||
paymentHash []byte
|
||||
)
|
||||
|
||||
if err := ReadElements(r, &indexType, &paymentHash); err != nil {
|
||||
return lntypes.Hash{}, err
|
||||
}
|
||||
|
||||
// While we only have on payment index type, we do not need to use our
|
||||
// index type to deserialize the index. However, we sanity check that
|
||||
// this type is as expected, since we had to read it out anyway.
|
||||
if indexType != paymentIndexTypeHash {
|
||||
return lntypes.Hash{}, fmt.Errorf("unknown payment index "+
|
||||
"type: %v", indexType)
|
||||
}
|
||||
|
||||
hash, err := lntypes.MakeHash(paymentHash)
|
||||
if err != nil {
|
||||
return lntypes.Hash{}, err
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
|
||||
// DB.
|
||||
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
||||
|
@ -1,6 +1,7 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
@ -9,9 +10,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func genPreimage() ([32]byte, error) {
|
||||
@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentIndex(t, pControl, info.PaymentHash)
|
||||
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, pControl, info.PaymentHash, info, nil, nil,
|
||||
@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
t, pControl, info.PaymentHash, info, &failReason, nil,
|
||||
)
|
||||
|
||||
// Lookup the payment so we can get its old sequence number before it is
|
||||
// overwritten.
|
||||
payment, err := pControl.FetchPayment(info.PaymentHash)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Sends the htlc again, which should succeed since the prior payment
|
||||
// failed.
|
||||
err = pControl.InitPayment(info.PaymentHash, info)
|
||||
@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
// Check that our index has been updated, and the old index has been
|
||||
// removed.
|
||||
assertPaymentIndex(t, pControl, info.PaymentHash)
|
||||
assertNoIndex(t, pControl, payment.SequenceNum)
|
||||
|
||||
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, pControl, info.PaymentHash, info, nil, nil,
|
||||
@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
|
||||
// Settle the attempt and verify that status was changed to
|
||||
// StatusSucceeded.
|
||||
var payment *MPPayment
|
||||
payment, err = pControl.SettleAttempt(
|
||||
info.PaymentHash, attempt.AttemptID,
|
||||
&HTLCSettleInfo{
|
||||
@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentIndex(t, pControl, info.PaymentHash)
|
||||
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, pControl, info.PaymentHash, info, nil, nil,
|
||||
@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
|
||||
assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown)
|
||||
}
|
||||
|
||||
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only
|
||||
// TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only
|
||||
// deletes payments from the database that are not in-flight.
|
||||
func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
|
||||
// Create a sequence number for duplicate payments that will not collide
|
||||
// with the sequence numbers for the payments we create. These values
|
||||
// start at 1, so 9999 is a safe bet for this test.
|
||||
var duplicateSeqNr = 9999
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
payments := []struct {
|
||||
failed bool
|
||||
success bool
|
||||
hasDuplicate bool
|
||||
}{
|
||||
{
|
||||
failed: true,
|
||||
success: false,
|
||||
hasDuplicate: false,
|
||||
},
|
||||
{
|
||||
failed: false,
|
||||
success: true,
|
||||
hasDuplicate: false,
|
||||
},
|
||||
{
|
||||
failed: false,
|
||||
success: false,
|
||||
hasDuplicate: false,
|
||||
},
|
||||
{
|
||||
failed: false,
|
||||
success: true,
|
||||
hasDuplicate: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
||||
t, pControl, info.PaymentHash, info, nil, htlc,
|
||||
)
|
||||
}
|
||||
|
||||
// If the payment is intended to have a duplicate payment, we
|
||||
// add one.
|
||||
if p.hasDuplicate {
|
||||
appendDuplicatePayment(
|
||||
t, pControl.db, info.PaymentHash,
|
||||
uint64(duplicateSeqNr),
|
||||
)
|
||||
duplicateSeqNr++
|
||||
}
|
||||
}
|
||||
|
||||
// Delete payments.
|
||||
@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
||||
if status != StatusInFlight {
|
||||
t.Fatalf("expected in-fligth status, got %v", status)
|
||||
}
|
||||
|
||||
// Finally, check that we only have a single index left in the payment
|
||||
// index bucket.
|
||||
var indexCount int
|
||||
err = kvdb.View(db, func(tx walletdb.ReadTx) error {
|
||||
index := tx.ReadBucket(paymentsIndexBucket)
|
||||
|
||||
return index.ForEach(func(k, v []byte) error {
|
||||
indexCount++
|
||||
return nil
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 1, indexCount)
|
||||
}
|
||||
|
||||
// TestPaymentControlMultiShard checks the ability of payment control to
|
||||
@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
||||
t.Fatalf("unable to send htlc message: %v", err)
|
||||
}
|
||||
|
||||
assertPaymentIndex(t, pControl, info.PaymentHash)
|
||||
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
|
||||
assertPaymentInfo(
|
||||
t, pControl, info.PaymentHash, info, nil, nil,
|
||||
@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash,
|
||||
t.Fatal("expected no settle info")
|
||||
}
|
||||
}
|
||||
|
||||
// fetchPaymentIndexEntry gets the payment hash for the sequence number provided
|
||||
// from our payment indexes bucket.
|
||||
func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
|
||||
sequenceNumber uint64) (*lntypes.Hash, error) {
|
||||
|
||||
var hash lntypes.Hash
|
||||
|
||||
if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error {
|
||||
indexBucket := tx.ReadBucket(paymentsIndexBucket)
|
||||
key := make([]byte, 8)
|
||||
byteOrder.PutUint64(key, sequenceNumber)
|
||||
|
||||
indexValue := indexBucket.Get(key)
|
||||
if indexValue == nil {
|
||||
return errNoSequenceNrIndex
|
||||
}
|
||||
|
||||
r := bytes.NewReader(indexValue)
|
||||
|
||||
var err error
|
||||
hash, err = deserializePaymentIndex(r)
|
||||
return err
|
||||
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &hash, nil
|
||||
}
|
||||
|
||||
// assertPaymentIndex looks up the index for a payment in the db and checks
|
||||
// that its payment hash matches the expected hash passed in.
|
||||
func assertPaymentIndex(t *testing.T, p *PaymentControl,
|
||||
expectedHash lntypes.Hash) {
|
||||
|
||||
// Lookup the payment so that we have its sequence number and check
|
||||
// that is has correctly been indexed in the payment indexes bucket.
|
||||
pmt, err := p.FetchPayment(expectedHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedHash, *hash)
|
||||
}
|
||||
|
||||
// assertNoIndex checks that an index for the sequence number provided does not
|
||||
// exist.
|
||||
func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) {
|
||||
_, err := fetchPaymentIndexEntry(t, p, seqNr)
|
||||
require.Equal(t, errNoSequenceNrIndex, err)
|
||||
}
|
||||
|
@ -3,9 +3,9 @@ package channeldb
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
@ -92,6 +92,35 @@ var (
|
||||
// paymentFailInfoKey is a key used in the payment's sub-bucket to
|
||||
// store information about the reason a payment failed.
|
||||
paymentFailInfoKey = []byte("payment-fail-info")
|
||||
|
||||
// paymentsIndexBucket is the name of the top-level bucket within the
|
||||
// database that stores an index of payment sequence numbers to its
|
||||
// payment hash.
|
||||
// payments-sequence-index-bucket
|
||||
// |--<sequence-number>: <payment hash>
|
||||
// |--...
|
||||
// |--<sequence-number>: <payment hash>
|
||||
paymentsIndexBucket = []byte("payments-index-bucket")
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoSequenceNumber is returned if we lookup a payment which does
|
||||
// not have a sequence number.
|
||||
ErrNoSequenceNumber = errors.New("sequence number not found")
|
||||
|
||||
// ErrDuplicateNotFound is returned when we lookup a payment by its
|
||||
// index and cannot find a payment with a matching sequence number.
|
||||
ErrDuplicateNotFound = errors.New("duplicate payment not found")
|
||||
|
||||
// ErrNoDuplicateBucket is returned when we expect to find duplicates
|
||||
// when looking up a payment from its index, but the payment does not
|
||||
// have any.
|
||||
ErrNoDuplicateBucket = errors.New("expected duplicate bucket")
|
||||
|
||||
// ErrNoDuplicateNestedBucket is returned if we do not find duplicate
|
||||
// payments in their own sub-bucket.
|
||||
ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " +
|
||||
"found")
|
||||
)
|
||||
|
||||
// FailureReason encodes the reason a payment ultimately failed.
|
||||
@ -481,62 +510,70 @@ type PaymentsResponse struct {
|
||||
func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
|
||||
var resp PaymentsResponse
|
||||
|
||||
allPayments, err := db.FetchPayments()
|
||||
if err := kvdb.View(db, func(tx kvdb.RTx) error {
|
||||
// Get the root payments bucket.
|
||||
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
|
||||
if paymentsBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the index bucket which maps sequence number -> payment
|
||||
// hash and duplicate bool. If we have a payments bucket, we
|
||||
// should have an indexes bucket as well.
|
||||
indexes := tx.ReadBucket(paymentsIndexBucket)
|
||||
if indexes == nil {
|
||||
return fmt.Errorf("index bucket does not exist")
|
||||
}
|
||||
|
||||
// accumulatePayments gets payments with the sequence number
|
||||
// and hash provided and adds them to our list of payments if
|
||||
// they meet the criteria of our query. It returns the number
|
||||
// of payments that were added.
|
||||
accumulatePayments := func(sequenceKey, hash []byte) (bool,
|
||||
error) {
|
||||
|
||||
r := bytes.NewReader(hash)
|
||||
paymentHash, err := deserializePaymentIndex(r)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(allPayments) == 0 {
|
||||
return resp, nil
|
||||
payment, err := fetchPaymentWithSequenceNumber(
|
||||
tx, paymentHash, sequenceKey,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
indexExclusiveLimit := query.IndexOffset
|
||||
// In backward pagination, if the index limit is the default 0 value,
|
||||
// we set our limit to maxint to include all payments from the highest
|
||||
// sequence number on.
|
||||
if query.Reversed && indexExclusiveLimit == 0 {
|
||||
indexExclusiveLimit = math.MaxInt64
|
||||
}
|
||||
|
||||
for i := range allPayments {
|
||||
var payment *MPPayment
|
||||
|
||||
// If we have the max number of payments we want, exit.
|
||||
if uint64(len(resp.Payments)) == query.MaxPayments {
|
||||
break
|
||||
}
|
||||
|
||||
if query.Reversed {
|
||||
payment = allPayments[len(allPayments)-1-i]
|
||||
|
||||
// In the reversed direction, skip over all payments
|
||||
// that have sequence numbers greater than or equal to
|
||||
// the index offset. We skip payments with equal index
|
||||
// because the offset is exclusive.
|
||||
if payment.SequenceNum >= indexExclusiveLimit {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
payment = allPayments[i]
|
||||
|
||||
// In the forward direction, skip over all payments that
|
||||
// have sequence numbers less than or equal to the index
|
||||
// offset. We skip payments with equal indexes because
|
||||
// the index offset is exclusive.
|
||||
if payment.SequenceNum <= indexExclusiveLimit {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// To keep compatibility with the old API, we only return
|
||||
// non-succeeded payments if requested.
|
||||
// To keep compatibility with the old API, we only
|
||||
// return non-succeeded payments if requested.
|
||||
if payment.Status != StatusSucceeded &&
|
||||
!query.IncludeIncomplete {
|
||||
|
||||
continue
|
||||
return false, err
|
||||
}
|
||||
|
||||
// At this point, we've exhausted the offset, so we'll
|
||||
// begin collecting invoices found within the range.
|
||||
resp.Payments = append(resp.Payments, payment)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create a paginator which reads from our sequence index bucket
|
||||
// with the parameters provided by the payments query.
|
||||
paginator := newPaginator(
|
||||
indexes.ReadCursor(), query.Reversed, query.IndexOffset,
|
||||
query.MaxPayments,
|
||||
)
|
||||
|
||||
// Run a paginated query, adding payments to our response.
|
||||
if err := paginator.query(accumulatePayments); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Need to swap the payments slice order if reversed order.
|
||||
@ -555,7 +592,84 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
|
||||
resp.Payments[len(resp.Payments)-1].SequenceNum
|
||||
}
|
||||
|
||||
return resp, err
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
|
||||
// *and* sequence number provided from the database. This is required because
|
||||
// we previously had more than one payment per hash, so we have multiple indexes
|
||||
// pointing to a single payment; we want to retrieve the correct one.
|
||||
func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash,
|
||||
sequenceNumber []byte) (*MPPayment, error) {
|
||||
|
||||
// We can now lookup the payment keyed by its hash in
|
||||
// the payments root bucket.
|
||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A single payment hash can have multiple payments associated with it.
|
||||
// We lookup our sequence number first, to determine whether this is
|
||||
// the payment we are actually looking for.
|
||||
seqBytes := bucket.Get(paymentSequenceKey)
|
||||
if seqBytes == nil {
|
||||
return nil, ErrNoSequenceNumber
|
||||
}
|
||||
|
||||
// If this top level payment has the sequence number we are looking for,
|
||||
// return it.
|
||||
if bytes.Equal(seqBytes, sequenceNumber) {
|
||||
return fetchPayment(bucket)
|
||||
}
|
||||
|
||||
// If we were not looking for the top level payment, we are looking for
|
||||
// one of our duplicate payments. We need to iterate through the seq
|
||||
// numbers in this bucket to find the correct payments. If we do not
|
||||
// find a duplicate payments bucket here, something is wrong.
|
||||
dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
|
||||
if dup == nil {
|
||||
return nil, ErrNoDuplicateBucket
|
||||
}
|
||||
|
||||
var duplicatePayment *MPPayment
|
||||
err = dup.ForEach(func(k, v []byte) error {
|
||||
subBucket := dup.NestedReadBucket(k)
|
||||
if subBucket == nil {
|
||||
// We one bucket for each duplicate to be found.
|
||||
return ErrNoDuplicateNestedBucket
|
||||
}
|
||||
|
||||
seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
|
||||
if seqBytes == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this duplicate payment is not the sequence number we are
|
||||
// looking for, we can continue.
|
||||
if !bytes.Equal(seqBytes, sequenceNumber) {
|
||||
return nil
|
||||
}
|
||||
|
||||
duplicatePayment, err = fetchDuplicatePayment(subBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If none of the duplicate payments matched our sequence number, we
|
||||
// failed to find the payment with this sequence number; something is
|
||||
// wrong.
|
||||
if duplicatePayment == nil {
|
||||
return nil, ErrDuplicateNotFound
|
||||
}
|
||||
|
||||
return duplicatePayment, nil
|
||||
}
|
||||
|
||||
// DeletePayments deletes all completed and failed payments from the DB.
|
||||
@ -566,7 +680,15 @@ func (db *DB) DeletePayments() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var deleteBuckets [][]byte
|
||||
var (
|
||||
// deleteBuckets is the set of payment buckets we need
|
||||
// to delete.
|
||||
deleteBuckets [][]byte
|
||||
|
||||
// deleteIndexes is the set of indexes pointing to these
|
||||
// payments that need to be deleted.
|
||||
deleteIndexes [][]byte
|
||||
)
|
||||
err := payments.ForEach(func(k, _ []byte) error {
|
||||
bucket := payments.NestedReadWriteBucket(k)
|
||||
if bucket == nil {
|
||||
@ -589,7 +711,18 @@ func (db *DB) DeletePayments() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add the bucket to the set of buckets we can delete.
|
||||
deleteBuckets = append(deleteBuckets, k)
|
||||
|
||||
// Get all the sequence number associated with the
|
||||
// payment, including duplicates.
|
||||
seqNrs, err := fetchSequenceNumbers(bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deleteIndexes = append(deleteIndexes, seqNrs...)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@ -602,10 +735,49 @@ func (db *DB) DeletePayments() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Get our index bucket and delete all indexes pointing to the
|
||||
// payments we are deleting.
|
||||
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
|
||||
for _, k := range deleteIndexes {
|
||||
if err := indexBucket.Delete(k); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// fetchSequenceNumbers fetches all the sequence numbers associated with a
|
||||
// payment, including those belonging to any duplicate payments.
|
||||
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
|
||||
seqNum := paymentBucket.Get(paymentSequenceKey)
|
||||
if seqNum == nil {
|
||||
return nil, errors.New("expected sequence number")
|
||||
}
|
||||
|
||||
sequenceNumbers := [][]byte{seqNum}
|
||||
|
||||
// Get the duplicate payments bucket, if it has no duplicates, just
|
||||
// return early with the payment sequence number.
|
||||
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
|
||||
if duplicates == nil {
|
||||
return sequenceNumbers, nil
|
||||
}
|
||||
|
||||
// If we do have duplicated, they are keyed by sequence number, so we
|
||||
// iterate through the duplicates bucket and add them to our set of
|
||||
// sequence numbers.
|
||||
if err := duplicates.ForEach(func(k, v []byte) error {
|
||||
sequenceNumbers = append(sequenceNumbers, k)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sequenceNumbers, nil
|
||||
}
|
||||
|
||||
// nolint: dupl
|
||||
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
||||
var scratch [8]byte
|
||||
|
@ -9,11 +9,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcwallet/walletdb"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -163,18 +165,24 @@ func TestRouteSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
// deletePayment removes a payment with paymentHash from the payments database.
|
||||
func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) {
|
||||
func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) {
|
||||
t.Helper()
|
||||
|
||||
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
|
||||
payments := tx.ReadWriteBucket(paymentsRootBucket)
|
||||
|
||||
// Delete the payment bucket.
|
||||
err := payments.DeleteNestedBucket(paymentHash[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
key := make([]byte, 8)
|
||||
byteOrder.PutUint64(key, seqNr)
|
||||
|
||||
// Delete the index that references this payment.
|
||||
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
|
||||
return indexes.Delete(key)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@ -188,6 +196,10 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) {
|
||||
func TestQueryPayments(t *testing.T) {
|
||||
// Define table driven test for QueryPayments.
|
||||
// Test payments have sequence indices [1, 3, 4, 5, 6, 7].
|
||||
// Note that the payment with index 7 has the same payment hash as 6,
|
||||
// and is stored in a nested bucket within payment 6 rather than being
|
||||
// its own entry in the payments bucket. We do this to test retrieval
|
||||
// of legacy payments.
|
||||
tests := []struct {
|
||||
name string
|
||||
query PaymentsQuery
|
||||
@ -344,6 +356,42 @@ func TestQueryPayments(t *testing.T) {
|
||||
lastIndex: 7,
|
||||
expectedSeqNrs: []uint64{3, 4, 5, 6, 7},
|
||||
},
|
||||
{
|
||||
name: "query payments reverse before index gap",
|
||||
query: PaymentsQuery{
|
||||
IndexOffset: 3,
|
||||
MaxPayments: 7,
|
||||
Reversed: true,
|
||||
IncludeIncomplete: true,
|
||||
},
|
||||
firstIndex: 1,
|
||||
lastIndex: 1,
|
||||
expectedSeqNrs: []uint64{1},
|
||||
},
|
||||
{
|
||||
name: "query payments reverse on index gap",
|
||||
query: PaymentsQuery{
|
||||
IndexOffset: 2,
|
||||
MaxPayments: 7,
|
||||
Reversed: true,
|
||||
IncludeIncomplete: true,
|
||||
},
|
||||
firstIndex: 1,
|
||||
lastIndex: 1,
|
||||
expectedSeqNrs: []uint64{1},
|
||||
},
|
||||
{
|
||||
name: "query payments forward on index gap",
|
||||
query: PaymentsQuery{
|
||||
IndexOffset: 2,
|
||||
MaxPayments: 2,
|
||||
Reversed: false,
|
||||
IncludeIncomplete: true,
|
||||
},
|
||||
firstIndex: 3,
|
||||
lastIndex: 4,
|
||||
expectedSeqNrs: []uint64{3, 4},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -352,17 +400,28 @@ func TestQueryPayments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup, err := makeTestDB()
|
||||
defer cleanup()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init db: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Make a preliminary query to make sure it's ok to
|
||||
// query when we have no payments.
|
||||
resp, err := db.QueryPayments(tt.query)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Payments, 0)
|
||||
|
||||
// Populate the database with a set of test payments.
|
||||
numberOfPayments := 7
|
||||
// We create 6 original payments, deleting the payment
|
||||
// at index 2 so that we cover the case where sequence
|
||||
// numbers are missing. We also add a duplicate payment
|
||||
// to the last payment added to test the legacy case
|
||||
// where we have duplicates in the nested duplicates
|
||||
// bucket.
|
||||
nonDuplicatePayments := 6
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
for i := 0; i < numberOfPayments; i++ {
|
||||
for i := 0; i < nonDuplicatePayments; i++ {
|
||||
// Generate a test payment.
|
||||
info, _, _, err := genInfo()
|
||||
if err != nil {
|
||||
@ -379,7 +438,29 @@ func TestQueryPayments(t *testing.T) {
|
||||
|
||||
// Immediately delete the payment with index 2.
|
||||
if i == 1 {
|
||||
deletePayment(t, db, info.PaymentHash)
|
||||
pmt, err := pControl.FetchPayment(
|
||||
info.PaymentHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletePayment(t, db, info.PaymentHash,
|
||||
pmt.SequenceNum)
|
||||
}
|
||||
|
||||
// If we are on the last payment entry, add a
|
||||
// duplicate payment with sequence number equal
|
||||
// to the parent payment + 1.
|
||||
if i == (nonDuplicatePayments - 1) {
|
||||
pmt, err := pControl.FetchPayment(
|
||||
info.PaymentHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
appendDuplicatePayment(
|
||||
t, pControl.db,
|
||||
info.PaymentHash,
|
||||
pmt.SequenceNum+1,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -424,3 +505,210 @@ func TestQueryPayments(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchPaymentWithSequenceNumber tests lookup of payments with their
|
||||
// sequence number. It sets up one payment with no duplicates, and another with
|
||||
// two duplicates in its duplicates bucket then uses these payments to test the
|
||||
// case where a specific duplicate is not found and the duplicates bucket is not
|
||||
// present when we expect it to be.
|
||||
func TestFetchPaymentWithSequenceNumber(t *testing.T) {
|
||||
db, cleanup, err := makeTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer cleanup()
|
||||
|
||||
pControl := NewPaymentControl(db)
|
||||
|
||||
// Generate a test payment which does not have duplicates.
|
||||
noDuplicates, _, _, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new payment entry in the database.
|
||||
err = pControl.InitPayment(noDuplicates.PaymentHash, noDuplicates)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch the payment so we can get its sequence nr.
|
||||
noDuplicatesPayment, err := pControl.FetchPayment(
|
||||
noDuplicates.PaymentHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a test payment which we will add duplicates to.
|
||||
hasDuplicates, _, _, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new payment entry in the database.
|
||||
err = pControl.InitPayment(hasDuplicates.PaymentHash, hasDuplicates)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch the payment so we can get its sequence nr.
|
||||
hasDuplicatesPayment, err := pControl.FetchPayment(
|
||||
hasDuplicates.PaymentHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We declare the sequence numbers used here so that we can reference
|
||||
// them in tests.
|
||||
var (
|
||||
duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1
|
||||
duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2
|
||||
)
|
||||
|
||||
// Add two duplicates to our second payment.
|
||||
appendDuplicatePayment(
|
||||
t, db, hasDuplicates.PaymentHash, duplicateOneSeqNr,
|
||||
)
|
||||
appendDuplicatePayment(
|
||||
t, db, hasDuplicates.PaymentHash, duplicateTwoSeqNr,
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
paymentHash lntypes.Hash
|
||||
sequenceNumber uint64
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "lookup payment without duplicates",
|
||||
paymentHash: noDuplicates.PaymentHash,
|
||||
sequenceNumber: noDuplicatesPayment.SequenceNum,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "lookup payment with duplicates",
|
||||
paymentHash: hasDuplicates.PaymentHash,
|
||||
sequenceNumber: hasDuplicatesPayment.SequenceNum,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "lookup first duplicate",
|
||||
paymentHash: hasDuplicates.PaymentHash,
|
||||
sequenceNumber: duplicateOneSeqNr,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "lookup second duplicate",
|
||||
paymentHash: hasDuplicates.PaymentHash,
|
||||
sequenceNumber: duplicateTwoSeqNr,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "lookup non-existent duplicate",
|
||||
paymentHash: hasDuplicates.PaymentHash,
|
||||
sequenceNumber: 999999,
|
||||
expectedErr: ErrDuplicateNotFound,
|
||||
},
|
||||
{
|
||||
name: "lookup duplicate, no duplicates bucket",
|
||||
paymentHash: noDuplicates.PaymentHash,
|
||||
sequenceNumber: duplicateTwoSeqNr,
|
||||
expectedErr: ErrNoDuplicateBucket,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := kvdb.Update(db,
|
||||
func(tx walletdb.ReadWriteTx) error {
|
||||
|
||||
var seqNrBytes [8]byte
|
||||
byteOrder.PutUint64(
|
||||
seqNrBytes[:], test.sequenceNumber,
|
||||
)
|
||||
|
||||
_, err := fetchPaymentWithSequenceNumber(
|
||||
tx, test.paymentHash, seqNrBytes[:],
|
||||
)
|
||||
return err
|
||||
})
|
||||
require.Equal(t, test.expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// appendDuplicatePayment adds a duplicate payment to an existing payment. Note
|
||||
// that this function requires a unique sequence number.
|
||||
//
|
||||
// This code is *only* intended to replicate legacy duplicate payments in lnd,
|
||||
// our current schema does not allow duplicates.
|
||||
func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash,
|
||||
seqNr uint64) {
|
||||
|
||||
err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error {
|
||||
bucket, err := fetchPaymentBucketUpdate(
|
||||
tx, paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the duplicates bucket if it is not
|
||||
// present.
|
||||
dup, err := bucket.CreateBucketIfNotExists(
|
||||
duplicatePaymentsBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sequenceKey [8]byte
|
||||
byteOrder.PutUint64(sequenceKey[:], seqNr)
|
||||
|
||||
// Create duplicate payments for the two dup
|
||||
// sequence numbers we've setup.
|
||||
putDuplicatePayment(t, dup, sequenceKey[:], paymentHash)
|
||||
|
||||
// Finally, once we have created our entry we add an index for
|
||||
// it.
|
||||
err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("could not create payment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// putDuplicatePayment creates a duplicate payment in the duplicates bucket
|
||||
// provided with the minimal information required for successful reading.
|
||||
func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket,
|
||||
sequenceKey []byte, paymentHash lntypes.Hash) {
|
||||
|
||||
paymentBucket, err := duplicateBucket.CreateBucketIfNotExists(
|
||||
sequenceKey,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate fake information for the duplicate payment.
|
||||
info, _, _, err := genInfo()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write the payment info to disk under the creation info key. This code
|
||||
// is copied rather than using serializePaymentCreationInfo to ensure
|
||||
// we always write in the legacy format used by duplicate payments.
|
||||
var b bytes.Buffer
|
||||
var scratch [8]byte
|
||||
_, err = b.Write(paymentHash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
byteOrder.PutUint64(scratch[:], uint64(info.Value))
|
||||
_, err = b.Write(scratch[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
err = serializeTime(&b, info.CreationTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
byteOrder.PutUint32(scratch[:4], 0)
|
||||
_, err = b.Write(scratch[:4])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the PaymentCreationInfo.
|
||||
err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user