Merge pull request #4261 from carlaKC/4164-indexpayments

channeldb: Index payments by sequence number
This commit is contained in:
Conner Fromknecht 2020-06-10 08:11:45 -07:00 committed by GitHub
commit d47d17b5d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1276 additions and 139 deletions

@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error {
return err return err
} }
case paymentIndexType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case lnwire.FundingFlag: case lnwire.FundingFlag:
if err := binary.Write(w, byteOrder, e); err != nil { if err := binary.Write(w, byteOrder, e); err != nil {
return err return err
@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error {
return err return err
} }
case *paymentIndexType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *lnwire.FundingFlag: case *lnwire.FundingFlag:
if err := binary.Read(r, byteOrder, e); err != nil { if err := binary.Read(r, byteOrder, e); err != nil {
return err return err

@ -16,6 +16,7 @@ import (
mig "github.com/lightningnetwork/lnd/channeldb/migration" mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration13"
"github.com/lightningnetwork/lnd/channeldb/migration16"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -144,6 +145,19 @@ var (
number: 14, number: 14,
migration: mig.CreateTLB(payAddrIndexBucket), migration: mig.CreateTLB(payAddrIndexBucket),
}, },
{
// Initialize payment index bucket which will be used
// to index payments by sequence number. This index will
// be used to allow more efficient ListPayments queries.
number: 15,
migration: mig.CreateTLB(paymentsIndexBucket),
},
{
// Add our existing payments to the index bucket created
// in migration 15.
number: 16,
migration: migration16.MigrateSequenceIndex,
},
} }
// Big endian is the preferred byte order, due to cursor scans over // Big endian is the preferred byte order, due to cursor scans over
@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{
fwdPackagesKey, fwdPackagesKey,
invoiceBucket, invoiceBucket,
payAddrIndexBucket, payAddrIndexBucket,
paymentsIndexBucket,
nodeInfoBucket, nodeInfoBucket,
nodeBucket, nodeBucket,
edgeBucket, edgeBucket,

@ -1007,6 +1007,18 @@ func TestQueryInvoices(t *testing.T) {
// still pending. // still pending.
expected: pendingInvoices[len(pendingInvoices)-15:], expected: pendingInvoices[len(pendingInvoices)-15:],
}, },
// Fetch all invoices paginating backwards, with an index offset
// that is beyond our last offset. We expect all invoices to be
// returned.
{
query: InvoiceQuery{
IndexOffset: numInvoices * 2,
PendingOnly: false,
Reversed: true,
NumMaxInvoices: numInvoices,
},
expected: invoices,
},
} }
for i, testCase := range testCases { for i, testCase := range testCases {

@ -839,85 +839,47 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
if invoices == nil { if invoices == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
} }
// Get the add index bucket which we will use to iterate through
// our indexed invoices.
invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket) invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
if invoiceAddIndex == nil { if invoiceAddIndex == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
} }
// keyForIndex is a helper closure that retrieves the invoice // Create a paginator which reads from our add index bucket with
// key for the given add index of an invoice. // the parameters provided by the invoice query.
keyForIndex := func(c kvdb.RCursor, index uint64) []byte { paginator := newPaginator(
var keyIndex [8]byte invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
byteOrder.PutUint64(keyIndex[:], index) q.NumMaxInvoices,
_, invoiceKey := c.Seek(keyIndex[:]) )
return invoiceKey
}
// nextKey is a helper closure to determine what the next // accumulateInvoices looks up an invoice based on the index we
// invoice key is when iterating over the invoice add index. // are given, adds it to our set of invoices if it has the right
nextKey := func(c kvdb.RCursor) ([]byte, []byte) { // characteristics for our query and returns the number of items
if q.Reversed { // we have added to our set of invoices.
return c.Prev() accumulateInvoices := func(_, indexValue []byte) (bool, error) {
} invoice, err := fetchInvoice(indexValue, invoices)
return c.Next()
}
// We'll be using a cursor to seek into the database and return
// a slice of invoices. We'll need to determine where to start
// our cursor depending on the parameters set within the query.
c := invoiceAddIndex.ReadCursor()
invoiceKey := keyForIndex(c, q.IndexOffset+1)
// If the query is specifying reverse iteration, then we must
// handle a few offset cases.
if q.Reversed {
switch q.IndexOffset {
// This indicates the default case, where no offset was
// specified. In that case we just start from the last
// invoice.
case 0:
_, invoiceKey = c.Last()
// This indicates the offset being set to the very
// first invoice. Since there are no invoices before
// this offset, and the direction is reversed, we can
// return without adding any invoices to the response.
case 1:
return nil
// Otherwise we start iteration at the invoice prior to
// the offset.
default:
invoiceKey = keyForIndex(c, q.IndexOffset-1)
}
}
// If we know that a set of invoices exists, then we'll begin
// our seek through the bucket in order to satisfy the query.
// We'll continue until either we reach the end of the range, or
// reach our max number of invoices.
for ; invoiceKey != nil; _, invoiceKey = nextKey(c) {
// If our current return payload exceeds the max number
// of invoices, then we'll exit now.
if uint64(len(resp.Invoices)) >= q.NumMaxInvoices {
break
}
invoice, err := fetchInvoice(invoiceKey, invoices)
if err != nil { if err != nil {
return err return false, err
} }
// Skip any settled or canceled invoices if the caller is // Skip any settled or canceled invoices if the caller
// only interested in pending ones. // is only interested in pending ones.
if q.PendingOnly && !invoice.IsPending() { if q.PendingOnly && !invoice.IsPending() {
continue return false, nil
} }
// At this point, we've exhausted the offset, so we'll // At this point, we've exhausted the offset, so we'll
// begin collecting invoices found within the range. // begin collecting invoices found within the range.
resp.Invoices = append(resp.Invoices, invoice) resp.Invoices = append(resp.Invoices, invoice)
return true, nil
}
// Query our paginator using accumulateInvoices to build up a
// set of invoices.
if err := paginator.query(accumulateInvoices); err != nil {
return err
} }
// If we iterated through the add index in reverse order, then // If we iterated through the add index in reverse order, then

