channeldb: complete migration 12 for TLV invoices

This commit is contained in:
Conner Fromknecht 2019-11-22 02:24:28 -08:00
parent 76682ad820
commit 4c872c438b
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
10 changed files with 216 additions and 144 deletions

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -116,6 +117,12 @@ var (
number: 11, number: 11,
migration: migration_01_to_11.MigrateInvoices, migration: migration_01_to_11.MigrateInvoices,
}, },
{
// Migrate to TLV invoice bodies, add payment address
// and features, remove receipt.
number: 12,
migration: migration12.MigrateInvoiceTLV,
},
} }
// Big endian is the preferred byte order, due to cursor scans over // Big endian is the preferred byte order, due to cursor scans over

@ -10,6 +10,10 @@ import (
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
var (
emptyFeatures = lnwire.NewFeatureVector(nil, lnwire.Features)
)
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
var pre [32]byte var pre [32]byte
if _, err := rand.Read(pre[:]); err != nil { if _, err := rand.Read(pre[:]); err != nil {
@ -23,12 +27,12 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
Terms: ContractTerm{ Terms: ContractTerm{
PaymentPreimage: pre, PaymentPreimage: pre,
Value: value, Value: value,
Features: emptyFeatures,
}, },
Htlcs: map[CircuitKey]*InvoiceHTLC{}, Htlcs: map[CircuitKey]*InvoiceHTLC{},
Expiry: 4000, Expiry: 4000,
} }
i.Memo = []byte("memo") i.Memo = []byte("memo")
i.Receipt = []byte("receipt")
// Create a random byte slice of MaxPaymentRequestSize bytes to be used // Create a random byte slice of MaxPaymentRequestSize bytes to be used
// as a dummy paymentrequest, and determine if it should be set based // as a dummy paymentrequest, and determine if it should be set based
@ -64,10 +68,10 @@ func TestInvoiceWorkflow(t *testing.T) {
Htlcs: map[CircuitKey]*InvoiceHTLC{}, Htlcs: map[CircuitKey]*InvoiceHTLC{},
} }
fakeInvoice.Memo = []byte("memo") fakeInvoice.Memo = []byte("memo")
fakeInvoice.Receipt = []byte("receipt")
fakeInvoice.PaymentRequest = []byte("") fakeInvoice.PaymentRequest = []byte("")
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
fakeInvoice.Terms.Features = emptyFeatures
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash()

