From 6f5d673679025b6c97a2837c74b456378875001f Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Tue, 9 Jan 2018 00:46:29 -0500 Subject: [PATCH] invoice: refactor parsing tagged fields This commit refactors parsing each of the tagged fields of an invoice into their own method. This makes the code easier to read and will allow us to introduce unit tests for each parsing method. --- zpay32/invoice.go | 334 ++++++++++++++++++++++++++++------------------ 1 file changed, 202 insertions(+), 132 deletions(-) diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 33752d34..c27ded22 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -588,7 +588,7 @@ func parseTimestamp(data []byte) (uint64, error) { func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) error { index := 0 for { - // If less than 3 groups less, it cannot possibly contain more + // If there are less than 3 groups to read, there cannot be more // interesting information, as we need the type (1 group) and // length (2 groups). if len(fields)-index < 3 { @@ -596,7 +596,10 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er } typ := fields[index] - dataLength := uint16(fields[index+1]<<5) | uint16(fields[index+2]) + dataLength, err := parseFieldDataLength(fields[index+1 : index+3]) + if err != nil { + return err + } // If we don't have enough field data left to read this length, // return error. @@ -616,17 +619,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - if dataLength != hashBase32Len { - // Skipping unknown field length. - continue - } - hash, err := bech32.ConvertBits(base32Data, 5, 8, false) - if err != nil { - return err - } - var pHash [32]byte - copy(pHash[:], hash[:]) - invoice.PaymentHash = &pHash + invoice.PaymentHash, err = parsePaymentHash(base32Data) case fieldTypeD: if invoice.Description != nil { // We skip the field if we have already seen a @@ -634,13 +627,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - base256Data, err := bech32.ConvertBits(base32Data, 5, 8, - false) - if err != nil { - return err - } - desc := string(base256Data) - invoice.Description = &desc + invoice.Description, err = parseDescription(base32Data) case fieldTypeN: if invoice.Destination != nil { // We skip the field if we have already seen a @@ -648,21 +635,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - if len(base32Data) != pubKeyBase32Len { - // Skip unknown length. - continue - } - - base256Data, err := bech32.ConvertBits(base32Data, 5, 8, - false) - if err != nil { - return err - } - invoice.Destination, err = btcec.ParsePubKey(base256Data, - btcec.S256()) - if err != nil { - return err - } + invoice.Destination, err = parseDestination(base32Data) case fieldTypeH: if invoice.DescriptionHash != nil { // We skip the field if we have already seen a @@ -670,17 +643,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - if len(base32Data) != hashBase32Len { - // Skip unknown length. - continue - } - hash, err := bech32.ConvertBits(base32Data, 5, 8, false) - if err != nil { - return err - } - var dHash [32]byte - copy(dHash[:], hash[:]) - invoice.DescriptionHash = &dHash + invoice.DescriptionHash, err = parseDescriptionHash(base32Data) case fieldTypeX: if invoice.expiry != nil { // We skip the field if we have already seen a @@ -688,12 +651,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - exp, err := base32ToUint64(base32Data) - if err != nil { - return err - } - dur := time.Duration(exp) * time.Second - invoice.expiry = &dur + invoice.expiry, err = parseExpiry(base32Data) case fieldTypeC: if invoice.minFinalCLTVExpiry != nil { // We skip the field if we have already seen a @@ -701,11 +659,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - expiry, err := base32ToUint64(base32Data) - if err != nil { - return err - } - invoice.minFinalCLTVExpiry = &expiry + invoice.minFinalCLTVExpiry, err = parseMinFinalCLTVExpiry(base32Data) case fieldTypeF: if invoice.FallbackAddr != nil { // We skip the field if we have already seen a @@ -713,56 +667,7 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - var addr btcutil.Address - version := base32Data[0] - switch version { - case 0: - witness, err := bech32.ConvertBits( - base32Data[1:], 5, 8, false) - if err != nil { - return err - } - switch len(witness) { - case 20: - addr, err = btcutil.NewAddressWitnessPubKeyHash( - witness, net) - case 32: - addr, err = btcutil.NewAddressWitnessScriptHash( - witness, net) - default: - return fmt.Errorf("unknown witness "+ - "program length: %d", len(witness)) - } - if err != nil { - return err - } - case 17: - pkHash, err := bech32.ConvertBits(base32Data[1:], - 5, 8, false) - if err != nil { - return err - } - addr, err = btcutil.NewAddressPubKeyHash(pkHash, - net) - if err != nil { - return err - } - case 18: - scriptHash, err := bech32.ConvertBits( - base32Data[1:], 5, 8, false) - if err != nil { - return err - } - addr, err = btcutil.NewAddressScriptHashFromHash( - scriptHash, net) - if err != nil { - return err - } - default: - // Skipping unknown witness version. - continue - } - invoice.FallbackAddr = addr + invoice.FallbackAddr, err = parseFallbackAddr(base32Data, net) case fieldTypeR: if invoice.RoutingInfo != nil { // We skip the field if we have already seen a @@ -770,39 +675,204 @@ func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) er continue } - base256Data, err := bech32.ConvertBits(base32Data, 5, 8, - false) - if err != nil { - return err - } - - for len(base256Data) > 0 { - info := ExtraRoutingInfo{} - info.PubKey, err = btcec.ParsePubKey( - base256Data[:33], btcec.S256()) - if err != nil { - return err - } - info.ShortChanID = binary.BigEndian.Uint64( - base256Data[33:41]) - info.FeeBaseMsat = binary.BigEndian.Uint32( - base256Data[41:45]) - info.FeeProportionalMillionths = binary.BigEndian.Uint32( - base256Data[45:49]) - info.CltvExpDelta = binary.BigEndian.Uint16( - base256Data[49:51]) - invoice.RoutingInfo = append( - invoice.RoutingInfo, info) - base256Data = base256Data[51:] - } + invoice.RoutingInfo, err = parseRoutingInfo(base32Data) default: // Ignore unknown type. } + + // Check if there was an error from parsing any of the tagged + // fields and return it. + if err != nil { + return err + } } return nil } +// parseFieldDataLength converts the two byte slice into a uint16. +func parseFieldDataLength(data []byte) (uint16, error) { + if len(data) != 2 { + return 0, fmt.Errorf("data length must be 2 bytes, was %d", + len(data)) + } + + return uint16(data[0]<<5) | uint16(data[1]), nil +} + +// parsePaymentHash converts a 256-bit payment hash (encoded in base32) +// to *[32]byte. +func parsePaymentHash(data []byte) (*[32]byte, error) { + var paymentHash [32]byte + + // As BOLT-11 states, a reader must skip over the payment hash field if + // it does not have a length of 52, so avoid returning an error. + if len(data) != hashBase32Len { + return nil, nil + } + + hash, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + copy(paymentHash[:], hash[:]) + + return &paymentHash, nil +} + +// parseDescription converts the data (encoded in base32) into a string to use +// as the description. +func parseDescription(data []byte) (*string, error) { + base256Data, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + description := string(base256Data) + + return &description, nil +} + +// parseDestination converts the data (encoded in base32) into a 33-byte public +// key of the payee node. +func parseDestination(data []byte) (*btcec.PublicKey, error) { + // As BOLT-11 states, a reader must skip over the destination field + // if it does not have a length of 53, so avoid returning an error. + if len(data) != pubKeyBase32Len { + return nil, nil + } + + base256Data, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + return btcec.ParsePubKey(base256Data, btcec.S256()) +} + +// parseDescriptionHash converts a 256-bit description hash (encoded in base32) +// to *[32]byte. +func parseDescriptionHash(data []byte) (*[32]byte, error) { + var descriptionHash [32]byte + + // As BOLT-11 states, a reader must skip over the description hash field + // if it does not have a length of 52, so avoid returning an error. + if len(data) != hashBase32Len { + return nil, nil + } + + hash, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + copy(descriptionHash[:], hash[:]) + + return &descriptionHash, nil +} + +// parseExpiry converts the data (encoded in base32) into the expiry time. +func parseExpiry(data []byte) (*time.Duration, error) { + expiry, err := base32ToUint64(data) + if err != nil { + return nil, err + } + + duration := time.Duration(expiry) * time.Second + + return &duration, nil +} + +// parseMinFinalCLTVExpiry converts the data (encoded in base32) into a uint64 +// to use as the minFinalCLTVExpiry. +func parseMinFinalCLTVExpiry(data []byte) (*uint64, error) { + expiry, err := base32ToUint64(data) + if err != nil { + return nil, err + } + + return &expiry, nil +} + +// parseFallbackAddr converts the data (encoded in base32) into a fallback +// on-chain address. +func parseFallbackAddr(data []byte, net *chaincfg.Params) (btcutil.Address, error) { + var addr btcutil.Address + version := data[0] + switch version { + case 0: + witness, err := bech32.ConvertBits(data[1:], 5, 8, false) + if err != nil { + return nil, err + } + + switch len(witness) { + case 20: + addr, err = btcutil.NewAddressWitnessPubKeyHash(witness, net) + case 32: + addr, err = btcutil.NewAddressWitnessScriptHash(witness, net) + default: + return nil, fmt.Errorf("unknown witness program length %d", + len(witness)) + } + + if err != nil { + return nil, err + } + case 17: + pubKeyHash, err := bech32.ConvertBits(data[1:], 5, 8, false) + if err != nil { + return nil, err + } + + addr, err = btcutil.NewAddressPubKeyHash(pubKeyHash, net) + if err != nil { + return nil, err + } + case 18: + scriptHash, err := bech32.ConvertBits(data[1:], 5, 8, false) + if err != nil { + return nil, err + } + + addr, err = btcutil.NewAddressScriptHashFromHash(scriptHash, net) + if err != nil { + return nil, err + } + default: + // Ignore unknown version. + } + + return addr, nil +} + +// parseRoutingInfo converts the data (encoded in base32) into an array +// containing one or more entries of extra routing info. +func parseRoutingInfo(data []byte) ([]ExtraRoutingInfo, error) { + base256Data, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return nil, err + } + + var routingInfo []ExtraRoutingInfo + info := ExtraRoutingInfo{} + for len(base256Data) > 0 { + info.PubKey, err = btcec.ParsePubKey(base256Data[:33], btcec.S256()) + if err != nil { + return nil, err + } + info.ShortChanID = binary.BigEndian.Uint64(base256Data[33:41]) + info.FeeBaseMsat = binary.BigEndian.Uint32(base256Data[41:45]) + info.FeeProportionalMillionths = binary.BigEndian.Uint32(base256Data[45:49]) + info.CltvExpDelta = binary.BigEndian.Uint16(base256Data[49:51]) + routingInfo = append(routingInfo, info) + base256Data = base256Data[51:] + } + + return routingInfo, nil +} + // writeTaggedFields writes the non-nil tagged fields of the Invoice to the // base32 buffer. func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error {