diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index c11bc5c9..454e2831 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -6,11 +6,18 @@ import ( "errors" "fmt" "io" + "sync" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" ) +const ( + // paymentSeqBlockSize is the block size used when we batch allocate + // payment sequences for future payments. + paymentSeqBlockSize = 1000 +) + var ( // ErrAlreadyPaid signals we have already paid this payment hash. ErrAlreadyPaid = errors.New("invoice is already paid") @@ -84,7 +91,10 @@ var ( // PaymentControl implements persistence for payments and payment attempts. type PaymentControl struct { - db *DB + paymentSeqMx sync.Mutex + currPaymentSeq uint64 + storedPaymentSeq uint64 + db *DB } // NewPaymentControl creates a new instance of the PaymentControl. @@ -101,6 +111,14 @@ func NewPaymentControl(db *DB) *PaymentControl { func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, info *PaymentCreationInfo) error { + // Obtain a new sequence number for this payment. This is used + // to sort the payments in order of creation, and also acts as + // a unique identifier for each payment. + sequenceNum, err := p.nextPaymentSequence() + if err != nil { + return err + } + var b bytes.Buffer if err := serializePaymentCreationInfo(&b, info); err != nil { return err @@ -108,7 +126,7 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, infoBytes := b.Bytes() var updateErr error - err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error { + err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error { // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. updateErr = nil @@ -150,14 +168,6 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return nil } - // Obtain a new sequence number for this payment. This is used - // to sort the payments in order of creation, and also acts as - // a unique identifier for each payment. - sequenceNum, err := nextPaymentSequence(tx) - if err != nil { - return err - } - // Before we set our new sequence number, we check whether this // payment has a previously set sequence number and remove its // index entry if it exists. This happens in the case where we @@ -615,19 +625,45 @@ func fetchPaymentBucketUpdate(tx kvdb.RwTx, paymentHash lntypes.Hash) ( // nextPaymentSequence returns the next sequence number to store for a new // payment. -func nextPaymentSequence(tx kvdb.RwTx) ([]byte, error) { - payments, err := tx.CreateTopLevelBucket(paymentsRootBucket) - if err != nil { - return nil, err - } +func (p *PaymentControl) nextPaymentSequence() ([]byte, error) { + p.paymentSeqMx.Lock() + defer p.paymentSeqMx.Unlock() + + // Set a new upper bound in the DB every 1000 payments to avoid + // conflicts on the sequence when using etcd. + if p.currPaymentSeq == p.storedPaymentSeq { + var currPaymentSeq, newUpperBound uint64 + if err := kvdb.Update(p.db.Backend, func(tx kvdb.RwTx) error { + paymentsBucket, err := tx.CreateTopLevelBucket( + paymentsRootBucket, + ) + if err != nil { + return err + } - seq, err := payments.NextSequence() - if err != nil { - return nil, err + currPaymentSeq = paymentsBucket.Sequence() + newUpperBound = currPaymentSeq + paymentSeqBlockSize + return paymentsBucket.SetSequence(newUpperBound) + }, func() {}); err != nil { + return nil, err + } + + // We lazy initialize the cached currPaymentSeq here using the + // first nextPaymentSequence() call. This if statement will auto + // initialize our stored currPaymentSeq, since by default both + // this variable and storedPaymentSeq are zero which in turn + // will have us fetch the current values from the DB. + if p.currPaymentSeq == 0 { + p.currPaymentSeq = currPaymentSeq + } + + p.storedPaymentSeq = newUpperBound } + p.currPaymentSeq++ b := make([]byte, 8) - binary.BigEndian.PutUint64(b, seq) + binary.BigEndian.PutUint64(b, p.currPaymentSeq) + return b, nil }