@ -6,6 +6,7 @@ import (
mig "github.com/lightningnetwork/lnd/channeldb/migration" mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration13"
"github.com/lightningnetwork/lnd/channeldb/migration16"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
) )
@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) {
migration_01_to_11.UseLogger(logger) migration_01_to_11.UseLogger(logger)
migration12.UseLogger(logger) migration12.UseLogger(logger)
migration13.UseLogger(logger) migration13.UseLogger(logger)
migration16.UseLogger(logger)
} }

@ -0,0 +1,14 @@
package migration16
import (
"github.com/btcsuite/btclog"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}

@ -0,0 +1,191 @@
package migration16
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
)
var (
paymentsRootBucket = []byte("payments-root-bucket")
paymentSequenceKey = []byte("payment-sequence-key")
duplicatePaymentsBucket = []byte("payment-duplicate-bucket")
paymentsIndexBucket = []byte("payments-index-bucket")
byteOrder = binary.BigEndian
)
// paymentIndexType indicates the type of index we have recorded in the payment
// indexes bucket.
type paymentIndexType uint8
// paymentIndexTypeHash is a payment index type which indicates that we have
// created an index of payment sequence number to payment hash.
const paymentIndexTypeHash paymentIndexType = 0
// paymentIndex stores all the information we require to create an index by
// sequence number for a payment.
type paymentIndex struct {
// paymentHash is the hash of the payment, which is its key in the
// payment root bucket.
paymentHash []byte
// sequenceNumbers is the set of sequence numbers associated with this
// payment hash. There will be more than one sequence number in the
// case where duplicate payments are present.
sequenceNumbers [][]byte
}
// MigrateSequenceIndex migrates the payments db to contain a new bucket which
// provides an index from sequence number to payment hash. This is required
// for more efficient sequential lookup of payments, which are keyed by payment
// hash before this migration.
func MigrateSequenceIndex(tx kvdb.RwTx) error {
log.Infof("Migrating payments to add sequence number index")
// Get a list of indices we need to write.
indexList, err := getPaymentIndexList(tx)
if err != nil {
return err
}
// Create the top level bucket that we will use to index payments in.
bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket)
if err != nil {
return err
}
// Write an index for each of our payments.
for _, index := range indexList {
// Write indexes for each of our sequence numbers.
for _, seqNr := range index.sequenceNumbers {
err := putIndex(bucket, seqNr, index.paymentHash)
if err != nil {
return err
}
}
}
return nil
}
// putIndex performs a sanity check that ensures we are not writing duplicate
// indexes to disk then creates the index provided.
func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error {
// Add a sanity check that we do not already have an entry with
// this sequence number.
existingEntry := bucket.Get(sequenceNr)
if existingEntry != nil {
return fmt.Errorf("sequence number: %x duplicated",
sequenceNr)
}
bytes, err := serializePaymentIndexEntry(paymentHash)
if err != nil {
return err
}
return bucket.Put(sequenceNr, bytes)
}
// serializePaymentIndexEntry serializes a payment hash typed index. The value
// produced contains a payment index type (which can be used in future to
// signal different payment index types) and the payment hash.
func serializePaymentIndexEntry(hash []byte) ([]byte, error) {
var b bytes.Buffer
err := binary.Write(&b, byteOrder, paymentIndexTypeHash)
if err != nil {
return nil, err
}
if err := wire.WriteVarBytes(&b, 0, hash); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// getPaymentIndexList gets a list of indices we need to write for our current
// set of payments.
func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) {
// Iterate over all payments and store their indexing keys. This is
// needed, because no modifications are allowed inside a Bucket.ForEach
// loop.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil, nil
}
var indexList []paymentIndex
err := paymentsBucket.ForEach(func(k, v []byte) error {
// Get the bucket which contains the payment, fail if the key
// does not have a bucket.
bucket := paymentsBucket.NestedReadBucket(k)
if bucket == nil {
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return fmt.Errorf("nil sequence number bytes")
}
seqNrs, err := fetchSequenceNumbers(bucket)
if err != nil {
return err
}
// Create an index object with our payment hash and sequence
// numbers and append it to our set of indexes.
index := paymentIndex{
paymentHash: k,
sequenceNumbers: seqNrs,
}
indexList = append(indexList, index)
return nil
})
if err != nil {
return nil, err
}
return indexList, nil
}
// fetchSequenceNumbers fetches all the sequence numbers associated with a
// payment, including those belonging to any duplicate payments.
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
seqNum := paymentBucket.Get(paymentSequenceKey)
if seqNum == nil {
return nil, errors.New("expected sequence number")
}
sequenceNumbers := [][]byte{seqNum}
// Get the duplicate payments bucket, if it has no duplicates, just
// return early with the payment sequence number.
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
if duplicates == nil {
return sequenceNumbers, nil
}
// If we do have duplicated, they are keyed by sequence number, so we
// iterate through the duplicates bucket and add them to our set of
// sequence numbers.
if err := duplicates.ForEach(func(k, v []byte) error {
sequenceNumbers = append(sequenceNumbers, k)
return nil
}); err != nil {
return nil, err
}
return sequenceNumbers, nil
}