@ -8,7 +8,6 @@ import (
"io" "io"
"time" "time"
"github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -84,10 +83,6 @@ const (
// in the database. // in the database.
MaxMemoSize = 1024 MaxMemoSize = 1024
// MaxReceiptSize is the maximum size of the payment receipt stored
// within the database along side incoming/outgoing invoices.
MaxReceiptSize = 1024
// MaxPaymentRequestSize is the max size of a payment request for // MaxPaymentRequestSize is the max size of a payment request for
// this invoice. // this invoice.
// TODO(halseth): determine the max length payment request when field // TODO(halseth): determine the max length payment request when field
@ -96,6 +91,11 @@ const (
// A set of tlv type definitions used to serialize invoice htlcs to the // A set of tlv type definitions used to serialize invoice htlcs to the
// database. // database.
//
// NOTE: A migration should be added whenever this list changes. This
// prevents against the database being rolled back to an older
// format where the surrounding logic might assume a different set of
// fields are known.
chanIDType tlv.Type = 1 chanIDType tlv.Type = 1
htlcIDType tlv.Type = 3 htlcIDType tlv.Type = 3
amtType tlv.Type = 5 amtType tlv.Type = 5
@ -103,7 +103,28 @@ const (
acceptTimeType tlv.Type = 9 acceptTimeType tlv.Type = 9
resolveTimeType tlv.Type = 11 resolveTimeType tlv.Type = 11
expiryHeightType tlv.Type = 13 expiryHeightType tlv.Type = 13
stateType tlv.Type = 15 htlcStateType tlv.Type = 15
// A set of tlv type definitions used to serialize invoice bodiees.
//
// NOTE: A migration should be added whenever this list changes. This
// prevents against the database being rolled back to an older
// format where the surrounding logic might assume a different set of
// fields are known.
memoType tlv.Type = 0
payReqType tlv.Type = 1
createTimeType tlv.Type = 2
settleTimeType tlv.Type = 3
addIndexType tlv.Type = 4
settleIndexType tlv.Type = 5
preimageType tlv.Type = 6
valueType tlv.Type = 7
cltvDeltaType tlv.Type = 8
expiryType tlv.Type = 9
paymentAddrType tlv.Type = 10
featuresType tlv.Type = 11
invStateType tlv.Type = 12
amtPaidType tlv.Type = 13
) )
// ContractState describes the state the invoice is in. // ContractState describes the state the invoice is in.
@ -156,6 +177,13 @@ type ContractTerm struct {
// State describes the state the invoice is in. // State describes the state the invoice is in.
State ContractState State ContractState
// PaymentAddr is a randomly generated value include in the MPP record
// by the sender to prevent probing of the receiver.
PaymentAddr [32]byte
// Features is the feature vectors advertised on the payment request.
Features *lnwire.FeatureVector
} }
// Invoice is a payment invoice generated by a payee in order to request // Invoice is a payment invoice generated by a payee in order to request
@ -174,12 +202,6 @@ type Invoice struct {
// or any other message which fits within the size constraints. // or any other message which fits within the size constraints.
Memo []byte Memo []byte
// Receipt is an optional field dedicated for storing a
// cryptographically binding receipt of payment.
//
// TODO(roasbeef): document scheme.
Receipt []byte
// PaymentRequest is an optional field where a payment request created // PaymentRequest is an optional field where a payment request created
// for this invoice can be stored. // for this invoice can be stored.
PaymentRequest []byte PaymentRequest []byte
@ -312,16 +334,14 @@ func validateInvoice(i *Invoice) error {
return fmt.Errorf("max length a memo is %v, and invoice "+ return fmt.Errorf("max length a memo is %v, and invoice "+
"of length %v was provided", MaxMemoSize, len(i.Memo)) "of length %v was provided", MaxMemoSize, len(i.Memo))
} }
if len(i.Receipt) > MaxReceiptSize {
return fmt.Errorf("max length a receipt is %v, and invoice "+
"of length %v was provided", MaxReceiptSize,
len(i.Receipt))
}
if len(i.PaymentRequest) > MaxPaymentRequestSize { if len(i.PaymentRequest) > MaxPaymentRequestSize {
return fmt.Errorf("max length of payment request is %v, length "+ return fmt.Errorf("max length of payment request is %v, length "+
"provided was %v", MaxPaymentRequestSize, "provided was %v", MaxPaymentRequestSize,
len(i.PaymentRequest)) len(i.PaymentRequest))
} }
if i.Terms.Features == nil {
return errors.New("invoice must have a feature vector")
}
return nil return nil
} }
@ -892,71 +912,73 @@ func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket,
// would modify the on disk format, make a copy of the original code and store // would modify the on disk format, make a copy of the original code and store
// it with the migration. // it with the migration.
func serializeInvoice(w io.Writer, i *Invoice) error { func serializeInvoice(w io.Writer, i *Invoice) error {
if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { creationDateBytes, err := i.CreationDate.MarshalBinary()
return err
}
if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil {
return err
}
if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil {
return err
}
birthBytes, err := i.CreationDate.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { settleDateBytes, err := i.SettleDate.MarshalBinary()
return err
}
settleBytes, err := i.SettleDate.MarshalBinary()
if err != nil { if err != nil {
return err return err
} }
if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { var fb bytes.Buffer
err = i.Terms.Features.EncodeBase256(&fb)
if err != nil {
return err
}
featureBytes := fb.Bytes()
preimage := [32]byte(i.Terms.PaymentPreimage)
value := uint64(i.Terms.Value)
cltvDelta := uint32(i.FinalCltvDelta)
expiry := uint64(i.Expiry)
amtPaid := uint64(i.AmtPaid)
state := uint8(i.Terms.State)
tlvStream, err := tlv.NewStream(
// Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo),
tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
// Add/settle metadata.
tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms.
tlv.MakePrimitiveRecord(preimageType, &preimage),
tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry),
tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
tlv.MakePrimitiveRecord(featuresType, &featureBytes),
// Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
)
if err != nil {
return err return err
} }
if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { var b bytes.Buffer
if err = tlvStream.Encode(&b); err != nil {
return err return err
} }
var scratch [8]byte err = binary.Write(w, byteOrder, uint64(b.Len()))
byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) if err != nil {
if _, err := w.Write(scratch[:]); err != nil {
return err return err
} }
if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { if _, err = w.Write(b.Bytes()); err != nil {
return err return err
} }
if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { return serializeHtlcs(w, i.Htlcs)
return err
}
if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil {
return err
}
if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil {
return err
}
if err := serializeHtlcs(w, i.Htlcs); err != nil {
return err
}
return nil
} }
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to // serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
@ -980,7 +1002,7 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(stateType, &state), tlv.MakePrimitiveRecord(htlcStateType, &state),
) )
if err != nil { if err != nil {
return err return err
@ -1018,79 +1040,89 @@ func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) {
} }
func deserializeInvoice(r io.Reader) (Invoice, error) { func deserializeInvoice(r io.Reader) (Invoice, error) {
var err error var (
invoice := Invoice{} preimage [32]byte
value uint64
cltvDelta uint32
expiry uint64
amtPaid uint64
state uint8
// TODO(roasbeef): use read full everywhere creationDateBytes []byte
invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") settleDateBytes []byte
featureBytes []byte
)
var i Invoice
tlvStream, err := tlv.NewStream(
// Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo),
tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
// Add/settle metadata.
tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms.
tlv.MakePrimitiveRecord(preimageType, &preimage),
tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry),
tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
tlv.MakePrimitiveRecord(featuresType, &featureBytes),
// Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
)
if err != nil { if err != nil {
return invoice, err return i, err
} }
invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "")
var bodyLen int64
err = binary.Read(r, byteOrder, &bodyLen)
if err != nil { if err != nil {
return invoice, err return i, err
} }
invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") lr := io.LimitReader(r, bodyLen)
if err = tlvStream.Decode(lr); err != nil {
return i, err
}
i.Terms.PaymentPreimage = lntypes.Preimage(preimage)
i.Terms.Value = lnwire.MilliSatoshi(value)
i.FinalCltvDelta = int32(cltvDelta)
i.Expiry = time.Duration(expiry)
i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
i.Terms.State = ContractState(state)
err = i.CreationDate.UnmarshalBinary(creationDateBytes)
if err != nil { if err != nil {
return invoice, err return i, err
} }
if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil { err = i.SettleDate.UnmarshalBinary(settleDateBytes)
return invoice, err
}
var expiry int64
if err := binary.Read(r, byteOrder, &expiry); err != nil {
return invoice, err
}
invoice.Expiry = time.Duration(expiry)
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
if err != nil { if err != nil {
return invoice, err return i, err
}
if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil {
return invoice, err
} }
settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") rawFeatures := lnwire.NewRawFeatureVector()
err = rawFeatures.DecodeBase256(
bytes.NewReader(featureBytes), len(featureBytes),
)
if err != nil { if err != nil {
return invoice, err return i, err
}
if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil {
return invoice, err
} }
if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { i.Terms.Features = lnwire.NewFeatureVector(
return invoice, err rawFeatures, lnwire.Features,
} )
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return invoice, err
}
invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { i.Htlcs, err = deserializeHtlcs(r)
return invoice, err return i, err
}
if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil {
return invoice, err
}
invoice.Htlcs, err = deserializeHtlcs(r)
if err != nil {
return Invoice{}, err
}
return invoice, nil
} }
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it // deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
@ -1134,7 +1166,7 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(stateType, &state), tlv.MakePrimitiveRecord(htlcStateType, &state),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -1167,9 +1199,9 @@ func copySlice(src []byte) []byte {
func copyInvoice(src *Invoice) *Invoice { func copyInvoice(src *Invoice) *Invoice {
dest := Invoice{ dest := Invoice{
Memo: copySlice(src.Memo), Memo: copySlice(src.Memo),
Receipt: copySlice(src.Receipt),
PaymentRequest: copySlice(src.PaymentRequest), PaymentRequest: copySlice(src.PaymentRequest),
FinalCltvDelta: src.FinalCltvDelta, FinalCltvDelta: src.FinalCltvDelta,
Expiry: src.Expiry,
CreationDate: src.CreationDate, CreationDate: src.CreationDate,
SettleDate: src.SettleDate, SettleDate: src.SettleDate,
Terms: src.Terms, Terms: src.Terms,
@ -1181,6 +1213,8 @@ func copyInvoice(src *Invoice) *Invoice {
), ),
} }
dest.Terms.Features = src.Terms.Features.Clone()
for k, v := range src.Htlcs { for k, v := range src.Htlcs {
dest.Htlcs[k] = v dest.Htlcs[k] = v
} }
@ -1202,10 +1236,10 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucke
// Create deep copy to prevent any accidental modification in the // Create deep copy to prevent any accidental modification in the
// callback. // callback.
copy := copyInvoice(&invoice) invoiceCopy := copyInvoice(&invoice)
// Call the callback and obtain the update descriptor. // Call the callback and obtain the update descriptor.
update, err := callback(copy) update, err := callback(invoiceCopy)
if err != nil { if err != nil {
return &invoice, err return &invoice, err
} }

@ -3,6 +3,7 @@ package channeldb
import ( import (
"github.com/btcsuite/btclog" "github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
) )
@ -27,4 +28,5 @@ func DisableLog() {
func UseLogger(logger btclog.Logger) { func UseLogger(logger btclog.Logger) {
log = logger log = logger
migration_01_to_11.UseLogger(logger) migration_01_to_11.UseLogger(logger)
migration12.UseLogger(logger)
} }

@ -154,8 +154,6 @@ type Invoice struct {
// LegacyDeserializeInvoice decodes an invoice from the passed io.Reader using // LegacyDeserializeInvoice decodes an invoice from the passed io.Reader using
// the pre-TLV serialization. // the pre-TLV serialization.
//
// nolint: dupl
func LegacyDeserializeInvoice(r io.Reader) (Invoice, error) { func LegacyDeserializeInvoice(r io.Reader) (Invoice, error) {
var err error var err error
invoice := Invoice{} invoice := Invoice{}
@ -241,6 +239,8 @@ func deserializeHtlcs(r io.Reader) ([]byte, error) {
} }
// SerializeInvoice serializes an invoice to a writer. // SerializeInvoice serializes an invoice to a writer.
//
// nolint: dupl
func SerializeInvoice(w io.Writer, i *Invoice) error { func SerializeInvoice(w io.Writer, i *Invoice) error {
creationDateBytes, err := i.CreationDate.MarshalBinary() creationDateBytes, err := i.CreationDate.MarshalBinary()
if err != nil { if err != nil {

@ -562,6 +562,9 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
Value: invoiceAmt, Value: invoiceAmt,
PaymentPreimage: preimage, PaymentPreimage: preimage,
Features: lnwire.NewFeatureVector(
nil, lnwire.Features,
),
}, },
FinalCltvDelta: testInvoiceCltvExpiry, FinalCltvDelta: testInvoiceCltvExpiry,
} }

@ -28,6 +28,10 @@ var (
testFinalCltvRejectDelta = int32(4) testFinalCltvRejectDelta = int32(4)
testCurrentHeight = int32(1) testCurrentHeight = int32(1)
testFeatures = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
) )
var ( var (
@ -35,6 +39,15 @@ var (
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
PaymentPreimage: preimage, PaymentPreimage: preimage,
Value: lnwire.MilliSatoshi(100000), Value: lnwire.MilliSatoshi(100000),
Features: testFeatures,
},
}
testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: lnwire.MilliSatoshi(100000),
Features: testFeatures,
}, },
} }
) )
@ -382,14 +395,7 @@ func TestSettleHoldInvoice(t *testing.T) {
} }
// Add the invoice. // Add the invoice.
invoice := &channeldb.Invoice{ _, err = registry.AddInvoice(testHodlInvoice, hash)
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: lnwire.MilliSatoshi(100000),
},
}
_, err = registry.AddInvoice(invoice, hash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -543,14 +549,7 @@ func TestCancelHoldInvoice(t *testing.T) {
defer registry.Stop() defer registry.Stop()
// Add the invoice. // Add the invoice.
invoice := &channeldb.Invoice{ _, err = registry.AddInvoice(testHodlInvoice, hash)
Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: lnwire.MilliSatoshi(100000),
},
}
_, err = registry.AddInvoice(invoice, hash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -363,6 +363,13 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
} }
// Set a blank feature vector, as our invoice generation forbids nil
// features.
invoiceFeatures := lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(), lnwire.Features,
)
options = append(options, zpay32.Features(invoiceFeatures))
// Create and encode the payment request as a bech32 (zpay32) string. // Create and encode the payment request as a bech32 (zpay32) string.
creationDate := time.Now() creationDate := time.Now()
payReq, err := zpay32.NewInvoice( payReq, err := zpay32.NewInvoice(
@ -390,6 +397,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
Value: amtMSat, Value: amtMSat,
PaymentPreimage: paymentPreimage, PaymentPreimage: paymentPreimage,
Features: invoiceFeatures,
}, },
} }

@ -371,3 +371,10 @@ func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool {
name2, known2 := fv.featureNames[bit^1] name2, known2 := fv.featureNames[bit^1]
return known1 && known2 && name1 == name2 return known1 && known2 && name1 == name2
} }
// Clone copies a feature vector, carrying over its feature bits. The feature
// names are not copied.
func (fv *FeatureVector) Clone() *FeatureVector {
features := fv.RawFeatureVector.Clone()
return NewFeatureVector(features, fv.featureNames)
}

@ -242,6 +242,14 @@ func RouteHint(routeHint []HopHint) func(*Invoice) {
} }
} }
// Features is a functional option that allows callers of NewInvoice to set the
// desired feature bits that are advertised on the invoice.
func Features(features *lnwire.FeatureVector) func(*Invoice) {
return func(i *Invoice) {
i.Features = features
}
}
// NewInvoice creates a new Invoice object. The last parameter is a set of // NewInvoice creates a new Invoice object. The last parameter is a set of
// variadic arguments for setting optional fields of the invoice. // variadic arguments for setting optional fields of the invoice.
// //