brontide+tor:add timeout value for network connections

This commit is contained in:
yyforyongyu 2020-08-25 12:48:32 +08:00
parent 9dcb522ebc
commit fb67b58d3f
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
4 changed files with 68 additions and 25 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
) )
// Conn is an implementation of net.Conn which enforces an authenticated key // Conn is an implementation of net.Conn which enforces an authenticated key
@ -34,12 +35,12 @@ var _ net.Conn = (*Conn)(nil)
// public key. In the case of a handshake failure, the connection is closed and // public key. In the case of a handshake failure, the connection is closed and
// a non-nil error is returned. // a non-nil error is returned.
func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (*Conn, error) { timeout time.Duration, dialer tor.DialFunc) (*Conn, error) {
ipAddr := netAddr.Address.String() ipAddr := netAddr.Address.String()
var conn net.Conn var conn net.Conn
var err error var err error
conn, err = dialer("tcp", ipAddr) conn, err = dialer("tcp", ipAddr, timeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
) )
type maybeNetConn struct { type maybeNetConn struct {
@ -66,7 +67,10 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) {
// successful. // successful.
remoteConnChan := make(chan maybeNetConn, 1) remoteConnChan := make(chan maybeNetConn, 1)
go func() { go func() {
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial) remoteConn, err := Dial(
remoteKeyECDH, netAddr,
tor.DefaultConnTimeout, net.DialTimeout,
)
remoteConnChan <- maybeNetConn{remoteConn, err} remoteConnChan <- maybeNetConn{remoteConn, err}
}() }()
@ -196,7 +200,10 @@ func TestConcurrentHandshakes(t *testing.T) {
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
go func() { go func() {
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial) remoteConn, err := Dial(
remoteKeyECDH, netAddr,
tor.DefaultConnTimeout, net.DialTimeout,
)
connChan <- maybeNetConn{remoteConn, err} connChan <- maybeNetConn{remoteConn, err}
}() }()

View File

@ -1,19 +1,31 @@
package tor package tor
import ( import (
"context"
"errors" "errors"
"net" "net"
"time"
) )
// TODO: this interface and its implementations should ideally be moved // TODO: this interface and its implementations should ideally be moved
// elsewhere as they are not Tor-specific. // elsewhere as they are not Tor-specific.
const (
// DefaultConnTimeout is the maximum amount of time a dial will wait for
// a connect to complete.
DefaultConnTimeout time.Duration = time.Second * 120
)
// DialFunc is a type defines the signature of a dialer used by our Net
// interface.
type DialFunc func(net, addr string, timeout time.Duration) (net.Conn, error)
// Net is an interface housing a Dial function and several DNS functions that // Net is an interface housing a Dial function and several DNS functions that
// allows us to abstract the implementations of these functions over different // allows us to abstract the implementations of these functions over different
// networks, e.g. clearnet, Tor net, etc. // networks, e.g. clearnet, Tor net, etc.
type Net interface { type Net interface {
// Dial connects to the address on the named network. // Dial connects to the address on the named network.
Dial(network, address string) (net.Conn, error) Dial(network, address string, timeout time.Duration) (net.Conn, error)
// LookupHost performs DNS resolution on a given host and returns its // LookupHost performs DNS resolution on a given host and returns its
// addresses. // addresses.
@ -21,7 +33,8 @@ type Net interface {
// LookupSRV tries to resolve an SRV query of the given service, // LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name. // protocol, and domain name.
LookupSRV(service, proto, name string) (string, []*net.SRV, error) LookupSRV(service, proto, name string,
timeout time.Duration) (string, []*net.SRV, error)
// ResolveTCPAddr resolves TCP addresses. // ResolveTCPAddr resolves TCP addresses.
ResolveTCPAddr(network, address string) (*net.TCPAddr, error) ResolveTCPAddr(network, address string) (*net.TCPAddr, error)
@ -32,8 +45,10 @@ type Net interface {
type ClearNet struct{} type ClearNet struct{}
// Dial on the regular network uses net.Dial // Dial on the regular network uses net.Dial
func (r *ClearNet) Dial(network, address string) (net.Conn, error) { func (r *ClearNet) Dial(
return net.Dial(network, address) network, address string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(network, address, timeout)
} }
// LookupHost for regular network uses the net.LookupHost function // LookupHost for regular network uses the net.LookupHost function
@ -42,8 +57,14 @@ func (r *ClearNet) LookupHost(host string) ([]string, error) {
} }
// LookupSRV for regular network uses net.LookupSRV function // LookupSRV for regular network uses net.LookupSRV function
func (r *ClearNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) { func (r *ClearNet) LookupSRV(service, proto, name string,
return net.LookupSRV(service, proto, name) timeout time.Duration) (string, []*net.SRV, error) {
// Create a context with a timeout value.
ctxt, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return net.DefaultResolver.LookupSRV(ctxt, service, proto, name)
} }
// ResolveTCPAddr for regular network uses net.ResolveTCPAddr function // ResolveTCPAddr for regular network uses net.ResolveTCPAddr function
@ -71,13 +92,15 @@ type ProxyNet struct {
// Dial uses the Tor Dial function in order to establish connections through // Dial uses the Tor Dial function in order to establish connections through
// Tor. Since Tor only supports TCP connections, only TCP networks are allowed. // Tor. Since Tor only supports TCP connections, only TCP networks are allowed.
func (p *ProxyNet) Dial(network, address string) (net.Conn, error) { func (p *ProxyNet) Dial(network, address string,
timeout time.Duration) (net.Conn, error) {
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
default: default:
return nil, errors.New("cannot dial non-tcp network via Tor") return nil, errors.New("cannot dial non-tcp network via Tor")
} }
return Dial(address, p.SOCKS, p.StreamIsolation) return Dial(address, p.SOCKS, p.StreamIsolation, timeout)
} }
// LookupHost uses the Tor LookupHost function in order to resolve hosts over // LookupHost uses the Tor LookupHost function in order to resolve hosts over
@ -88,8 +111,13 @@ func (p *ProxyNet) LookupHost(host string) ([]string, error) {
// LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries // LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries
// over Tor. // over Tor.
func (p *ProxyNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) { func (p *ProxyNet) LookupSRV(service, proto,
return LookupSRV(service, proto, name, p.SOCKS, p.DNS, p.StreamIsolation) name string, timeout time.Duration) (string, []*net.SRV, error) {
return LookupSRV(
service, proto, name, p.SOCKS, p.DNS,
p.StreamIsolation, timeout,
)
} }
// ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP // ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"time"
"github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/connmgr"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -54,8 +55,10 @@ func (c *proxyConn) RemoteAddr() net.Addr {
// Dial is a wrapper over the non-exported dial function that returns a wrapper // Dial is a wrapper over the non-exported dial function that returns a wrapper
// around net.Conn in order to expose the actual remote address we're dialing, // around net.Conn in order to expose the actual remote address we're dialing,
// rather than the proxy's address. // rather than the proxy's address.
func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { func Dial(address, socksAddr string, streamIsolation bool,
conn, err := dial(address, socksAddr, streamIsolation) timeout time.Duration) (net.Conn, error) {
conn, err := dial(address, socksAddr, streamIsolation, timeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,11 +78,13 @@ func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
} }
// dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP // dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP
// is supported over Tor. The final argument determines if we should force // is supported over Tor. The argument streamIsolation determines if we should
// stream isolation for this new connection. If we do, then this means this new // force stream isolation for this new connection. If we do, then this means
// connection will use a fresh circuit, rather than possibly re-using an // this new connection will use a fresh circuit, rather than possibly re-using
// existing circuit. // an existing circuit.
func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) { func dial(address, socksAddr string, streamIsolation bool,
timeout time.Duration) (net.Conn, error) {
// If we were requested to force stream isolation for this connection, // If we were requested to force stream isolation for this connection,
// we'll populate the authentication credentials with random data as // we'll populate the authentication credentials with random data as
// Tor will create a new circuit for each set of credentials. // Tor will create a new circuit for each set of credentials.
@ -97,7 +102,8 @@ func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
} }
// Establish the connection through Tor's SOCKS proxy. // Establish the connection through Tor's SOCKS proxy.
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxy.Direct) proxyDialer := &net.Dialer{Timeout: timeout}
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxyDialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,11 +127,12 @@ func LookupHost(host, socksAddr string) ([]string, error) {
// natively support SRV queries so we must route all SRV queries through the // natively support SRV queries so we must route all SRV queries through the
// proxy by connecting directly to a DNS server and querying it. The DNS server // proxy by connecting directly to a DNS server and querying it. The DNS server
// must have TCP resolution enabled for the given port. // must have TCP resolution enabled for the given port.
func LookupSRV(service, proto, name, socksAddr, dnsServer string, func LookupSRV(service, proto, name, socksAddr,
streamIsolation bool) (string, []*net.SRV, error) { dnsServer string, streamIsolation bool,
timeout time.Duration) (string, []*net.SRV, error) {
// Connect to the DNS server we'll be using to query SRV records. // Connect to the DNS server we'll be using to query SRV records.
conn, err := dial(dnsServer, socksAddr, streamIsolation) conn, err := dial(dnsServer, socksAddr, streamIsolation, timeout)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }