channeldb: modify new payments module to match project code-style

This commit modifies the new payment module within the database to
match the coding style of the rest of the package and the project as a
hole. Additionally, a few fields have been renamed, and the extra
timestamp added to the OutgoingPayment struct has been removed as
there’s already a CreationTime field within the Invoice struct that’s
embedded within the OutgoingPayment struct.
This commit is contained in:
Olaoluwa Osuntokun 2016-12-30 16:32:20 -08:00
parent 276c384455
commit f7510cf1fc
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
2 changed files with 123 additions and 163 deletions

@ -3,117 +3,122 @@ package channeldb
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"github.com/boltdb/bolt"
"github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil"
"io" "io"
"time"
"github.com/boltdb/bolt"
"github.com/roasbeef/btcutil"
) )
var ( var (
// invoiceBucket is the name of the bucket within // paymentBucket is the name of the bucket within the database that
// the database that stores all data related to payments. // stores all data related to payments.
// Within the payments bucket, each invoice is keyed //
// by its invoice ID // Within the payments bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint64. // which is a monotonically increasing uint64. BoltDB's sequence
// BoltDB sequence feature is used for generating // feature is used for generating monotonically increasing id.
// monotonically increasing id.
paymentBucket = []byte("payments") paymentBucket = []byte("payments")
) )
// OutgoingPayment represents payment from given node. // OutgoingPayment represents a successful payment between the daemon and a
// remote node. Details such as the total fee paid, and the time of the payment
// are stored.
type OutgoingPayment struct { type OutgoingPayment struct {
Invoice Invoice
// Total fee paid. // Fee is the total fee paid for the payment in satoshis.
Fee btcutil.Amount Fee btcutil.Amount
// Path including starting and ending nodes. // TotalTimeLock is the total cumulative time-lock in the HTLC extended
Path [][33]byte // from the second-to-last hop to the destination.
// Timelock length.
TimeLockLength uint32 TimeLockLength uint32
// RHash value used for payment. // Path encodes the path the payment took throuhg the network. The path
// We need RHash because we start payment knowing only RHash // excludes the outgoing node and consists of the hex-encoded
RHash [32]byte // compressed public key of each of the nodes involved in the payment.
Path [][33]byte
// Timestamp is time when payment was created. // PaymentHash is the payment hash (r-hash) used to send the payment.
Timestamp time.Time //
// TODO(roasbeef): weave through preimage on payment success to can
// store only supplemental info the embedded Invoice
PaymentHash [32]byte
} }
// AddPayment adds payment to DB. // AddPayment saves a successful payment to the database. It is assumed that
// There is no checking that payment with the same hash already exist. // all payment are sent using unique payment hashes.
func (db *DB) AddPayment(p *OutgoingPayment) error { func (db *DB) AddPayment(payment *OutgoingPayment) error {
err := validateInvoice(&p.Invoice) // Validate the field of the inner voice within the outgoing payment,
if err != nil { // these must also adhere to the same constraints as regular invoices.
if err := validateInvoice(&payment.Invoice); err != nil {
return err return err
} }
// We serialize before writing to database // We first serialize the payment before starting the database
// so no db access in the case of serialization errors // transaction so we can avoid creating a DB payment in the case of a
b := new(bytes.Buffer) // serialization error.
err = serializeOutgoingPayment(b, p) var b bytes.Buffer
if err != nil { if err := serializeOutgoingPayment(&b, payment); err != nil {
return err return err
} }
paymentBytes := b.Bytes() paymentBytes := b.Bytes()
return db.Update(func(tx *bolt.Tx) error { return db.Update(func(tx *bolt.Tx) error {
payments, err := tx.CreateBucketIfNotExists(paymentBucket) payments, err := tx.CreateBucketIfNotExists(paymentBucket)
if err != nil { if err != nil {
return err return err
} }
// Obtain the new unique sequence number for this payment.
paymentId, err := payments.NextSequence() paymentId, err := payments.NextSequence()
if err != nil { if err != nil {
return err return err
} }
// We use BigEndian for keys because // We use BigEndian for keys as it orders keys in
// it orders keys in ascending order // ascending order. This allows bucket scans to order payments
// in the order in which they were created.
paymentIdBytes := make([]byte, 8) paymentIdBytes := make([]byte, 8)
binary.BigEndian.PutUint64(paymentIdBytes, paymentId) binary.BigEndian.PutUint64(paymentIdBytes, paymentId)
err = payments.Put(paymentIdBytes, paymentBytes)
if err != nil { return payments.Put(paymentIdBytes, paymentBytes)
return err
}
return nil
}) })
} }
// FetchAllPayments returns all outgoing payments in DB. // FetchAllPayments returns all outgoing payments in DB.
func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) {
var payments []*OutgoingPayment var payments []*OutgoingPayment
err := db.View(func(tx *bolt.Tx) error { err := db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(paymentBucket) bucket := tx.Bucket(paymentBucket)
if bucket == nil { if bucket == nil {
return ErrNoPaymentsCreated return ErrNoPaymentsCreated
} }
err := bucket.ForEach(func(k, v []byte) error {
// Value can be nil if it is a sub-backet return bucket.ForEach(func(k, v []byte) error {
// so simply ignore it. // If the value is nil, then we ignore it as it may be
// a sub-bucket.
if v == nil { if v == nil {
return nil return nil
} }
r := bytes.NewReader(v) r := bytes.NewReader(v)
payment, err := deserializeOutgoingPayment(r) payment, err := deserializeOutgoingPayment(r)
if err != nil { if err != nil {
return err return err
} }
payments = append(payments, payment) payments = append(payments, payment)
return nil return nil
}) })
return err
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return payments, nil return payments, nil
} }
// DeleteAllPayments deletes all payments from DB. // DeleteAllPayments deletes all payments from DB.
// If payments bucket does not exist it will create
// new bucket without error.
func (db *DB) DeleteAllPayments() error { func (db *DB) DeleteAllPayments() error {
return db.Update(func(tx *bolt.Tx) error { return db.Update(func(tx *bolt.Tx) error {
err := tx.DeleteBucket(paymentBucket) err := tx.DeleteBucket(paymentBucket)
@ -125,124 +130,87 @@ func (db *DB) DeleteAllPayments() error {
if err != nil { if err != nil {
return err return err
} }
return err
return nil
}) })
} }
func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
err := serializeInvoice(w, &p.Invoice) var scratch [8]byte
if err != nil {
if err := serializeInvoice(w, &p.Invoice); err != nil {
return err return err
} }
// Serialize fee. byteOrder.PutUint64(scratch[:], uint64(p.Fee))
feeBytes := make([]byte, 8) if _, err := w.Write(scratch[:]); err != nil {
byteOrder.PutUint64(feeBytes, uint64(p.Fee))
_, err = w.Write(feeBytes)
if err != nil {
return err return err
} }
// Serialize path. // First write out the length of the bytes to prefix the value.
pathLen := uint32(len(p.Path)) pathLen := uint32(len(p.Path))
pathLenBytes := make([]byte, 4) byteOrder.PutUint32(scratch[:4], pathLen)
// Write length of the path if _, err := w.Write(scratch[:4]); err != nil {
byteOrder.PutUint32(pathLenBytes, pathLen)
_, err = w.Write(pathLenBytes)
if err != nil {
return err return err
} }
// Serialize each element of the path
for i := uint32(0); i < pathLen; i++ { // Then with the path written, we write out the series of public keys
_, err := w.Write(p.Path[i][:]) // involved in the path.
if err != nil { for _, hop := range p.Path {
if _, err := w.Write(hop[:]); err != nil {
return err return err
} }
} }
// Serialize TimeLockLength byteOrder.PutUint32(scratch[:4], p.TimeLockLength)
timeLockLengthBytes := make([]byte, 4) if _, err := w.Write(scratch[:4]); err != nil {
byteOrder.PutUint32(timeLockLengthBytes, p.TimeLockLength)
_, err = w.Write(timeLockLengthBytes)
if err != nil {
return err return err
} }
// Serialize RHash if _, err := w.Write(p.PaymentHash[:]); err != nil {
_, err = w.Write(p.RHash[:])
if err != nil {
return err return err
} }
// Serialize Timestamp.
tBytes, err := p.Timestamp.MarshalBinary()
if err != nil {
return err
}
err = wire.WriteVarBytes(w, 0, tBytes)
if err != nil {
return err
}
return nil return nil
} }
func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
var scratch [8]byte
p := &OutgoingPayment{} p := &OutgoingPayment{}
// Deserialize invoice
inv, err := deserializeInvoice(r) inv, err := deserializeInvoice(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.Invoice = *inv p.Invoice = *inv
// Deserialize fee if _, err := r.Read(scratch[:]); err != nil {
feeBytes := make([]byte, 8)
_, err = r.Read(feeBytes)
if err != nil {
return nil, err return nil, err
} }
p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes)) p.Fee = btcutil.Amount(byteOrder.Uint64(scratch[:]))
// Deserialize path if _, err = r.Read(scratch[:4]); err != nil {
pathLenBytes := make([]byte, 4)
_, err = r.Read(pathLenBytes)
if err != nil {
return nil, err return nil, err
} }
pathLen := byteOrder.Uint32(pathLenBytes) pathLen := byteOrder.Uint32(scratch[:4])
path := make([][33]byte, pathLen) path := make([][33]byte, pathLen)
for i := uint32(0); i < pathLen; i++ { for i := uint32(0); i < pathLen; i++ {
_, err := r.Read(path[i][:]) if _, err := r.Read(path[i][:]); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
p.Path = path p.Path = path
// Deserialize TimeLockLength if _, err = r.Read(scratch[:4]); err != nil {
timeLockLengthBytes := make([]byte, 4)
_, err = r.Read(timeLockLengthBytes)
if err != nil {
return nil, err return nil, err
} }
p.TimeLockLength = byteOrder.Uint32(timeLockLengthBytes) p.TimeLockLength = byteOrder.Uint32(scratch[:4])
// Deserialize RHash if _, err := r.Read(p.PaymentHash[:]); err != nil {
_, err = r.Read(p.RHash[:])
if err != nil {
return nil, err return nil, err
} }
// Deserialize Timestamp
tBytes, err := wire.ReadVarBytes(r, 0, 100,
"OutgoingPayment.Timestamp")
if err != nil {
return nil, err
}
err = p.Timestamp.UnmarshalBinary(tBytes)
if err != nil {
return nil, err
}
return p, nil return p, nil
} }

