package brontide

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"io"
	"math"
	"net"
	"testing"
	"testing/iotest"

	"github.com/btcsuite/btcd/btcec"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/tor"
)

type maybeNetConn struct {
	conn net.Conn
	err  error
}

func makeListener() (*Listener, *lnwire.NetAddress, error) {
	// First, generate the long-term private keys for the brontide listener.
	localPriv, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		return nil, nil, err
	}
	localKeyECDH := &keychain.PrivKeyECDH{PrivKey: localPriv}

	// Having a port of ":0" means a random port, and interface will be
	// chosen for our listener.
	addr := "localhost:0"

	// Our listener will be local, and the connection remote.
	listener, err := NewListener(localKeyECDH, addr)
	if err != nil {
		return nil, nil, err
	}

	netAddr := &lnwire.NetAddress{
		IdentityKey: localPriv.PubKey(),
		Address:     listener.Addr().(*net.TCPAddr),
	}

	return listener, netAddr, nil
}

func establishTestConnection() (net.Conn, net.Conn, func(), error) {
	listener, netAddr, err := makeListener()
	if err != nil {
		return nil, nil, nil, err
	}
	defer listener.Close()

	// Nos, generate the long-term private keys remote end of the connection
	// within our test.
	remotePriv, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		return nil, nil, nil, err
	}
	remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}

	// Initiate a connection with a separate goroutine, and listen with our
	// main one. If both errors are nil, then encryption+auth was
	// successful.
	remoteConnChan := make(chan maybeNetConn, 1)
	go func() {
		remoteConn, err := Dial(
			remoteKeyECDH, netAddr,
			tor.DefaultConnTimeout, net.DialTimeout,
		)
		remoteConnChan <- maybeNetConn{remoteConn, err}
	}()

	localConnChan := make(chan maybeNetConn, 1)
	go func() {
		localConn, err := listener.Accept()
		localConnChan <- maybeNetConn{localConn, err}
	}()

	remote := <-remoteConnChan
	if remote.err != nil {
		return nil, nil, nil, err
	}

	local := <-localConnChan
	if local.err != nil {
		return nil, nil, nil, err
	}

	cleanUp := func() {
		local.conn.Close()
		remote.conn.Close()
	}

	return local.conn, remote.conn, cleanUp, nil
}

func TestConnectionCorrectness(t *testing.T) {
	// Create a test connection, grabbing either side of the connection
	// into local variables. If the initial crypto handshake fails, then
	// we'll get a non-nil error here.
	localConn, remoteConn, cleanUp, err := establishTestConnection()
	if err != nil {
		t.Fatalf("unable to establish test connection: %v", err)
	}
	defer cleanUp()

	// Test out some message full-message reads.
	for i := 0; i < 10; i++ {
		msg := []byte(fmt.Sprintf("hello%d", i))

		if _, err := localConn.Write(msg); err != nil {
			t.Fatalf("remote conn failed to write: %v", err)
		}

		readBuf := make([]byte, len(msg))
		if _, err := remoteConn.Read(readBuf); err != nil {
			t.Fatalf("local conn failed to read: %v", err)
		}

		if !bytes.Equal(readBuf, msg) {
			t.Fatalf("messages don't match, %v vs %v",
				string(readBuf), string(msg))
		}
	}

	// Now try incremental message reads. This simulates first writing a
	// message header, then a message body.
	outMsg := []byte("hello world")
	if _, err := localConn.Write(outMsg); err != nil {
		t.Fatalf("remote conn failed to write: %v", err)
	}

	readBuf := make([]byte, len(outMsg))
	if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
		t.Fatalf("local conn failed to read: %v", err)
	}
	if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
		t.Fatalf("local conn failed to read: %v", err)
	}

	if !bytes.Equal(outMsg, readBuf) {
		t.Fatalf("messages don't match, %v vs %v",
			string(readBuf), string(outMsg))
	}
}

