diff --git a/channeldb/payments.go b/channeldb/payments.go index b10749c2..774eea57 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,117 +3,122 @@ package channeldb import ( "bytes" "encoding/binary" - "github.com/boltdb/bolt" - "github.com/roasbeef/btcd/wire" - "github.com/roasbeef/btcutil" "io" - "time" + + "github.com/boltdb/bolt" + "github.com/roasbeef/btcutil" ) var ( - // invoiceBucket is the name of the bucket within - // the database that stores all data related to payments. - // Within the payments bucket, each invoice is keyed - // by its invoice ID - // which is a monotonically increasing uint64. - // BoltDB sequence feature is used for generating - // monotonically increasing id. + // paymentBucket is the name of the bucket within the database that + // stores all data related to payments. + // + // Within the payments bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint64. BoltDB's sequence + // feature is used for generating monotonically increasing id. paymentBucket = []byte("payments") ) -// OutgoingPayment represents payment from given node. +// OutgoingPayment represents a successful payment between the daemon and a +// remote node. Details such as the total fee paid, and the time of the payment +// are stored. type OutgoingPayment struct { Invoice - // Total fee paid. + // Fee is the total fee paid for the payment in satoshis. Fee btcutil.Amount - // Path including starting and ending nodes. - Path [][33]byte - - // Timelock length. + // TotalTimeLock is the total cumulative time-lock in the HTLC extended + // from the second-to-last hop to the destination. TimeLockLength uint32 - // RHash value used for payment. - // We need RHash because we start payment knowing only RHash - RHash [32]byte + // Path encodes the path the payment took throuhg the network. The path + // excludes the outgoing node and consists of the hex-encoded + // compressed public key of each of the nodes involved in the payment. + Path [][33]byte - // Timestamp is time when payment was created. - Timestamp time.Time + // PaymentHash is the payment hash (r-hash) used to send the payment. + // + // TODO(roasbeef): weave through preimage on payment success to can + // store only supplemental info the embedded Invoice + PaymentHash [32]byte } -// AddPayment adds payment to DB. -// There is no checking that payment with the same hash already exist. -func (db *DB) AddPayment(p *OutgoingPayment) error { - err := validateInvoice(&p.Invoice) - if err != nil { +// AddPayment saves a successful payment to the database. It is assumed that +// all payment are sent using unique payment hashes. +func (db *DB) AddPayment(payment *OutgoingPayment) error { + // Validate the field of the inner voice within the outgoing payment, + // these must also adhere to the same constraints as regular invoices. + if err := validateInvoice(&payment.Invoice); err != nil { return err } - // We serialize before writing to database - // so no db access in the case of serialization errors - b := new(bytes.Buffer) - err = serializeOutgoingPayment(b, p) - if err != nil { + // We first serialize the payment before starting the database + // transaction so we can avoid creating a DB payment in the case of a + // serialization error. + var b bytes.Buffer + if err := serializeOutgoingPayment(&b, payment); err != nil { return err } paymentBytes := b.Bytes() + return db.Update(func(tx *bolt.Tx) error { payments, err := tx.CreateBucketIfNotExists(paymentBucket) if err != nil { return err } + // Obtain the new unique sequence number for this payment. paymentId, err := payments.NextSequence() if err != nil { return err } - // We use BigEndian for keys because - // it orders keys in ascending order + // We use BigEndian for keys as it orders keys in + // ascending order. This allows bucket scans to order payments + // in the order in which they were created. paymentIdBytes := make([]byte, 8) binary.BigEndian.PutUint64(paymentIdBytes, paymentId) - err = payments.Put(paymentIdBytes, paymentBytes) - if err != nil { - return err - } - return nil + + return payments.Put(paymentIdBytes, paymentBytes) }) } // FetchAllPayments returns all outgoing payments in DB. func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { var payments []*OutgoingPayment + err := db.View(func(tx *bolt.Tx) error { bucket := tx.Bucket(paymentBucket) if bucket == nil { return ErrNoPaymentsCreated } - err := bucket.ForEach(func(k, v []byte) error { - // Value can be nil if it is a sub-backet - // so simply ignore it. + + return bucket.ForEach(func(k, v []byte) error { + // If the value is nil, then we ignore it as it may be + // a sub-bucket. if v == nil { return nil } + r := bytes.NewReader(v) payment, err := deserializeOutgoingPayment(r) if err != nil { return err } + payments = append(payments, payment) return nil }) - return err }) if err != nil { return nil, err } + return payments, nil } // DeleteAllPayments deletes all payments from DB. -// If payments bucket does not exist it will create -// new bucket without error. func (db *DB) DeleteAllPayments() error { return db.Update(func(tx *bolt.Tx) error { err := tx.DeleteBucket(paymentBucket) @@ -125,124 +130,87 @@ func (db *DB) DeleteAllPayments() error { if err != nil { return err } - return err + + return nil }) } func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { - err := serializeInvoice(w, &p.Invoice) - if err != nil { + var scratch [8]byte + + if err := serializeInvoice(w, &p.Invoice); err != nil { return err } - // Serialize fee. - feeBytes := make([]byte, 8) - byteOrder.PutUint64(feeBytes, uint64(p.Fee)) - _, err = w.Write(feeBytes) - if err != nil { + byteOrder.PutUint64(scratch[:], uint64(p.Fee)) + if _, err := w.Write(scratch[:]); err != nil { return err } - // Serialize path. + // First write out the length of the bytes to prefix the value. pathLen := uint32(len(p.Path)) - pathLenBytes := make([]byte, 4) - // Write length of the path - byteOrder.PutUint32(pathLenBytes, pathLen) - _, err = w.Write(pathLenBytes) - if err != nil { + byteOrder.PutUint32(scratch[:4], pathLen) + if _, err := w.Write(scratch[:4]); err != nil { return err } - // Serialize each element of the path - for i := uint32(0); i < pathLen; i++ { - _, err := w.Write(p.Path[i][:]) - if err != nil { + + // Then with the path written, we write out the series of public keys + // involved in the path. + for _, hop := range p.Path { + if _, err := w.Write(hop[:]); err != nil { return err } } - // Serialize TimeLockLength - timeLockLengthBytes := make([]byte, 4) - byteOrder.PutUint32(timeLockLengthBytes, p.TimeLockLength) - _, err = w.Write(timeLockLengthBytes) - if err != nil { + byteOrder.PutUint32(scratch[:4], p.TimeLockLength) + if _, err := w.Write(scratch[:4]); err != nil { return err } - // Serialize RHash - _, err = w.Write(p.RHash[:]) - if err != nil { + if _, err := w.Write(p.PaymentHash[:]); err != nil { return err } - // Serialize Timestamp. - tBytes, err := p.Timestamp.MarshalBinary() - if err != nil { - return err - } - err = wire.WriteVarBytes(w, 0, tBytes) - if err != nil { - return err - } return nil } func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { + var scratch [8]byte + p := &OutgoingPayment{} - // Deserialize invoice inv, err := deserializeInvoice(r) if err != nil { return nil, err } p.Invoice = *inv - // Deserialize fee - feeBytes := make([]byte, 8) - _, err = r.Read(feeBytes) - if err != nil { + if _, err := r.Read(scratch[:]); err != nil { return nil, err } - p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes)) + p.Fee = btcutil.Amount(byteOrder.Uint64(scratch[:])) - // Deserialize path - pathLenBytes := make([]byte, 4) - _, err = r.Read(pathLenBytes) - if err != nil { + if _, err = r.Read(scratch[:4]); err != nil { return nil, err } - pathLen := byteOrder.Uint32(pathLenBytes) + pathLen := byteOrder.Uint32(scratch[:4]) + path := make([][33]byte, pathLen) for i := uint32(0); i < pathLen; i++ { - _, err := r.Read(path[i][:]) - if err != nil { + if _, err := r.Read(path[i][:]); err != nil { return nil, err } } p.Path = path - // Deserialize TimeLockLength - timeLockLengthBytes := make([]byte, 4) - _, err = r.Read(timeLockLengthBytes) - if err != nil { + if _, err = r.Read(scratch[:4]); err != nil { return nil, err } - p.TimeLockLength = byteOrder.Uint32(timeLockLengthBytes) + p.TimeLockLength = byteOrder.Uint32(scratch[:4]) - // Deserialize RHash - _, err = r.Read(p.RHash[:]) - if err != nil { + if _, err := r.Read(p.PaymentHash[:]); err != nil { return nil, err } - // Deserialize Timestamp - tBytes, err := wire.ReadVarBytes(r, 0, 100, - "OutgoingPayment.Timestamp") - if err != nil { - return nil, err - } - err = p.Timestamp.UnmarshalBinary(tBytes) - if err != nil { - return nil, err - } return p, nil } diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index edd64370..b887e370 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -3,55 +3,50 @@ package channeldb import ( "bytes" "fmt" - "github.com/btcsuite/fastsha256" - "github.com/davecgh/go-spew/spew" - "github.com/roasbeef/btcutil" "math/rand" "reflect" "testing" "time" + + "github.com/btcsuite/fastsha256" + "github.com/davecgh/go-spew/spew" + "github.com/roasbeef/btcutil" ) func makeFakePayment() *OutgoingPayment { - // Create a fake invoice which - // we'll use several times in the tests below. fakeInvoice := &Invoice{ CreationDate: time.Now(), + Memo: []byte("fake memo"), + Receipt: []byte("fake receipt"), } - fakeInvoice.Memo = []byte("memo") - fakeInvoice.Receipt = []byte("recipt") + copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) fakeInvoice.Terms.Value = btcutil.Amount(10000) - // Make fake path + fakePath := make([][33]byte, 3) for i := 0; i < 3; i++ { - for j := 0; j < 33; j++ { - fakePath[i][j] = byte(i) - } + copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33)) } - var rHash [32]byte = fastsha256.Sum256(rev[:]) - fakePayment := &OutgoingPayment{ + + return &OutgoingPayment{ Invoice: *fakeInvoice, Fee: 101, Path: fakePath, TimeLockLength: 1000, - RHash: rHash, - Timestamp: time.Unix(100000, 0), + PaymentHash: fastsha256.Sum256(rev[:]), } - return fakePayment } -// randomBytes creates random []byte with length -// in range [minLen, maxLen) +// randomBytes creates random []byte with length in range [minLen, maxLen) func randomBytes(minLen, maxLen int) ([]byte, error) { - l := minLen + rand.Intn(maxLen-minLen) - b := make([]byte, l) - _, err := rand.Read(b) - if err != nil { + randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) + + if _, err := rand.Read(randBuf); err != nil { return nil, fmt.Errorf("Internal error. "+ "Cannot generate random string: %v", err) } - return b, nil + + return randBuf, nil } func makeRandomFakePayment() (*OutgoingPayment, error) { @@ -78,7 +73,6 @@ func makeRandomFakePayment() (*OutgoingPayment, error) { fakeInvoice.Terms.Value = btcutil.Amount(rand.Intn(10000)) - // Make fake path fakePathLen := 1 + rand.Intn(5) fakePath := make([][33]byte, fakePathLen) for i := 0; i < fakePathLen; i++ { @@ -89,31 +83,29 @@ func makeRandomFakePayment() (*OutgoingPayment, error) { copy(fakePath[i][:], b) } - var rHash [32]byte = fastsha256.Sum256( - fakeInvoice.Terms.PaymentPreimage[:], - ) + rHash := fastsha256.Sum256(fakeInvoice.Terms.PaymentPreimage[:]) fakePayment := &OutgoingPayment{ Invoice: *fakeInvoice, Fee: btcutil.Amount(rand.Intn(1001)), Path: fakePath, TimeLockLength: uint32(rand.Intn(10000)), - RHash: rHash, - Timestamp: time.Unix(rand.Int63n(10000), 0), + PaymentHash: rHash, } + return fakePayment, nil } func TestOutgoingPaymentSerialization(t *testing.T) { fakePayment := makeFakePayment() - b := new(bytes.Buffer) - err := serializeOutgoingPayment(b, fakePayment) - if err != nil { - t.Fatalf("Can't serialize outgoing payment: %v", err) + + var b bytes.Buffer + if err := serializeOutgoingPayment(&b, fakePayment); err != nil { + t.Fatalf("unable to serialize outgoing payment: %v", err) } - newPayment, err := deserializeOutgoingPayment(b) + newPayment, err := deserializeOutgoingPayment(&b) if err != nil { - t.Fatalf("Can't deserialize outgoing payment: %v", err) + t.Fatalf("unable to deserialize outgoing payment: %v", err) } if !reflect.DeepEqual(fakePayment, newPayment) { @@ -127,27 +119,27 @@ func TestOutgoingPaymentSerialization(t *testing.T) { func TestOutgoingPaymentWorkflow(t *testing.T) { db, cleanUp, err := makeTestDB() + defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) } - defer cleanUp() fakePayment := makeFakePayment() - err = db.AddPayment(fakePayment) - if err != nil { - t.Fatalf("Can't put payment in DB: %v", err) + if err = db.AddPayment(fakePayment); err != nil { + t.Fatalf("unable to put payment in DB: %v", err) } payments, err := db.FetchAllPayments() if err != nil { - t.Fatalf("Can't get payments from DB: %v", err) + t.Fatalf("unable to fetch payments from DB: %v", err) } - correctPayments := []*OutgoingPayment{fakePayment} - if !reflect.DeepEqual(payments, correctPayments) { + + expectedPayments := []*OutgoingPayment{fakePayment} + if !reflect.DeepEqual(payments, expectedPayments) { t.Fatalf("Wrong payments after reading from DB."+ "Got %v, want %v", spew.Sdump(payments), - spew.Sdump(correctPayments), + spew.Sdump(expectedPayments), ) } @@ -157,11 +149,12 @@ func TestOutgoingPaymentWorkflow(t *testing.T) { if err != nil { t.Fatalf("Internal error in tests: %v", err) } - err = db.AddPayment(randomPayment) - if err != nil { - t.Fatalf("Can't put payment in DB: %v", err) + + if err = db.AddPayment(randomPayment); err != nil { + t.Fatalf("unable to put payment in DB: %v", err) } - correctPayments = append(correctPayments, randomPayment) + + expectedPayments = append(expectedPayments, randomPayment) } payments, err = db.FetchAllPayments() @@ -169,18 +162,17 @@ func TestOutgoingPaymentWorkflow(t *testing.T) { t.Fatalf("Can't get payments from DB: %v", err) } - if !reflect.DeepEqual(payments, correctPayments) { + if !reflect.DeepEqual(payments, expectedPayments) { t.Fatalf("Wrong payments after reading from DB."+ "Got %v, want %v", spew.Sdump(payments), - spew.Sdump(correctPayments), + spew.Sdump(expectedPayments), ) } // Delete all payments. - err = db.DeleteAllPayments() - if err != nil { - t.Fatalf("Can't delete payments from DB: %v", err) + if err = db.DeleteAllPayments(); err != nil { + t.Fatalf("unable to delete payments from DB: %v", err) } // Check that there is no payments after deletion