package lnwire import ( "bytes" "encoding/binary" "fmt" "io" "math" "net" "github.com/go-errors/errors" "github.com/roasbeef/btcd/btcec" "github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcutil" ) // MaxSliceLength is the maximum allowed lenth for any opaque byte slices in // the wire protocol. const MaxSliceLength = 65535 // PkScript is simple type definition which represents a raw serialized public // key script. type PkScript []byte // addressType specifies the network protocol and version that should be used // when connecting to a node at a particular address. type addressType uint8 const ( // noAddr denotes a blank address. An address of this type indicates // that a node doesn't have any advertise d addresses. noAddr addressType = 0 // tcp4Addr denotes an IPv4 TCP address. tcp4Addr addressType = 1 // tcp4Addr denotes an IPv6 TCP address. tcp6Addr addressType = 2 // v2OnionAddr denotes a version 2 Tor onion service address. v2OnionAddr addressType = 3 // v3OnionAddr denotes a version 3 Tor (prop224) onion service // addresses v3OnionAddr addressType = 4 ) // AddrLen returns the number of bytes that it takes to encode the target // address. func (a addressType) AddrLen() uint16 { switch a { case noAddr: return 0 case tcp4Addr: return 6 case tcp6Addr: return 18 case v2OnionAddr: return 12 case v3OnionAddr: return 37 default: return 0 } } // writeElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able // to dynamically expand to accommodate additional data. // // TODO(roasbeef): this should eventually draw from a buffer pool for // serialization. // TODO(roasbeef): switch to var-ints for all? func writeElement(w io.Writer, element interface{}) error { switch e := element.(type) { case uint8: var b [1]byte b[0] = e if _, err := w.Write(b[:]); err != nil { return err } case uint16: var b [2]byte binary.BigEndian.PutUint16(b[:], e) if _, err := w.Write(b[:]); err != nil { return err } case ErrorCode: var b [2]byte binary.BigEndian.PutUint16(b[:], uint16(e)) if _, err := w.Write(b[:]); err != nil { return err } case MilliSatoshi: var b [8]byte binary.BigEndian.PutUint64(b[:], uint64(e)) if _, err := w.Write(b[:]); err != nil { return err } case btcutil.Amount: var b [8]byte binary.BigEndian.PutUint64(b[:], uint64(e)) if _, err := w.Write(b[:]); err != nil { return err } case uint32: var b [4]byte binary.BigEndian.PutUint32(b[:], e) if _, err := w.Write(b[:]); err != nil { return err } case uint64: var b [8]byte binary.BigEndian.PutUint64(b[:], e) if _, err := w.Write(b[:]); err != nil { return err } case *btcec.PublicKey: if e == nil { return fmt.Errorf("cannot write nil pubkey") } var b [33]byte serializedPubkey := e.SerializeCompressed() copy(b[:], serializedPubkey) if _, err := w.Write(b[:]); err != nil { return err } case []*btcec.Signature: var b [2]byte numSigs := uint16(len(e)) binary.BigEndian.PutUint16(b[:], numSigs) if _, err := w.Write(b[:]); err != nil { return err } for _, sig := range e { if err := writeElement(w, sig); err != nil { return err } } case *btcec.Signature: if e == nil { return fmt.Errorf("cannot write nil signature") } var b [64]byte err := SerializeSigToWire(&b, e) if err != nil { return err } // Write buffer if _, err = w.Write(b[:]); err != nil { return err } case PingPayload: var l [2]byte binary.BigEndian.PutUint16(l[:], uint16(len(e))) if _, err := w.Write(l[:]); err != nil { return err } if _, err := w.Write(e[:]); err != nil { return err } case PongPayload: var l [2]byte binary.BigEndian.PutUint16(l[:], uint16(len(e))) if _, err := w.Write(l[:]); err != nil { return err } if _, err := w.Write(e[:]); err != nil { return err } case ErrorData: var l [2]byte binary.BigEndian.PutUint16(l[:], uint16(len(e))) if _, err := w.Write(l[:]); err != nil { return err } if _, err := w.Write(e[:]); err != nil { return err } case OpaqueReason: var l [2]byte binary.BigEndian.PutUint16(l[:], uint16(len(e))) if _, err := w.Write(l[:]); err != nil { return err } if _, err := w.Write(e[:]); err != nil { return err } case []byte: if _, err := w.Write(e[:]); err != nil { return err } case PkScript: // The largest script we'll accept is a p2wsh which is exactly // 34 bytes long. scriptLength := len(e) if scriptLength > 34 { return fmt.Errorf("'PkScript' too long") } if err := wire.WriteVarBytes(w, 0, e); err != nil { return err } case *FeatureVector: if e == nil { return fmt.Errorf("cannot write nil feature vector") } if err := e.Encode(w); err != nil { return err } case wire.OutPoint: var h [32]byte copy(h[:], e.Hash[:]) if _, err := w.Write(h[:]); err != nil { return err } if e.Index > math.MaxUint16 { return fmt.Errorf("index for outpoint (%v) is "+ "greater than max index of %v", e.Index, math.MaxUint16) } var idx [2]byte binary.BigEndian.PutUint16(idx[:], uint16(e.Index)) if _, err := w.Write(idx[:]); err != nil { return err } case ChannelID: if _, err := w.Write(e[:]); err != nil { return err } case FailCode: if err := writeElement(w, uint16(e)); err != nil { return err } case ShortChannelID: // Check that field fit in 3 bytes and write the blockHeight if e.BlockHeight > ((1 << 24) - 1) { return errors.New("block height should fit in 3 bytes") } var blockHeight [4]byte binary.BigEndian.PutUint32(blockHeight[:], e.BlockHeight) if _, err := w.Write(blockHeight[1:]); err != nil { return err } // Check that field fit in 3 bytes and write the txIndex if e.TxIndex > ((1 << 24) - 1) { return errors.New("tx index should fit in 3 bytes") } var txIndex [4]byte binary.BigEndian.PutUint32(txIndex[:], e.TxIndex) if _, err := w.Write(txIndex[1:]); err != nil { return err } // Write the txPosition var txPosition [2]byte binary.BigEndian.PutUint16(txPosition[:], e.TxPosition) if _, err := w.Write(txPosition[:]); err != nil { return err } case *net.TCPAddr: if e == nil { return fmt.Errorf("cannot write nil TCPAddr") } // TODO(roasbeef): account for onion types too if e.IP.To4() != nil { var descriptor [1]byte descriptor[0] = uint8(tcp4Addr) if _, err := w.Write(descriptor[:]); err != nil { return err } var ip [4]byte copy(ip[:], e.IP.To4()) if _, err := w.Write(ip[:]); err != nil { return err } } else { var descriptor [1]byte descriptor[0] = uint8(tcp6Addr) if _, err := w.Write(descriptor[:]); err != nil { return err } var ip [16]byte copy(ip[:], e.IP.To16()) if _, err := w.Write(ip[:]); err != nil { return err } } var port [2]byte binary.BigEndian.PutUint16(port[:], uint16(e.Port)) if _, err := w.Write(port[:]); err != nil { return err } case []net.Addr: // First, we'll encode all the addresses into an intermediate // buffer. We need to do this in order to compute the total // length of the addresses. var addrBuf bytes.Buffer for _, address := range e { if err := writeElement(&addrBuf, address); err != nil { return err } } // With the addresses fully encoded, we can now write out the // number of bytes needed to encode them. addrLen := addrBuf.Len() if err := writeElement(w, uint16(addrLen)); err != nil { return err } // Finally, we'll write out the raw addresses themselves, but // only if we have any bytes to write. if addrLen > 0 { if _, err := w.Write(addrBuf.Bytes()); err != nil { return err } } case RGB: if err := writeElements(w, e.red, e.green, e.blue); err != nil { return err } case DeliveryAddress: var length [2]byte binary.BigEndian.PutUint16(length[:], uint16(len(e))) if _, err := w.Write(length[:]); err != nil { return err } if _, err := w.Write(e[:]); err != nil { return err } default: return fmt.Errorf("Unknown type in writeElement: %T", e) } return nil } // writeElements is writes each element in the elements slice to the passed // io.Writer using writeElement. func writeElements(w io.Writer, elements ...interface{}) error { for _, element := range elements { err := writeElement(w, element) if err != nil { return err } } return nil } // readElement is a one-stop utility function to deserialize any datastructure // encoded using the serialization format of lnwire. func readElement(r io.Reader, element interface{}) error { var err error switch e := element.(type) { case *uint8: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { return err } *e = b[0] case *uint16: var b [2]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint16(b[:]) case *ErrorCode: var b [2]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = ErrorCode(binary.BigEndian.Uint16(b[:])) case *uint32: var b [4]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint32(b[:]) case *uint64: var b [8]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = binary.BigEndian.Uint64(b[:]) case *MilliSatoshi: var b [8]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = MilliSatoshi(int64(binary.BigEndian.Uint64(b[:]))) case *btcutil.Amount: var b [8]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } *e = btcutil.Amount(int64(binary.BigEndian.Uint64(b[:]))) case **btcec.PublicKey: var b [btcec.PubKeyBytesLenCompressed]byte if _, err = io.ReadFull(r, b[:]); err != nil { return err } pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) if err != nil { return err } *e = pubKey case **FeatureVector: f, err := NewFeatureVectorFromReader(r) if err != nil { return err } *e = f case *[]*btcec.Signature: var l [2]byte if _, err := io.ReadFull(r, l[:]); err != nil { return err } numSigs := binary.BigEndian.Uint16(l[:]) var sigs []*btcec.Signature if numSigs > 0 { sigs = make([]*btcec.Signature, numSigs) for i := 0; i < int(numSigs); i++ { if err := readElement(r, &sigs[i]); err != nil { return err } } } *e = sigs case **btcec.Signature: var b [64]byte if _, err := io.ReadFull(r, b[:]); err != nil { return err } err = DeserializeSigFromWire(e, b) if err != nil { return err } case *OpaqueReason: var l [2]byte if _, err := io.ReadFull(r, l[:]); err != nil { return err } reasonLen := binary.BigEndian.Uint16(l[:]) *e = OpaqueReason(make([]byte, reasonLen)) if _, err := io.ReadFull(r, *e); err != nil { return err } case *ErrorData: var l [2]byte if _, err := io.ReadFull(r, l[:]); err != nil { return err } errorLen := binary.BigEndian.Uint16(l[:]) *e = ErrorData(make([]byte, errorLen)) if _, err := io.ReadFull(r, *e); err != nil { return err } case *PingPayload: var l [2]byte if _, err := io.ReadFull(r, l[:]); err != nil { return err } pingLen := binary.BigEndian.Uint16(l[:]) *e = PingPayload(make([]byte, pingLen)) if _, err := io.ReadFull(r, *e); err != nil { return err } case *PongPayload: var l [2]byte if _, err := io.ReadFull(r, l[:]); err != nil { return err } pongLen := binary.BigEndian.Uint16(l[:]) *e = PongPayload(make([]byte, pongLen)) if _, err := io.ReadFull(r, *e); err != nil { return err } case []byte: if _, err := io.ReadFull(r, e); err != nil { return err } case *PkScript: pkScript, err := wire.ReadVarBytes(r, 0, 34, "pkscript") if err != nil { return err } *e = pkScript case *wire.OutPoint: var h [32]byte if _, err = io.ReadFull(r, h[:]); err != nil { return err } hash, err := chainhash.NewHash(h[:]) if err != nil { return err } var idxBytes [2]byte _, err = io.ReadFull(r, idxBytes[:]) if err != nil { return err } index := binary.BigEndian.Uint16(idxBytes[:]) *e = wire.OutPoint{ Hash: *hash, Index: uint32(index), } case *FailCode: if err := readElement(r, (*uint16)(e)); err != nil { return err } case *ChannelID: if _, err := io.ReadFull(r, e[:]); err != nil { return err } case *ShortChannelID: var blockHeight [4]byte if _, err = io.ReadFull(r, blockHeight[1:]); err != nil { return err } var txIndex [4]byte if _, err = io.ReadFull(r, txIndex[1:]); err != nil { return err } var txPosition [2]byte if _, err = io.ReadFull(r, txPosition[:]); err != nil { return err } *e = ShortChannelID{ BlockHeight: binary.BigEndian.Uint32(blockHeight[:]), TxIndex: binary.BigEndian.Uint32(txIndex[:]), TxPosition: binary.BigEndian.Uint16(txPosition[:]), } case *[]net.Addr: // First, we'll read the number of total bytes that have been // used to encode the set of addresses. var numAddrsBytes [2]byte if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil { return err } addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:]) // With the number of addresses, read, we'll now pull in the // buffer of the encoded addresses into memory. addrs := make([]byte, addrsLen) if _, err := io.ReadFull(r, addrs[:]); err != nil { return err } addrBuf := bytes.NewReader(addrs) // Finally, we'll parse the remaining address payload in // series, using the first byte to denote how to decode the // address itself. var ( addresses []net.Addr addrBytesRead uint16 ) for addrBytesRead < addrsLen { var descriptor [1]byte if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil { return err } addrBytesRead++ address := &net.TCPAddr{} aType := addressType(descriptor[0]) switch aType { case noAddr: addrBytesRead += aType.AddrLen() continue case tcp4Addr: var ip [4]byte if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) var port [2]byte if _, err = io.ReadFull(addrBuf, port[:]); err != nil { return err } address.Port = int(binary.BigEndian.Uint16(port[:])) addrBytesRead += aType.AddrLen() case tcp6Addr: var ip [16]byte if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) var port [2]byte if _, err = io.ReadFull(addrBuf, port[:]); err != nil { return err } address.Port = int(binary.BigEndian.Uint16(port[:])) addrBytesRead += aType.AddrLen() case v2OnionAddr: addrBytesRead += aType.AddrLen() continue case v3OnionAddr: addrBytesRead += aType.AddrLen() continue default: return fmt.Errorf("unknown address type: %v", aType) } addresses = append(addresses, address) } *e = addresses case *RGB: err := readElements(r, &e.red, &e.green, &e.blue, ) if err != nil { return err } case *DeliveryAddress: var addrLen [2]byte if _, err = io.ReadFull(r, addrLen[:]); err != nil { return err } length := binary.BigEndian.Uint16(addrLen[:]) var addrBytes [34]byte if length > 34 { return fmt.Errorf("Cannot read %d bytes into addrBytes", length) } if _, err = io.ReadFull(r, addrBytes[:length]); err != nil { return err } *e = addrBytes[:length] default: return fmt.Errorf("Unknown type in readElement: %T", e) } return nil } // readElements deserializes a variable number of elements into the passed // io.Reader, with each element being deserialized according to the readElement // function. func readElements(r io.Reader, elements ...interface{}) error { for _, element := range elements { err := readElement(r, element) if err != nil { return err } } return nil }