From c1b1dd96efcdd50f6e228021fd5846d673c6ebd6 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 23 May 2018 15:41:16 +0200 Subject: [PATCH] lncfg: move configuration helper methods to new package --- config.go | 170 ++++++------------------------------------ lncfg/address.go | 161 +++++++++++++++++++++++++++++++++++++++ lncfg/address_test.go | 99 ++++++++++++++++++++++++ lnd.go | 9 ++- server.go | 3 +- 5 files changed, 290 insertions(+), 152 deletions(-) create mode 100644 lncfg/address.go create mode 100644 lncfg/address_test.go diff --git a/config.go b/config.go index 10e5184e..5b0779c3 100644 --- a/config.go +++ b/config.go @@ -18,11 +18,11 @@ import ( "strconv" "strings" "time" - "crypto/tls" flags "github.com/jessevdk/go-flags" "github.com/lightningnetwork/lnd/brontide" "github.com/lightningnetwork/lnd/htlcswitch/hodl" + "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tor" "github.com/roasbeef/btcd/btcec" @@ -90,8 +90,6 @@ var ( defaultBitcoindDir = btcutil.AppDataDir("bitcoin", false) defaultLitecoindDir = btcutil.AppDataDir("litecoin", false) - loopBackAddrs = []string{"localhost", "127.0.0.1", "[::1]"} - defaultTorSOCKS = net.JoinHostPort("localhost", strconv.Itoa(defaultTorSOCKSPort)) defaultTorDNS = net.JoinHostPort(defaultTorDNSHost, strconv.Itoa(defaultTorDNSPort)) defaultTorControl = net.JoinHostPort("localhost", strconv.Itoa(defaultTorControlPort)) @@ -431,15 +429,27 @@ func loadConfig() (*config, error) { } // Validate the Tor config parameters. - cfg.Tor.SOCKS = normalizeAddress( + socks, err := lncfg.ParseAddressString( cfg.Tor.SOCKS, strconv.Itoa(defaultTorSOCKSPort), ) - cfg.Tor.DNS = normalizeAddress( + if err != nil { + return nil, err + } + cfg.Tor.SOCKS = socks.String() + dns, err := lncfg.ParseAddressString( cfg.Tor.DNS, strconv.Itoa(defaultTorDNSPort), ) - cfg.Tor.Control = normalizeAddress( + if err != nil { + return nil, err + } + cfg.Tor.DNS = dns.String() + control, err := lncfg.ParseAddressString( cfg.Tor.Control, strconv.Itoa(defaultTorControlPort), ) + if err != nil { + return nil, err + } + cfg.Tor.Control = control.String() switch { case cfg.Tor.V2 && cfg.Tor.V3: return nil, errors.New("either tor.v2 or tor.v3 can be set, " + @@ -783,13 +793,13 @@ func loadConfig() (*config, error) { // For each of the RPC listeners (REST+gRPC), we'll ensure that users // have specified a safe combo for authentication. If not, we'll bail // out with an error. - err := enforceSafeAuthentication( + err = lncfg.EnforceSafeAuthentication( cfg.RPCListeners, !cfg.NoMacaroons, ) if err != nil { return nil, err } - err = enforceSafeAuthentication( + err = lncfg.EnforceSafeAuthentication( cfg.RESTListeners, !cfg.NoMacaroons, ) if err != nil { @@ -805,7 +815,7 @@ func loadConfig() (*config, error) { // Add default port to all RPC listener addresses if needed and remove // duplicate addresses. - cfg.RPCListeners, err = normalizeAddresses( + cfg.RPCListeners, err = lncfg.NormalizeAddresses( cfg.RawRPCListeners, strconv.Itoa(defaultRPCPort), ) if err != nil { @@ -814,7 +824,7 @@ func loadConfig() (*config, error) { // Add default port to all REST listener addresses if needed and remove // duplicate addresses. - cfg.RESTListeners, err = normalizeAddresses( + cfg.RESTListeners, err = lncfg.NormalizeAddresses( cfg.RawRESTListeners, strconv.Itoa(defaultRESTPort), ) if err != nil { @@ -823,7 +833,7 @@ func loadConfig() (*config, error) { // Add default port to all listener addresses if needed and remove // duplicate addresses. - cfg.Listeners, err = normalizeAddresses( + cfg.Listeners, err = lncfg.NormalizeAddresses( cfg.RawListeners, strconv.Itoa(defaultPeerPort), ) if err != nil { @@ -832,7 +842,7 @@ func loadConfig() (*config, error) { // Add default port to all external IP addresses if needed and remove // duplicate addresses. - cfg.ExternalIPs, err = normalizeAddresses( + cfg.ExternalIPs, err = lncfg.NormalizeAddresses( cfg.RawExternalIPs, strconv.Itoa(defaultPeerPort), ) if err != nil { @@ -843,7 +853,7 @@ func loadConfig() (*config, error) { // Also, we would need to refactor the brontide listener to support // that. for _, p2pListener := range cfg.Listeners { - if isUnix(p2pListener) { + if lncfg.IsUnix(p2pListener) { err := fmt.Errorf("unix socket addresses cannot be " + "used for the p2p connection listener: %s", p2pListener) @@ -1241,140 +1251,6 @@ func extractBitcoindRPCParams(bitcoindConfigPath string) (string, string, string string(zmqPathSubmatches[1]), nil } -// normalizeAddresses returns a new slice with all the passed addresses -// normalized with the given default port and all duplicates removed. -func normalizeAddresses(addrs []string, - defaultPort string) ([]net.Addr, error) { - result := make([]net.Addr, 0, len(addrs)) - seen := map[string]struct{}{} - for _, strAddr := range addrs { - addr, err := parseAddressString(strAddr, defaultPort) - if err != nil { - return nil, err - } - - if _, ok := seen[addr.String()]; !ok { - result = append(result, addr) - seen[addr.String()] = struct{}{} - } - } - return result, nil -} - -// enforceSafeAuthentication enforces "safe" authentication taking into account -// the interfaces that the RPC servers are listening on, and if macaroons are -// activated or not. To protect users from using dangerous config combinations, -// we'll prevent disabling authentication if the sever is listening on a public -// interface. -func enforceSafeAuthentication(addrs []net.Addr, macaroonsActive bool) error { - // We'll now examine all addresses that this RPC server is listening - // on. If it's a localhost address, we'll skip it, otherwise, we'll - // return an error if macaroons are inactive. - for _, addr := range addrs { - if isLoopback(addr) || isUnix(addr) { - continue - } - - if !macaroonsActive { - return fmt.Errorf("Detected RPC server listening on "+ - "publicly reachable interface %v with "+ - "authentication disabled! Refusing to start "+ - "with --no-macaroons specified.", addr) - } - } - - return nil -} - -// listenOnAddress creates a listener that listens on the given -// address. -func listenOnAddress(addr net.Addr) (net.Listener, error) { - return net.Listen(addr.Network(), addr.String()) -} - -// tlsListenOnAddress creates a TLS listener that listens on the given -// address. -func tlsListenOnAddress(addr net.Addr, - config *tls.Config) (net.Listener, error) { - return tls.Listen(addr.Network(), addr.String(), config) -} - -// isLoopback returns true if an address describes a loopback interface. -func isLoopback(addr net.Addr) bool { - for _, loopback := range loopBackAddrs { - if strings.Contains(addr.String(), loopback) { - return true - } - } - - return false -} - -// isUnix returns true if an address describes an Unix socket address. -func isUnix(addr net.Addr) bool { - return strings.HasPrefix(addr.Network(), "unix") -} - -// parseAddressString converts an address in string format to a net.Addr that is -// compatible with lnd. UDP is not supported because lnd needs reliable -// connections. -func parseAddressString(strAddress string, - defaultPort string) (net.Addr, error) { - var parsedNetwork, parsedAddr string - - // Addresses can either be in network://address:port format or only - // address:port. We want to support both. - if strings.Contains(strAddress, "://") { - parts := strings.Split(strAddress, "://") - parsedNetwork, parsedAddr = parts[0], parts[1] - } else if strings.Contains(strAddress, ":") { - parts := strings.Split(strAddress, ":") - parsedNetwork = parts[0] - parsedAddr = strings.Join(parts[1:], ":") - } - - // Only TCP and Unix socket addresses are valid. We can't use IP or - // UDP only connections for anything we do in lnd. - switch parsedNetwork { - case "unix", "unixpacket": - return net.ResolveUnixAddr(parsedNetwork, parsedAddr) - case "tcp", "tcp4", "tcp6": - return net.ResolveTCPAddr(parsedNetwork, - verifyPort(parsedAddr, defaultPort)) - case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram": - return nil, fmt.Errorf("only TCP or unix socket "+ - "addresses are supported: %s", parsedAddr) - default: - // There was no network specified, just try to parse as host - // and port. - return net.ResolveTCPAddr( - "tcp", verifyPort(strAddress, defaultPort), - ) - } -} - -// verifyPort makes sure that an address string has both a host and a port. -// If there is no port found, the default port is appended. -func verifyPort(strAddress string, defaultPort string) string { - host, port, err := net.SplitHostPort(strAddress) - if err != nil { - // If we already have an IPv6 address with brackets, don't use - // the JoinHostPort function, since it will always add a pair - // of brackets too. - if strings.HasPrefix(strAddress, "[") { - strAddress = strAddress + ":" + defaultPort - } else { - strAddress = net.JoinHostPort(strAddress, defaultPort) - } - } else if host == "" && port == "" { - // The string ':' is parsed as valid empty host and empty port. - // But in that case, we want the default port to be applied too. - strAddress = ":" + defaultPort - } - - return strAddress -} - // normalizeNetwork returns the common name of a network type used to create // file paths. This allows differently versioned networks to use the same path. func normalizeNetwork(network string) string { diff --git a/lncfg/address.go b/lncfg/address.go new file mode 100644 index 00000000..c3c418e4 --- /dev/null +++ b/lncfg/address.go @@ -0,0 +1,161 @@ +package lncfg + +import ( + "time" + "net" + "fmt" + "crypto/tls" + "strings" +) + +var ( + loopBackAddrs = []string{"localhost", "127.0.0.1", "[::1]"} +) + +// NormalizeAddresses returns a new slice with all the passed addresses +// normalized with the given default port and all duplicates removed. +func NormalizeAddresses(addrs []string, + defaultPort string) ([]net.Addr, error) { + result := make([]net.Addr, 0, len(addrs)) + seen := map[string]struct{}{} + for _, strAddr := range addrs { + addr, err := ParseAddressString(strAddr, defaultPort) + if err != nil { + return nil, err + } + + if _, ok := seen[addr.String()]; !ok { + result = append(result, addr) + seen[addr.String()] = struct{}{} + } + } + return result, nil +} + +// EnforceSafeAuthentication enforces "safe" authentication taking into account +// the interfaces that the RPC servers are listening on, and if macaroons are +// activated or not. To protect users from using dangerous config combinations, +// we'll prevent disabling authentication if the sever is listening on a public +// interface. +func EnforceSafeAuthentication(addrs []net.Addr, macaroonsActive bool) error { + // We'll now examine all addresses that this RPC server is listening + // on. If it's a localhost address, we'll skip it, otherwise, we'll + // return an error if macaroons are inactive. + for _, addr := range addrs { + if IsLoopback(addr) || IsUnix(addr) { + continue + } + + if !macaroonsActive { + return fmt.Errorf("Detected RPC server listening on "+ + "publicly reachable interface %v with "+ + "authentication disabled! Refusing to start "+ + "with --no-macaroons specified.", addr) + } + } + + return nil +} + +// ListenOnAddress creates a listener that listens on the given +// address. +func ListenOnAddress(addr net.Addr) (net.Listener, error) { + return net.Listen(addr.Network(), addr.String()) +} + +// TlsListenOnAddress creates a TLS listener that listens on the given +// address. +func TlsListenOnAddress(addr net.Addr, + config *tls.Config) (net.Listener, error) { + return tls.Listen(addr.Network(), addr.String(), config) +} + +// IsLoopback returns true if an address describes a loopback interface. +func IsLoopback(addr net.Addr) bool { + for _, loopback := range loopBackAddrs { + if strings.Contains(addr.String(), loopback) { + return true + } + } + + return false +} + +// isUnix returns true if an address describes an Unix socket address. +func IsUnix(addr net.Addr) bool { + return strings.HasPrefix(addr.Network(), "unix") +} + +// ParseAddressString converts an address in string format to a net.Addr that is +// compatible with lnd. UDP is not supported because lnd needs reliable +// connections. +func ParseAddressString(strAddress string, + defaultPort string) (net.Addr, error) { + var parsedNetwork, parsedAddr string + + // Addresses can either be in network://address:port format or only + // address:port. We want to support both. + if strings.Contains(strAddress, "://") { + parts := strings.Split(strAddress, "://") + parsedNetwork, parsedAddr = parts[0], parts[1] + } else if strings.Contains(strAddress, ":") { + parts := strings.Split(strAddress, ":") + parsedNetwork = parts[0] + parsedAddr = strings.Join(parts[1:], ":") + } + + // Only TCP and Unix socket addresses are valid. We can't use IP or + // UDP only connections for anything we do in lnd. + switch parsedNetwork { + case "unix", "unixpacket": + return net.ResolveUnixAddr(parsedNetwork, parsedAddr) + case "tcp", "tcp4", "tcp6": + return net.ResolveTCPAddr(parsedNetwork, + verifyPort(parsedAddr, defaultPort)) + case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram": + return nil, fmt.Errorf("only TCP or unix socket "+ + "addresses are supported: %s", parsedAddr) + default: + // There was no network specified, just try to parse as host + // and port. + return net.ResolveTCPAddr( + "tcp", verifyPort(strAddress, defaultPort), + ) + } +} + +// verifyPort makes sure that an address string has both a host and a port. +// If there is no port found, the default port is appended. +func verifyPort(strAddress string, defaultPort string) string { + host, port, err := net.SplitHostPort(strAddress) + if err != nil { + // If we already have an IPv6 address with brackets, don't use + // the JoinHostPort function, since it will always add a pair + // of brackets too. + if strings.HasPrefix(strAddress, "[") { + strAddress = strAddress + ":" + defaultPort + } else { + strAddress = net.JoinHostPort(strAddress, defaultPort) + } + } else if host == "" && port == "" { + // The string ':' is parsed as valid empty host and empty port. + // But in that case, we want the default port to be applied too. + strAddress = ":" + defaultPort + } + + return strAddress +} + +// ClientAddressDialer creates a gRPC dialer that can also dial unix socket +// addresses instead of just TCP addresses. +func ClientAddressDialer(defaultPort string) func(string, + time.Duration) (net.Conn, error) { + return func(addr string, timeout time.Duration) (net.Conn, error) { + parsedAddr, err := ParseAddressString(addr, defaultPort) + if err != nil { + return nil, err + } + return net.DialTimeout(parsedAddr.Network(), + parsedAddr.String(), timeout) + } +} diff --git a/lncfg/address_test.go b/lncfg/address_test.go new file mode 100644 index 00000000..8453a46a --- /dev/null +++ b/lncfg/address_test.go @@ -0,0 +1,99 @@ +// +build !rpctest + +package lncfg + +import "testing" + +// addressTest defines a test vector for an address that contains the non- +// normalized input and the expected normalized output. +type addressTest struct { + address string + expectedNetwork string + expectedAddress string + isLoopback bool + isUnix bool +} + +var ( + defaultTestPort = "1234" + addressTestVectors = []addressTest{ + {"tcp://127.0.0.1:9735", "tcp", "127.0.0.1:9735", true, false}, + {"tcp:127.0.0.1:9735", "tcp", "127.0.0.1:9735", true, false}, + {"127.0.0.1:9735", "tcp", "127.0.0.1:9735", true, false}, + {":9735", "tcp", ":9735", false, false}, + {"", "tcp", ":1234", false, false}, + {":", "tcp", ":1234", false, false}, + {"tcp4://127.0.0.1:9735", "tcp", "127.0.0.1:9735", true, false}, + {"tcp4:127.0.0.1:9735", "tcp", "127.0.0.1:9735", true, false}, + {"127.0.0.1", "tcp", "127.0.0.1:1234", true, false}, + {"[::1]", "tcp", "[::1]:1234", true, false}, + {"::1", "tcp", "[::1]:1234", true, false}, + {"tcp6://::1", "tcp", "[::1]:1234", true, false}, + {"tcp6:::1", "tcp", "[::1]:1234", true, false}, + {"localhost:9735", "tcp", "127.0.0.1:9735", true, false}, + {"localhost", "tcp", "127.0.0.1:1234", true, false}, + {"unix:///tmp/lnd.sock", "unix", "/tmp/lnd.sock", false, true}, + {"unix:/tmp/lnd.sock", "unix", "/tmp/lnd.sock", false, true}, + } + invalidTestVectors = []string{ + "some string", + "://", + "12.12.12", + "123", + } +) + +// TestAddresses ensures that all supported address formats can be parsed and +// normalized correctly. +func TestAddresses(t *testing.T) { + // First, test all correct addresses. + for _, testVector := range addressTestVectors { + addr := []string{testVector.address} + normalized, err := NormalizeAddresses(addr, defaultTestPort) + if err != nil { + t.Fatalf("unable to normalize address %s: %v", + testVector.address, err) + } + netAddr := normalized[0] + if err != nil { + t.Fatalf("unable to split normalized address: %v", err) + } + if netAddr.Network() != testVector.expectedNetwork || + netAddr.String() != testVector.expectedAddress { + t.Fatalf( + "mismatched address: expected %s://%s, got "+ + "%s://%s", + testVector.expectedNetwork, + testVector.expectedAddress, + netAddr.Network(), netAddr.String(), + ) + } + isAddrLoopback := IsLoopback(normalized[0]) + if testVector.isLoopback != isAddrLoopback { + t.Fatalf( + "mismatched loopback detection: expected "+ + "%v, got %v for addr %s", + testVector.isLoopback, isAddrLoopback, + testVector.address, + ) + } + isAddrUnix := IsUnix(normalized[0]) + if testVector.isUnix != isAddrUnix { + t.Fatalf( + "mismatched unix detection: expected "+ + "%v, got %v for addr %s", + testVector.isUnix, isAddrUnix, + testVector.address, + ) + } + } + + // Finally, test invalid inputs to see if they are handled correctly. + for _, testVector := range invalidTestVectors { + addr := []string{testVector} + _, err := NormalizeAddresses(addr, defaultTestPort) + if err == nil { + t.Fatalf("expected error when parsing %v", testVector) + } + } +} diff --git a/lnd.go b/lnd.go index 2cc5d5b2..b05bf6e4 100644 --- a/lnd.go +++ b/lnd.go @@ -38,6 +38,7 @@ import ( "github.com/lightningnetwork/lnd/autopilot" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/btcwallet" @@ -522,7 +523,7 @@ func lndMain() error { // Next, Start the gRPC server listening for HTTP/2 connections. for _, listener := range cfg.RPCListeners { - lis, err := listenOnAddress(listener) + lis, err := lncfg.ListenOnAddress(listener) if err != nil { ltndLog.Errorf( "RPC server unable to listen on %s", listener, @@ -545,7 +546,7 @@ func lndMain() error { return err } for _, restEndpoint := range cfg.RESTListeners { - lis, err := tlsListenOnAddress(restEndpoint, tlsConf) + lis, err := lncfg.TlsListenOnAddress(restEndpoint, tlsConf) if err != nil { ltndLog.Errorf( "gRPC proxy unable to listen on %s", @@ -926,7 +927,7 @@ func waitForWalletPassword(grpcEndpoints, restEndpoints []net.Addr, for _, grpcEndpoint := range grpcEndpoints { // Start a gRPC server listening for HTTP/2 connections, solely // used for getting the encryption password from the client. - lis, err := listenOnAddress(grpcEndpoint) + lis, err := lncfg.ListenOnAddress(grpcEndpoint) if err != nil { ltndLog.Errorf( "password RPC server unable to listen on %s", @@ -964,7 +965,7 @@ func waitForWalletPassword(grpcEndpoints, restEndpoints []net.Addr, srv := &http.Server{Handler: mux} for _, restEndpoint := range restEndpoints { - lis, err := tlsListenOnAddress(restEndpoint, tlsConf) + lis, err := lncfg.TlsListenOnAddress(restEndpoint, tlsConf) if err != nil { ltndLog.Errorf( "password gRPC proxy unable to listen on %s", diff --git a/server.go b/server.go index decd9fdb..ebdd6e27 100644 --- a/server.go +++ b/server.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" @@ -396,7 +397,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, // If external IP addresses have been specified, add those to the list // of this server's addresses. - externalIPs, err := normalizeAddresses( + externalIPs, err := lncfg.NormalizeAddresses( externalIpStrings, strconv.Itoa(defaultPeerPort), ) if err != nil {