diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 1a9ea7f3..a62593f9 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "encoding/binary" "fmt" "io" @@ -28,11 +29,47 @@ type PkScript []byte type addressType uint8 const ( - tcp4Addr addressType = 1 - tcp6Addr addressType = 2 - onionAddr addressType = 3 + // noAddr denotes a blank address. An address of this type indicates + // that a node doesn't have any advertise d addresses. + noAddr addressType = 0 + + // tcp4Addr denotes an IPv4 TCP address. + tcp4Addr addressType = 1 + + // tcp4Addr denotes an IPv6 TCP address. + tcp6Addr addressType = 2 + + // v2OnionAddr denotes a version 2 Tor onion service address. + v2OnionAddr addressType = 3 + + // v3OnionAddr denotes a version 3 Tor (prop224) onion service + // addresses + v3OnionAddr addressType = 4 ) +// AddrLen returns the number of bytes that it takes to encode the target +// address. +func (a addressType) AddrLen() uint16 { + switch a { + case noAddr: + return 0 + case tcp4Addr: + return 6 + + case tcp6Addr: + return 18 + + case v2OnionAddr: + return 12 + + case v3OnionAddr: + return 37 + + default: + return 0 + } +} + // writeElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able @@ -93,7 +130,6 @@ func writeElement(w io.Writer, element interface{}) error { var b [33]byte serializedPubkey := e.SerializeCompressed() copy(b[:], serializedPubkey) - // TODO(roasbeef): use WriteVarBytes here? if _, err := w.Write(b[:]); err != nil { return err } @@ -251,6 +287,7 @@ func writeElement(w io.Writer, element interface{}) error { return fmt.Errorf("cannot write nil TCPAddr") } + // TODO(roasbeef): account for onion types too if e.IP.To4() != nil { var descriptor [1]byte descriptor[0] = uint8(tcp4Addr) @@ -280,27 +317,37 @@ func writeElement(w io.Writer, element interface{}) error { if _, err := w.Write(port[:]); err != nil { return err } + case []net.Addr: - // Write out the number of addresses. - if err := writeElement(w, uint16(len(e))); err != nil { + // First, we'll encode all the addresses into an intermediate + // buffer. We need to do this in order to compute the total + // length of the addresses. + var addrBuf bytes.Buffer + for _, address := range e { + if err := writeElement(&addrBuf, address); err != nil { + return err + } + } + + // With the addresses fully encoded, we can now write out the + // number of bytes needed to encode them. + addrLen := addrBuf.Len() + if err := writeElement(w, uint16(addrLen)); err != nil { return err } - // Append the actual addresses. - for _, address := range e { - if err := writeElement(w, address); err != nil { + // Finally, we'll write out the raw addresses themselves, but + // only if we have any bytes to write. + if addrLen > 0 { + if _, err := w.Write(addrBuf.Bytes()); err != nil { return err } } case RGB: - err := writeElements(w, - e.red, - e.green, - e.blue, - ) - if err != nil { + if err := writeElements(w, e.red, e.green, e.blue); err != nil { return err } + case DeliveryAddress: var length [2]byte binary.BigEndian.PutUint16(length[:], uint16(len(e))) @@ -531,44 +578,91 @@ func readElement(r io.Reader, element interface{}) error { } case *[]net.Addr: + // First, we'll read the number of total bytes that have been + // used to encode the set of addresses. var numAddrsBytes [2]byte if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil { return err } + addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:]) - numAddrs := binary.BigEndian.Uint16(numAddrsBytes[:]) - addresses := make([]net.Addr, 0, numAddrs) + // With the number of addresses, read, we'll now pull in the + // buffer of the encoded addresses into memory. + addrs := make([]byte, addrsLen) + if _, err := io.ReadFull(r, addrs[:]); err != nil { + return err + } + addrBuf := bytes.NewReader(addrs) - for i := 0; i < int(numAddrs); i++ { + // Finally, we'll parse the remaining address payload in + // series, using the first byte to denote how to decode the + // address itself. + var ( + addresses []net.Addr + addrBytesRead uint16 + ) + for addrBytesRead < addrsLen { var descriptor [1]byte - if _, err = io.ReadFull(r, descriptor[:]); err != nil { + if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil { return err } + addrBytesRead += 1 + address := &net.TCPAddr{} - switch descriptor[0] { - case 1: + aType := addressType(descriptor[0]) + switch aType { + + case noAddr: + addrBytesRead += aType.AddrLen() + continue + + case tcp4Addr: var ip [4]byte - if _, err = io.ReadFull(r, ip[:]); err != nil { + if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) - case 2: + + var port [2]byte + if _, err = io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + + address.Port = int(binary.BigEndian.Uint16(port[:])) + + addrBytesRead += aType.AddrLen() + + case tcp6Addr: var ip [16]byte - if _, err = io.ReadFull(r, ip[:]); err != nil { + if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) + + var port [2]byte + if _, err = io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + address.Port = int(binary.BigEndian.Uint16(port[:])) + + addrBytesRead += aType.AddrLen() + + case v2OnionAddr: + addrBytesRead += aType.AddrLen() + continue + + case v3OnionAddr: + addrBytesRead += aType.AddrLen() + continue + + default: + return fmt.Errorf("unknown address type: %v", aType) } - var port [2]byte - if _, err = io.ReadFull(r, port[:]); err != nil { - return err - } - - address.Port = int(binary.BigEndian.Uint16(port[:])) addresses = append(addresses, address) } + *e = addresses case *RGB: err := readElements(r, diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index b937a9b6..d3b6bf2b 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -404,6 +404,7 @@ func TestLightningWireProtocol(t *testing.T) { green: uint8(r.Int31()), blue: uint8(r.Int31()), }, + // TODO(roasbeef): proper gen rand addrs Addresses: testAddrs, } req.Features.featuresMap = nil