From 4c872c438b729d9d1f471dbba803543ec9b10498 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 22 Nov 2019 02:24:28 -0800 Subject: [PATCH] channeldb: complete migration 12 for TLV invoices --- channeldb/db.go | 7 + channeldb/invoice_test.go | 8 +- channeldb/invoices.go | 282 +++++++++++++++++------------- channeldb/log.go | 2 + channeldb/migration12/invoices.go | 4 +- htlcswitch/test_utils.go | 3 + invoices/invoiceregistry_test.go | 31 ++-- lnrpc/invoicesrpc/addinvoice.go | 8 + lnwire/features.go | 7 + zpay32/invoice.go | 8 + 10 files changed, 216 insertions(+), 144 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index e8a8e5af..cf811aa6 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/lnwire" ) @@ -116,6 +117,12 @@ var ( number: 11, migration: migration_01_to_11.MigrateInvoices, }, + { + // Migrate to TLV invoice bodies, add payment address + // and features, remove receipt. + number: 12, + migration: migration12.MigrateInvoiceTLV, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 4b5dda87..e109d387 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -10,6 +10,10 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +var ( + emptyFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) +) + func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { var pre [32]byte if _, err := rand.Read(pre[:]); err != nil { @@ -23,12 +27,12 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { Terms: ContractTerm{ PaymentPreimage: pre, Value: value, + Features: emptyFeatures, }, Htlcs: map[CircuitKey]*InvoiceHTLC{}, Expiry: 4000, } i.Memo = []byte("memo") - i.Receipt = []byte("receipt") // Create a random byte slice of MaxPaymentRequestSize bytes to be used // as a dummy paymentrequest, and determine if it should be set based @@ -64,10 +68,10 @@ func TestInvoiceWorkflow(t *testing.T) { Htlcs: map[CircuitKey]*InvoiceHTLC{}, } fakeInvoice.Memo = []byte("memo") - fakeInvoice.Receipt = []byte("receipt") fakeInvoice.PaymentRequest = []byte("") copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) + fakeInvoice.Terms.Features = emptyFeatures paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 20f0c0a8..1f17d5f3 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -8,7 +8,6 @@ import ( "io" "time" - "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -84,10 +83,6 @@ const ( // in the database. MaxMemoSize = 1024 - // MaxReceiptSize is the maximum size of the payment receipt stored - // within the database along side incoming/outgoing invoices. - MaxReceiptSize = 1024 - // MaxPaymentRequestSize is the max size of a payment request for // this invoice. // TODO(halseth): determine the max length payment request when field @@ -96,6 +91,11 @@ const ( // A set of tlv type definitions used to serialize invoice htlcs to the // database. + // + // NOTE: A migration should be added whenever this list changes. This + // prevents against the database being rolled back to an older + // format where the surrounding logic might assume a different set of + // fields are known. chanIDType tlv.Type = 1 htlcIDType tlv.Type = 3 amtType tlv.Type = 5 @@ -103,7 +103,28 @@ const ( acceptTimeType tlv.Type = 9 resolveTimeType tlv.Type = 11 expiryHeightType tlv.Type = 13 - stateType tlv.Type = 15 + htlcStateType tlv.Type = 15 + + // A set of tlv type definitions used to serialize invoice bodiees. + // + // NOTE: A migration should be added whenever this list changes. This + // prevents against the database being rolled back to an older + // format where the surrounding logic might assume a different set of + // fields are known. + memoType tlv.Type = 0 + payReqType tlv.Type = 1 + createTimeType tlv.Type = 2 + settleTimeType tlv.Type = 3 + addIndexType tlv.Type = 4 + settleIndexType tlv.Type = 5 + preimageType tlv.Type = 6 + valueType tlv.Type = 7 + cltvDeltaType tlv.Type = 8 + expiryType tlv.Type = 9 + paymentAddrType tlv.Type = 10 + featuresType tlv.Type = 11 + invStateType tlv.Type = 12 + amtPaidType tlv.Type = 13 ) // ContractState describes the state the invoice is in. @@ -156,6 +177,13 @@ type ContractTerm struct { // State describes the state the invoice is in. State ContractState + + // PaymentAddr is a randomly generated value include in the MPP record + // by the sender to prevent probing of the receiver. + PaymentAddr [32]byte + + // Features is the feature vectors advertised on the payment request. + Features *lnwire.FeatureVector } // Invoice is a payment invoice generated by a payee in order to request @@ -174,12 +202,6 @@ type Invoice struct { // or any other message which fits within the size constraints. Memo []byte - // Receipt is an optional field dedicated for storing a - // cryptographically binding receipt of payment. - // - // TODO(roasbeef): document scheme. - Receipt []byte - // PaymentRequest is an optional field where a payment request created // for this invoice can be stored. PaymentRequest []byte @@ -312,16 +334,14 @@ func validateInvoice(i *Invoice) error { return fmt.Errorf("max length a memo is %v, and invoice "+ "of length %v was provided", MaxMemoSize, len(i.Memo)) } - if len(i.Receipt) > MaxReceiptSize { - return fmt.Errorf("max length a receipt is %v, and invoice "+ - "of length %v was provided", MaxReceiptSize, - len(i.Receipt)) - } if len(i.PaymentRequest) > MaxPaymentRequestSize { return fmt.Errorf("max length of payment request is %v, length "+ "provided was %v", MaxPaymentRequestSize, len(i.PaymentRequest)) } + if i.Terms.Features == nil { + return errors.New("invoice must have a feature vector") + } return nil } @@ -892,71 +912,73 @@ func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, // would modify the on disk format, make a copy of the original code and store // it with the migration. func serializeInvoice(w io.Writer, i *Invoice) error { - if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { - return err - } - if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { - return err - } - if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { - return err - } - - if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil { - return err - } - - if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil { - return err - } - - birthBytes, err := i.CreationDate.MarshalBinary() + creationDateBytes, err := i.CreationDate.MarshalBinary() if err != nil { return err } - if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { - return err - } - - settleBytes, err := i.SettleDate.MarshalBinary() + settleDateBytes, err := i.SettleDate.MarshalBinary() if err != nil { return err } - if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + var fb bytes.Buffer + err = i.Terms.Features.EncodeBase256(&fb) + if err != nil { + return err + } + featureBytes := fb.Bytes() + + preimage := [32]byte(i.Terms.PaymentPreimage) + value := uint64(i.Terms.Value) + cltvDelta := uint32(i.FinalCltvDelta) + expiry := uint64(i.Expiry) + + amtPaid := uint64(i.AmtPaid) + state := uint8(i.Terms.State) + + tlvStream, err := tlv.NewStream( + // Memo and payreq. + tlv.MakePrimitiveRecord(memoType, &i.Memo), + tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest), + + // Add/settle metadata. + tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes), + tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes), + tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex), + tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex), + + // Terms. + tlv.MakePrimitiveRecord(preimageType, &preimage), + tlv.MakePrimitiveRecord(valueType, &value), + tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta), + tlv.MakePrimitiveRecord(expiryType, &expiry), + tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr), + tlv.MakePrimitiveRecord(featuresType, &featureBytes), + + // Invoice state. + tlv.MakePrimitiveRecord(invStateType, &state), + tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), + ) + if err != nil { return err } - if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + var b bytes.Buffer + if err = tlvStream.Encode(&b); err != nil { return err } - var scratch [8]byte - byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) - if _, err := w.Write(scratch[:]); err != nil { + err = binary.Write(w, byteOrder, uint64(b.Len())) + if err != nil { return err } - if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + if _, err = w.Write(b.Bytes()); err != nil { return err } - if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { - return err - } - if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { - return err - } - if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { - return err - } - - if err := serializeHtlcs(w, i.Htlcs); err != nil { - return err - } - - return nil + return serializeHtlcs(w, i.Htlcs) } // serializeHtlcs serializes a map containing circuit keys and invoice htlcs to @@ -980,7 +1002,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), - tlv.MakePrimitiveRecord(stateType, &state), + tlv.MakePrimitiveRecord(htlcStateType, &state), ) if err != nil { return err @@ -1018,79 +1040,89 @@ func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) { } func deserializeInvoice(r io.Reader) (Invoice, error) { - var err error - invoice := Invoice{} + var ( + preimage [32]byte + value uint64 + cltvDelta uint32 + expiry uint64 + amtPaid uint64 + state uint8 - // TODO(roasbeef): use read full everywhere - invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + creationDateBytes []byte + settleDateBytes []byte + featureBytes []byte + ) + + var i Invoice + tlvStream, err := tlv.NewStream( + // Memo and payreq. + tlv.MakePrimitiveRecord(memoType, &i.Memo), + tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest), + + // Add/settle metadata. + tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes), + tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes), + tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex), + tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex), + + // Terms. + tlv.MakePrimitiveRecord(preimageType, &preimage), + tlv.MakePrimitiveRecord(valueType, &value), + tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta), + tlv.MakePrimitiveRecord(expiryType, &expiry), + tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr), + tlv.MakePrimitiveRecord(featuresType, &featureBytes), + + // Invoice state. + tlv.MakePrimitiveRecord(invStateType, &state), + tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), + ) if err != nil { - return invoice, err + return i, err } - invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + + var bodyLen int64 + err = binary.Read(r, byteOrder, &bodyLen) if err != nil { - return invoice, err + return i, err } - invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + lr := io.LimitReader(r, bodyLen) + if err = tlvStream.Decode(lr); err != nil { + return i, err + } + + i.Terms.PaymentPreimage = lntypes.Preimage(preimage) + i.Terms.Value = lnwire.MilliSatoshi(value) + i.FinalCltvDelta = int32(cltvDelta) + i.Expiry = time.Duration(expiry) + i.AmtPaid = lnwire.MilliSatoshi(amtPaid) + i.Terms.State = ContractState(state) + + err = i.CreationDate.UnmarshalBinary(creationDateBytes) if err != nil { - return invoice, err + return i, err } - if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil { - return invoice, err - } - - var expiry int64 - if err := binary.Read(r, byteOrder, &expiry); err != nil { - return invoice, err - } - invoice.Expiry = time.Duration(expiry) - - birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + err = i.SettleDate.UnmarshalBinary(settleDateBytes) if err != nil { - return invoice, err - } - if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { - return invoice, err + return i, err } - settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + rawFeatures := lnwire.NewRawFeatureVector() + err = rawFeatures.DecodeBase256( + bytes.NewReader(featureBytes), len(featureBytes), + ) if err != nil { - return invoice, err - } - if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { - return invoice, err + return i, err } - if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { - return invoice, err - } - var scratch [8]byte - if _, err := io.ReadFull(r, scratch[:]); err != nil { - return invoice, err - } - invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + i.Terms.Features = lnwire.NewFeatureVector( + rawFeatures, lnwire.Features, + ) - if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { - return invoice, err - } - - if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { - return invoice, err - } - if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { - return invoice, err - } - if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { - return invoice, err - } - - invoice.Htlcs, err = deserializeHtlcs(r) - if err != nil { - return Invoice{}, err - } - - return invoice, nil + i.Htlcs, err = deserializeHtlcs(r) + return i, err } // deserializeHtlcs reads a list of invoice htlcs from a reader and returns it @@ -1134,7 +1166,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), - tlv.MakePrimitiveRecord(stateType, &state), + tlv.MakePrimitiveRecord(htlcStateType, &state), ) if err != nil { return nil, err @@ -1167,9 +1199,9 @@ func copySlice(src []byte) []byte { func copyInvoice(src *Invoice) *Invoice { dest := Invoice{ Memo: copySlice(src.Memo), - Receipt: copySlice(src.Receipt), PaymentRequest: copySlice(src.PaymentRequest), FinalCltvDelta: src.FinalCltvDelta, + Expiry: src.Expiry, CreationDate: src.CreationDate, SettleDate: src.SettleDate, Terms: src.Terms, @@ -1181,6 +1213,8 @@ func copyInvoice(src *Invoice) *Invoice { ), } + dest.Terms.Features = src.Terms.Features.Clone() + for k, v := range src.Htlcs { dest.Htlcs[k] = v } @@ -1202,10 +1236,10 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucke // Create deep copy to prevent any accidental modification in the // callback. - copy := copyInvoice(&invoice) + invoiceCopy := copyInvoice(&invoice) // Call the callback and obtain the update descriptor. - update, err := callback(copy) + update, err := callback(invoiceCopy) if err != nil { return &invoice, err } diff --git a/channeldb/log.go b/channeldb/log.go index 30ddff03..5229edbf 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -3,6 +3,7 @@ package channeldb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) @@ -27,4 +28,5 @@ func DisableLog() { func UseLogger(logger btclog.Logger) { log = logger migration_01_to_11.UseLogger(logger) + migration12.UseLogger(logger) } diff --git a/channeldb/migration12/invoices.go b/channeldb/migration12/invoices.go index e1c34c7b..0b83fe1f 100644 --- a/channeldb/migration12/invoices.go +++ b/channeldb/migration12/invoices.go @@ -154,8 +154,6 @@ type Invoice struct { // LegacyDeserializeInvoice decodes an invoice from the passed io.Reader using // the pre-TLV serialization. -// -// nolint: dupl func LegacyDeserializeInvoice(r io.Reader) (Invoice, error) { var err error invoice := Invoice{} @@ -241,6 +239,8 @@ func deserializeHtlcs(r io.Reader) ([]byte, error) { } // SerializeInvoice serializes an invoice to a writer. +// +// nolint: dupl func SerializeInvoice(w io.Writer, i *Invoice) error { creationDateBytes, err := i.CreationDate.MarshalBinary() if err != nil { diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 44085f90..4f6d02eb 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -562,6 +562,9 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, Terms: channeldb.ContractTerm{ Value: invoiceAmt, PaymentPreimage: preimage, + Features: lnwire.NewFeatureVector( + nil, lnwire.Features, + ), }, FinalCltvDelta: testInvoiceCltvExpiry, } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index a137963e..51270354 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -28,6 +28,10 @@ var ( testFinalCltvRejectDelta = int32(4) testCurrentHeight = int32(1) + + testFeatures = lnwire.NewFeatureVector( + nil, lnwire.Features, + ) ) var ( @@ -35,6 +39,15 @@ var ( Terms: channeldb.ContractTerm{ PaymentPreimage: preimage, Value: lnwire.MilliSatoshi(100000), + Features: testFeatures, + }, + } + + testHodlInvoice = &channeldb.Invoice{ + Terms: channeldb.ContractTerm{ + PaymentPreimage: channeldb.UnknownPreimage, + Value: lnwire.MilliSatoshi(100000), + Features: testFeatures, }, } ) @@ -382,14 +395,7 @@ func TestSettleHoldInvoice(t *testing.T) { } // Add the invoice. - invoice := &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ - PaymentPreimage: channeldb.UnknownPreimage, - Value: lnwire.MilliSatoshi(100000), - }, - } - - _, err = registry.AddInvoice(invoice, hash) + _, err = registry.AddInvoice(testHodlInvoice, hash) if err != nil { t.Fatal(err) } @@ -543,14 +549,7 @@ func TestCancelHoldInvoice(t *testing.T) { defer registry.Stop() // Add the invoice. - invoice := &channeldb.Invoice{ - Terms: channeldb.ContractTerm{ - PaymentPreimage: channeldb.UnknownPreimage, - Value: lnwire.MilliSatoshi(100000), - }, - } - - _, err = registry.AddInvoice(invoice, hash) + _, err = registry.AddInvoice(testHodlInvoice, hash) if err != nil { t.Fatal(err) } diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index da613808..70955e0f 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -363,6 +363,13 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, } + // Set a blank feature vector, as our invoice generation forbids nil + // features. + invoiceFeatures := lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), lnwire.Features, + ) + options = append(options, zpay32.Features(invoiceFeatures)) + // Create and encode the payment request as a bech32 (zpay32) string. creationDate := time.Now() payReq, err := zpay32.NewInvoice( @@ -390,6 +397,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, Terms: channeldb.ContractTerm{ Value: amtMSat, PaymentPreimage: paymentPreimage, + Features: invoiceFeatures, }, } diff --git a/lnwire/features.go b/lnwire/features.go index 7900a9b3..c9a60b5b 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -371,3 +371,10 @@ func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool { name2, known2 := fv.featureNames[bit^1] return known1 && known2 && name1 == name2 } + +// Clone copies a feature vector, carrying over its feature bits. The feature +// names are not copied. +func (fv *FeatureVector) Clone() *FeatureVector { + features := fv.RawFeatureVector.Clone() + return NewFeatureVector(features, fv.featureNames) +} diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 4e7210f7..d36b31ca 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -242,6 +242,14 @@ func RouteHint(routeHint []HopHint) func(*Invoice) { } } +// Features is a functional option that allows callers of NewInvoice to set the +// desired feature bits that are advertised on the invoice. +func Features(features *lnwire.FeatureVector) func(*Invoice) { + return func(i *Invoice) { + i.Features = features + } +} + // NewInvoice creates a new Invoice object. The last parameter is a set of // variadic arguments for setting optional fields of the invoice. //