// TestConecurrentHandshakes verifies the listener's ability to not be blocked
// by other pending handshakes. This is tested by opening multiple tcp
// connections with the listener, without completing any of the brontide acts.
// The test passes if real brontide dialer connects while the others are
// stalled.
func TestConcurrentHandshakes(t *testing.T) {
	listener, netAddr, err := makeListener()
	if err != nil {
		t.Fatalf("unable to create listener connection: %v", err)
	}
	defer listener.Close()

	const nblocking = 5

	// Open a handful of tcp connections, that do not complete any steps of
	// the brontide handshake.
	connChan := make(chan maybeNetConn)
	for i := 0; i < nblocking; i++ {
		go func() {
			conn, err := net.Dial("tcp", listener.Addr().String())
			connChan <- maybeNetConn{conn, err}
		}()
	}

	// Receive all connections/errors from our blocking tcp dials. We make a
	// pass to gather all connections and errors to make sure we defer the
	// calls to Close() on all successful connections.
	tcpErrs := make([]error, 0, nblocking)
	for i := 0; i < nblocking; i++ {
		result := <-connChan
		if result.conn != nil {
			defer result.conn.Close()
		}
		if result.err != nil {
			tcpErrs = append(tcpErrs, result.err)
		}
	}
	for _, tcpErr := range tcpErrs {
		if tcpErr != nil {
			t.Fatalf("unable to tcp dial listener: %v", tcpErr)
		}
	}

	// Now, construct a new private key and use the brontide dialer to
	// connect to the listener.
	remotePriv, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		t.Fatalf("unable to generate private key: %v", err)
	}
	remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}

	go func() {
		remoteConn, err := Dial(
			remoteKeyECDH, netAddr,
			tor.DefaultConnTimeout, net.DialTimeout,
		)
		connChan <- maybeNetConn{remoteConn, err}
	}()

	// This connection should be accepted without error, as the brontide
	// connection should bypass stalled tcp connections.
	conn, err := listener.Accept()
	if err != nil {
		t.Fatalf("unable to accept dial: %v", err)
	}
	defer conn.Close()

	result := <-connChan
	if result.err != nil {
		t.Fatalf("unable to dial %v: %v", netAddr, result.err)
	}
	result.conn.Close()
}

func TestMaxPayloadLength(t *testing.T) {
	t.Parallel()

	b := Machine{}
	b.split()

	// Create a payload that's only *slightly* above the maximum allotted
	// payload length.
	payloadToReject := make([]byte, math.MaxUint16+1)

	// A write of the payload generated above to the state machine should
	// be rejected as it's over the max payload length.
	err := b.WriteMessage(payloadToReject)
	if err != ErrMaxMessageLengthExceeded {
		t.Fatalf("payload is over the max allowed length, the write " +
			"should have been rejected")
	}

	// Generate another payload which should be accepted as a valid
	// payload.
	payloadToAccept := make([]byte, math.MaxUint16-1)
	if err := b.WriteMessage(payloadToAccept); err != nil {
		t.Fatalf("write for payload was rejected, should have been " +
			"accepted")
	}

	// Generate a final payload which is only *slightly* above the max payload length
	// when the MAC is accounted for.
	payloadToReject = make([]byte, math.MaxUint16+1)

	// This payload should be rejected.
	err = b.WriteMessage(payloadToReject)
	if err != ErrMaxMessageLengthExceeded {
		t.Fatalf("payload is over the max allowed length, the write " +
			"should have been rejected")
	}
}

func TestWriteMessageChunking(t *testing.T) {
	// Create a test connection, grabbing either side of the connection
	// into local variables. If the initial crypto handshake fails, then
	// we'll get a non-nil error here.
	localConn, remoteConn, cleanUp, err := establishTestConnection()
	if err != nil {
		t.Fatalf("unable to establish test connection: %v", err)
	}
	defer cleanUp()

	// Attempt to write a message which is over 3x the max allowed payload
	// size.
	largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)

	// Launch a new goroutine to write the large message generated above in
	// chunks. We spawn a new goroutine because otherwise, we may block as
	// the kernel waits for the buffer to flush.
	errCh := make(chan error)
	go func() {
		defer close(errCh)

		bytesWritten, err := localConn.Write(largeMessage)
		if err != nil {
			errCh <- fmt.Errorf("unable to write message: %v", err)
			return
		}

		// The entire message should have been written out to the remote
		// connection.
		if bytesWritten != len(largeMessage) {
			errCh <- fmt.Errorf("bytes not fully written")
			return
		}
	}()

	// Attempt to read the entirety of the message generated above.
	buf := make([]byte, len(largeMessage))
	if _, err := io.ReadFull(remoteConn, buf); err != nil {
		t.Fatalf("unable to read message: %v", err)
	}

	err = <-errCh
	if err != nil {
		t.Fatal(err)
	}

	// Finally, the message the remote end of the connection received
	// should be identical to what we sent from the local connection.
	if !bytes.Equal(buf, largeMessage) {
		t.Fatalf("bytes don't match")
	}
}