@ -0,0 +1,144 @@
package migration16
import (
"encoding/hex"
"testing"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/channeldb/migtest"
)
var (
hexStr = migtest.Hex
hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e"
hash1 = hexStr(hash1Str)
paymentID1 = hexStr("0000000000000001")
hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c"
hash2 = hexStr(hash2Str)
paymentID2 = hexStr("0000000000000002")
paymentID3 = hexStr("0000000000000003")
// pre is the data in the payments root bucket in database version 13 format.
pre = map[string]interface{}{
// A payment without duplicates.
hash1: map[string]interface{}{
"payment-sequence-key": paymentID1,
},
// A payment with a duplicate.
hash2: map[string]interface{}{
"payment-sequence-key": paymentID2,
"payment-duplicate-bucket": map[string]interface{}{
paymentID3: map[string]interface{}{
"payment-sequence-key": paymentID3,
},
},
},
}
preFails = map[string]interface{}{
// A payment without duplicates.
hash1: map[string]interface{}{
"payment-sequence-key": paymentID1,
"payment-duplicate-bucket": map[string]interface{}{
paymentID1: map[string]interface{}{
"payment-sequence-key": paymentID1,
},
},
},
}
// post is the expected data after migration.
post = map[string]interface{}{
paymentID1: paymentHashIndex(hash1Str),
paymentID2: paymentHashIndex(hash2Str),
paymentID3: paymentHashIndex(hash2Str),
}
)
// paymentHashIndex produces a string that represents the value we expect for
// our payment indexes from a hex encoded payment hash string.
func paymentHashIndex(hashStr string) string {
hash, err := hex.DecodeString(hashStr)
if err != nil {
panic(err)
}
bytes, err := serializePaymentIndexEntry(hash)
if err != nil {
panic(err)
}
return string(bytes)
}
// MigrateSequenceIndex asserts that the database is properly migrated to
// contain a payments index.
func TestMigrateSequenceIndex(t *testing.T) {
tests := []struct {
name string
shouldFail bool
pre map[string]interface{}
post map[string]interface{}
}{
{
name: "migration ok",
shouldFail: false,
pre: pre,
post: post,
},
{
name: "duplicate sequence number",
shouldFail: true,
pre: preFails,
post: post,
},
{
name: "no payments",
shouldFail: false,
pre: nil,
post: nil,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
// Before the migration we have a payments bucket.
before := func(tx kvdb.RwTx) error {
return migtest.RestoreDB(
tx, paymentsRootBucket, test.pre,
)
}
// After the migration, we should have an untouched
// payments bucket and a new index bucket.
after := func(tx kvdb.RwTx) error {
if err := migtest.VerifyDB(
tx, paymentsRootBucket, test.pre,
); err != nil {
return err
}
// If we expect our migration to fail, we don't
// expect an index bucket.
if test.shouldFail {
return nil
}
return migtest.VerifyDB(
tx, paymentsIndexBucket, test.post,
)
}
migtest.ApplyMigration(
t, before, after, MigrateSequenceIndex,
test.shouldFail,
)
})
}
}

140
channeldb/paginate.go Normal file

