brontide+tor:add timeout value for network connections
This commit is contained in:
parent
9dcb522ebc
commit
fb67b58d3f
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn is an implementation of net.Conn which enforces an authenticated key
|
// Conn is an implementation of net.Conn which enforces an authenticated key
|
||||||
@ -34,12 +35,12 @@ var _ net.Conn = (*Conn)(nil)
|
|||||||
// public key. In the case of a handshake failure, the connection is closed and
|
// public key. In the case of a handshake failure, the connection is closed and
|
||||||
// a non-nil error is returned.
|
// a non-nil error is returned.
|
||||||
func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
||||||
dialer func(string, string) (net.Conn, error)) (*Conn, error) {
|
timeout time.Duration, dialer tor.DialFunc) (*Conn, error) {
|
||||||
|
|
||||||
ipAddr := netAddr.Address.String()
|
ipAddr := netAddr.Address.String()
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
var err error
|
var err error
|
||||||
conn, err = dialer("tcp", ipAddr)
|
conn, err = dialer("tcp", ipAddr, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
)
|
)
|
||||||
|
|
||||||
type maybeNetConn struct {
|
type maybeNetConn struct {
|
||||||
@ -66,7 +67,10 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) {
|
|||||||
// successful.
|
// successful.
|
||||||
remoteConnChan := make(chan maybeNetConn, 1)
|
remoteConnChan := make(chan maybeNetConn, 1)
|
||||||
go func() {
|
go func() {
|
||||||
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial)
|
remoteConn, err := Dial(
|
||||||
|
remoteKeyECDH, netAddr,
|
||||||
|
tor.DefaultConnTimeout, net.DialTimeout,
|
||||||
|
)
|
||||||
remoteConnChan <- maybeNetConn{remoteConn, err}
|
remoteConnChan <- maybeNetConn{remoteConn, err}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -196,7 +200,10 @@ func TestConcurrentHandshakes(t *testing.T) {
|
|||||||
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
|
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial)
|
remoteConn, err := Dial(
|
||||||
|
remoteKeyECDH, netAddr,
|
||||||
|
tor.DefaultConnTimeout, net.DialTimeout,
|
||||||
|
)
|
||||||
connChan <- maybeNetConn{remoteConn, err}
|
connChan <- maybeNetConn{remoteConn, err}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
48
tor/net.go
48
tor/net.go
@ -1,19 +1,31 @@
|
|||||||
package tor
|
package tor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: this interface and its implementations should ideally be moved
|
// TODO: this interface and its implementations should ideally be moved
|
||||||
// elsewhere as they are not Tor-specific.
|
// elsewhere as they are not Tor-specific.
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultConnTimeout is the maximum amount of time a dial will wait for
|
||||||
|
// a connect to complete.
|
||||||
|
DefaultConnTimeout time.Duration = time.Second * 120
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialFunc is a type defines the signature of a dialer used by our Net
|
||||||
|
// interface.
|
||||||
|
type DialFunc func(net, addr string, timeout time.Duration) (net.Conn, error)
|
||||||
|
|
||||||
// Net is an interface housing a Dial function and several DNS functions that
|
// Net is an interface housing a Dial function and several DNS functions that
|
||||||
// allows us to abstract the implementations of these functions over different
|
// allows us to abstract the implementations of these functions over different
|
||||||
// networks, e.g. clearnet, Tor net, etc.
|
// networks, e.g. clearnet, Tor net, etc.
|
||||||
type Net interface {
|
type Net interface {
|
||||||
// Dial connects to the address on the named network.
|
// Dial connects to the address on the named network.
|
||||||
Dial(network, address string) (net.Conn, error)
|
Dial(network, address string, timeout time.Duration) (net.Conn, error)
|
||||||
|
|
||||||
// LookupHost performs DNS resolution on a given host and returns its
|
// LookupHost performs DNS resolution on a given host and returns its
|
||||||
// addresses.
|
// addresses.
|
||||||
@ -21,7 +33,8 @@ type Net interface {
|
|||||||
|
|
||||||
// LookupSRV tries to resolve an SRV query of the given service,
|
// LookupSRV tries to resolve an SRV query of the given service,
|
||||||
// protocol, and domain name.
|
// protocol, and domain name.
|
||||||
LookupSRV(service, proto, name string) (string, []*net.SRV, error)
|
LookupSRV(service, proto, name string,
|
||||||
|
timeout time.Duration) (string, []*net.SRV, error)
|
||||||
|
|
||||||
// ResolveTCPAddr resolves TCP addresses.
|
// ResolveTCPAddr resolves TCP addresses.
|
||||||
ResolveTCPAddr(network, address string) (*net.TCPAddr, error)
|
ResolveTCPAddr(network, address string) (*net.TCPAddr, error)
|
||||||
@ -32,8 +45,10 @@ type Net interface {
|
|||||||
type ClearNet struct{}
|
type ClearNet struct{}
|
||||||
|
|
||||||
// Dial on the regular network uses net.Dial
|
// Dial on the regular network uses net.Dial
|
||||||
func (r *ClearNet) Dial(network, address string) (net.Conn, error) {
|
func (r *ClearNet) Dial(
|
||||||
return net.Dial(network, address)
|
network, address string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
|
||||||
|
return net.DialTimeout(network, address, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost for regular network uses the net.LookupHost function
|
// LookupHost for regular network uses the net.LookupHost function
|
||||||
@ -42,8 +57,14 @@ func (r *ClearNet) LookupHost(host string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LookupSRV for regular network uses net.LookupSRV function
|
// LookupSRV for regular network uses net.LookupSRV function
|
||||||
func (r *ClearNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) {
|
func (r *ClearNet) LookupSRV(service, proto, name string,
|
||||||
return net.LookupSRV(service, proto, name)
|
timeout time.Duration) (string, []*net.SRV, error) {
|
||||||
|
|
||||||
|
// Create a context with a timeout value.
|
||||||
|
ctxt, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return net.DefaultResolver.LookupSRV(ctxt, service, proto, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveTCPAddr for regular network uses net.ResolveTCPAddr function
|
// ResolveTCPAddr for regular network uses net.ResolveTCPAddr function
|
||||||
@ -71,13 +92,15 @@ type ProxyNet struct {
|
|||||||
|
|
||||||
// Dial uses the Tor Dial function in order to establish connections through
|
// Dial uses the Tor Dial function in order to establish connections through
|
||||||
// Tor. Since Tor only supports TCP connections, only TCP networks are allowed.
|
// Tor. Since Tor only supports TCP connections, only TCP networks are allowed.
|
||||||
func (p *ProxyNet) Dial(network, address string) (net.Conn, error) {
|
func (p *ProxyNet) Dial(network, address string,
|
||||||
|
timeout time.Duration) (net.Conn, error) {
|
||||||
|
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("cannot dial non-tcp network via Tor")
|
return nil, errors.New("cannot dial non-tcp network via Tor")
|
||||||
}
|
}
|
||||||
return Dial(address, p.SOCKS, p.StreamIsolation)
|
return Dial(address, p.SOCKS, p.StreamIsolation, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost uses the Tor LookupHost function in order to resolve hosts over
|
// LookupHost uses the Tor LookupHost function in order to resolve hosts over
|
||||||
@ -88,8 +111,13 @@ func (p *ProxyNet) LookupHost(host string) ([]string, error) {
|
|||||||
|
|
||||||
// LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries
|
// LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries
|
||||||
// over Tor.
|
// over Tor.
|
||||||
func (p *ProxyNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) {
|
func (p *ProxyNet) LookupSRV(service, proto,
|
||||||
return LookupSRV(service, proto, name, p.SOCKS, p.DNS, p.StreamIsolation)
|
name string, timeout time.Duration) (string, []*net.SRV, error) {
|
||||||
|
|
||||||
|
return LookupSRV(
|
||||||
|
service, proto, name, p.SOCKS, p.DNS,
|
||||||
|
p.StreamIsolation, timeout,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP
|
// ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP
|
||||||
|
29
tor/tor.go
29
tor/tor.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/connmgr"
|
"github.com/btcsuite/btcd/connmgr"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -54,8 +55,10 @@ func (c *proxyConn) RemoteAddr() net.Addr {
|
|||||||
// Dial is a wrapper over the non-exported dial function that returns a wrapper
|
// Dial is a wrapper over the non-exported dial function that returns a wrapper
|
||||||
// around net.Conn in order to expose the actual remote address we're dialing,
|
// around net.Conn in order to expose the actual remote address we're dialing,
|
||||||
// rather than the proxy's address.
|
// rather than the proxy's address.
|
||||||
func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
|
func Dial(address, socksAddr string, streamIsolation bool,
|
||||||
conn, err := dial(address, socksAddr, streamIsolation)
|
timeout time.Duration) (net.Conn, error) {
|
||||||
|
|
||||||
|
conn, err := dial(address, socksAddr, streamIsolation, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -75,11 +78,13 @@ func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP
|
// dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP
|
||||||
// is supported over Tor. The final argument determines if we should force
|
// is supported over Tor. The argument streamIsolation determines if we should
|
||||||
// stream isolation for this new connection. If we do, then this means this new
|
// force stream isolation for this new connection. If we do, then this means
|
||||||
// connection will use a fresh circuit, rather than possibly re-using an
|
// this new connection will use a fresh circuit, rather than possibly re-using
|
||||||
// existing circuit.
|
// an existing circuit.
|
||||||
func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
|
func dial(address, socksAddr string, streamIsolation bool,
|
||||||
|
timeout time.Duration) (net.Conn, error) {
|
||||||
|
|
||||||
// If we were requested to force stream isolation for this connection,
|
// If we were requested to force stream isolation for this connection,
|
||||||
// we'll populate the authentication credentials with random data as
|
// we'll populate the authentication credentials with random data as
|
||||||
// Tor will create a new circuit for each set of credentials.
|
// Tor will create a new circuit for each set of credentials.
|
||||||
@ -97,7 +102,8 @@ func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Establish the connection through Tor's SOCKS proxy.
|
// Establish the connection through Tor's SOCKS proxy.
|
||||||
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxy.Direct)
|
proxyDialer := &net.Dialer{Timeout: timeout}
|
||||||
|
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxyDialer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -121,11 +127,12 @@ func LookupHost(host, socksAddr string) ([]string, error) {
|
|||||||
// natively support SRV queries so we must route all SRV queries through the
|
// natively support SRV queries so we must route all SRV queries through the
|
||||||
// proxy by connecting directly to a DNS server and querying it. The DNS server
|
// proxy by connecting directly to a DNS server and querying it. The DNS server
|
||||||
// must have TCP resolution enabled for the given port.
|
// must have TCP resolution enabled for the given port.
|
||||||
func LookupSRV(service, proto, name, socksAddr, dnsServer string,
|
func LookupSRV(service, proto, name, socksAddr,
|
||||||
streamIsolation bool) (string, []*net.SRV, error) {
|
dnsServer string, streamIsolation bool,
|
||||||
|
timeout time.Duration) (string, []*net.SRV, error) {
|
||||||
|
|
||||||
// Connect to the DNS server we'll be using to query SRV records.
|
// Connect to the DNS server we'll be using to query SRV records.
|
||||||
conn, err := dial(dnsServer, socksAddr, streamIsolation)
|
conn, err := dial(dnsServer, socksAddr, streamIsolation, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user