// TestBolt0008TestVectors ensures that our implementation of brontide exactly
// matches the test vectors within the specification.
func TestBolt0008TestVectors(t *testing.T) {
	t.Parallel()

	// First, we'll generate the state of the initiator from the test
	// vectors at the appendix of BOLT-0008
	initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" +
		"111111111111111111111111111111111111111111")
	if err != nil {
		t.Fatalf("unable to decode hex: %v", err)
	}
	initiatorPriv, _ := btcec.PrivKeyFromBytes(
		btcec.S256(), initiatorKeyBytes,
	)
	initiatorKeyECDH := &keychain.PrivKeyECDH{PrivKey: initiatorPriv}

	// We'll then do the same for the responder.
	responderKeyBytes, err := hex.DecodeString("212121212121212121212121" +
		"2121212121212121212121212121212121212121")
	if err != nil {
		t.Fatalf("unable to decode hex: %v", err)
	}
	responderPriv, responderPub := btcec.PrivKeyFromBytes(
		btcec.S256(), responderKeyBytes,
	)
	responderKeyECDH := &keychain.PrivKeyECDH{PrivKey: responderPriv}

	// With the initiator's key data parsed, we'll now define a custom
	// EphemeralGenerator function for the state machine to ensure that the
	// initiator and responder both generate the ephemeral public key
	// defined within the test vectors.
	initiatorEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) {
		e := "121212121212121212121212121212121212121212121212121212" +
			"1212121212"
		eBytes, err := hex.DecodeString(e)
		if err != nil {
			return nil, err
		}

		priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), eBytes)
		return priv, nil
	})
	responderEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) {
		e := "222222222222222222222222222222222222222222222222222" +
			"2222222222222"
		eBytes, err := hex.DecodeString(e)
		if err != nil {
			return nil, err
		}

		priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), eBytes)
		return priv, nil
	})

	// Finally, we'll create both brontide state machines, so we can begin
	// our test.
	initiator := NewBrontideMachine(
		true, initiatorKeyECDH, responderPub, initiatorEphemeral,
	)
	responder := NewBrontideMachine(
		false, responderKeyECDH, nil, responderEphemeral,
	)

	// We'll start with the initiator generating the initial payload for
	// act one. This should consist of exactly 50 bytes. We'll assert that
	// the payload return is _exactly_ the same as what's specified within
	// the test vectors.
	actOne, err := initiator.GenActOne()
	if err != nil {
		t.Fatalf("unable to generate act one: %v", err)
	}
	expectedActOne, err := hex.DecodeString("00036360e856310ce5d294e" +
		"8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df608655115" +
		"1f58b8afe6c195782c6a")
	if err != nil {
		t.Fatalf("unable to parse expected act one: %v", err)
	}
	if !bytes.Equal(expectedActOne, actOne[:]) {
		t.Fatalf("act one mismatch: expected %x, got %x",
			expectedActOne, actOne)
	}

	// With the assertion above passed, we'll now process the act one
	// payload with the responder of the crypto handshake.
	if err := responder.RecvActOne(actOne); err != nil {
		t.Fatalf("responder unable to process act one: %v", err)
	}

	// Next, we'll start the second act by having the responder generate
	// its contribution to the crypto handshake. We'll also verify that we
	// produce the _exact_ same byte stream as advertised within the spec's
	// test vectors.
	actTwo, err := responder.GenActTwo()
	if err != nil {
		t.Fatalf("unable to generate act two: %v", err)
	}
	expectedActTwo, err := hex.DecodeString("0002466d7fcae563e5cb09a0" +
		"d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac58" +
		"3c9ef6eafca3f730ae")
	if err != nil {
		t.Fatalf("unable to parse expected act two: %v", err)
	}
	if !bytes.Equal(expectedActTwo, actTwo[:]) {
		t.Fatalf("act two mismatch: expected %x, got %x",
			expectedActTwo, actTwo)
	}

	// Moving the handshake along, we'll also ensure that the initiator
	// accepts the act two payload.
	if err := initiator.RecvActTwo(actTwo); err != nil {
		t.Fatalf("initiator unable to process act two: %v", err)
	}

	// At the final step, we'll generate the last act from the initiator
	// and once again verify that it properly matches the test vectors.
	actThree, err := initiator.GenActThree()
	if err != nil {
		t.Fatalf("unable to generate act three: %v", err)
	}
	expectedActThree, err := hex.DecodeString("00b9e3a702e93e3a9948c2e" +
		"d6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8f" +
		"c28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba")
	if err != nil {
		t.Fatalf("unable to parse expected act three: %v", err)
	}
	if !bytes.Equal(expectedActThree, actThree[:]) {
		t.Fatalf("act three mismatch: expected %x, got %x",
			expectedActThree, actThree)
	}

	// Finally, we'll ensure that the responder itself also properly parses
	// the last payload in the crypto handshake.
	if err := responder.RecvActThree(actThree); err != nil {
		t.Fatalf("responder unable to process act three: %v", err)
	}

	// As a final assertion, we'll ensure that both sides have derived the
	// proper symmetric encryption keys.
	sendingKey, err := hex.DecodeString("969ab31b4d288cedf6218839b27a3e2" +
		"140827047f2c0f01bf5c04435d43511a9")
	if err != nil {
		t.Fatalf("unable to parse sending key: %v", err)
	}
	recvKey, err := hex.DecodeString("bb9020b8965f4df047e07f955f3c4b884" +
		"18984aadc5cdb35096b9ea8fa5c3442")
	if err != nil {
		t.Fatalf("unable to parse receiving key: %v", err)
	}

	chainKey, err := hex.DecodeString("919219dbb2920afa8db80f9a51787a840" +
		"bcf111ed8d588caf9ab4be716e42b01")
	if err != nil {
		t.Fatalf("unable to parse chaining key: %v", err)
	}

	if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) {
		t.Fatalf("sending key mismatch: expected %x, got %x",
			initiator.sendCipher.secretKey[:], sendingKey)
	}
	if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) {
		t.Fatalf("receiving key mismatch: expected %x, got %x",
			initiator.recvCipher.secretKey[:], recvKey)
	}
	if !bytes.Equal(initiator.chainingKey[:], chainKey) {
		t.Fatalf("chaining key mismatch: expected %x, got %x",
			initiator.chainingKey[:], chainKey)
	}

	if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) {
		t.Fatalf("sending key mismatch: expected %x, got %x",
			responder.sendCipher.secretKey[:], recvKey)
	}
	if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) {
		t.Fatalf("receiving key mismatch: expected %x, got %x",
			responder.recvCipher.secretKey[:], sendingKey)
	}
	if !bytes.Equal(responder.chainingKey[:], chainKey) {
		t.Fatalf("chaining key mismatch: expected %x, got %x",
			responder.chainingKey[:], chainKey)
	}

	// Now test as per section "transport-message test" in Test Vectors
	// (the transportMessageVectors ciphertexts are from this section of BOLT 8);
	// we do slightly greater than 1000 encryption/decryption operations
	// to ensure that the key rotation algorithm is operating as expected.
	// The starting point for enc/decr is already guaranteed correct from the
	// above tests of sendingKey, receivingKey, chainingKey.
	transportMessageVectors := map[int]string{
		0: "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cb" +
			"cf25d2f214cf9ea1d95",
		1: "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837" +
			"d0f0d7ad833b1a256a1",
		500: "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb2" +
			"20f48364f7a4c68bf8",
		501: "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a6" +
			"0ca2c104b56b60e45bd",
		1000: "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3" +
			"c05c52b76b29b740f09",
		1001: "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e2" +
			"68338b1a16cf4ef2d36",
	}

	// Payload for every message is the string "hello".
	payload := []byte("hello")

	var buf bytes.Buffer

	for i := 0; i < 1002; i++ {
		err = initiator.WriteMessage(payload)
		if err != nil {
			t.Fatalf("could not write message %s", payload)
		}
		_, err = initiator.Flush(&buf)
		if err != nil {
			t.Fatalf("could not flush message: %v", err)
		}
		if val, ok := transportMessageVectors[i]; ok {
			binaryVal, err := hex.DecodeString(val)
			if err != nil {
				t.Fatalf("Failed to decode hex string %s", val)
			}
			if !bytes.Equal(buf.Bytes(), binaryVal) {
				t.Fatalf("Ciphertext %x was not equal to expected %s",
					buf.String()[:], val)
			}
		}

		// Responder decrypts the bytes, in every iteration, and
		// should always be able to decrypt the same payload message.
		plaintext, err := responder.ReadMessage(&buf)
		if err != nil {
			t.Fatalf("failed to read message in responder: %v", err)
		}

		// Ensure decryption succeeded
		if !bytes.Equal(plaintext, payload) {
			t.Fatalf("Decryption failed to receive plaintext: %s, got %s",
				payload, plaintext)
		}

		// Clear out the buffer for the next iteration
		buf.Reset()
	}
}

// timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after
// writing n bytes.
type timeoutWriter struct {
	w io.Writer
	n int64
}

func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
	return &timeoutWriter{w, n}
}

func (t *timeoutWriter) Write(p []byte) (int, error) {
	n := len(p)
	if int64(n) > t.n {
		n = int(t.n)
	}
	n, err := t.w.Write(p[:n])
	t.n -= int64(n)
	if err == nil && t.n == 0 {
		return n, iotest.ErrTimeout
	}
	return n, err
}

const payloadSize = 10

type flushChunk struct {
	errAfter int64
	expN     int
	expErr   error
}

type flushTest struct {
	name   string
	chunks []flushChunk
}

var flushTests = []flushTest{
	{
		name: "partial header write",
		chunks: []flushChunk{
			// Write 18-byte header in two parts, 16 then 2.
			{
				errAfter: encHeaderSize - 2,
				expN:     0,
				expErr:   iotest.ErrTimeout,
			},
			{
				errAfter: 2,
				expN:     0,
				expErr:   iotest.ErrTimeout,
			},
			// Write payload and MAC in one go.
			{
				errAfter: -1,
				expN:     payloadSize,
			},
		},
	},
	{
		name: "full payload then full mac",
		chunks: []flushChunk{
			// Write entire header and entire payload w/o MAC.
			{
				errAfter: encHeaderSize + payloadSize,
				expN:     payloadSize,
				expErr:   iotest.ErrTimeout,
			},
			// Write the entire MAC.
			{
				errAfter: -1,
				expN:     0,
			},
		},
	},
	{
		name: "payload-only, straddle, mac-only",
		chunks: []flushChunk{
			// Write header and all but last byte of payload.
			{
				errAfter: encHeaderSize + payloadSize - 1,
				expN:     payloadSize - 1,
				expErr:   iotest.ErrTimeout,
			},
			// Write last byte of payload and first byte of MAC.
			{
				errAfter: 2,
				expN:     1,
				expErr:   iotest.ErrTimeout,
			},
			// Write 10 bytes of the MAC.
			{
				errAfter: 10,
				expN:     0,
				expErr:   iotest.ErrTimeout,
			},
			// Write the remaining 5 MAC bytes.
			{
				errAfter: -1,
				expN:     0,
			},
		},
	},
}