@ -3,55 +3,50 @@ package channeldb
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcutil"
"math/rand" "math/rand"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcutil"
) )
func makeFakePayment() *OutgoingPayment { func makeFakePayment() *OutgoingPayment {
// Create a fake invoice which
// we'll use several times in the tests below.
fakeInvoice := &Invoice{ fakeInvoice := &Invoice{
CreationDate: time.Now(), CreationDate: time.Now(),
Memo: []byte("fake memo"),
Receipt: []byte("fake receipt"),
} }
fakeInvoice.Memo = []byte("memo")
fakeInvoice.Receipt = []byte("recipt")
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = btcutil.Amount(10000) fakeInvoice.Terms.Value = btcutil.Amount(10000)
// Make fake path
fakePath := make([][33]byte, 3) fakePath := make([][33]byte, 3)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
for j := 0; j < 33; j++ { copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33))
fakePath[i][j] = byte(i)
} }
}
var rHash [32]byte = fastsha256.Sum256(rev[:]) return &OutgoingPayment{
fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice, Invoice: *fakeInvoice,
Fee: 101, Fee: 101,
Path: fakePath, Path: fakePath,
TimeLockLength: 1000, TimeLockLength: 1000,
RHash: rHash, PaymentHash: fastsha256.Sum256(rev[:]),
Timestamp: time.Unix(100000, 0),
} }
return fakePayment
} }
// randomBytes creates random []byte with length // randomBytes creates random []byte with length in range [minLen, maxLen)
// in range [minLen, maxLen)
func randomBytes(minLen, maxLen int) ([]byte, error) { func randomBytes(minLen, maxLen int) ([]byte, error) {
l := minLen + rand.Intn(maxLen-minLen) randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
b := make([]byte, l)
_, err := rand.Read(b) if _, err := rand.Read(randBuf); err != nil {
if err != nil {
return nil, fmt.Errorf("Internal error. "+ return nil, fmt.Errorf("Internal error. "+
"Cannot generate random string: %v", err) "Cannot generate random string: %v", err)
} }
return b, nil
return randBuf, nil
} }
func makeRandomFakePayment() (*OutgoingPayment, error) { func makeRandomFakePayment() (*OutgoingPayment, error) {
@ -78,7 +73,6 @@ func makeRandomFakePayment() (*OutgoingPayment, error) {
fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000)) fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000))
// Make fake path
fakePathLen := 1 + rand.Intn(5) fakePathLen := 1 + rand.Intn(5)
fakePath := make([][33]byte, fakePathLen) fakePath := make([][33]byte, fakePathLen)
for i := 0; i < fakePathLen; i++ { for i := 0; i < fakePathLen; i++ {
@ -89,31 +83,29 @@ func makeRandomFakePayment() (*OutgoingPayment, error) {
copy(fakePath[i][:], b) copy(fakePath[i][:], b)
} }
var rHash [32]byte = fastsha256.Sum256( rHash := fastsha256.Sum256(fakeInvoice.Terms.PaymentPreimage[:])
fakeInvoice.Terms.PaymentPreimage[:],
)
fakePayment := &OutgoingPayment{ fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice, Invoice: *fakeInvoice,
Fee: btcutil.Amount(rand.Intn(1001)), Fee: btcutil.Amount(rand.Intn(1001)),
Path: fakePath, Path: fakePath,
TimeLockLength: uint32(rand.Intn(10000)), TimeLockLength: uint32(rand.Intn(10000)),
RHash: rHash, PaymentHash: rHash,
Timestamp: time.Unix(rand.Int63n(10000), 0),
} }
return fakePayment, nil return fakePayment, nil
} }
func TestOutgoingPaymentSerialization(t *testing.T) { func TestOutgoingPaymentSerialization(t *testing.T) {
fakePayment := makeFakePayment() fakePayment := makeFakePayment()
b := new(bytes.Buffer)
err := serializeOutgoingPayment(b, fakePayment) var b bytes.Buffer
if err != nil { if err := serializeOutgoingPayment(&b, fakePayment); err != nil {
t.Fatalf("Can't serialize outgoing payment: %v", err) t.Fatalf("unable to serialize outgoing payment: %v", err)
} }
newPayment, err := deserializeOutgoingPayment(b) newPayment, err := deserializeOutgoingPayment(&b)
if err != nil { if err != nil {
t.Fatalf("Can't deserialize outgoing payment: %v", err) t.Fatalf("unable to deserialize outgoing payment: %v", err)
} }
if !reflect.DeepEqual(fakePayment, newPayment) { if !reflect.DeepEqual(fakePayment, newPayment) {
@ -127,27 +119,27 @@ func TestOutgoingPaymentSerialization(t *testing.T) {
func TestOutgoingPaymentWorkflow(t *testing.T) { func TestOutgoingPaymentWorkflow(t *testing.T) {
db, cleanUp, err := makeTestDB() db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil { if err != nil {
t.Fatalf("unable to make test db: %v", err) t.Fatalf("unable to make test db: %v", err)
} }
defer cleanUp()
fakePayment := makeFakePayment() fakePayment := makeFakePayment()
err = db.AddPayment(fakePayment) if err = db.AddPayment(fakePayment); err != nil {
if err != nil { t.Fatalf("unable to put payment in DB: %v", err)
t.Fatalf("Can't put payment in DB: %v", err)
} }
payments, err := db.FetchAllPayments() payments, err := db.FetchAllPayments()
if err != nil { if err != nil {
t.Fatalf("Can't get payments from DB: %v", err) t.Fatalf("unable to fetch payments from DB: %v", err)
} }
correctPayments := []*OutgoingPayment{fakePayment}
if !reflect.DeepEqual(payments, correctPayments) { expectedPayments := []*OutgoingPayment{fakePayment}
if !reflect.DeepEqual(payments, expectedPayments) {
t.Fatalf("Wrong payments after reading from DB."+ t.Fatalf("Wrong payments after reading from DB."+
"Got %v, want %v", "Got %v, want %v",
spew.Sdump(payments), spew.Sdump(payments),
spew.Sdump(correctPayments), spew.Sdump(expectedPayments),
) )
} }
@ -157,11 +149,12 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Internal error in tests: %v", err) t.Fatalf("Internal error in tests: %v", err)
} }
err = db.AddPayment(randomPayment)
if err != nil { if err = db.AddPayment(randomPayment); err != nil {
t.Fatalf("Can't put payment in DB: %v", err) t.Fatalf("unable to put payment in DB: %v", err)
} }
correctPayments = append(correctPayments, randomPayment)
expectedPayments = append(expectedPayments, randomPayment)
} }
payments, err = db.FetchAllPayments() payments, err = db.FetchAllPayments()
@ -169,18 +162,17 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
t.Fatalf("Can't get payments from DB: %v", err) t.Fatalf("Can't get payments from DB: %v", err)
} }
if !reflect.DeepEqual(payments, correctPayments) { if !reflect.DeepEqual(payments, expectedPayments) {
t.Fatalf("Wrong payments after reading from DB."+ t.Fatalf("Wrong payments after reading from DB."+
"Got %v, want %v", "Got %v, want %v",
spew.Sdump(payments), spew.Sdump(payments),
spew.Sdump(correctPayments), spew.Sdump(expectedPayments),
) )
} }
// Delete all payments. // Delete all payments.
err = db.DeleteAllPayments() if err = db.DeleteAllPayments(); err != nil {
if err != nil { t.Fatalf("unable to delete payments from DB: %v", err)
t.Fatalf("Can't delete payments from DB: %v", err)
} }
// Check that there is no payments after deletion // Check that there is no payments after deletion