From e3cdbcac0c38ee12bfe39b2490f44751af89eb09 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 9 Nov 2017 19:44:31 -0800 Subject: [PATCH] channeldb: add new codec.go file which houses common serialization funcs In this commit, we add a new files to the channeldb package: codec.go. This file is similar to the way that serialization is done within the lnwire package. The goal of this file is to reduce all the duplication within the common serialization methods across the package. --- channeldb/codec.go | 309 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 309 insertions(+) create mode 100644 channeldb/codec.go diff --git a/channeldb/codec.go b/channeldb/codec.go new file mode 100644 index 00000000..77b0236c --- /dev/null +++ b/channeldb/codec.go @@ -0,0 +1,309 @@ +package channeldb + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" + "github.com/roasbeef/btcd/btcec" + "github.com/roasbeef/btcd/chaincfg/chainhash" + "github.com/roasbeef/btcd/wire" + "github.com/roasbeef/btcutil" +) + +// outPointSize is the size of a serialized outpoint on disk. +const outPointSize = 36 + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// writeElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func writeElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case ChannelType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + return err + } + + case ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + default: + return fmt.Errorf("Unknown type in writeElement: %T", e) + } + + return nil +} + +// writeElements is writes each element in the elements slice to the passed +// io.Writer using writeElement. +func writeElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := writeElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// readElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func readElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *ChannelType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes(root[:]) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + msg, err := lnwire.ReadMessage(r, 0) + if err != nil { + return err + } + + *e = msg + + case *ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + default: + return fmt.Errorf("Unknown type in readElement: %T", e) + } + + return nil +} + +// readElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the readElement +// function. +func readElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := readElement(r, element) + if err != nil { + return err + } + } + return nil +}