@ -0,0 +1,140 @@
package channeldb
import "github.com/lightningnetwork/lnd/channeldb/kvdb"
type paginator struct {
// cursor is the cursor which we are using to iterate through a bucket.
cursor kvdb.RCursor
// reversed indicates whether we are paginating forwards or backwards.
reversed bool
// indexOffset is the index from which we will begin querying.
indexOffset uint64
// totalItems is the total number of items we allow in our response.
totalItems uint64
}
// newPaginator returns a struct which can be used to query an indexed bucket
// in pages.
func newPaginator(c kvdb.RCursor, reversed bool,
indexOffset, totalItems uint64) paginator {
return paginator{
cursor: c,
reversed: reversed,
indexOffset: indexOffset,
totalItems: totalItems,
}
}
// keyValueForIndex seeks our cursor to a given index and returns the key and
// value at that position.
func (p paginator) keyValueForIndex(index uint64) ([]byte, []byte) {
var keyIndex [8]byte
byteOrder.PutUint64(keyIndex[:], index)
return p.cursor.Seek(keyIndex[:])
}
// lastIndex returns the last value in our index, if our index is empty it
// returns 0.
func (p paginator) lastIndex() uint64 {
keyIndex, _ := p.cursor.Last()
if keyIndex == nil {
return 0
}
return byteOrder.Uint64(keyIndex)
}
// nextKey is a helper closure to determine what key we should use next when
// we are iterating, depending on whether we are iterating forwards or in
// reverse.
func (p paginator) nextKey() ([]byte, []byte) {
if p.reversed {
return p.cursor.Prev()
}
return p.cursor.Next()
}
// cursorStart gets the index key and value for the first item we are looking
// up, taking into account that we may be paginating in reverse. The index
// offset provided is *excusive* so we will start with the item after the offset
// for forwards queries, and the item before the index for backwards queries.
func (p paginator) cursorStart() ([]byte, []byte) {
indexKey, indexValue := p.keyValueForIndex(p.indexOffset + 1)
// If the query is specifying reverse iteration, then we must
// handle a few offset cases.
if p.reversed {
switch {
// This indicates the default case, where no offset was
// specified. In that case we just start from the last
// entry.
case p.indexOffset == 0:
indexKey, indexValue = p.cursor.Last()
// This indicates the offset being set to the very
// first entry. Since there are no entries before
// this offset, and the direction is reversed, we can
// return without adding any invoices to the response.
case p.indexOffset == 1:
return nil, nil
// If we have been given an index offset that is beyond our last
// index value, we just return the last indexed value in our set
// since we are querying in reverse. We do not cover the case
// where our index offset equals our last index value, because
// index offset is exclusive, so we would want to start at the
// value before our last index.
case p.indexOffset > p.lastIndex():
return p.cursor.Last()
// Otherwise we have an index offset which is within our set of
// indexed keys, and we want to start at the item before our
// offset. We seek to our index offset, then return the element
// before it. We do this rather than p.indexOffset-1 to account
// for indexes that have gaps.
default:
p.keyValueForIndex(p.indexOffset)
indexKey, indexValue = p.cursor.Prev()
}
}
return indexKey, indexValue
}
// query gets the start point for our index offset and iterates through keys
// in our index until we reach the total number of items required for the query
// or we run out of cursor values. This function takes a fetchAndAppend function
// which is responsible for looking up the entry at that index, adding the entry
// to its set of return items (if desired) and return a boolean which indicates
// whether the item was added. This is required to allow the paginator to
// determine when the response has the maximum number of required items.
func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error {
indexKey, indexValue := p.cursorStart()
var totalItems int
for ; indexKey != nil; indexKey, indexValue = p.nextKey() {
// If our current return payload exceeds the max number
// of invoices, then we'll exit now.
if uint64(totalItems) >= p.totalItems {
break
}
added, err := fetchAndAppend(indexKey, indexValue)
if err != nil {
return err
}
// If we added an item to our set in the latest fetch and append
// we increment our total count.
if added {
totalItems++
}
}
return nil
}

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
"github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
@ -74,6 +75,11 @@ var (
// errNoAttemptInfo is returned when no attempt info is stored yet. // errNoAttemptInfo is returned when no attempt info is stored yet.
errNoAttemptInfo = errors.New("unable to find attempt info for " + errNoAttemptInfo = errors.New("unable to find attempt info for " +
"inflight payment") "inflight payment")
// errNoSequenceNrIndex is returned when an attempt to lookup a payment
// index is made for a sequence number that is not indexed.
errNoSequenceNrIndex = errors.New("payment sequence number index " +
"does not exist")
) )
// PaymentControl implements persistence for payments and payment attempts. // PaymentControl implements persistence for payments and payment attempts.
@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
return err return err
} }
// Before we set our new sequence number, we check whether this
// payment has a previously set sequence number and remove its
// index entry if it exists. This happens in the case where we
// have a previously attempted payment which was left in a state
// where we can retry.
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes != nil {
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
if err := indexBucket.Delete(seqBytes); err != nil {
return err
}
}
// Once we have obtained a sequence number, we add an entry
// to our index bucket which will map the sequence number to
// our payment hash.
err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash)
if err != nil {
return err
}
err = bucket.Put(paymentSequenceKey, sequenceNum) err = bucket.Put(paymentSequenceKey, sequenceNum)
if err != nil { if err != nil {
return err return err
@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
return updateErr return updateErr
} }
// paymentIndexTypeHash is a payment index type which indicates that we have
// created an index of payment sequence number to payment hash.
type paymentIndexType uint8
// paymentIndexTypeHash is a payment index type which indicates that we have
// created an index of payment sequence number to payment hash.
const paymentIndexTypeHash paymentIndexType = 0
// createPaymentIndexEntry creates a payment hash typed index for a payment. The
// index produced contains a payment index type (which can be used in future to
// signal different payment index types) and the payment hash.
func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte,
hash lntypes.Hash) error {
var b bytes.Buffer
if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil {
return err
}
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Put(sequenceNumber, b.Bytes())
}
// deserializePaymentIndex deserializes a payment index entry. This function
// currently only supports deserialization of payment hash indexes, and will
// fail for other types.
func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) {
var (
indexType paymentIndexType
paymentHash []byte
)
if err := ReadElements(r, &indexType, &paymentHash); err != nil {
return lntypes.Hash{}, err
}
// While we only have on payment index type, we do not need to use our
// index type to deserialize the index. However, we sanity check that
// this type is as expected, since we had to read it out anyway.
if indexType != paymentIndexTypeHash {
return lntypes.Hash{}, fmt.Errorf("unknown payment index "+
"type: %v", indexType)
}
hash, err := lntypes.MakeHash(paymentHash)
if err != nil {
return lntypes.Hash{}, err
}
return hash, nil
}
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the // RegisterAttempt atomically records the provided HTLCAttemptInfo to the
// DB. // DB.
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,

