channeldb: refactor payments code

Go-fmt files. Refactored code according to the guidelines.
Enhanced payment test: add error checking
and individual context for each API call.
Add Timestamp field to payment struct.
This commit is contained in:
BitfuryLightning 2016-12-21 04:19:01 -05:00 committed by Olaoluwa Osuntokun
parent eb4d0e035e
commit 1c7f87c3f1
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
5 changed files with 215 additions and 114 deletions

@ -116,8 +116,7 @@ func validateInvoice(i *Invoice) error {
// insertion will be aborted and rejected due to the strict policy banning any // insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes. // duplicate payment hashes.
func (d *DB) AddInvoice(i *Invoice) error { func (d *DB) AddInvoice(i *Invoice) error {
err := validateInvoice(i) if err := validateInvoice(i); err != nil {
if err != nil {
return err return err
} }
return d.Update(func(tx *bolt.Tx) error { return d.Update(func(tx *bolt.Tx) error {

@ -1,51 +1,55 @@
package channeldb package channeldb
import ( import (
"bytes"
"encoding/binary"
"github.com/boltdb/bolt"
"github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil" "github.com/roasbeef/btcutil"
"io" "io"
"github.com/roasbeef/btcd/wire" "time"
"github.com/boltdb/bolt"
"encoding/binary"
"bytes"
) )
var ( var (
// invoiceBucket is the name of the bucket within the database that // invoiceBucket is the name of the bucket within
// stores all data related to payments. // the database that stores all data related to payments.
// Within the payments bucket, each invoice is keyed by its invoice ID // Within the payments bucket, each invoice is keyed
// by its invoice ID
// which is a monotonically increasing uint64. // which is a monotonically increasing uint64.
// BoltDB sequence feature is used for generating monotonically increasing // BoltDB sequence feature is used for generating
// id. // monotonically increasing id.
paymentBucket = []byte("payments") paymentBucket = []byte("payments")
) )
// OutgoingPayment represents payment from given node.
type OutgoingPayment struct { type OutgoingPayment struct {
Invoice Invoice
// Total fee paid
Fee btcutil.Amount
// Path including starting and ending nodes
Path [][]byte
TimeLockLength uint64
// We probably need both RHash and Preimage
// because we start knowing only RHash
RHash [32]byte
}
func validatePayment(p *OutgoingPayment) error { // Total fee paid.
err := validateInvoice(&p.Invoice) Fee btcutil.Amount
if err != nil {
return err // Path including starting and ending nodes.
} Path [][33]byte
return nil
// Timelock length.
TimeLockLength uint32
// RHash value used for payment.
// We need RHash because we start payment knowing only RHash
RHash [32]byte
// Timestamp is time when payment was created.
Timestamp time.Time
} }
// AddPayment adds payment to DB. // AddPayment adds payment to DB.
// There is no checking that payment with the same hash already exist. // There is no checking that payment with the same hash already exist.
func (db *DB) AddPayment(p *OutgoingPayment) error { func (db *DB) AddPayment(p *OutgoingPayment) error {
err := validatePayment(p) err := validateInvoice(&p.Invoice)
if err != nil { if err != nil {
return err return err
} }
// We serialize before writing to database // We serialize before writing to database
// so no db access in the case of serialization errors // so no db access in the case of serialization errors
b := new(bytes.Buffer) b := new(bytes.Buffer)
@ -54,15 +58,17 @@ func (db *DB) AddPayment(p *OutgoingPayment) error {
return err return err
} }
paymentBytes := b.Bytes() paymentBytes := b.Bytes()
return db.Update(func (tx *bolt.Tx) error { return db.Update(func(tx *bolt.Tx) error {
payments, err := tx.CreateBucketIfNotExists(paymentBucket) payments, err := tx.CreateBucketIfNotExists(paymentBucket)
if err != nil { if err != nil {
return err return err
} }
paymentId, err := payments.NextSequence() paymentId, err := payments.NextSequence()
if err != nil { if err != nil {
return err return err
} }
// We use BigEndian for keys because // We use BigEndian for keys because
// it orders keys in ascending order // it orders keys in ascending order
paymentIdBytes := make([]byte, 8) paymentIdBytes := make([]byte, 8)
@ -78,12 +84,12 @@ func (db *DB) AddPayment(p *OutgoingPayment) error {
// FetchAllPayments returns all outgoing payments in DB. // FetchAllPayments returns all outgoing payments in DB.
func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) {
var payments []*OutgoingPayment var payments []*OutgoingPayment
err := db.View(func (tx *bolt.Tx) error { err := db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(paymentBucket) bucket := tx.Bucket(paymentBucket)
if bucket == nil { if bucket == nil {
return nil return ErrNoPaymentsCreated
} }
err := bucket.ForEach(func (k, v []byte) error { err := bucket.ForEach(func(k, v []byte) error {
// Value can be nil if it is a sub-backet // Value can be nil if it is a sub-backet
// so simply ignore it. // so simply ignore it.
if v == nil { if v == nil {
@ -109,11 +115,12 @@ func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) {
// If payments bucket does not exist it will create // If payments bucket does not exist it will create
// new bucket without error. // new bucket without error.
func (db *DB) DeleteAllPayments() error { func (db *DB) DeleteAllPayments() error {
return db.Update(func (tx *bolt.Tx) error { return db.Update(func(tx *bolt.Tx) error {
err := tx.DeleteBucket(paymentBucket) err := tx.DeleteBucket(paymentBucket)
if err != nil && err != bolt.ErrBucketNotFound { if err != nil && err != bolt.ErrBucketNotFound {
return err return err
} }
_, err = tx.CreateBucket(paymentBucket) _, err = tx.CreateBucket(paymentBucket)
if err != nil { if err != nil {
return err return err
@ -127,6 +134,7 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
if err != nil { if err != nil {
return err return err
} }
// Serialize fee. // Serialize fee.
feeBytes := make([]byte, 8) feeBytes := make([]byte, 8)
byteOrder.PutUint64(feeBytes, uint64(p.Fee)) byteOrder.PutUint64(feeBytes, uint64(p.Fee))
@ -134,43 +142,60 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
if err != nil { if err != nil {
return err return err
} }
// Serialize path. // Serialize path.
pathLen := uint32(len(p.Path)) pathLen := uint32(len(p.Path))
pathLenBytes := make([]byte, 4) pathLenBytes := make([]byte, 4)
// Write length of the path
byteOrder.PutUint32(pathLenBytes, pathLen) byteOrder.PutUint32(pathLenBytes, pathLen)
_, err = w.Write(pathLenBytes) _, err = w.Write(pathLenBytes)
if err != nil { if err != nil {
return err return err
} }
// Serialize each element of the path
for i := uint32(0); i < pathLen; i++ { for i := uint32(0); i < pathLen; i++ {
err := wire.WriteVarBytes(w, 0, p.Path[i]) _, err := w.Write(p.Path[i][:])
if err != nil { if err != nil {
return err return err
} }
} }
// Serialize TimeLockLength // Serialize TimeLockLength
timeLockLengthBytes := make([]byte, 8) timeLockLengthBytes := make([]byte, 4)
byteOrder.PutUint64(timeLockLengthBytes, p.TimeLockLength) byteOrder.PutUint32(timeLockLengthBytes, p.TimeLockLength)
_, err = w.Write(timeLockLengthBytes) _, err = w.Write(timeLockLengthBytes)
if err != nil { if err != nil {
return err return err
} }
// Serialize RHash // Serialize RHash
_, err = w.Write(p.RHash[:]) _, err = w.Write(p.RHash[:])
if err != nil { if err != nil {
return err return err
} }
// Serialize Timestamp.
tBytes, err := p.Timestamp.MarshalBinary()
if err != nil {
return err
}
err = wire.WriteVarBytes(w, 0, tBytes)
if err != nil {
return err
}
return nil return nil
} }
func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
p := &OutgoingPayment{} p := &OutgoingPayment{}
// Deserialize invoice // Deserialize invoice
inv, err := deserializeInvoice(r) inv, err := deserializeInvoice(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.Invoice = *inv p.Invoice = *inv
// Deserialize fee // Deserialize fee
feeBytes := make([]byte, 8) feeBytes := make([]byte, 8)
_, err = r.Read(feeBytes) _, err = r.Read(feeBytes)
@ -178,6 +203,7 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
return nil, err return nil, err
} }
p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes)) p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes))
// Deserialize path // Deserialize path
pathLenBytes := make([]byte, 4) pathLenBytes := make([]byte, 4)
_, err = r.Read(pathLenBytes) _, err = r.Read(pathLenBytes)
@ -185,27 +211,38 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) {
return nil, err return nil, err
} }
pathLen := byteOrder.Uint32(pathLenBytes) pathLen := byteOrder.Uint32(pathLenBytes)
path := make([][]byte, pathLen) path := make([][33]byte, pathLen)
for i := uint32(0); i<pathLen; i++ { for i := uint32(0); i < pathLen; i++ {
// Each node in path have 33 bytes. It may be changed in future. _, err := r.Read(path[i][:])
// So put 100 here.
path[i], err = wire.ReadVarBytes(r, 0, 100, "Node id")
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
p.Path = path p.Path = path
// Deserialize TimeLockLength // Deserialize TimeLockLength
timeLockLengthBytes := make([]byte, 8) timeLockLengthBytes := make([]byte, 4)
_, err = r.Read(timeLockLengthBytes) _, err = r.Read(timeLockLengthBytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.TimeLockLength = byteOrder.Uint64(timeLockLengthBytes) p.TimeLockLength = byteOrder.Uint32(timeLockLengthBytes)
// Deserialize RHash // Deserialize RHash
_, err = r.Read(p.RHash[:]) _, err = r.Read(p.RHash[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Deserialize Timestamp
tBytes, err := wire.ReadVarBytes(r, 0, 100,
"OutgoingPayment.Timestamp")
if err != nil {
return nil, err
}
err = p.Timestamp.UnmarshalBinary(tBytes)
if err != nil {
return nil, err
}
return p, nil return p, nil
} }

@ -1,20 +1,20 @@
package channeldb package channeldb
import ( import (
"testing"
"time"
"github.com/roasbeef/btcutil"
"bytes" "bytes"
"reflect"
"github.com/davecgh/go-spew/spew"
"math/rand"
"fmt" "fmt"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcutil"
"math/rand"
"reflect"
"testing"
"time"
) )
func makeFakePayment() *OutgoingPayment { func makeFakePayment() *OutgoingPayment {
// Create a fake invoice which we'll use several times in the tests // Create a fake invoice which
// below. // we'll use several times in the tests below.
fakeInvoice := &Invoice{ fakeInvoice := &Invoice{
CreationDate: time.Now(), CreationDate: time.Now(),
} }
@ -23,63 +23,84 @@ func makeFakePayment() *OutgoingPayment {
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = btcutil.Amount(10000) fakeInvoice.Terms.Value = btcutil.Amount(10000)
// Make fake path // Make fake path
fakePath := make([][]byte, 3) fakePath := make([][33]byte, 3)
for i:=0; i<3; i++ { for i := 0; i < 3; i++ {
fakePath[i] = make([]byte, 33) for j := 0; j < 33; j++ {
for j:=0; j<33; j++ {
fakePath[i][j] = byte(i) fakePath[i][j] = byte(i)
} }
} }
var rHash [32]byte = fastsha256.Sum256(rev[:]) var rHash [32]byte = fastsha256.Sum256(rev[:])
fakePayment := & OutgoingPayment{ fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice, Invoice: *fakeInvoice,
Fee: 101, Fee: 101,
Path: fakePath, Path: fakePath,
TimeLockLength: 1000, TimeLockLength: 1000,
RHash: rHash, RHash: rHash,
Timestamp: time.Unix(100000, 0),
} }
return fakePayment return fakePayment
} }
// randomBytes creates random []byte with length // randomBytes creates random []byte with length
// in range [minLen, maxLen) // in range [minLen, maxLen)
func randomBytes(minLen, maxLen int) []byte { func randomBytes(minLen, maxLen int) ([]byte, error) {
l := minLen + rand.Intn(maxLen-minLen) l := minLen + rand.Intn(maxLen-minLen)
b := make([]byte, l) b := make([]byte, l)
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
panic(fmt.Sprintf("Internal error. Cannot generate random string: %v", err)) return nil, fmt.Errorf("Internal error. "+
"Cannot generate random string: %v", err)
} }
return b return b, nil
} }
func makeRandomFakePayment() *OutgoingPayment { func makeRandomFakePayment() (*OutgoingPayment, error) {
// Create a fake invoice which we'll use several times in the tests var err error
// below.
fakeInvoice := &Invoice{ fakeInvoice := &Invoice{
CreationDate: time.Now(), CreationDate: time.Now(),
} }
fakeInvoice.Memo = randomBytes(1, 50)
fakeInvoice.Receipt = randomBytes(1, 50) fakeInvoice.Memo, err = randomBytes(1, 50)
copy(fakeInvoice.Terms.PaymentPreimage[:], randomBytes(32, 33)) if err != nil {
return nil, err
}
fakeInvoice.Receipt, err = randomBytes(1, 50)
if err != nil {
return nil, err
}
preImg, err := randomBytes(32, 33)
if err != nil {
return nil, err
}
copy(fakeInvoice.Terms.PaymentPreimage[:], preImg)
fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000)) fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000))
// Make fake path // Make fake path
fakePathLen := 1 + rand.Intn(5) fakePathLen := 1 + rand.Intn(5)
fakePath := make([][]byte, fakePathLen) fakePath := make([][33]byte, fakePathLen)
for i:=0; i<fakePathLen; i++ { for i := 0; i < fakePathLen; i++ {
fakePath[i] = randomBytes(33, 34) b, err := randomBytes(33, 34)
if err != nil {
return nil, err
}
copy(fakePath[i][:], b)
} }
var rHash [32]byte = fastsha256.Sum256( var rHash [32]byte = fastsha256.Sum256(
fakeInvoice.Terms.PaymentPreimage[:], fakeInvoice.Terms.PaymentPreimage[:],
) )
fakePayment := & OutgoingPayment{ fakePayment := &OutgoingPayment{
Invoice: *fakeInvoice, Invoice: *fakeInvoice,
Fee: btcutil.Amount(rand.Intn(1001)), Fee: btcutil.Amount(rand.Intn(1001)),
Path: fakePath, Path: fakePath,
TimeLockLength: uint64(rand.Intn(10000)), TimeLockLength: uint32(rand.Intn(10000)),
RHash: rHash, RHash: rHash,
Timestamp: time.Unix(rand.Int63n(10000), 0),
} }
return fakePayment return fakePayment, nil
} }
func TestOutgoingPaymentSerialization(t *testing.T) { func TestOutgoingPaymentSerialization(t *testing.T) {
@ -89,12 +110,15 @@ func TestOutgoingPaymentSerialization(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Can't serialize outgoing payment: %v", err) t.Fatalf("Can't serialize outgoing payment: %v", err)
} }
newPayment, err := deserializeOutgoingPayment(b) newPayment, err := deserializeOutgoingPayment(b)
if err != nil { if err != nil {
t.Fatalf("Can't deserialize outgoing payment: %v", err) t.Fatalf("Can't deserialize outgoing payment: %v", err)
} }
if !reflect.DeepEqual(fakePayment, newPayment) { if !reflect.DeepEqual(fakePayment, newPayment) {
t.Fatalf("Payments do not match after serialization/deserialization %v vs %v", t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(fakePayment), spew.Sdump(fakePayment),
spew.Sdump(newPayment), spew.Sdump(newPayment),
) )
@ -128,18 +152,23 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
} }
// Make some random payments // Make some random payments
for i:=0; i<5; i++ { for i := 0; i < 5; i++ {
randomPayment := makeRandomFakePayment() randomPayment, err := makeRandomFakePayment()
err := db.AddPayment(randomPayment) if err != nil {
t.Fatalf("Internal error in tests: %v", err)
}
err = db.AddPayment(randomPayment)
if err != nil { if err != nil {
t.Fatalf("Can't put payment in DB: %v", err) t.Fatalf("Can't put payment in DB: %v", err)
} }
correctPayments = append(correctPayments, randomPayment) correctPayments = append(correctPayments, randomPayment)
} }
payments, err = db.FetchAllPayments() payments, err = db.FetchAllPayments()
if err != nil { if err != nil {
t.Fatalf("Can't get payments from DB: %v", err) t.Fatalf("Can't get payments from DB: %v", err)
} }
if !reflect.DeepEqual(payments, correctPayments) { if !reflect.DeepEqual(payments, correctPayments) {
t.Fatalf("Wrong payments after reading from DB."+ t.Fatalf("Wrong payments after reading from DB."+
"Got %v, want %v", "Got %v, want %v",
@ -153,11 +182,14 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Can't delete payments from DB: %v", err) t.Fatalf("Can't delete payments from DB: %v", err)
} }
// Check that there is no payments after deletion
paymentsAfterDeletion, err := db.FetchAllPayments() paymentsAfterDeletion, err := db.FetchAllPayments()
if err != nil { if err != nil {
t.Fatalf("Can't get payments after deletion: %v", err) t.Fatalf("Can't get payments after deletion: %v", err)
} }
if len(paymentsAfterDeletion) != 0 { if len(paymentsAfterDeletion) != 0 {
t.Fatalf("After deletion DB has %v payments, want %v", len(paymentsAfterDeletion), 0) t.Fatalf("After deletion DB has %v payments, want %v",
len(paymentsAfterDeletion), 0)
} }
} }

@ -495,35 +495,41 @@ func testSingleHopInvoice(net *networkHarness, t *harnessTest) {
func testListPayments(net *networkHarness, t *harnessTest) { func testListPayments(net *networkHarness, t *harnessTest) {
ctxb := context.Background() ctxb := context.Background()
timeout := time.Duration(time.Second * 5) timeout := time.Duration(time.Second * 2)
ctxt, _ := context.WithTimeout(ctxb, timeout)
// Delete all payments from Alice. DB should have no payments // Delete all payments from Alice. DB should have no payments
deleteAllPaymentsInitialReq := &lnrpc.DeleteAllPaymentsRequest{} deleteAllPaymentsInitialReq := &lnrpc.DeleteAllPaymentsRequest{}
_, err := net.Alice.DeleteAllPayments(ctxt, deleteAllPaymentsInitialReq) deleteAllPaymentsInitialCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err := net.Alice.DeleteAllPayments(deleteAllPaymentsInitialCtxt,
deleteAllPaymentsInitialReq)
if err != nil { if err != nil {
t.Fatalf("Can't delete payments at the begining: %v", err) t.Fatalf("Can't delete payments at the begining: %v", err)
} }
// Check that there are no payments before test. // Check that there are no payments before test.
reqInit := &lnrpc.ListPaymentsRequest{} reqInit := &lnrpc.ListPaymentsRequest{}
paymentsRespInit, err := net.Alice.ListPayments(ctxt, reqInit) reqInitCtxt, _ := context.WithTimeout(ctxb, timeout)
paymentsRespInit, err := net.Alice.ListPayments(reqInitCtxt, reqInit)
if err != nil { if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err) t.Fatalf("error when obtaining Alice payments: %v", err)
} }
if len(paymentsRespInit.Payments) != 0 { if len(paymentsRespInit.Payments) != 0 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsRespInit.Payments), 0) t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsRespInit.Payments), 0)
} }
// Open a channel with 100k satoshis between Alice and Bob with Alice being // Open a channel with 100k satoshis
// between Alice and Bob with Alice being
// the sole funder of the channel. // the sole funder of the channel.
chanAmt := btcutil.Amount(100000) chanAmt := btcutil.Amount(100000)
chanPoint := openChannelAndAssert(t, net, ctxt, net.Alice, net.Bob, chanAmt) openChannelCtxt, _ := context.WithTimeout(ctxb, timeout)
chanPoint := openChannelAndAssert(t, net, openChannelCtxt,
net.Alice, net.Bob, chanAmt)
// Now that the channel is open, create an invoice for Bob which // Now that the channel is open, create an invoice for Bob which
// expects a payment of 1000 satoshis from Alice paid via a particular // expects a payment of 1000 satoshis from Alice
// pre-image. // paid via a particular pre-image.
const paymentAmt = 1000 const paymentAmt = 1000
preimage := bytes.Repeat([]byte("B"), 32) preimage := bytes.Repeat([]byte("B"), 32)
invoice := &lnrpc.Invoice{ invoice := &lnrpc.Invoice{
@ -531,14 +537,16 @@ func testListPayments(net *networkHarness, t *harnessTest) {
RPreimage: preimage, RPreimage: preimage,
Value: paymentAmt, Value: paymentAmt,
} }
invoiceResp, err := net.Bob.AddInvoice(ctxt, invoice) addInvoiceCtxt, _ := context.WithTimeout(ctxb, timeout)
invoiceResp, err := net.Bob.AddInvoice(addInvoiceCtxt, invoice)
if err != nil { if err != nil {
t.Fatalf("unable to add invoice: %v", err) t.Fatalf("unable to add invoice: %v", err)
} }
// With the invoice for Bob added, send a payment towards Alice paying // With the invoice for Bob added, send a payment towards Alice paying
// to the above generated invoice. // to the above generated invoice.
sendStream, err := net.Alice.SendPayment(ctxt) sendPaymentCtxt, _ := context.WithTimeout(ctxb, timeout)
sendStream, err := net.Alice.SendPayment(sendPaymentCtxt)
if err != nil { if err != nil {
t.Fatalf("unable to create alice payment stream: %v", err) t.Fatalf("unable to create alice payment stream: %v", err)
} }
@ -558,33 +566,38 @@ func testListPayments(net *networkHarness, t *harnessTest) {
// like balance here because it is already checked in // like balance here because it is already checked in
// testSingleHopInvoice // testSingleHopInvoice
req := &lnrpc.ListPaymentsRequest{} req := &lnrpc.ListPaymentsRequest{}
paymentsResp, err := net.Alice.ListPayments(ctxt, req) listPaymentsCtxt, _ := context.WithTimeout(ctxb, timeout)
paymentsResp, err := net.Alice.ListPayments(listPaymentsCtxt, req)
if err != nil { if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err) t.Fatalf("error when obtaining Alice payments: %v", err)
} }
if len(paymentsResp.Payments) != 1 { if len(paymentsResp.Payments) != 1 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsResp.Payments), 1) t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsResp.Payments), 1)
} }
p := paymentsResp.Payments[0] p := paymentsResp.Payments[0]
// Check path. // Check path.
expectedPath := []string { expectedPath := []string{
net.Alice.PubKeyStr, net.Alice.PubKeyStr,
net.Bob.PubKeyStr, net.Bob.PubKeyStr,
} }
if !reflect.DeepEqual(p.Path, expectedPath) { if !reflect.DeepEqual(p.Path, expectedPath) {
t.Fatalf("incorrect path, got %v, want %v", p.Path, expectedPath) t.Fatalf("incorrect path, got %v, want %v",
p.Path, expectedPath)
} }
// Check amount. // Check amount.
if p.Value != paymentAmt { if p.Value != paymentAmt {
t.Fatalf("incorrect amount, got %v, want %v", p.Value, paymentAmt) t.Fatalf("incorrect amount, got %v, want %v",
p.Value, paymentAmt)
} }
// Check RHash. // Check RHash.
correctRHash := hex.EncodeToString(invoiceResp.RHash) correctRHash := hex.EncodeToString(invoiceResp.RHash)
if !reflect.DeepEqual(p.RHash, correctRHash){ if !reflect.DeepEqual(p.RHash, correctRHash) {
t.Fatalf("incorrect RHash, got %v, want %v", p.RHash, correctRHash) t.Fatalf("incorrect RHash, got %v, want %v",
p.RHash, correctRHash)
} }
// Check Fee. // Check Fee.
@ -595,26 +608,31 @@ func testListPayments(net *networkHarness, t *harnessTest) {
// Delete all payments from Alice. DB should have no payments. // Delete all payments from Alice. DB should have no payments.
deleteAllPaymentsEndReq := &lnrpc.DeleteAllPaymentsRequest{} deleteAllPaymentsEndReq := &lnrpc.DeleteAllPaymentsRequest{}
_, err = net.Alice.DeleteAllPayments(ctxt, deleteAllPaymentsEndReq) deleteAllPaymentsEndCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err = net.Alice.DeleteAllPayments(deleteAllPaymentsEndCtxt,
deleteAllPaymentsEndReq)
if err != nil { if err != nil {
t.Fatalf("Can't delete payments at the end: %v", err) t.Fatalf("Can't delete payments at the end: %v", err)
} }
// Check that there are no payments before test. // Check that there are no payments before test.
reqEnd := &lnrpc.ListPaymentsRequest{} reqEnd := &lnrpc.ListPaymentsRequest{}
_, err = net.Alice.ListPayments(ctxt, reqEnd) listPaymentsEndCtxt, _ := context.WithTimeout(ctxb, timeout)
_, err = net.Alice.ListPayments(listPaymentsEndCtxt, reqEnd)
if err != nil { if err != nil {
t.Fatalf("error when obtaining Alice payments: %v", err) t.Fatalf("error when obtaining Alice payments: %v", err)
} }
if len(paymentsRespInit.Payments) != 0 { if len(paymentsRespInit.Payments) != 0 {
t.Fatalf("incorrect number of payments, got %v, want %v", len(paymentsRespInit.Payments), 0) t.Fatalf("incorrect number of payments, got %v, want %v",
len(paymentsRespInit.Payments), 0)
} }
closeChannelCtxt, _ := context.WithTimeout(ctxb, timeout)
ctxt, _ = context.WithTimeout(ctxb, timeout) closeChannelAndAssert(t, net, closeChannelCtxt,
closeChannelAndAssert(t, net, ctxt, net.Alice, chanPoint, false) net.Alice, chanPoint, false)
} }
func testMultiHopPayments(net *networkHarness, t *harnessTest) { func testMultiHopPayments(net *networkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(100000) const chanAmt = btcutil.Amount(100000)
ctxb := context.Background() ctxb := context.Background()

@ -598,15 +598,22 @@ func (r *rpcServer) ListChannels(ctx context.Context,
} }
func constructPayment(path []graph.Vertex, amount btcutil.Amount, rHash []byte) *channeldb.OutgoingPayment { func constructPayment(path []graph.Vertex, amount btcutil.Amount, rHash []byte) *channeldb.OutgoingPayment {
payment := &channeldb.OutgoingPayment{} payment := &channeldb.OutgoingPayment{}
// When we create payment we do not know preImage.
// So we need to save rHash
copy(payment.RHash[:], rHash) copy(payment.RHash[:], rHash)
payment.Invoice.Terms.Value = btcutil.Amount(amount) payment.Invoice.Terms.Value = btcutil.Amount(amount)
payment.Invoice.CreationDate = time.Now() payment.Invoice.CreationDate = time.Now()
pathBytes := make([][]byte, len(path)) payment.Timestamp = time.Now()
pathBytes33 := make([][33]byte, len(path))
for i:=0; i<len(path); i++ { for i:=0; i<len(path); i++ {
pathBytes[i] = path[i].ToByte() pathBytes33[i] = path[i].ToByte33()
} }
payment.Path = pathBytes payment.Path = pathBytes33
return payment return payment
} }
@ -687,9 +694,12 @@ func (r *rpcServer) SendPayment(paymentStream lnrpc.Lightning_SendPaymentServer)
errChan <- err errChan <- err
return return
} }
// Save payment to DB. // Save payment to DB.
payment := constructPayment(path, btcutil.Amount(nextPayment.Amt), rHash[:]) payment := constructPayment(path,
btcutil.Amount(nextPayment.Amt), rHash[:])
r.server.chanDB.AddPayment(payment) r.server.chanDB.AddPayment(payment)
// TODO(roasbeef): proper responses // TODO(roasbeef): proper responses
resp := &lnrpc.SendResponse{} resp := &lnrpc.SendResponse{}
if err := paymentStream.Send(resp); err != nil { if err := paymentStream.Send(resp); err != nil {
@ -1136,22 +1146,25 @@ func (r *rpcServer) ShowRoutingTable(ctx context.Context,
// ListPayments returns a list of all outgoing payments. // ListPayments returns a list of all outgoing payments.
func (r *rpcServer) ListPayments(context.Context, func (r *rpcServer) ListPayments(context.Context,
*lnrpc.ListPaymentsRequest) (*lnrpc.ListPaymentsResponse, error) { *lnrpc.ListPaymentsRequest) (*lnrpc.ListPaymentsResponse, error) {
rpcsLog.Debugf("[ListPayments]") rpcsLog.Debugf("[ListPayments]")
payments, err := r.server.chanDB.FetchAllPayments() payments, err := r.server.chanDB.FetchAllPayments()
if err != nil { if err != nil {
return nil, err return nil, err
} }
paymentsResp := &lnrpc.ListPaymentsResponse{ paymentsResp := &lnrpc.ListPaymentsResponse{
Payments: make([]*lnrpc.Payment, len(payments)), Payments: make([]*lnrpc.Payment, len(payments)),
} }
for i:=0; i<len(payments); i++{ for i:=0; i<len(payments); i++ {
p := &lnrpc.Payment{} p := &lnrpc.Payment{}
p.CreationDate = payments[i].CreationDate.Unix() p.CreationDate = payments[i].CreationDate.Unix()
p.Value = int64(payments[i].Terms.Value) p.Value = int64(payments[i].Terms.Value)
p.RHash = hex.EncodeToString(payments[i].RHash[:]) p.RHash = hex.EncodeToString(payments[i].RHash[:])
path := make([]string, len(payments[i].Path)) path := make([]string, len(payments[i].Path))
for j:=0; j<len(path); j++ { for j:=0; j<len(path); j++ {
path[j] = hex.EncodeToString(payments[i].Path[j]) path[j] = hex.EncodeToString(payments[i].Path[j][:])
} }
p.Path = path p.Path = path
paymentsResp.Payments[i] = p paymentsResp.Payments[i] = p
@ -1163,7 +1176,9 @@ func (r *rpcServer) ListPayments(context.Context,
// DeleteAllPayments deletes all outgoing payments from DB. // DeleteAllPayments deletes all outgoing payments from DB.
func (r *rpcServer) DeleteAllPayments(context.Context, func (r *rpcServer) DeleteAllPayments(context.Context,
*lnrpc.DeleteAllPaymentsRequest) (*lnrpc.DeleteAllPaymentsResponse, error) { *lnrpc.DeleteAllPaymentsRequest) (*lnrpc.DeleteAllPaymentsResponse, error) {
rpcsLog.Debugf("[DeleteAllPayments]") rpcsLog.Debugf("[DeleteAllPayments]")
err := r.server.chanDB.DeleteAllPayments() err := r.server.chanDB.DeleteAllPayments()
resp := &lnrpc.DeleteAllPaymentsResponse{} resp := &lnrpc.DeleteAllPaymentsResponse{}
return resp, err return resp, err