zpay32: ensure feature vector is always populated

This commit is contained in:
Conner Fromknecht 2019-12-11 15:10:04 -08:00
parent 62dadff291
commit 840476996c
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 44 additions and 3 deletions

View File

@ -257,7 +257,8 @@ func RouteHint(routeHint []HopHint) func(*Invoice) {
}
// Features is a functional option that allows callers of NewInvoice to set the
// desired feature bits that are advertised on the invoice.
// desired feature bits that are advertised on the invoice. If this option is
// not used, an empty feature vector will automatically be populated.
func Features(features *lnwire.FeatureVector) func(*Invoice) {
return func(i *Invoice) {
i.Features = features
@ -290,6 +291,13 @@ func NewInvoice(net *chaincfg.Params, paymentHash [32]byte,
option(invoice)
}
// If no features were set, we'll populate an empty feature vector.
if invoice.Features == nil {
invoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
}
if err := validateInvoice(invoice); err != nil {
return nil, err
}
@ -398,6 +406,13 @@ func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) {
decodedInvoice.Destination = pubkey
}
// If no feature vector was decoded, populate an empty one.
if decodedInvoice.Features == nil {
decodedInvoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
}
// Now that we have created the invoice, make sure it has the required
// fields set.
if err := validateInvoice(&decodedInvoice); err != nil {
@ -586,6 +601,11 @@ func validateInvoice(invoice *Invoice) error {
len(invoice.Destination.SerializeCompressed()))
}
// Ensure that all invoices have feature vectors.
if invoice.Features == nil {
return fmt.Errorf("missing feature vector")
}
return nil
}
@ -1070,7 +1090,7 @@ func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error {
return err
}
}
if invoice.Features != nil && invoice.Features.SerializeSize32() > 0 {
if invoice.Features.SerializeSize32() > 0 {
var b bytes.Buffer
err := invoice.Features.RawFeatureVector.EncodeBase32(&b)
if err != nil {

View File

@ -105,6 +105,8 @@ var (
},
}
emptyFeatures = lnwire.NewFeatureVector(nil, lnwire.Features)
// Must be initialized in init().
testDescriptionHash [32]byte
@ -190,6 +192,7 @@ func TestDecodeEncode(t *testing.T) {
Timestamp: time.Unix(1496314658, 0),
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
},
@ -206,6 +209,7 @@ func TestDecodeEncode(t *testing.T) {
Description: &testPleaseConsider,
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
},
@ -220,6 +224,7 @@ func TestDecodeEncode(t *testing.T) {
Timestamp: time.Unix(1496314658, 0),
PaymentHash: &testPaymentHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
},
@ -235,6 +240,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Description: &testPleaseConsider,
Destination: testPubKey,
Features: emptyFeatures,
}
},
skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode.
@ -251,6 +257,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode.
@ -267,6 +274,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Destination: testPubKey,
DescriptionHash: &testDescriptionHash,
Features: emptyFeatures,
}
},
skipEncoding: true, // Skip encoding since we don't have the unknown fields to encode.
@ -282,6 +290,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Description: &testCupOfCoffee,
Destination: testPubKey,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -302,6 +311,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Description: &testPleaseConsider,
Destination: testPubKey,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -323,6 +333,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Destination: testPubKey,
Description: &testEmptyString,
Features: emptyFeatures,
}
},
},
@ -382,6 +393,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -404,6 +416,7 @@ func TestDecodeEncode(t *testing.T) {
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
FallbackAddr: testAddrTestnet,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -427,6 +440,7 @@ func TestDecodeEncode(t *testing.T) {
Destination: testPubKey,
FallbackAddr: testRustyAddr,
RouteHints: [][]HopHint{testSingleHop},
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -450,6 +464,7 @@ func TestDecodeEncode(t *testing.T) {
Destination: testPubKey,
FallbackAddr: testRustyAddr,
RouteHints: [][]HopHint{testDoubleHop},
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -472,6 +487,7 @@ func TestDecodeEncode(t *testing.T) {
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
FallbackAddr: testAddrMainnetP2SH,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -547,6 +563,7 @@ func TestDecodeEncode(t *testing.T) {
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
FallbackAddr: testAddrMainnetP2WPKH,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -569,6 +586,7 @@ func TestDecodeEncode(t *testing.T) {
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
FallbackAddr: testAddrMainnetP2WSH,
Features: emptyFeatures,
}
},
beforeEncoding: func(i *Invoice) {
@ -628,6 +646,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
Destination: testPubKey,
Description: &testEmptyString,
Features: emptyFeatures,
}
},
skipEncoding: true, // Skip encoding since we were given the wrong net
@ -645,6 +664,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
},
@ -660,6 +680,7 @@ func TestDecodeEncode(t *testing.T) {
PaymentHash: &testPaymentHash,
DescriptionHash: &testDescriptionHash,
Destination: testPubKey,
Features: emptyFeatures,
}
},
},
@ -983,7 +1004,7 @@ func compareInvoices(expected, actual *Invoice) error {
if !reflect.DeepEqual(expected.Features, actual.Features) {
return fmt.Errorf("expected features %v, got %v",
expected.Features.RawFeatureVector, actual.Features.RawFeatureVector)
expected.Features, actual.Features)
}
return nil