@ -1,6 +1,7 @@
package channeldb package channeldb
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
@ -9,9 +10,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func genPreimage() ([32]byte, error) { func genPreimage() ([32]byte, error) {
@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentIndex(t, pControl, info.PaymentHash)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo( assertPaymentInfo(
t, pControl, info.PaymentHash, info, nil, nil, t, pControl, info.PaymentHash, info, nil, nil,
@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) {
t, pControl, info.PaymentHash, info, &failReason, nil, t, pControl, info.PaymentHash, info, &failReason, nil,
) )
// Lookup the payment so we can get its old sequence number before it is
// overwritten.
payment, err := pControl.FetchPayment(info.PaymentHash)
assert.NoError(t, err)
// Sends the htlc again, which should succeed since the prior payment // Sends the htlc again, which should succeed since the prior payment
// failed. // failed.
err = pControl.InitPayment(info.PaymentHash, info) err = pControl.InitPayment(info.PaymentHash, info)
@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
// Check that our index has been updated, and the old index has been
// removed.
assertPaymentIndex(t, pControl, info.PaymentHash)
assertNoIndex(t, pControl, payment.SequenceNum)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo( assertPaymentInfo(
t, pControl, info.PaymentHash, info, nil, nil, t, pControl, info.PaymentHash, info, nil, nil,
@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) {
// Settle the attempt and verify that status was changed to // Settle the attempt and verify that status was changed to
// StatusSucceeded. // StatusSucceeded.
var payment *MPPayment
payment, err = pControl.SettleAttempt( payment, err = pControl.SettleAttempt(
info.PaymentHash, attempt.AttemptID, info.PaymentHash, attempt.AttemptID,
&HTLCSettleInfo{ &HTLCSettleInfo{
@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentIndex(t, pControl, info.PaymentHash)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo( assertPaymentInfo(
t, pControl, info.PaymentHash, info, nil, nil, t, pControl, info.PaymentHash, info, nil, nil,
@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown)
} }
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only // TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only
// deletes payments from the database that are not in-flight. // deletes payments from the database that are not in-flight.
func TestPaymentControlDeleteNonInFligt(t *testing.T) { func TestPaymentControlDeleteNonInFligt(t *testing.T) {
t.Parallel() t.Parallel()
@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
t.Fatalf("unable to init db: %v", err) t.Fatalf("unable to init db: %v", err)
} }
// Create a sequence number for duplicate payments that will not collide
// with the sequence numbers for the payments we create. These values
// start at 1, so 9999 is a safe bet for this test.
var duplicateSeqNr = 9999
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
payments := []struct { payments := []struct {
failed bool failed bool
success bool success bool
hasDuplicate bool
}{ }{
{ {
failed: true, failed: true,
success: false, success: false,
hasDuplicate: false,
}, },
{ {
failed: false, failed: false,
success: true, success: true,
hasDuplicate: false,
}, },
{ {
failed: false, failed: false,
success: false, success: false,
hasDuplicate: false,
},
{
failed: false,
success: true,
hasDuplicate: true,
}, },
} }
@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
t, pControl, info.PaymentHash, info, nil, htlc, t, pControl, info.PaymentHash, info, nil, htlc,
) )
} }
// If the payment is intended to have a duplicate payment, we
// add one.
if p.hasDuplicate {
appendDuplicatePayment(
t, pControl.db, info.PaymentHash,
uint64(duplicateSeqNr),
)
duplicateSeqNr++
}
} }
// Delete payments. // Delete payments.
@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
if status != StatusInFlight { if status != StatusInFlight {
t.Fatalf("expected in-fligth status, got %v", status) t.Fatalf("expected in-fligth status, got %v", status)
} }
// Finally, check that we only have a single index left in the payment
// index bucket.
var indexCount int
err = kvdb.View(db, func(tx walletdb.ReadTx) error {
index := tx.ReadBucket(paymentsIndexBucket)
return index.ForEach(func(k, v []byte) error {
indexCount++
return nil
})
})
require.NoError(t, err)
require.Equal(t, 1, indexCount)
} }
// TestPaymentControlMultiShard checks the ability of payment control to // TestPaymentControlMultiShard checks the ability of payment control to
@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentIndex(t, pControl, info.PaymentHash)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo( assertPaymentInfo(
t, pControl, info.PaymentHash, info, nil, nil, t, pControl, info.PaymentHash, info, nil, nil,
@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash,
t.Fatal("expected no settle info") t.Fatal("expected no settle info")
} }
} }
// fetchPaymentIndexEntry gets the payment hash for the sequence number provided
// from our payment indexes bucket.
func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl,
sequenceNumber uint64) (*lntypes.Hash, error) {
var hash lntypes.Hash
if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error {
indexBucket := tx.ReadBucket(paymentsIndexBucket)
key := make([]byte, 8)
byteOrder.PutUint64(key, sequenceNumber)
indexValue := indexBucket.Get(key)
if indexValue == nil {
return errNoSequenceNrIndex
}
r := bytes.NewReader(indexValue)
var err error
hash, err = deserializePaymentIndex(r)
return err
}); err != nil {
return nil, err
}
return &hash, nil
}
// assertPaymentIndex looks up the index for a payment in the db and checks
// that its payment hash matches the expected hash passed in.
func assertPaymentIndex(t *testing.T, p *PaymentControl,
expectedHash lntypes.Hash) {
// Lookup the payment so that we have its sequence number and check
// that is has correctly been indexed in the payment indexes bucket.
pmt, err := p.FetchPayment(expectedHash)
require.NoError(t, err)
hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum)
require.NoError(t, err)
assert.Equal(t, expectedHash, *hash)
}
// assertNoIndex checks that an index for the sequence number provided does not
// exist.
func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) {
_, err := fetchPaymentIndexEntry(t, p, seqNr)
require.Equal(t, errNoSequenceNrIndex, err)
}

