From 6e7fcac1f5d21982b2bbf5abac40a195d11bd338 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 18 Sep 2017 17:27:37 -0700 Subject: [PATCH] lnwire: properly encode/decode addrs in NodeAnnouncement msg This commit fixes an existing deviation in the way we encode+decode the addresses within the NodeAnnouncement message with that of the specification. Prior to this commit, we would encode the _number_ of addresses, rather than the number of bytes it takes to encode all the addresses. In this commit, we fix this mistake by properly writing out the total number of bytes, modifying our parsing to take account of this new encoding. --- lnwire/lnwire.go | 154 ++++++++++++++++++++++++++++++++++-------- lnwire/lnwire_test.go | 1 + 2 files changed, 125 insertions(+), 30 deletions(-) diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 1a9ea7f3..a62593f9 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -1,6 +1,7 @@ package lnwire import ( + "bytes" "encoding/binary" "fmt" "io" @@ -28,11 +29,47 @@ type PkScript []byte type addressType uint8 const ( - tcp4Addr addressType = 1 - tcp6Addr addressType = 2 - onionAddr addressType = 3 + // noAddr denotes a blank address. An address of this type indicates + // that a node doesn't have any advertise d addresses. + noAddr addressType = 0 + + // tcp4Addr denotes an IPv4 TCP address. + tcp4Addr addressType = 1 + + // tcp4Addr denotes an IPv6 TCP address. + tcp6Addr addressType = 2 + + // v2OnionAddr denotes a version 2 Tor onion service address. + v2OnionAddr addressType = 3 + + // v3OnionAddr denotes a version 3 Tor (prop224) onion service + // addresses + v3OnionAddr addressType = 4 ) +// AddrLen returns the number of bytes that it takes to encode the target +// address. +func (a addressType) AddrLen() uint16 { + switch a { + case noAddr: + return 0 + case tcp4Addr: + return 6 + + case tcp6Addr: + return 18 + + case v2OnionAddr: + return 12 + + case v3OnionAddr: + return 37 + + default: + return 0 + } +} + // writeElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for the wire protocol. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able @@ -93,7 +130,6 @@ func writeElement(w io.Writer, element interface{}) error { var b [33]byte serializedPubkey := e.SerializeCompressed() copy(b[:], serializedPubkey) - // TODO(roasbeef): use WriteVarBytes here? if _, err := w.Write(b[:]); err != nil { return err } @@ -251,6 +287,7 @@ func writeElement(w io.Writer, element interface{}) error { return fmt.Errorf("cannot write nil TCPAddr") } + // TODO(roasbeef): account for onion types too if e.IP.To4() != nil { var descriptor [1]byte descriptor[0] = uint8(tcp4Addr) @@ -280,27 +317,37 @@ func writeElement(w io.Writer, element interface{}) error { if _, err := w.Write(port[:]); err != nil { return err } + case []net.Addr: - // Write out the number of addresses. - if err := writeElement(w, uint16(len(e))); err != nil { + // First, we'll encode all the addresses into an intermediate + // buffer. We need to do this in order to compute the total + // length of the addresses. + var addrBuf bytes.Buffer + for _, address := range e { + if err := writeElement(&addrBuf, address); err != nil { + return err + } + } + + // With the addresses fully encoded, we can now write out the + // number of bytes needed to encode them. + addrLen := addrBuf.Len() + if err := writeElement(w, uint16(addrLen)); err != nil { return err } - // Append the actual addresses. - for _, address := range e { - if err := writeElement(w, address); err != nil { + // Finally, we'll write out the raw addresses themselves, but + // only if we have any bytes to write. + if addrLen > 0 { + if _, err := w.Write(addrBuf.Bytes()); err != nil { return err } } case RGB: - err := writeElements(w, - e.red, - e.green, - e.blue, - ) - if err != nil { + if err := writeElements(w, e.red, e.green, e.blue); err != nil { return err } + case DeliveryAddress: var length [2]byte binary.BigEndian.PutUint16(length[:], uint16(len(e))) @@ -531,44 +578,91 @@ func readElement(r io.Reader, element interface{}) error { } case *[]net.Addr: + // First, we'll read the number of total bytes that have been + // used to encode the set of addresses. var numAddrsBytes [2]byte if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil { return err } + addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:]) - numAddrs := binary.BigEndian.Uint16(numAddrsBytes[:]) - addresses := make([]net.Addr, 0, numAddrs) + // With the number of addresses, read, we'll now pull in the + // buffer of the encoded addresses into memory. + addrs := make([]byte, addrsLen) + if _, err := io.ReadFull(r, addrs[:]); err != nil { + return err + } + addrBuf := bytes.NewReader(addrs) - for i := 0; i < int(numAddrs); i++ { + // Finally, we'll parse the remaining address payload in + // series, using the first byte to denote how to decode the + // address itself. + var ( + addresses []net.Addr + addrBytesRead uint16 + ) + for addrBytesRead < addrsLen { var descriptor [1]byte - if _, err = io.ReadFull(r, descriptor[:]); err != nil { + if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil { return err } + addrBytesRead += 1 + address := &net.TCPAddr{} - switch descriptor[0] { - case 1: + aType := addressType(descriptor[0]) + switch aType { + + case noAddr: + addrBytesRead += aType.AddrLen() + continue + + case tcp4Addr: var ip [4]byte - if _, err = io.ReadFull(r, ip[:]); err != nil { + if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) - case 2: + + var port [2]byte + if _, err = io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + + address.Port = int(binary.BigEndian.Uint16(port[:])) + + addrBytesRead += aType.AddrLen() + + case tcp6Addr: var ip [16]byte - if _, err = io.ReadFull(r, ip[:]); err != nil { + if _, err = io.ReadFull(addrBuf, ip[:]); err != nil { return err } address.IP = (net.IP)(ip[:]) + + var port [2]byte + if _, err = io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + address.Port = int(binary.BigEndian.Uint16(port[:])) + + addrBytesRead += aType.AddrLen() + + case v2OnionAddr: + addrBytesRead += aType.AddrLen() + continue + + case v3OnionAddr: + addrBytesRead += aType.AddrLen() + continue + + default: + return fmt.Errorf("unknown address type: %v", aType) } - var port [2]byte - if _, err = io.ReadFull(r, port[:]); err != nil { - return err - } - - address.Port = int(binary.BigEndian.Uint16(port[:])) addresses = append(addresses, address) } + *e = addresses case *RGB: err := readElements(r, diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index b937a9b6..d3b6bf2b 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -404,6 +404,7 @@ func TestLightningWireProtocol(t *testing.T) { green: uint8(r.Int31()), blue: uint8(r.Int31()), }, + // TODO(roasbeef): proper gen rand addrs Addresses: testAddrs, } req.Features.featuresMap = nil