diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 9709f491..d62d9a43 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -62,252 +62,197 @@ func (c CreditsAmount) ToSatoshi() int64 { return int64(c / 1000) } -// Writes the big endian representation of element -// Unified function to call when writing different types -// Pre-allocate a byte-array of the correct size for cargo-cult security -// More copies but whatever... +// writeElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for the wire protocol. The passed +// io.Writer should be backed by an appropriatly sized byte slice, or be able +// to dynamically expand to accomdate additional data. +// +// TODO(roasbeef): this should eventually draw from a buffer pool for +// serialization. +// TODO(roasbeef): switch to var-ints for all? func writeElement(w io.Writer, element interface{}) error { - var err error switch e := element.(type) { case uint8: var b [1]byte b[0] = byte(e) - _, err = w.Write(b[:]) - if err != nil { + if _, err := w.Write(b[:]); err != nil { return err } - return nil case uint16: var b [2]byte binary.BigEndian.PutUint16(b[:], uint16(e)) - _, err = w.Write(b[:]) - if err != nil { + if _, err := w.Write(b[:]); err != nil { return err } - return nil case CreditsAmount: - err = binary.Write(w, binary.BigEndian, int64(e)) - if err != nil { + if err := binary.Write(w, binary.BigEndian, int64(e)); err != nil { return err } - return nil case uint32: var b [4]byte binary.BigEndian.PutUint32(b[:], uint32(e)) - _, err = w.Write(b[:]) - if err != nil { + if _, err := w.Write(b[:]); err != nil { return err } - return nil case uint64: var b [8]byte binary.BigEndian.PutUint64(b[:], uint64(e)) - _, err = w.Write(b[:]) - if err != nil { + if _, err := w.Write(b[:]); err != nil { return err } - return nil case HTLCKey: - err = binary.Write(w, binary.BigEndian, int64(e)) - if err != nil { + if err := binary.Write(w, binary.BigEndian, int64(e)); err != nil { return err } - return nil case btcutil.Amount: - err = binary.Write(w, binary.BigEndian, int64(e)) - if err != nil { + if err := binary.Write(w, binary.BigEndian, int64(e)); err != nil { return err } - return nil case *btcec.PublicKey: var b [33]byte serializedPubkey := e.SerializeCompressed() - if len(serializedPubkey) != 33 { - return fmt.Errorf("Wrong size pubkey") - } copy(b[:], serializedPubkey) - _, err = w.Write(b[:]) - if err != nil { + // TODO(roasbeef): use WriteVarBytes here? + if _, err := w.Write(b[:]); err != nil { return err } - return nil case []uint64: + // Enforce a max number of elements in a uint64 slice. numItems := len(e) if numItems > 65535 { return fmt.Errorf("Too many []uint64s") } - // Write the size - err = writeElement(w, uint16(numItems)) - if err != nil { + + // First write out the the number of elements in the slice as a + // length prefix. + if err := writeElement(w, uint16(numItems)); err != nil { return err } - // Write the data + + // After the prefix detailing the number of elements, write out + // each uint64 in series. for i := 0; i < numItems; i++ { - err = writeElement(w, e[i]) - if err != nil { + if err := writeElement(w, e[i]); err != nil { return err } } - return nil case []*btcec.Signature: + // Enforce a sane number for the maximum number of signatures. numSigs := len(e) if numSigs > 127 { return fmt.Errorf("Too many signatures!") } - // Write the size - err = writeElement(w, uint8(numSigs)) - if err != nil { + + // First write out the the number of elements in the slice as a + // length prefix. + if err := writeElement(w, uint8(numSigs)); err != nil { return err } - // Write the data + + // After the prefix detailing the number of elements, write out + // each signature in series. for i := 0; i < numSigs; i++ { - err = writeElement(w, e[i]) - if err != nil { + if err := writeElement(w, e[i]); err != nil { return err } } - return nil case *btcec.Signature: sig := e.Serialize() - sigLength := len(sig) - if sigLength > 73 { + if len(sig) > 73 { return fmt.Errorf("Signature too long!") } - // Write the size - err = writeElement(w, uint8(sigLength)) - if err != nil { + + if err := wire.WriteVarBytes(w, 0, sig); err != nil { return err } - // Write the data - _, err = w.Write(sig) - if err != nil { - return err - } - return nil case *wire.ShaHash: - _, err = w.Write(e[:]) - if err != nil { + if _, err := w.Write(e[:]); err != nil { return err } - return nil - case []*[20]byte: - // Get size of slice and dump in slice + case [][20]byte: + // First write out the number of elements in the slice. sliceSize := len(e) - err = writeElement(w, uint16(sliceSize)) - if err != nil { + if err := writeElement(w, uint16(sliceSize)); err != nil { return err } - // Write in each sequentially + + // Then write each out sequentially. for _, element := range e { - err = writeElement(w, &element) - if err != nil { + if err := writeElement(w, &element); err != nil { return err } } - return nil - case **[20]byte: - _, err = w.Write((*e)[:]) - if err != nil { - return err - } case [20]byte: - _, err = w.Write(e[:]) - if err != nil { + // TODO(roasbeef): should be factor out to caller logic... + if _, err := w.Write(e[:]); err != nil { return err } - return nil case wire.BitcoinNet: var b [4]byte binary.BigEndian.PutUint32(b[:], uint32(e)) - _, err := w.Write(b[:]) - if err != nil { + if _, err := w.Write(b[:]); err != nil { return err } - return nil case []byte: + // Enforce the maxmium length of all slices used in the wire + // protocol. sliceLength := len(e) - if sliceLength > MAX_SLICE_LENGTH { + if sliceLength > MaxSliceLength { return fmt.Errorf("Slice length too long!") } - // Write the size - err = writeElement(w, uint16(sliceLength)) - if err != nil { - return err - } - // Write the data - _, err = w.Write(e) - if err != nil { - return err - } - return nil case PkScript: + // Make sure it's P2PKH or P2SH size or less. scriptLength := len(e) - // Make sure it's P2PKH or P2SH size or less if scriptLength > 25 { return fmt.Errorf("PkScript too long!") } - // Write the size (1-byte) - err = writeElement(w, uint8(scriptLength)) - if err != nil { + + if err := wire.WriteVarBytes(w, 0, e); err != nil { return err } - // Write the data - _, err = w.Write(e) - if err != nil { - return err - } - return nil case string: strlen := len(e) - if strlen > 65535 { + if strlen > MaxSliceLength { return fmt.Errorf("String too long!") } - // Write the size (2-bytes) - err = writeElement(w, uint16(strlen)) - if err != nil { - return err - } - // Write the data - _, err = w.Write([]byte(e)) - if err != nil { + + if err := wire.WriteVarString(w, 0, e); err != nil { return err } case []*wire.TxIn: - // Append the unsigned(!!!) txins // Write the size (1-byte) if len(e) > 127 { return fmt.Errorf("Too many txins") } - err = writeElement(w, uint8(len(e))) - if err != nil { + + // Write out the number of txins. + if err := writeElement(w, uint8(len(e))); err != nil { return err } + // Append the actual TxIns (Size: NumOfTxins * 36) - // Do not include the sequence number to eliminate funny business + // During serialization we leave out the sequence number to + // eliminate any funny business. for _, in := range e { - err = writeElement(w, in) - if err != nil { + if err := writeElement(w, in); err != nil { return err } } - return nil case *wire.TxIn: - // Hash + // First write out the previous txid. var h [32]byte - copy(h[:], e.PreviousOutPoint.Hash.Bytes()) - _, err = w.Write(h[:]) - if err != nil { + copy(h[:], e.PreviousOutPoint.Hash[:]) + if _, err := w.Write(h[:]); err != nil { return err } - // Index + + // Then the exact index of the previous out point. var idx [4]byte binary.BigEndian.PutUint32(idx[:], e.PreviousOutPoint.Index) - _, err = w.Write(idx[:]) - if err != nil { + if _, err := w.Write(idx[:]); err != nil { return err } - return nil - + // TODO(roasbeef): *MsgTx default: return fmt.Errorf("Unknown type in writeElement: %T", e) } @@ -315,6 +260,8 @@ func writeElement(w io.Writer, element interface{}) error { return nil } +// writeElements is writes each element in the elements slice to the passed +// io.Writer using writeElement. func writeElements(w io.Writer, elements ...interface{}) error { for _, element := range elements { err := writeElement(w, element) @@ -325,89 +272,73 @@ func writeElements(w io.Writer, elements ...interface{}) error { return nil } +// readElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of lnwire. func readElement(r io.Reader, element interface{}) error { var err error switch e := element.(type) { case *uint8: var b [1]uint8 - _, err = r.Read(b[:]) - if err != nil { + if _, err := r.Read(b[:]); err != nil { return err } *e = b[0] - return nil case *uint16: var b [2]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint16(b[:]) - return nil case *CreditsAmount: var b [8]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = CreditsAmount(int64(binary.BigEndian.Uint64(b[:]))) - return nil case *uint32: var b [4]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint32(b[:]) - return nil case *uint64: var b [8]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint64(b[:]) - return nil case *HTLCKey: var b [8]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = HTLCKey(int64(binary.BigEndian.Uint64(b[:]))) - return nil case *btcutil.Amount: var b [8]byte - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = btcutil.Amount(int64(binary.BigEndian.Uint64(b[:]))) - return nil case **wire.ShaHash: var b wire.ShaHash - _, err = io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = &b - return nil case **btcec.PublicKey: var b [33]byte - _, err = io.ReadFull(r, b[:]) + if _, err = io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) if err != nil { return err } - x, err := btcec.ParsePubKey(b[:], btcec.S256()) - if err != nil { - return err - } - *e = x - return nil + *e = pubKey case *[]uint64: var numItems uint16 - err = readElement(r, &numItems) - if err != nil { + if err := readElement(r, &numItems); err != nil { return err } // if numItems > 65535 { @@ -425,7 +356,6 @@ func readElement(r io.Reader, element interface{}) error { items = append(items, item) } *e = items - return nil case *[]*btcec.Signature: var numSigs uint8 err = readElement(r, &numSigs) @@ -449,39 +379,25 @@ func readElement(r io.Reader, element interface{}) error { *e = sigs return nil case **btcec.Signature: - var sigLength uint8 - err = readElement(r, &sigLength) + sigBytes, err := wire.ReadVarBytes(r, 0, 73, "signature") if err != nil { return err } - if sigLength > 73 { - return fmt.Errorf("Signature too long!") - } - - // Read the sig length - l := io.LimitReader(r, int64(sigLength)) - sig, err := ioutil.ReadAll(l) + sig, err := btcec.ParseSignature(sigBytes, btcec.S256()) if err != nil { return err } - if len(sig) != int(sigLength) { - return fmt.Errorf("EOF: Signature length mismatch.") - } - btcecSig, err := btcec.ParseSignature(sig, btcec.S256()) - if err != nil { - return err - } - *e = btcecSig - return nil - case *[]*[20]byte: + *e = sig + case *[][20]byte: // How many to read var sliceSize uint16 err = readElement(r, &sliceSize) if err != nil { return err } - var data []*[20]byte + + data := make([][20]byte, 0, sliceSize) // Append the actual for i := uint16(0); i < sliceSize; i++ { var element [20]byte @@ -489,93 +405,42 @@ func readElement(r io.Reader, element interface{}) error { if err != nil { return err } - data = append(data, &element) + data = append(data, element) } *e = data - return nil case *[20]byte: - _, err = io.ReadFull(r, e[:]) - if err != nil { + if _, err = io.ReadFull(r, e[:]); err != nil { return err } - return nil case *wire.BitcoinNet: var b [4]byte - _, err := io.ReadFull(r, b[:]) - if err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = wire.BitcoinNet(binary.BigEndian.Uint32(b[:])) return nil case *[]byte: - // Get the blob length first - var blobLength uint16 - err = readElement(r, &blobLength) + bytes, err := wire.ReadVarBytes(r, 0, MaxSliceLength, "byte slice") if err != nil { return err } - - // Shouldn't need to do this, since it's uint16, but we - // might have a different value for MAX_SLICE_LENGTH... - if int(blobLength) > MAX_SLICE_LENGTH { - return fmt.Errorf("Slice length too long!") - } - - // Read the slice length - l := io.LimitReader(r, int64(blobLength)) - *e, err = ioutil.ReadAll(l) - if err != nil { - return err - } - if len(*e) != int(blobLength) { - return fmt.Errorf("EOF: Slice length mismatch.") - } - return nil + *e = bytes case *PkScript: - // Get the script length first - var scriptLength uint8 - err = readElement(r, &scriptLength) + pkScript, err := wire.ReadVarBytes(r, 0, 25, "pkscript") if err != nil { return err } - - if scriptLength > 25 { - return fmt.Errorf("PkScript too long!") - } - - // Read the script length - l := io.LimitReader(r, int64(scriptLength)) - *e, err = ioutil.ReadAll(l) - if err != nil { - return err - } - if len(*e) != int(scriptLength) { - return fmt.Errorf("EOF: Signature length mismatch.") - } - return nil + *e = pkScript case *string: - // Get the string length first - var strlen uint16 - err = readElement(r, &strlen) + str, err := wire.ReadVarString(r, 0) if err != nil { return err } - // Read the string for the length - l := io.LimitReader(r, int64(strlen)) - b, err := ioutil.ReadAll(l) - if len(b) != int(strlen) { - return fmt.Errorf("EOF: String length mismatch.") - } - *e = string(b) - if err != nil { - return err - } - return nil + *e = str case *[]*wire.TxIn: // Read the size (1-byte number of txins) var numScripts uint8 - err = readElement(r, &numScripts) - if err != nil { + if err := readElement(r, &numScripts); err != nil { return err } if numScripts > 127 { @@ -583,23 +448,20 @@ func readElement(r io.Reader, element interface{}) error { } // Append the actual TxIns - var txins []*wire.TxIn + txins := make([]*wire.TxIn, 0, numScripts) for i := uint8(0); i < numScripts; i++ { outpoint := new(wire.OutPoint) txin := wire.NewTxIn(outpoint, nil, nil) - err = readElement(r, &txin) - if err != nil { + if err := readElement(r, &txin); err != nil { return err } txins = append(txins, txin) } *e = txins - return nil case **wire.TxIn: // Hash var h [32]byte - _, err = io.ReadFull(r, h[:]) - if err != nil { + if _, err = io.ReadFull(r, h[:]); err != nil { return err } hash, err := wire.NewShaHash(h[:]) @@ -607,6 +469,7 @@ func readElement(r io.Reader, element interface{}) error { return err } (*e).PreviousOutPoint.Hash = *hash + // Index var idxBytes [4]byte _, err = io.ReadFull(r, idxBytes[:]) @@ -622,6 +485,9 @@ func readElement(r io.Reader, element interface{}) error { return nil } +// readElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the readElement +// function. func readElements(r io.Reader, elements ...interface{}) error { for _, element := range elements { err := readElement(r, element)