From 840476996c84601747792184f304963b9dc449c0 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Wed, 11 Dec 2019 15:10:04 -0800 Subject: [PATCH] zpay32: ensure feature vector is always populated --- zpay32/invoice.go | 24 ++++++++++++++++++++++-- zpay32/invoice_test.go | 23 ++++++++++++++++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 68b0b41b..ab871125 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -257,7 +257,8 @@ func RouteHint(routeHint []HopHint) func(*Invoice) { } // Features is a functional option that allows callers of NewInvoice to set the -// desired feature bits that are advertised on the invoice. +// desired feature bits that are advertised on the invoice. If this option is +// not used, an empty feature vector will automatically be populated. func Features(features *lnwire.FeatureVector) func(*Invoice) { return func(i *Invoice) { i.Features = features @@ -290,6 +291,13 @@ func NewInvoice(net *chaincfg.Params, paymentHash [32]byte, option(invoice) } + // If no features were set, we'll populate an empty feature vector. + if invoice.Features == nil { + invoice.Features = lnwire.NewFeatureVector( + nil, lnwire.Features, + ) + } + if err := validateInvoice(invoice); err != nil { return nil, err } @@ -398,6 +406,13 @@ func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) { decodedInvoice.Destination = pubkey } + // If no feature vector was decoded, populate an empty one. + if decodedInvoice.Features == nil { + decodedInvoice.Features = lnwire.NewFeatureVector( + nil, lnwire.Features, + ) + } + // Now that we have created the invoice, make sure it has the required // fields set. if err := validateInvoice(&decodedInvoice); err != nil { @@ -586,6 +601,11 @@ func validateInvoice(invoice *Invoice) error { len(invoice.Destination.SerializeCompressed())) } + // Ensure that all invoices have feature vectors. + if invoice.Features == nil { + return fmt.Errorf("missing feature vector") + } + return nil } @@ -1070,7 +1090,7 @@ func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error { return err } } - if invoice.Features != nil && invoice.Features.SerializeSize32() > 0 { + if invoice.Features.SerializeSize32() > 0 { var b bytes.Buffer err := invoice.Features.RawFeatureVector.EncodeBase32(&b) if err != nil { diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go index abd418d2..9a34b590 100644 --- a/zpay32/invoice_test.go +++ b/zpay32/invoice_test.go @@ -105,6 +105,8 @@ var ( }, } + emptyFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + // Must be initialized in init(). testDescriptionHash [32]byte @@ -190,6 +192,7 @@ func TestDecodeEncode(t *testing.T) { Timestamp: time.Unix(1496314658, 0), DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, }, @@ -206,6 +209,7 @@ func TestDecodeEncode(t *testing.T) { Description: &testPleaseConsider, DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, }, @@ -220,6 +224,7 @@ func TestDecodeEncode(t *testing.T) { Timestamp: time.Unix(1496314658, 0), PaymentHash: &testPaymentHash, Destination: testPubKey, + Features: emptyFeatures, } }, }, @@ -235,6 +240,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Description: &testPleaseConsider, Destination: testPubKey, + Features: emptyFeatures, } }, skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode. @@ -251,6 +257,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode. @@ -267,6 +274,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Destination: testPubKey, DescriptionHash: &testDescriptionHash, + Features: emptyFeatures, } }, skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode. @@ -282,6 +290,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Description: &testCupOfCoffee, Destination: testPubKey, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -302,6 +311,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Description: &testPleaseConsider, Destination: testPubKey, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -323,6 +333,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Destination: testPubKey, Description: &testEmptyString, + Features: emptyFeatures, } }, }, @@ -382,6 +393,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -404,6 +416,7 @@ func TestDecodeEncode(t *testing.T) { DescriptionHash: &testDescriptionHash, Destination: testPubKey, FallbackAddr: testAddrTestnet, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -427,6 +440,7 @@ func TestDecodeEncode(t *testing.T) { Destination: testPubKey, FallbackAddr: testRustyAddr, RouteHints: [][]HopHint{testSingleHop}, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -450,6 +464,7 @@ func TestDecodeEncode(t *testing.T) { Destination: testPubKey, FallbackAddr: testRustyAddr, RouteHints: [][]HopHint{testDoubleHop}, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -472,6 +487,7 @@ func TestDecodeEncode(t *testing.T) { DescriptionHash: &testDescriptionHash, Destination: testPubKey, FallbackAddr: testAddrMainnetP2SH, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -547,6 +563,7 @@ func TestDecodeEncode(t *testing.T) { DescriptionHash: &testDescriptionHash, Destination: testPubKey, FallbackAddr: testAddrMainnetP2WPKH, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -569,6 +586,7 @@ func TestDecodeEncode(t *testing.T) { DescriptionHash: &testDescriptionHash, Destination: testPubKey, FallbackAddr: testAddrMainnetP2WSH, + Features: emptyFeatures, } }, beforeEncoding: func(i *Invoice) { @@ -628,6 +646,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, Destination: testPubKey, Description: &testEmptyString, + Features: emptyFeatures, } }, skipEncoding: true, // Skip encoding since we were given the wrong net @@ -645,6 +664,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, }, @@ -660,6 +680,7 @@ func TestDecodeEncode(t *testing.T) { PaymentHash: &testPaymentHash, DescriptionHash: &testDescriptionHash, Destination: testPubKey, + Features: emptyFeatures, } }, }, @@ -983,7 +1004,7 @@ func compareInvoices(expected, actual *Invoice) error { if !reflect.DeepEqual(expected.Features, actual.Features) { return fmt.Errorf("expected features %v, got %v", - expected.Features.RawFeatureVector, actual.Features.RawFeatureVector) + expected.Features, actual.Features) } return nil