diff --git a/brontide/conn.go b/brontide/conn.go index 33b550b8..c64d8a64 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tor" ) // Conn is an implementation of net.Conn which enforces an authenticated key @@ -34,12 +35,12 @@ var _ net.Conn = (*Conn)(nil) // public key. In the case of a handshake failure, the connection is closed and // a non-nil error is returned. func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, - dialer func(string, string) (net.Conn, error)) (*Conn, error) { + timeout time.Duration, dialer tor.DialFunc) (*Conn, error) { ipAddr := netAddr.Address.String() var conn net.Conn var err error - conn, err = dialer("tcp", ipAddr) + conn, err = dialer("tcp", ipAddr, timeout) if err != nil { return nil, err } diff --git a/brontide/noise_test.go b/brontide/noise_test.go index ed2229c1..dd0882ce 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tor" ) type maybeNetConn struct { @@ -66,7 +67,10 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) { // successful. remoteConnChan := make(chan maybeNetConn, 1) go func() { - remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial) + remoteConn, err := Dial( + remoteKeyECDH, netAddr, + tor.DefaultConnTimeout, net.DialTimeout, + ) remoteConnChan <- maybeNetConn{remoteConn, err} }() @@ -196,7 +200,10 @@ func TestConcurrentHandshakes(t *testing.T) { remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} go func() { - remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial) + remoteConn, err := Dial( + remoteKeyECDH, netAddr, + tor.DefaultConnTimeout, net.DialTimeout, + ) connChan <- maybeNetConn{remoteConn, err} }() diff --git a/tor/net.go b/tor/net.go index febf7227..d389cc73 100644 --- a/tor/net.go +++ b/tor/net.go @@ -1,19 +1,31 @@ package tor import ( + "context" "errors" "net" + "time" ) // TODO: this interface and its implementations should ideally be moved // elsewhere as they are not Tor-specific. +const ( + // DefaultConnTimeout is the maximum amount of time a dial will wait for + // a connect to complete. + DefaultConnTimeout time.Duration = time.Second * 120 +) + +// DialFunc is a type defines the signature of a dialer used by our Net +// interface. +type DialFunc func(net, addr string, timeout time.Duration) (net.Conn, error) + // Net is an interface housing a Dial function and several DNS functions that // allows us to abstract the implementations of these functions over different // networks, e.g. clearnet, Tor net, etc. type Net interface { // Dial connects to the address on the named network. - Dial(network, address string) (net.Conn, error) + Dial(network, address string, timeout time.Duration) (net.Conn, error) // LookupHost performs DNS resolution on a given host and returns its // addresses. @@ -21,7 +33,8 @@ type Net interface { // LookupSRV tries to resolve an SRV query of the given service, // protocol, and domain name. - LookupSRV(service, proto, name string) (string, []*net.SRV, error) + LookupSRV(service, proto, name string, + timeout time.Duration) (string, []*net.SRV, error) // ResolveTCPAddr resolves TCP addresses. ResolveTCPAddr(network, address string) (*net.TCPAddr, error) @@ -32,8 +45,10 @@ type Net interface { type ClearNet struct{} // Dial on the regular network uses net.Dial -func (r *ClearNet) Dial(network, address string) (net.Conn, error) { - return net.Dial(network, address) +func (r *ClearNet) Dial( + network, address string, timeout time.Duration) (net.Conn, error) { + + return net.DialTimeout(network, address, timeout) } // LookupHost for regular network uses the net.LookupHost function @@ -42,8 +57,14 @@ func (r *ClearNet) LookupHost(host string) ([]string, error) { } // LookupSRV for regular network uses net.LookupSRV function -func (r *ClearNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) { - return net.LookupSRV(service, proto, name) +func (r *ClearNet) LookupSRV(service, proto, name string, + timeout time.Duration) (string, []*net.SRV, error) { + + // Create a context with a timeout value. + ctxt, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return net.DefaultResolver.LookupSRV(ctxt, service, proto, name) } // ResolveTCPAddr for regular network uses net.ResolveTCPAddr function @@ -71,13 +92,15 @@ type ProxyNet struct { // Dial uses the Tor Dial function in order to establish connections through // Tor. Since Tor only supports TCP connections, only TCP networks are allowed. -func (p *ProxyNet) Dial(network, address string) (net.Conn, error) { +func (p *ProxyNet) Dial(network, address string, + timeout time.Duration) (net.Conn, error) { + switch network { case "tcp", "tcp4", "tcp6": default: return nil, errors.New("cannot dial non-tcp network via Tor") } - return Dial(address, p.SOCKS, p.StreamIsolation) + return Dial(address, p.SOCKS, p.StreamIsolation, timeout) } // LookupHost uses the Tor LookupHost function in order to resolve hosts over @@ -88,8 +111,13 @@ func (p *ProxyNet) LookupHost(host string) ([]string, error) { // LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries // over Tor. -func (p *ProxyNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) { - return LookupSRV(service, proto, name, p.SOCKS, p.DNS, p.StreamIsolation) +func (p *ProxyNet) LookupSRV(service, proto, + name string, timeout time.Duration) (string, []*net.SRV, error) { + + return LookupSRV( + service, proto, name, p.SOCKS, p.DNS, + p.StreamIsolation, timeout, + ) } // ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP diff --git a/tor/tor.go b/tor/tor.go index 3b5d46a0..e5430f57 100644 --- a/tor/tor.go +++ b/tor/tor.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "strconv" + "time" "github.com/btcsuite/btcd/connmgr" "github.com/miekg/dns" @@ -54,8 +55,10 @@ func (c *proxyConn) RemoteAddr() net.Addr { // Dial is a wrapper over the non-exported dial function that returns a wrapper // around net.Conn in order to expose the actual remote address we're dialing, // rather than the proxy's address. -func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { - conn, err := dial(address, socksAddr, streamIsolation) +func Dial(address, socksAddr string, streamIsolation bool, + timeout time.Duration) (net.Conn, error) { + + conn, err := dial(address, socksAddr, streamIsolation, timeout) if err != nil { return nil, err } @@ -75,11 +78,13 @@ func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { } // dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP -// is supported over Tor. The final argument determines if we should force -// stream isolation for this new connection. If we do, then this means this new -// connection will use a fresh circuit, rather than possibly re-using an -// existing circuit. -func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { +// is supported over Tor. The argument streamIsolation determines if we should +// force stream isolation for this new connection. If we do, then this means +// this new connection will use a fresh circuit, rather than possibly re-using +// an existing circuit. +func dial(address, socksAddr string, streamIsolation bool, + timeout time.Duration) (net.Conn, error) { + // If we were requested to force stream isolation for this connection, // we'll populate the authentication credentials with random data as // Tor will create a new circuit for each set of credentials. @@ -97,7 +102,8 @@ func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { } // Establish the connection through Tor's SOCKS proxy. - dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxy.Direct) + proxyDialer := &net.Dialer{Timeout: timeout} + dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxyDialer) if err != nil { return nil, err } @@ -121,11 +127,12 @@ func LookupHost(host, socksAddr string) ([]string, error) { // natively support SRV queries so we must route all SRV queries through the // proxy by connecting directly to a DNS server and querying it. The DNS server // must have TCP resolution enabled for the given port. -func LookupSRV(service, proto, name, socksAddr, dnsServer string, - streamIsolation bool) (string, []*net.SRV, error) { +func LookupSRV(service, proto, name, socksAddr, + dnsServer string, streamIsolation bool, + timeout time.Duration) (string, []*net.SRV, error) { // Connect to the DNS server we'll be using to query SRV records. - conn, err := dial(dnsServer, socksAddr, streamIsolation) + conn, err := dial(dnsServer, socksAddr, streamIsolation, timeout) if err != nil { return "", nil, err }