diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 83370737..0df3b1d9 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -642,7 +642,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - invoice.PaymentHash, err = parsePaymentHash(base32Data) + invoice.PaymentHash, err = parse32Bytes(base32Data) case fieldTypeD: if invoice.Description != nil { // We skip the field if we have already seen a @@ -666,7 +666,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - invoice.DescriptionHash, err = parseDescriptionHash(base32Data) + invoice.DescriptionHash, err = parse32Bytes(base32Data) case fieldTypeX: if invoice.expiry != nil { // We skip the field if we have already seen a @@ -733,12 +733,12 @@ func parseFieldDataLength(data []byte) (uint16, error) { return uint16(data[0])<<5 | uint16(data[1]), nil } -// parsePaymentHash converts a 256-bit payment hash (encoded in base32) -// to *[32]byte. -func parsePaymentHash(data []byte) (*[32]byte, error) { +// parse32Bytes converts a 256-bit value (encoded in base32) to *[32]byte. This +// can be used for payment hashes, description hashes, payment addresses, etc. +func parse32Bytes(data []byte) (*[32]byte, error) { var paymentHash [32]byte - // As BOLT-11 states, a reader must skip over the payment hash field if + // As BOLT-11 states, a reader must skip over the 32-byte fields if // it does not have a length of 52, so avoid returning an error. if len(data) != hashBase32Len { return nil, nil @@ -784,27 +784,6 @@ func parseDestination(data []byte) (*btcec.PublicKey, error) { return btcec.ParsePubKey(base256Data, btcec.S256()) } -// parseDescriptionHash converts a 256-bit description hash (encoded in base32) -// to *[32]byte. -func parseDescriptionHash(data []byte) (*[32]byte, error) { - var descriptionHash [32]byte - - // As BOLT-11 states, a reader must skip over the description hash field - // if it does not have a length of 52, so avoid returning an error. - if len(data) != hashBase32Len { - return nil, nil - } - - hash, err := bech32.ConvertBits(data, 5, 8, false) - if err != nil { - return nil, err - } - - copy(descriptionHash[:], hash[:]) - - return &descriptionHash, nil -} - // parseExpiry converts the data (encoded in base32) into the expiry time. func parseExpiry(data []byte) (*time.Duration, error) { expiry, err := base32ToUint64(data) @@ -944,18 +923,7 @@ func parseFeatures(data []byte) (*lnwire.FeatureVector, error) { // base32 buffer. func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error { if invoice.PaymentHash != nil { - // Convert 32 byte hash to 52 5-bit groups. - base32, err := bech32.ConvertBits(invoice.PaymentHash[:], 8, 5, - true) - if err != nil { - return err - } - if len(base32) != hashBase32Len { - return fmt.Errorf("invalid payment hash length: %d", - len(invoice.PaymentHash)) - } - - err = writeTaggedField(bufferBase32, fieldTypeP, base32) + err := writeBytes32(bufferBase32, fieldTypeP, *invoice.PaymentHash) if err != nil { return err } @@ -974,19 +942,9 @@ func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error { } if invoice.DescriptionHash != nil { - // Convert 32 byte hash to 52 5-bit groups. - descBase32, err := bech32.ConvertBits( - invoice.DescriptionHash[:], 8, 5, true) - if err != nil { - return err - } - - if len(descBase32) != hashBase32Len { - return fmt.Errorf("invalid description hash length: %d", - len(invoice.DescriptionHash)) - } - - err = writeTaggedField(bufferBase32, fieldTypeH, descBase32) + err := writeBytes32( + bufferBase32, fieldTypeH, *invoice.DescriptionHash, + ) if err != nil { return err } @@ -1106,6 +1064,18 @@ func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error { return nil } +// writeBytes32 encodes a 32-byte array as base32 and writes it to bufferBase32 +// under the passed fieldType. +func writeBytes32(bufferBase32 *bytes.Buffer, fieldType byte, b [32]byte) error { + // Convert 32 byte hash to 52 5-bit groups. + base32, err := bech32.ConvertBits(b[:], 8, 5, true) + if err != nil { + return err + } + + return writeTaggedField(bufferBase32, fieldType, base32) +} + // writeTaggedField takes the type of a tagged data field, and the data of // the tagged field (encoded in base32), and writes the type, length and data // to the buffer. diff --git a/zpay32/invoice_internal_test.go b/zpay32/invoice_internal_test.go index 8089c69d..b72c72e0 100644 --- a/zpay32/invoice_internal_test.go +++ b/zpay32/invoice_internal_test.go @@ -314,10 +314,10 @@ func TestParseFieldDataLength(t *testing.T) { } } -// TestParsePaymentHash checks that the payment hash is properly parsed. +// TestParse32Bytes checks that the payment hash is properly parsed. // If the data does not have a length of 52 bytes, we skip over parsing the // field and do not return an error. -func TestParsePaymentHash(t *testing.T) { +func TestParse32Bytes(t *testing.T) { t.Parallel() testPaymentHashData, _ := bech32.ConvertBits(testPaymentHash[:], 8, 5, true) @@ -350,7 +350,7 @@ func TestParsePaymentHash(t *testing.T) { } for i, test := range tests { - paymentHash, err := parsePaymentHash(test.data) + paymentHash, err := parse32Bytes(test.data) if (err == nil) != test.valid { t.Errorf("payment hash decoding test %d failed: %v", i, err) return @@ -458,56 +458,6 @@ func TestParseDestination(t *testing.T) { } } -// TestParseDescriptionHash checks that the description hash is properly parsed. -// If the data does not have a length of 52 bytes, we skip over parsing the -// field and do not return an error. -func TestParseDescriptionHash(t *testing.T) { - t.Parallel() - - testDescriptionHashData, _ := bech32.ConvertBits(testDescriptionHash[:], 8, 5, true) - - tests := []struct { - data []byte - valid bool - result *[32]byte - }{ - { - data: []byte{}, - valid: true, - result: nil, // skip unknown length, not 52 bytes - }, - { - data: []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, - valid: true, - result: nil, // skip unknown length, not 52 bytes - }, - { - data: testDescriptionHashData, - valid: true, - result: &testDescriptionHash, - }, - { - data: append(testDescriptionHashData, 0x0), - valid: true, - result: nil, // skip unknown length, not 52 bytes - }, - } - - for i, test := range tests { - descriptionHash, err := parseDescriptionHash(test.data) - if (err == nil) != test.valid { - t.Errorf("description hash decoding test %d failed: %v", i, err) - return - } - if test.valid && !compareHashes(descriptionHash, test.result) { - t.Fatalf("test %d failed decoding description hash: "+ - "expected %x, got %x", - i, *test.result, *descriptionHash) - return - } - } -} - // TestParseExpiry checks that the expiry is properly parsed. func TestParseExpiry(t *testing.T) { t.Parallel()