From c22b46d4624ef610ff09a2174b203506006ea9c5 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Sun, 1 Apr 2018 07:00:57 -0700 Subject: [PATCH] brontide/noise_test: test parallel handshakes --- brontide/noise_test.go | 152 ++++++++++++++++++++++++++++++----------- 1 file changed, 114 insertions(+), 38 deletions(-) diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 5b0562c2..415659ca 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -13,16 +13,16 @@ import ( "github.com/roasbeef/btcd/btcec" ) -func establishTestConnection() (net.Conn, net.Conn, func(), error) { - // First, generate the long-term private keys both ends of the - // connection within our test. +type maybeNetConn struct { + conn net.Conn + err error +} + +func makeListener() (*Listener, *lnwire.NetAddress, error) { + // First, generate the long-term private keys for the brontide listener. localPriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - return nil, nil, nil, err - } - remotePriv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - return nil, nil, nil, err + return nil, nil, err } // Having a port of ":0" means a random port, and interface will be @@ -32,56 +32,62 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) { // Our listener will be local, and the connection remote. listener, err := NewListener(localPriv, addr) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - defer listener.Close() netAddr := &lnwire.NetAddress{ IdentityKey: localPriv.PubKey(), Address: listener.Addr().(*net.TCPAddr), } + return listener, netAddr, nil +} + +func establishTestConnection() (net.Conn, net.Conn, func(), error) { + listener, netAddr, err := makeListener() + if err != nil { + return nil, nil, nil, err + } + defer listener.Close() + + // Nos, generate the long-term private keys remote end of the connection + // within our test. + remotePriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, nil, nil, err + } + // Initiate a connection with a separate goroutine, and listen with our // main one. If both errors are nil, then encryption+auth was // successful. - conErrChan := make(chan error, 1) - connChan := make(chan net.Conn, 1) + remoteConnChan := make(chan maybeNetConn, 1) go func() { - conn, err := Dial(remotePriv, netAddr, net.Dial) - - conErrChan <- err - connChan <- conn + remoteConn, err := Dial(remotePriv, netAddr, net.Dial) + remoteConnChan <- maybeNetConn{remoteConn, err} }() - lisErrChan := make(chan error, 1) - lisChan := make(chan net.Conn, 1) + localConnChan := make(chan maybeNetConn, 1) go func() { - localConn, listenErr := listener.Accept() - - lisErrChan <- listenErr - lisChan <- localConn + localConn, err := listener.Accept() + localConnChan <- maybeNetConn{localConn, err} }() - select { - case err := <-conErrChan: - if err != nil { - return nil, nil, nil, err - } - case err := <-lisErrChan: - if err != nil { - return nil, nil, nil, err - } + remote := <-remoteConnChan + if remote.err != nil { + return nil, nil, nil, err } - localConn := <-lisChan - remoteConn := <-connChan + local := <-localConnChan + if local.err != nil { + return nil, nil, nil, err + } cleanUp := func() { - localConn.Close() - remoteConn.Close() + local.conn.Close() + remote.conn.Close() } - return localConn, remoteConn, cleanUp, nil + return local.conn, remote.conn, cleanUp, nil } func TestConnectionCorrectness(t *testing.T) { @@ -134,14 +140,84 @@ func TestConnectionCorrectness(t *testing.T) { } } +// TestConecurrentHandshakes verifies the listener's ability to not be blocked +// by other pending handshakes. This is tested by opening multiple tcp +// connections with the listener, without completing any of the brontide acts. +// The test passes if real brontide dialer connects while the others are +// stalled. +func TestConcurrentHandshakes(t *testing.T) { + listener, netAddr, err := makeListener() + if err != nil { + t.Fatalf("unable to create listener connection: %v", err) + } + defer listener.Close() + + const nblocking = 5 + + // Open a handful of tcp connections, that do not complete any steps of + // the brontide handshake. + connChan := make(chan maybeNetConn) + for i := 0; i < nblocking; i++ { + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + connChan <- maybeNetConn{conn, err} + }() + } + + // Receive all connections/errors from our blocking tcp dials. We make a + // pass to gather all connections and errors to make sure we defer the + // calls to Close() on all successful connections. + tcpErrs := make([]error, 0, nblocking) + for i := 0; i < nblocking; i++ { + result := <-connChan + if result.conn != nil { + defer result.conn.Close() + } + if result.err != nil { + tcpErrs = append(tcpErrs, result.err) + } + } + for _, tcpErr := range tcpErrs { + if tcpErr != nil { + t.Fatalf("unable to tcp dial listener: %v", tcpErr) + } + } + + // Now, construct a new private key and use the brontide dialer to + // connect to the listener. + remotePriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate private key: %v", err) + } + + go func() { + remoteConn, err := Dial(remotePriv, netAddr, net.Dial) + connChan <- maybeNetConn{remoteConn, err} + }() + + // This connection should be accepted without error, as the brontide + // connection should bypass stalled tcp connections. + conn, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept dial: %v", err) + } + defer conn.Close() + + result := <-connChan + if result.err != nil { + t.Fatalf("unable to dial %v: %v", netAddr, result.err) + } + result.conn.Close() +} + func TestMaxPayloadLength(t *testing.T) { t.Parallel() b := Machine{} b.split() - // Create a payload that's only *slightly* above the maximum allotted payload - // length. + // Create a payload that's only *slightly* above the maximum allotted + // payload length. payloadToReject := make([]byte, math.MaxUint16+1) var buf bytes.Buffer