channeldb: refactor payments code

Go-fmt files. Refactored code according to the guidelines.
Enhanced payment test: add error checking
and individual context for each API call.
Add Timestamp field to payment struct.
This commit is contained in:
BitfuryLightning 2016-12-21 04:19:01 -05:00 committed by Olaoluwa Osuntokun
parent eb4d0e035e
commit 1c7f87c3f1
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
5 changed files with 215 additions and 114 deletions

@ -116,8 +116,7 @@ func validateInvoice(i *Invoice) error {
// insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes.
func (d *DB) AddInvoice(i *Invoice) error {
err := validateInvoice(i)
if err != nil {
if err := validateInvoice(i); err != nil {
return err
}
return d.Update(func(tx *bolt.Tx) error {

@ -1,51 +1,55 @@
package channeldb
import (
"bytes"
"encoding/binary"
"github.com/boltdb/bolt"
"github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil"
"io"
"github.com/roasbeef/btcd/wire"
"github.com/boltdb/bolt"
"encoding/binary"
"bytes"
"time"
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to payments.
// Within the payments bucket, each invoice is keyed by its invoice ID
// invoiceBucket is the name of the bucket within
// the database that stores all data related to payments.
// Within the payments bucket, each invoice is keyed
// by its invoice ID
// which is a monotonically increasing uint64.
// BoltDB sequence feature is used for generating monotonically increasing
// id.
// BoltDB sequence feature is used for generating
// monotonically increasing id.
paymentBucket = []byte("payments")
)
// OutgoingPayment represents payment from given node.
type OutgoingPayment struct {
Invoice
// Total fee paid
Fee btcutil.Amount
// Path including starting and ending nodes
Path [][]byte
TimeLockLength uint64
// We probably need both RHash and Preimage
// because we start knowing only RHash
RHash [32]byte
}
func validatePayment(p *OutgoingPayment) error {
err := validateInvoice(&p.Invoice)
if err != nil {
return err
}
return nil
// Total fee paid.
Fee btcutil.Amount
// Path including starting and ending nodes.
Path [][33]byte
// Timelock length.
TimeLockLength uint32
// RHash value used for payment.
// We need RHash because we start payment knowing only RHash
RHash [32]byte
// Timestamp is time when payment was created.
Timestamp time.Time
}
// AddPayment adds payment to DB.
// There is no checking that payment with the same hash already exist.
func (db *DB) AddPayment(p *OutgoingPayment) error {
err := validatePayment(p)
err := validateInvoice(&p.Invoice)
if err != nil {
return err
}
// We serialize before writing to database
// so no db access in the case of serialization errors
b := new(bytes.Buffer)
@ -54,15 +58,17 @@ func (db *DB) AddPayment(p *OutgoingPayment) error {
return err
}
paymentBytes := b.Bytes()
return db.Update(func (tx *bolt.Tx) error {
return db.Update(func(tx *bolt.Tx) error {
payments, err := tx.CreateBucketIfNotExists(paymentBucket)
if err != nil {
return err
}
paymentId, err := payments.NextSequence()
if err != nil {
return err
}
// We use BigEndian for keys because
// it orders keys in ascending order
paymentIdBytes := make([]byte, 8)
@ -78,12 +84,12 @@ func (db *DB) AddPayment(p *OutgoingPayment) error {
// FetchAllPayments returns all outgoing payments in DB.
func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) {
var payments []*OutgoingPayment
err := db.View(func (tx *bolt.Tx) error {
err := db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(paymentBucket)
if bucket == nil {
return nil
return ErrNoPaymentsCreated
}
err := bucket.ForEach(func (k, v []byte) error {
err := bucket.ForEach(func(k, v []byte) error {
// Value can be nil if it is a sub-backet
// so simply ignore it.
if v == nil {
@ -109,11 +115,12 @@ func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) {
// If payments bucket does not exist it will create
// new bucket without error.
func (db *DB) DeleteAllPayments() error {
return db.Update(func (tx *bolt.Tx) error {
return db.Update(func(tx *bolt.Tx) error {
err := tx.DeleteBucket(paymentBucket)
if err != nil && err != bolt.ErrBucketNotFound {
return err
}
_, err = tx.CreateBucket(paymentBucket)
if err != nil {
return err
@ -127,6 +134,7 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
if err != nil {
return err
}
// Serialize fee.
feeBytes := make([]byte, 8)
byteOrder.PutUint64(feeBytes, uint64(p.Fee))
@ -134,43 +142,60 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
if err != nil {
return err
}
// Serialize path.
pathLen := uint32(len(p.Path))
pathLenBytes := make([]byte, 4)
// Write length of the path
byteOrder.PutUint32(pathLenBytes, pathLen)
_, err = w.Write(pathLenBytes)
if err != nil {
return err
}
// Serialize each element of the path
for i := uint32(0); i < pathLen; i++ {
err := wire.WriteVarBytes(w, 0, p.Path[i])
_, err := w.Write(p.Path[i][:])
if err != nil {
return err
}
}
// Serialize TimeLockLength
timeLockLengthBytes := make([]byte, 8)
byteOrder.PutUint64(timeLockLengthBytes, p.TimeLockLength)
timeLockLengthBytes := make([]byte, 4)
byteOrder.PutUint32(timeLockLengthBytes, p.TimeLockLength)
_, err = w.Write(timeLockLengthBytes)
if err != nil {
return err
}
// Serialize RHash
_, err = w.Write(p.RHash[:])
if err != nil {
return err
}
// Serialize Timestamp.
tBytes, err := p.Timestamp.MarshalBinary()
if err != nil {
return err
}
err = wire.WriteVarBytes(w, 0, tBytes)
if err != nil {
return err
}
return nil
}
func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
p := &OutgoingPayment{}
// Deserialize invoice
inv, err := deserializeInvoice(r)
if err != nil {
return nil, err
}
p.Invoice = *inv
// Deserialize fee
feeBytes := make([]byte, 8)
_, err = r.Read(feeBytes)
@ -178,6 +203,7 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
return nil, err
}
p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes))
// Deserialize path
pathLenBytes := make([]byte, 4)
_, err = r.Read(pathLenBytes)
@ -185,27 +211,38 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
return nil, err
}
pathLen := byteOrder.Uint32(pathLenBytes)
path := make([][]byte, pathLen)
for i := uint32(0); i<pathLen; i++ {
// Each node in path have 33 bytes. It may be changed in future.
// So put 100 here.
path[i], err = wire.ReadVarBytes(r, 0, 100, "Node id")
path := make([][33]byte, pathLen)
for i := uint32(0); i < pathLen; i++ {
_, err := r.Read(path[i][:])
if err != nil {
return nil, err
}
}
p.Path = path
// Deserialize TimeLockLength
timeLockLengthBytes := make([]byte, 8)
timeLockLengthBytes := make([]byte, 4)
_, err = r.Read(timeLockLengthBytes)
if err != nil {
return nil, err
}
p.TimeLockLength = byteOrder.Uint64(timeLockLengthBytes)
p.TimeLockLength = byteOrder.Uint32(timeLockLengthBytes)
// Deserialize RHash
_, err = r.Read(p.RHash[:])
if err != nil {
return nil, err
}
// Deserialize Timestamp
tBytes, err := wire.ReadVarBytes(r, 0, 100,
"OutgoingPayment.Timestamp")
if err != nil {
return nil, err
}
err = p.Timestamp.UnmarshalBinary(tBytes)
if err != nil {
return nil, err
}
return p, nil
}
}

@ -1,20 +1,20 @@
package channeldb
import (
"testing"
"time"
"github.com/roasbeef/btcutil"
"bytes"
"reflect"
"github.com/davecgh/go-spew/spew"
"math/rand"
"fmt"
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcutil"
"math/rand"
"reflect"
"testing"
"time"
)
func makeFakePayment() *OutgoingPayment {
// Create a fake invoice which we'll use several times in the tests
// below.
// Create a fake invoice which
// we'll use several times in the tests below.
fakeInvoice := &Invoice{
CreationDate: time.Now(),
}
@ -23,63 +23,84 @@ func makeFakePayment() *OutgoingPayment {
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = btcutil.Amount(10000)
// Make fake path
fakePath := make([][]byte, 3)
for i:=0; i<3; i++ {
fakePath[i] = make([]byte, 33)
for j:=0; j<33; j++ {
fakePath := make([][33]byte, 3)
for i := 0; i < 3; i++ {
for j := 0; j < 33; j++ {
fakePath[i][j] = byte(i)
}
}
var rHash [32]byte = fastsha256.Sum256(rev[:])
fakePayment := & OutgoingPayment{
Invoice: *fakeInvoice,
Fee: 101,
Path: fakePath,
fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice,
Fee: 101,
Path: fakePath,
TimeLockLength: 1000,
RHash: rHash,
RHash: rHash,
Timestamp: time.Unix(100000, 0),
}
return fakePayment
}
// randomBytes creates random []byte with length
// in range [minLen, maxLen)
func randomBytes(minLen, maxLen int) []byte {
func randomBytes(minLen, maxLen int) ([]byte, error) {
l := minLen + rand.Intn(maxLen-minLen)
b := make([]byte, l)
_, err := rand.Read(b)
if err != nil {
panic(fmt.Sprintf("Internal error. Cannot generate random string: %v", err))
return nil, fmt.Errorf("Internal error. "+
"Cannot generate random string: %v", err)
}
return b
return b, nil
}
func makeRandomFakePayment() *OutgoingPayment {
// Create a fake invoice which we'll use several times in the tests
// below.
func makeRandomFakePayment() (*OutgoingPayment, error) {
var err error
fakeInvoice := &Invoice{
CreationDate: time.Now(),
}
fakeInvoice.Memo = randomBytes(1, 50)
fakeInvoice.Receipt = randomBytes(1, 50)
copy(fakeInvoice.Terms.PaymentPreimage[:], randomBytes(32, 33))
fakeInvoice.Memo, err = randomBytes(1, 50)
if err != nil {
return nil, err
}
fakeInvoice.Receipt, err = randomBytes(1, 50)
if err != nil {
return nil, err
}
preImg, err := randomBytes(32, 33)
if err != nil {
return nil, err
}
copy(fakeInvoice.Terms.PaymentPreimage[:], preImg)
fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000))
// Make fake path
fakePathLen := 1 + rand.Intn(5)
fakePath := make([][]byte, fakePathLen)
for i:=0; i<fakePathLen; i++ {
fakePath[i] = randomBytes(33, 34)
fakePath := make([][33]byte, fakePathLen)
for i := 0; i < fakePathLen; i++ {
b, err := randomBytes(33, 34)
if err != nil {
return nil, err
}
copy(fakePath[i][:], b)
}
var rHash [32]byte = fastsha256.Sum256(
fakeInvoice.Terms.PaymentPreimage[:],
)
fakePayment := & OutgoingPayment{
Invoice: *fakeInvoice,
Fee: btcutil.Amount(rand.Intn(1001)),
Path: fakePath,
TimeLockLength: uint64(rand.Intn(10000)),
RHash: rHash,
fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice,
Fee: btcutil.Amount(rand.Intn(1001)),
Path: fakePath,
TimeLockLength: uint32(rand.Intn(10000)),
RHash: rHash,
Timestamp: time.Unix(rand.Int63n(10000), 0),
}
return fakePayment
return fakePayment, nil
}
func TestOutgoingPaymentSerialization(t *testing.T) {
@ -89,12 +110,15 @@ func TestOutgoingPaymentSerialization(t *testing.T) {
if err != nil {
t.Fatalf("Can't serialize outgoing payment: %v", err)
}
newPayment, err := deserializeOutgoingPayment(b)
if err != nil {
t.Fatalf("Can't deserialize outgoing payment: %v", err)
}
if !reflect.DeepEqual(fakePayment, newPayment) {
t.Fatalf("Payments do not match after serialization/deserialization %v vs %v",
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(fakePayment),
spew.Sdump(newPayment),
)
@ -128,18 +152,23 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
}
// Make some random payments
for i:=0; i<5; i++ {
randomPayment := makeRandomFakePayment()
err := db.AddPayment(randomPayment)
for i := 0; i < 5; i++ {
randomPayment, err := makeRandomFakePayment()
if err != nil {
t.Fatalf("Internal error in tests: %v", err)
}
err = db.AddPayment(randomPayment)
if err != nil {
t.Fatalf("Can't put payment in DB: %v", err)
}
correctPayments = append(correctPayments, randomPayment)
}
payments, err = db.FetchAllPayments()
if err != nil {
t.Fatalf("Can't get payments from DB: %v", err)
}
if !reflect.DeepEqual(payments, correctPayments) {
t.Fatalf("Wrong payments after reading from DB."+
"Got %v, want %v",
@ -153,11 +182,14 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
if err != nil {
t.Fatalf("Can't delete payments from DB: %v", err)
}
// Check that there is no payments after deletion
paymentsAfterDeletion, err := db.FetchAllPayments()
if err != nil {
t.Fatalf("Can't get payments after deletion: %v", err)
}
if len(paymentsAfterDeletion) != 0 {
t.Fatalf("After deletion DB has %v payments, want %v", len(paymentsAfterDeletion), 0)
t.Fatalf("After deletion DB has %v payments, want %v",
len(paymentsAfterDeletion), 0)
}
}

@ -495,35 +495,41 @@ func testSingleHopInvoice(net *networkHarness, t *harnessTest) {
func testListPayments(net *networkHarness, t *harnessTest) {
ctxb := context.Background()
timeout := time.Duration(time.Second * 5)
ctxt, _ := context.WithTimeout(ctxb, timeout)
timeout := time.Duration(time.Second * 2)
// Delete all payments from Alice. DB should have no payments
deleteAllPaymentsInitialReq := &lnrpc.DeleteAllPaymentsRequest{}
_, err := net.Alice.DeleteAllPayments(ctxt, deleteAllPaymentsInitialReq)
deleteAllPaymentsInitialCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err := net.Alice.DeleteAllPayments(deleteAllPaymentsInitialCtxt,
deleteAllPaymentsInitialReq)
if err != nil {
t.Fatalf("Can't delete payments at the begining: %v", err)
}
// Check that there are no payments before test.
reqInit := &lnrpc.ListPaymentsRequest{}
paymentsRespInit, err := net.Alice.ListPayments(ctxt, reqInit)
reqInitCtxt, _ := context.WithTimeout(ctxb, timeout)
paymentsRespInit, err := net.Alice.ListPayments(reqInitCtxt, reqInit)
if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err)
}
if len(paymentsRespInit.Payments) != 0 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsRespInit.Payments), 0)
t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsRespInit.Payments), 0)
}
// Open a channel with 100k satoshis between Alice and Bob with Alice being
// Open a channel with 100k satoshis
// between Alice and Bob with Alice being
// the sole funder of the channel.
chanAmt := btcutil.Amount(100000)
chanPoint := openChannelAndAssert(t, net, ctxt, net.Alice, net.Bob, chanAmt)
openChannelCtxt, _ := context.WithTimeout(ctxb, timeout)
chanPoint := openChannelAndAssert(t, net, openChannelCtxt,
net.Alice, net.Bob, chanAmt)
// Now that the channel is open, create an invoice for Bob which
// expects a payment of 1000 satoshis from Alice paid via a particular
// pre-image.
// expects a payment of 1000 satoshis from Alice
// paid via a particular pre-image.
const paymentAmt = 1000
preimage := bytes.Repeat([]byte("B"), 32)
invoice := &lnrpc.Invoice{
@ -531,14 +537,16 @@ func testListPayments(net *networkHarness, t *harnessTest) {
RPreimage: preimage,
Value: paymentAmt,
}
invoiceResp, err := net.Bob.AddInvoice(ctxt, invoice)
addInvoiceCtxt, _ := context.WithTimeout(ctxb, timeout)
invoiceResp, err := net.Bob.AddInvoice(addInvoiceCtxt, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
// With the invoice for Bob added, send a payment towards Alice paying
// to the above generated invoice.
sendStream, err := net.Alice.SendPayment(ctxt)
sendPaymentCtxt, _ := context.WithTimeout(ctxb, timeout)
sendStream, err := net.Alice.SendPayment(sendPaymentCtxt)
if err != nil {
t.Fatalf("unable to create alice payment stream: %v", err)
}
@ -558,33 +566,38 @@ func testListPayments(net *networkHarness, t *harnessTest) {
// like balance here because it is already checked in
// testSingleHopInvoice
req := &lnrpc.ListPaymentsRequest{}
paymentsResp, err := net.Alice.ListPayments(ctxt, req)
listPaymentsCtxt, _ := context.WithTimeout(ctxb, timeout)
paymentsResp, err := net.Alice.ListPayments(listPaymentsCtxt, req)
if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err)
}
if len(paymentsResp.Payments) != 1 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsResp.Payments), 1)
t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsResp.Payments), 1)
}
p := paymentsResp.Payments[0]
// Check path.
expectedPath := []string {
expectedPath := []string{
net.Alice.PubKeyStr,
net.Bob.PubKeyStr,
}
if !reflect.DeepEqual(p.Path, expectedPath) {
t.Fatalf("incorrect path, got %v, want %v", p.Path, expectedPath)
t.Fatalf("incorrect path, got %v, want %v",
p.Path, expectedPath)
}
// Check amount.
if p.Value != paymentAmt {
t.Fatalf("incorrect amount, got %v, want %v", p.Value, paymentAmt)
t.Fatalf("incorrect amount, got %v, want %v",
p.Value, paymentAmt)
}
// Check RHash.
correctRHash := hex.EncodeToString(invoiceResp.RHash)
if !reflect.DeepEqual(p.RHash, correctRHash){
t.Fatalf("incorrect RHash, got %v, want %v", p.RHash, correctRHash)
if !reflect.DeepEqual(p.RHash, correctRHash) {
t.Fatalf("incorrect RHash, got %v, want %v",
p.RHash, correctRHash)
}
// Check Fee.
@ -595,26 +608,31 @@ func testListPayments(net *networkHarness, t *harnessTest) {
// Delete all payments from Alice. DB should have no payments.
deleteAllPaymentsEndReq := &lnrpc.DeleteAllPaymentsRequest{}
_, err = net.Alice.DeleteAllPayments(ctxt, deleteAllPaymentsEndReq)
deleteAllPaymentsEndCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err = net.Alice.DeleteAllPayments(deleteAllPaymentsEndCtxt,
deleteAllPaymentsEndReq)
if err != nil {
t.Fatalf("Can't delete payments at the end: %v", err)
}
// Check that there are no payments before test.
reqEnd := &lnrpc.ListPaymentsRequest{}
_, err = net.Alice.ListPayments(ctxt, reqEnd)
listPaymentsEndCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err = net.Alice.ListPayments(listPaymentsEndCtxt, reqEnd)
if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err)
}
if len(paymentsRespInit.Payments) != 0 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsRespInit.Payments), 0)
t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsRespInit.Payments), 0)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(t, net, ctxt, net.Alice, chanPoint, false)
closeChannelCtxt, _ := context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(t, net, closeChannelCtxt,
net.Alice, chanPoint, false)
}
func testMultiHopPayments(net *networkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(100000)
ctxb := context.Background()

@ -598,15 +598,22 @@ func (r *rpcServer) ListChannels(ctx context.Context,
}
func constructPayment(path []graph.Vertex, amount btcutil.Amount, rHash []byte) *channeldb.OutgoingPayment {
payment := &channeldb.OutgoingPayment{}
// When we create payment we do not know preImage.
// So we need to save rHash
copy(payment.RHash[:], rHash)
payment.Invoice.Terms.Value = btcutil.Amount(amount)
payment.Invoice.CreationDate = time.Now()
pathBytes := make([][]byte, len(path))
payment.Timestamp = time.Now()
pathBytes33 := make([][33]byte, len(path))
for i:=0; i<len(path); i++ {
pathBytes[i] = path[i].ToByte()
pathBytes33[i] = path[i].ToByte33()
}
payment.Path = pathBytes
payment.Path = pathBytes33
return payment
}
@ -687,9 +694,12 @@ func (r *rpcServer) SendPayment(paymentStream lnrpc.Lightning_SendPaymentServer)
errChan <- err
return
}
// Save payment to DB.
payment := constructPayment(path, btcutil.Amount(nextPayment.Amt), rHash[:])
payment := constructPayment(path,
btcutil.Amount(nextPayment.Amt), rHash[:])
r.server.chanDB.AddPayment(payment)
// TODO(roasbeef): proper responses
resp := &lnrpc.SendResponse{}
if err := paymentStream.Send(resp); err != nil {
@ -1136,22 +1146,25 @@ func (r *rpcServer) ShowRoutingTable(ctx context.Context,
// ListPayments returns a list of all outgoing payments.
func (r *rpcServer) ListPayments(context.Context,
*lnrpc.ListPaymentsRequest) (*lnrpc.ListPaymentsResponse, error) {
rpcsLog.Debugf("[ListPayments]")
payments, err := r.server.chanDB.FetchAllPayments()
if err != nil {
return nil, err
}
paymentsResp := &lnrpc.ListPaymentsResponse{
Payments: make([]*lnrpc.Payment, len(payments)),
}
for i:=0; i<len(payments); i++{
for i:=0; i<len(payments); i++ {
p := &lnrpc.Payment{}
p.CreationDate = payments[i].CreationDate.Unix()
p.Value = int64(payments[i].Terms.Value)
p.RHash = hex.EncodeToString(payments[i].RHash[:])
path := make([]string, len(payments[i].Path))
for j:=0; j<len(path); j++ {
path[j] = hex.EncodeToString(payments[i].Path[j])
path[j] = hex.EncodeToString(payments[i].Path[j][:])
}
p.Path = path
paymentsResp.Payments[i] = p
@ -1163,7 +1176,9 @@ func (r *rpcServer) ListPayments(context.Context,
// DeleteAllPayments deletes all outgoing payments from DB.
func (r *rpcServer) DeleteAllPayments(context.Context,
*lnrpc.DeleteAllPaymentsRequest) (*lnrpc.DeleteAllPaymentsResponse, error) {
rpcsLog.Debugf("[DeleteAllPayments]")
err := r.server.chanDB.DeleteAllPayments()
resp := &lnrpc.DeleteAllPaymentsResponse{}
return resp, err