@ -3,9 +3,9 @@ package channeldb
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"math"
"sort" "sort"
"time" "time"
@ -92,6 +92,35 @@ var (
// paymentFailInfoKey is a key used in the payment's sub-bucket to // paymentFailInfoKey is a key used in the payment's sub-bucket to
// store information about the reason a payment failed. // store information about the reason a payment failed.
paymentFailInfoKey = []byte("payment-fail-info") paymentFailInfoKey = []byte("payment-fail-info")
// paymentsIndexBucket is the name of the top-level bucket within the
// database that stores an index of payment sequence numbers to its
// payment hash.
// payments-sequence-index-bucket
// |--<sequence-number>: <payment hash>
// |--...
// |--<sequence-number>: <payment hash>
paymentsIndexBucket = []byte("payments-index-bucket")
)
var (
// ErrNoSequenceNumber is returned if we lookup a payment which does
// not have a sequence number.
ErrNoSequenceNumber = errors.New("sequence number not found")
// ErrDuplicateNotFound is returned when we lookup a payment by its
// index and cannot find a payment with a matching sequence number.
ErrDuplicateNotFound = errors.New("duplicate payment not found")
// ErrNoDuplicateBucket is returned when we expect to find duplicates
// when looking up a payment from its index, but the payment does not
// have any.
ErrNoDuplicateBucket = errors.New("expected duplicate bucket")
// ErrNoDuplicateNestedBucket is returned if we do not find duplicate
// payments in their own sub-bucket.
ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " +
"found")
) )
// FailureReason encodes the reason a payment ultimately failed. // FailureReason encodes the reason a payment ultimately failed.
@ -481,62 +510,70 @@ type PaymentsResponse struct {
func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
var resp PaymentsResponse var resp PaymentsResponse
allPayments, err := db.FetchPayments() if err := kvdb.View(db, func(tx kvdb.RTx) error {
// Get the root payments bucket.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
// Get the index bucket which maps sequence number -> payment
// hash and duplicate bool. If we have a payments bucket, we
// should have an indexes bucket as well.
indexes := tx.ReadBucket(paymentsIndexBucket)
if indexes == nil {
return fmt.Errorf("index bucket does not exist")
}
// accumulatePayments gets payments with the sequence number
// and hash provided and adds them to our list of payments if
// they meet the criteria of our query. It returns the number
// of payments that were added.
accumulatePayments := func(sequenceKey, hash []byte) (bool,
error) {
r := bytes.NewReader(hash)
paymentHash, err := deserializePaymentIndex(r)
if err != nil { if err != nil {
return resp, err return false, err
} }
if len(allPayments) == 0 { payment, err := fetchPaymentWithSequenceNumber(
return resp, nil tx, paymentHash, sequenceKey,
)
if err != nil {
return false, err
} }
indexExclusiveLimit := query.IndexOffset // To keep compatibility with the old API, we only
// In backward pagination, if the index limit is the default 0 value, // return non-succeeded payments if requested.
// we set our limit to maxint to include all payments from the highest
// sequence number on.
if query.Reversed && indexExclusiveLimit == 0 {
indexExclusiveLimit = math.MaxInt64
}
for i := range allPayments {
var payment *MPPayment
// If we have the max number of payments we want, exit.
if uint64(len(resp.Payments)) == query.MaxPayments {
break
}
if query.Reversed {
payment = allPayments[len(allPayments)-1-i]
// In the reversed direction, skip over all payments
// that have sequence numbers greater than or equal to
// the index offset. We skip payments with equal index
// because the offset is exclusive.
if payment.SequenceNum >= indexExclusiveLimit {
continue
}
} else {
payment = allPayments[i]
// In the forward direction, skip over all payments that
// have sequence numbers less than or equal to the index
// offset. We skip payments with equal indexes because
// the index offset is exclusive.
if payment.SequenceNum <= indexExclusiveLimit {
continue
}
}
// To keep compatibility with the old API, we only return
// non-succeeded payments if requested.
if payment.Status != StatusSucceeded && if payment.Status != StatusSucceeded &&
!query.IncludeIncomplete { !query.IncludeIncomplete {
continue return false, err
} }
// At this point, we've exhausted the offset, so we'll
// begin collecting invoices found within the range.
resp.Payments = append(resp.Payments, payment) resp.Payments = append(resp.Payments, payment)
return true, nil
}
// Create a paginator which reads from our sequence index bucket
// with the parameters provided by the payments query.
paginator := newPaginator(
indexes.ReadCursor(), query.Reversed, query.IndexOffset,
query.MaxPayments,
)
// Run a paginated query, adding payments to our response.
if err := paginator.query(accumulatePayments); err != nil {
return err
}
return nil
}); err != nil {
return resp, err
} }
// Need to swap the payments slice order if reversed order. // Need to swap the payments slice order if reversed order.
@ -555,7 +592,84 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
resp.Payments[len(resp.Payments)-1].SequenceNum resp.Payments[len(resp.Payments)-1].SequenceNum
} }
return resp, err return resp, nil
}
// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
// *and* sequence number provided from the database. This is required because
// we previously had more than one payment per hash, so we have multiple indexes
// pointing to a single payment; we want to retrieve the correct one.
func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash,
sequenceNumber []byte) (*MPPayment, error) {
// We can now lookup the payment keyed by its hash in
// the payments root bucket.
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return nil, err
}
// A single payment hash can have multiple payments associated with it.
// We lookup our sequence number first, to determine whether this is
// the payment we are actually looking for.
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return nil, ErrNoSequenceNumber
}
// If this top level payment has the sequence number we are looking for,
// return it.
if bytes.Equal(seqBytes, sequenceNumber) {
return fetchPayment(bucket)
}
// If we were not looking for the top level payment, we are looking for
// one of our duplicate payments. We need to iterate through the seq
// numbers in this bucket to find the correct payments. If we do not
// find a duplicate payments bucket here, something is wrong.
dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
if dup == nil {
return nil, ErrNoDuplicateBucket
}
var duplicatePayment *MPPayment
err = dup.ForEach(func(k, v []byte) error {
subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
// We one bucket for each duplicate to be found.
return ErrNoDuplicateNestedBucket
}
seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
if seqBytes == nil {
return err
}
// If this duplicate payment is not the sequence number we are
// looking for, we can continue.
if !bytes.Equal(seqBytes, sequenceNumber) {
return nil
}
duplicatePayment, err = fetchDuplicatePayment(subBucket)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// If none of the duplicate payments matched our sequence number, we
// failed to find the payment with this sequence number; something is
// wrong.
if duplicatePayment == nil {
return nil, ErrDuplicateNotFound
}
return duplicatePayment, nil
} }
// DeletePayments deletes all completed and failed payments from the DB. // DeletePayments deletes all completed and failed payments from the DB.
@ -566,7 +680,15 @@ func (db *DB) DeletePayments() error {
return nil return nil
} }
var deleteBuckets [][]byte var (
// deleteBuckets is the set of payment buckets we need
// to delete.
deleteBuckets [][]byte
// deleteIndexes is the set of indexes pointing to these
// payments that need to be deleted.
deleteIndexes [][]byte
)
err := payments.ForEach(func(k, _ []byte) error { err := payments.ForEach(func(k, _ []byte) error {
bucket := payments.NestedReadWriteBucket(k) bucket := payments.NestedReadWriteBucket(k)
if bucket == nil { if bucket == nil {
@ -589,7 +711,18 @@ func (db *DB) DeletePayments() error {
return nil return nil
} }
// Add the bucket to the set of buckets we can delete.
deleteBuckets = append(deleteBuckets, k) deleteBuckets = append(deleteBuckets, k)
// Get all the sequence number associated with the
// payment, including duplicates.
seqNrs, err := fetchSequenceNumbers(bucket)
if err != nil {
return err
}
deleteIndexes = append(deleteIndexes, seqNrs...)
return nil return nil
}) })
if err != nil { if err != nil {
@ -602,10 +735,49 @@ func (db *DB) DeletePayments() error {
} }
} }
// Get our index bucket and delete all indexes pointing to the
// payments we are deleting.
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
for _, k := range deleteIndexes {
if err := indexBucket.Delete(k); err != nil {
return err
}
}
return nil return nil
}) })
} }
// fetchSequenceNumbers fetches all the sequence numbers associated with a
// payment, including those belonging to any duplicate payments.
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
seqNum := paymentBucket.Get(paymentSequenceKey)
if seqNum == nil {
return nil, errors.New("expected sequence number")
}
sequenceNumbers := [][]byte{seqNum}
// Get the duplicate payments bucket, if it has no duplicates, just
// return early with the payment sequence number.
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
if duplicates == nil {
return sequenceNumbers, nil
}
// If we do have duplicated, they are keyed by sequence number, so we
// iterate through the duplicates bucket and add them to our set of
// sequence numbers.
if err := duplicates.ForEach(func(k, v []byte) error {
sequenceNumbers = append(sequenceNumbers, k)
return nil
}); err != nil {
return nil, err
}
return sequenceNumbers, nil
}
// nolint: dupl // nolint: dupl
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
var scratch [8]byte var scratch [8]byte

