diff --git a/channeldb/addr.go b/channeldb/addr.go index 0d40099e..dd057265 100644 --- a/channeldb/addr.go +++ b/channeldb/addr.go @@ -3,6 +3,7 @@ package channeldb import ( "encoding/binary" "errors" + "fmt" "io" "net" @@ -43,6 +44,10 @@ func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error { ip = addr.IP.To16() } + if ip == nil { + return fmt.Errorf("unable to encode IP %v", addr.IP) + } + if _, err := w.Write([]byte{addrType}); err != nil { return err } @@ -64,7 +69,8 @@ func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error { // representation. func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error { var suffixIndex int - switch len(addr.OnionService) { + hostLen := len(addr.OnionService) + switch hostLen { case tor.V2Len: if _, err := w.Write([]byte{byte(v2OnionAddr)}); err != nil { return err @@ -79,12 +85,29 @@ func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error { return errors.New("unknown onion service length") } + suffix := addr.OnionService[suffixIndex:] + if suffix != tor.OnionSuffix { + return fmt.Errorf("invalid suffix \"%v\"", suffix) + } + host, err := tor.Base32Encoding.DecodeString( addr.OnionService[:suffixIndex], ) if err != nil { return err } + + // Sanity check the decoded length. + switch { + case hostLen == tor.V2Len && len(host) != tor.V2DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + + case hostLen == tor.V3Len && len(host) != tor.V3DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + } + if _, err := w.Write(host); err != nil { return err } diff --git a/channeldb/addr_test.go b/channeldb/addr_test.go index c4bc4e8e..c761989c 100644 --- a/channeldb/addr_test.go +++ b/channeldb/addr_test.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "net" + "strings" "testing" "github.com/lightningnetwork/lnd/tor" @@ -13,19 +14,23 @@ type unknownAddrType struct{} func (t unknownAddrType) Network() string { return "unknown" } func (t unknownAddrType) String() string { return "unknown" } +var testIP4 = net.ParseIP("192.168.1.1") +var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") + var addrTests = []struct { expAddr net.Addr - serErr error + serErr string }{ + // Valid addresses. { expAddr: &net.TCPAddr{ - IP: net.ParseIP("192.168.1.1"), + IP: testIP4, Port: 12345, }, }, { expAddr: &net.TCPAddr{ - IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"), + IP: testIP6, Port: 65535, }, }, @@ -41,9 +46,67 @@ var addrTests = []struct { Port: 80, }, }, + + // Invalid addresses. { expAddr: unknownAddrType{}, - serErr: ErrUnknownAddressType, + serErr: ErrUnknownAddressType.Error(), + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv4 address. + IP: testIP4[:len(testIP4)-1], + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte of IPv4 address. + IP: append(testIP4, 0xff), + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv6 address. + IP: testIP6[:len(testIP6)-1], + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte to the IPv6 address. + IP: append(testIP6, 0xff), + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid suffix. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion", + Port: 80, + }, + serErr: "invalid suffix", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid length. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion", + Port: 80, + }, + serErr: "unknown onion service length", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid encoding. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion", + Port: 80, + }, + serErr: "illegal base32", }, } @@ -55,11 +118,21 @@ func TestAddrSerialization(t *testing.T) { var b bytes.Buffer for _, test := range addrTests { err := serializeAddr(&b, test.expAddr) - if err != test.serErr { + switch { + case err == nil && test.serErr != "": + t.Fatalf("expected serialization err for addr %v", + test.expAddr) + + case err != nil && test.serErr == "": + t.Fatalf("unexpected serialization err for addr %v: %v", + test.expAddr, err) + + case err != nil && !strings.Contains(err.Error(), test.serErr): t.Fatalf("unexpected serialization err for addr %v, "+ - "want: %v, got %v", - test.expAddr, test.serErr, err) - } else if test.serErr != nil { + "want: %v, got %v", test.expAddr, test.serErr, + err) + + case err != nil: continue }