// TestFlush asserts a Machine's ability to handle timeouts during Flush that
// cause partial writes, and that the machine can properly resume writes on
// subsequent calls to Flush.
func TestFlush(t *testing.T) {
	// Run each test individually, to assert that they pass in isolation.
	for _, test := range flushTests {
		t.Run(test.name, func(t *testing.T) {
			var (
				w bytes.Buffer
				b Machine
			)
			b.split()
			testFlush(t, test, &b, &w)
		})
	}

	// Finally, run the tests serially as if all on one connection.
	t.Run("flush serial", func(t *testing.T) {
		var (
			w bytes.Buffer
			b Machine
		)
		b.split()
		for _, test := range flushTests {
			testFlush(t, test, &b, &w)
		}
	})
}

// testFlush buffers a message on the Machine, then flushes it to the io.Writer
// in chunks. Once complete, a final call to flush is made to assert that Write
// is not called again.
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
	payload := make([]byte, payloadSize)
	if err := b.WriteMessage(payload); err != nil {
		t.Fatalf("unable to write message: %v", err)
	}

	for _, chunk := range test.chunks {
		assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
	}

	// We should always be able to call Flush after a message has been
	// successfully written, and it should result in a NOP.
	assertFlush(t, b, w, 0, 0, nil)
}

// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
// n bytes. The method asserts that the returned error matches expErr and that
// the number of bytes written by Flush matches expN.
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
	expErr error) {

	t.Helper()

	if n >= 0 {
		w = NewTimeoutWriter(w, n)
	}
	nn, err := b.Flush(w)
	if err != expErr {
		t.Fatalf("expected flush err: %v, got: %v", expErr, err)
	}
	if nn != expN {
		t.Fatalf("expected n: %d, got: %d", expN, nn)
	}
}