@ -9,11 +9,13 @@ import (
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/channeldb/kvdb"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -163,18 +165,24 @@ func TestRouteSerialization(t *testing.T) {
} }
// deletePayment removes a payment with paymentHash from the payments database. // deletePayment removes a payment with paymentHash from the payments database.
func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) {
t.Helper() t.Helper()
err := kvdb.Update(db, func(tx kvdb.RwTx) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error {
payments := tx.ReadWriteBucket(paymentsRootBucket) payments := tx.ReadWriteBucket(paymentsRootBucket)
// Delete the payment bucket.
err := payments.DeleteNestedBucket(paymentHash[:]) err := payments.DeleteNestedBucket(paymentHash[:])
if err != nil { if err != nil {
return err return err
} }
return nil key := make([]byte, 8)
byteOrder.PutUint64(key, seqNr)
// Delete the index that references this payment.
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Delete(key)
}) })
if err != nil { if err != nil {
@ -188,6 +196,10 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) {
func TestQueryPayments(t *testing.T) { func TestQueryPayments(t *testing.T) {
// Define table driven test for QueryPayments. // Define table driven test for QueryPayments.
// Test payments have sequence indices [1, 3, 4, 5, 6, 7]. // Test payments have sequence indices [1, 3, 4, 5, 6, 7].
// Note that the payment with index 7 has the same payment hash as 6,
// and is stored in a nested bucket within payment 6 rather than being
// its own entry in the payments bucket. We do this to test retrieval
// of legacy payments.
tests := []struct { tests := []struct {
name string name string
query PaymentsQuery query PaymentsQuery
@ -344,6 +356,42 @@ func TestQueryPayments(t *testing.T) {
lastIndex: 7, lastIndex: 7,
expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, expectedSeqNrs: []uint64{3, 4, 5, 6, 7},
}, },
{
name: "query payments reverse before index gap",
query: PaymentsQuery{
IndexOffset: 3,
MaxPayments: 7,
Reversed: true,
IncludeIncomplete: true,
},
firstIndex: 1,
lastIndex: 1,
expectedSeqNrs: []uint64{1},
},
{
name: "query payments reverse on index gap",
query: PaymentsQuery{
IndexOffset: 2,
MaxPayments: 7,
Reversed: true,
IncludeIncomplete: true,
},
firstIndex: 1,
lastIndex: 1,
expectedSeqNrs: []uint64{1},
},
{
name: "query payments forward on index gap",
query: PaymentsQuery{
IndexOffset: 2,
MaxPayments: 2,
Reversed: false,
IncludeIncomplete: true,
},
firstIndex: 3,
lastIndex: 4,
expectedSeqNrs: []uint64{3, 4},
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -352,17 +400,28 @@ func TestQueryPayments(t *testing.T) {
t.Parallel() t.Parallel()
db, cleanup, err := makeTestDB() db, cleanup, err := makeTestDB()
defer cleanup()
if err != nil { if err != nil {
t.Fatalf("unable to init db: %v", err) t.Fatalf("unable to init db: %v", err)
} }
defer cleanup()
// Make a preliminary query to make sure it's ok to
// query when we have no payments.
resp, err := db.QueryPayments(tt.query)
require.NoError(t, err)
require.Len(t, resp.Payments, 0)
// Populate the database with a set of test payments. // Populate the database with a set of test payments.
numberOfPayments := 7 // We create 6 original payments, deleting the payment
// at index 2 so that we cover the case where sequence
// numbers are missing. We also add a duplicate payment
// to the last payment added to test the legacy case
// where we have duplicates in the nested duplicates
// bucket.
nonDuplicatePayments := 6
pControl := NewPaymentControl(db) pControl := NewPaymentControl(db)
for i := 0; i < numberOfPayments; i++ { for i := 0; i < nonDuplicatePayments; i++ {
// Generate a test payment. // Generate a test payment.
info, _, _, err := genInfo() info, _, _, err := genInfo()
if err != nil { if err != nil {
@ -379,7 +438,29 @@ func TestQueryPayments(t *testing.T) {
// Immediately delete the payment with index 2. // Immediately delete the payment with index 2.
if i == 1 { if i == 1 {
deletePayment(t, db, info.PaymentHash) pmt, err := pControl.FetchPayment(
info.PaymentHash,
)
require.NoError(t, err)
deletePayment(t, db, info.PaymentHash,
pmt.SequenceNum)
}
// If we are on the last payment entry, add a
// duplicate payment with sequence number equal
// to the parent payment + 1.
if i == (nonDuplicatePayments - 1) {
pmt, err := pControl.FetchPayment(
info.PaymentHash,
)
require.NoError(t, err)
appendDuplicatePayment(
t, pControl.db,
info.PaymentHash,
pmt.SequenceNum+1,
)
} }
} }
@ -424,3 +505,210 @@ func TestQueryPayments(t *testing.T) {
}) })
} }
} }
// TestFetchPaymentWithSequenceNumber tests lookup of payments with their
// sequence number. It sets up one payment with no duplicates, and another with
// two duplicates in its duplicates bucket then uses these payments to test the
// case where a specific duplicate is not found and the duplicates bucket is not
// present when we expect it to be.
func TestFetchPaymentWithSequenceNumber(t *testing.T) {
db, cleanup, err := makeTestDB()
require.NoError(t, err)
defer cleanup()
pControl := NewPaymentControl(db)
// Generate a test payment which does not have duplicates.
noDuplicates, _, _, err := genInfo()
require.NoError(t, err)
// Create a new payment entry in the database.
err = pControl.InitPayment(noDuplicates.PaymentHash, noDuplicates)
require.NoError(t, err)
// Fetch the payment so we can get its sequence nr.
noDuplicatesPayment, err := pControl.FetchPayment(
noDuplicates.PaymentHash,
)
require.NoError(t, err)
// Generate a test payment which we will add duplicates to.
hasDuplicates, _, _, err := genInfo()
require.NoError(t, err)
// Create a new payment entry in the database.
err = pControl.InitPayment(hasDuplicates.PaymentHash, hasDuplicates)
require.NoError(t, err)
// Fetch the payment so we can get its sequence nr.
hasDuplicatesPayment, err := pControl.FetchPayment(
hasDuplicates.PaymentHash,
)
require.NoError(t, err)
// We declare the sequence numbers used here so that we can reference
// them in tests.
var (
duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1
duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2
)
// Add two duplicates to our second payment.
appendDuplicatePayment(
t, db, hasDuplicates.PaymentHash, duplicateOneSeqNr,
)
appendDuplicatePayment(
t, db, hasDuplicates.PaymentHash, duplicateTwoSeqNr,
)
tests := []struct {
name string
paymentHash lntypes.Hash
sequenceNumber uint64
expectedErr error
}{
{
name: "lookup payment without duplicates",
paymentHash: noDuplicates.PaymentHash,
sequenceNumber: noDuplicatesPayment.SequenceNum,
expectedErr: nil,
},
{
name: "lookup payment with duplicates",
paymentHash: hasDuplicates.PaymentHash,
sequenceNumber: hasDuplicatesPayment.SequenceNum,
expectedErr: nil,
},
{
name: "lookup first duplicate",
paymentHash: hasDuplicates.PaymentHash,
sequenceNumber: duplicateOneSeqNr,
expectedErr: nil,
},
{
name: "lookup second duplicate",
paymentHash: hasDuplicates.PaymentHash,
sequenceNumber: duplicateTwoSeqNr,
expectedErr: nil,
},
{
name: "lookup non-existent duplicate",
paymentHash: hasDuplicates.PaymentHash,
sequenceNumber: 999999,
expectedErr: ErrDuplicateNotFound,
},
{
name: "lookup duplicate, no duplicates bucket",
paymentHash: noDuplicates.PaymentHash,
sequenceNumber: duplicateTwoSeqNr,
expectedErr: ErrNoDuplicateBucket,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
err := kvdb.Update(db,
func(tx walletdb.ReadWriteTx) error {
var seqNrBytes [8]byte
byteOrder.PutUint64(
seqNrBytes[:], test.sequenceNumber,
)
_, err := fetchPaymentWithSequenceNumber(
tx, test.paymentHash, seqNrBytes[:],
)
return err
})
require.Equal(t, test.expectedErr, err)
})
}
}
// appendDuplicatePayment adds a duplicate payment to an existing payment. Note
// that this function requires a unique sequence number.
//
// This code is *only* intended to replicate legacy duplicate payments in lnd,
// our current schema does not allow duplicates.
func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash,
seqNr uint64) {
err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error {
bucket, err := fetchPaymentBucketUpdate(
tx, paymentHash,
)
if err != nil {
return err
}
// Create the duplicates bucket if it is not
// present.
dup, err := bucket.CreateBucketIfNotExists(
duplicatePaymentsBucket,
)
if err != nil {
return err
}
var sequenceKey [8]byte
byteOrder.PutUint64(sequenceKey[:], seqNr)
// Create duplicate payments for the two dup
// sequence numbers we've setup.
putDuplicatePayment(t, dup, sequenceKey[:], paymentHash)
// Finally, once we have created our entry we add an index for
// it.
err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash)
require.NoError(t, err)
return nil
})
if err != nil {
t.Fatalf("could not create payment: %v", err)
}
}
// putDuplicatePayment creates a duplicate payment in the duplicates bucket
// provided with the minimal information required for successful reading.
func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket,
sequenceKey []byte, paymentHash lntypes.Hash) {
paymentBucket, err := duplicateBucket.CreateBucketIfNotExists(
sequenceKey,
)
require.NoError(t, err)
err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey)
require.NoError(t, err)
// Generate fake information for the duplicate payment.
info, _, _, err := genInfo()
require.NoError(t, err)
// Write the payment info to disk under the creation info key. This code
// is copied rather than using serializePaymentCreationInfo to ensure
// we always write in the legacy format used by duplicate payments.
var b bytes.Buffer
var scratch [8]byte
_, err = b.Write(paymentHash[:])
require.NoError(t, err)
byteOrder.PutUint64(scratch[:], uint64(info.Value))
_, err = b.Write(scratch[:])
require.NoError(t, err)
err = serializeTime(&b, info.CreationTime)
require.NoError(t, err)
byteOrder.PutUint32(scratch[:4], 0)
_, err = b.Write(scratch[:4])
require.NoError(t, err)
// Get the PaymentCreationInfo.
err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes())
require.NoError(t, err)
}