package wtserver_test

import (
	"bytes"
	"reflect"
	"testing"
	"time"

	"github.com/btcsuite/btcd/btcec"
	"github.com/btcsuite/btcd/chaincfg"
	"github.com/btcsuite/btcd/txscript"
	"github.com/btcsuite/btcutil"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/blob"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
	"github.com/lightningnetwork/lnd/watchtower/wtmock"
	"github.com/lightningnetwork/lnd/watchtower/wtserver"
	"github.com/lightningnetwork/lnd/watchtower/wtwire"
)

var (
	// addr is the server's reward address given to watchtower clients.
	// NOTE: the error returned by DecodeAddress is discarded here.
	addr, _ = btcutil.DecodeAddress(
		"mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params,
	)

	// addrScript is the pay-to-address script built from addr. It is the
	// Data payload expected in the CreateSessionReply of the reward test
	// case. NOTE: the error returned by PayToAddrScript is discarded here.
	addrScript, _ = txscript.PayToAddrScript(addr)

	// testnetChainHash is the testnet3 genesis hash. It is used as the
	// chain hash in the server config and in every client Init message.
	testnetChainHash = *chaincfg.TestNet3Params.GenesisHash

	// rewardType is the blob type obtained by combining the
	// FlagCommitOutputs and FlagReward flags.
	rewardType = (blob.FlagCommitOutputs | blob.FlagReward).Type()
)

// randPubKey returns the public half of a freshly generated secp256k1 keypair.
// The test is failed immediately if the private key cannot be generated.
func randPubKey(t *testing.T) *btcec.PublicKey {
	t.Helper()

	privKey, err := btcec.NewPrivateKey(btcec.S256())
	if err == nil {
		return privKey.PubKey()
	}

	// Fatalf stops the test goroutine; the return below only satisfies
	// the compiler.
	t.Fatalf("unable to generate pubkey: %v", err)
	return nil
}

// initServer builds a watchtower server backed by db and starts it before
// returning. The given timeout is used as both the read and the write timeout.
// A nil db is replaced with a fresh mock tower database. Any failure to create
// or start the server fails the test.
func initServer(t *testing.T, db wtserver.DB,
	timeout time.Duration) wtserver.Interface {

	t.Helper()

	if db == nil {
		db = wtmock.NewTowerDB()
	}

	// Always hand out the fixed package-level reward address.
	newAddress := func() (btcutil.Address, error) {
		return addr, nil
	}

	cfg := &wtserver.Config{
		DB:           db,
		ReadTimeout:  timeout,
		WriteTimeout: timeout,
		NewAddress:   newAddress,
		ChainHash:    testnetChainHash,
	}

	srv, err := wtserver.New(cfg)
	if err != nil {
		t.Fatalf("unable to create server: %v", err)
	}

	if err := srv.Start(); err != nil {
		t.Fatalf("unable to start server: %v", err)
	}

	return srv
}

// TestServerOnlyAcceptOnePeer checks that the server will reject duplicate
// peers with the same session id by disconnecting them. This is accomplished by
// connecting two distinct peers with the same session id, and trying to send
// messages on both connections. Since one should be rejected, we verify that
// only one of the connections is able to send messages.
func TestServerOnlyAcceptOnePeer(t *testing.T) {
	t.Parallel()

	const timeoutDuration = 500 * time.Millisecond

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create two peers using the same session id.
	peerPub := randPubKey(t)
	peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)

	// Serialize an Init message to be sent by both peers.
	init := wtwire.NewInitMessage(
		lnwire.NewRawFeatureVector(), testnetChainHash,
	)

	var b bytes.Buffer
	_, err := wtwire.WriteMessage(&b, init, 0)
	if err != nil {
		t.Fatalf("unable to write message: %v", err)
	}

	msg := b.Bytes()

	// Connect both peers to the server simultaneously.
	s.InboundPeerConnected(peer1)
	s.InboundPeerConnected(peer2)

	// Use a timeout of twice the server's timeouts, to ensure the server
	// has time to process the messages. The same timeout channel is shared
	// by every select below, so it bounds the whole remainder of the test.
	timeout := time.After(2 * timeoutDuration)

	// Try to send a message on either peer, and record the opposite peer as
	// the one we assume to be rejected.
	var (
		rejectedPeer *wtmock.MockPeer
		acceptedPeer *wtmock.MockPeer
	)
	select {
	case peer1.IncomingMsgs <- msg:
		acceptedPeer = peer1
		rejectedPeer = peer2
	case peer2.IncomingMsgs <- msg:
		acceptedPeer = peer2
		rejectedPeer = peer1
	case <-timeout:
		t.Fatalf("unable to send message via either peer")
	}

	// Try again to send a message, this time only via the assumed-rejected
	// peer. We expect our conservative timeout to expire, as the server
	// isn't reading from this peer. Before the timeout, the accepted peer
	// should also receive a reply to its Init message.
	select {
	case <-acceptedPeer.OutgoingMsgs:
		// The accepted peer got its reply. Keep watching the rejected
		// peer until the shared timeout fires.
		select {
		case rejectedPeer.IncomingMsgs <- msg:
			t.Fatalf("rejected peer should not have received message")
		case <-timeout:
			// Accepted peer got reply, rejected peer got nothing.
		}
	case rejectedPeer.IncomingMsgs <- msg:
		t.Fatalf("rejected peer should not have received message")
	case <-timeout:
		t.Fatalf("accepted peer should have received init message")
	}
}

// createSessionTestCase describes one scenario exercised by
// TestServerCreateSession.
type createSessionTestCase struct {
	// name is the subtest name passed to t.Run.
	name string

	// initMsg is the Init message sent on every connection the test opens.
	initMsg *wtwire.Init

	// createMsg is the CreateSession message sent to the server. The same
	// message is sent again for the duplicate attempt.
	createMsg *wtwire.CreateSession

	// expReply is the reply expected for the first CreateSession.
	expReply *wtwire.CreateSessionReply

	// expDupReply is the reply expected for the second, duplicate
	// CreateSession. When nil, the duplicate attempt is skipped entirely.
	expDupReply *wtwire.CreateSessionReply

	// sendStateUpdate, when true, makes the test send a StateUpdate on a
	// separate connection between the two CreateSession attempts.
	sendStateUpdate bool
}

// createSessionTests is the table of scenarios run by TestServerCreateSession.
var createSessionTests = []createSessionTestCase{
	// Send the same CreateSession twice without using the session in
	// between; both replies are expected to be CodeOK.
	{
		name: "duplicate session create",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   1000,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		expReply: &wtwire.CreateSessionReply{
			Code: wtwire.CodeOK,
			Data: []byte{},
		},
		expDupReply: &wtwire.CreateSessionReply{
			Code: wtwire.CodeOK,
			Data: []byte{},
		},
	},
	// Send a state update between the two CreateSession attempts; the
	// duplicate is expected to be answered with
	// CreateSessionCodeAlreadyExists and a LastApplied of 1.
	{
		name: "duplicate session create after use",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   1000,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		expReply: &wtwire.CreateSessionReply{
			Code: wtwire.CodeOK,
			Data: []byte{},
		},
		expDupReply: &wtwire.CreateSessionReply{
			Code:        wtwire.CreateSessionCodeAlreadyExists,
			LastApplied: 1,
			Data:        []byte{},
		},
		sendStateUpdate: true,
	},
	// Same as the first case but with the reward blob type; both replies
	// are expected to carry addrScript as their Data.
	{
		name: "duplicate session create reward",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     rewardType,
			MaxUpdates:   1000,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		expReply: &wtwire.CreateSessionReply{
			Code: wtwire.CodeOK,
			Data: addrScript,
		},
		expDupReply: &wtwire.CreateSessionReply{
			Code: wtwire.CodeOK,
			Data: addrScript,
		},
	},
	// A BlobType of 0 is expected to be rejected. No expDupReply is set,
	// so the duplicate attempt is skipped for this case.
	{
		name: "reject unsupported blob type",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     0,
			MaxUpdates:   1000,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		expReply: &wtwire.CreateSessionReply{
			Code: wtwire.CreateSessionCodeRejectBlobType,
			Data: []byte{},
		},
	},
	// TODO(conner): add policy rejection tests
}

// TestServerCreateSession runs every entry of createSessionTests as its own
// subtest, checking how the server answers each CreateSession scenario.
func TestServerCreateSession(t *testing.T) {
	t.Parallel()

	for idx := range createSessionTests {
		// Take per-iteration copies for the subtest closure.
		idx, tc := idx, createSessionTests[idx]

		t.Run(tc.name, func(t *testing.T) {
			testServerCreateSession(t, idx, tc)
		})
	}
}

// testServerCreateSession runs a single createSessionTestCase against a fresh
// server. It sends the case's CreateSession message, checks the reply and that
// the connection is closed afterwards. If the case defines an expected
// duplicate reply, it reconnects with the same session id, optionally sending
// a StateUpdate first, re-sends the same CreateSession and checks that reply
// as well. The index i is only used to label failure messages.
func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
	const timeoutDuration = 500 * time.Millisecond

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create a new client and connect to server.
	peerPub := randPubKey(t)
	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the CreateSession message, and wait for a reply.
	sendMsg(t, test.createMsg, peer, timeoutDuration)

	reply := recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Verify that the server's response matches our expectation.
	if !reflect.DeepEqual(reply, test.expReply) {
		t.Fatalf("[test %d] expected reply %v, got %v",
			i, test.expReply, reply)
	}

	// Assert that the server closes the connection after processing the
	// CreateSession.
	assertConnClosed(t, peer, 2*timeoutDuration)

	// If this test did not request sending a duplicate CreateSession, we can
	// continue to the next test.
	if test.expDupReply == nil {
		return
	}

	// Optionally use the session once before the duplicate attempt, on a
	// fresh connection with the same session id.
	if test.sendStateUpdate {
		peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
		connect(t, s, peer, test.initMsg, timeoutDuration)
		update := &wtwire.StateUpdate{
			SeqNum:     1,
			IsComplete: 1,
		}
		sendMsg(t, update, peer, timeoutDuration)

		assertConnClosed(t, peer, 2*timeoutDuration)
	}

	// Simulate a peer with the same session id connecting to the server
	// again.
	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the _same_ CreateSession message as the first attempt.
	sendMsg(t, test.createMsg, peer, timeoutDuration)

	reply = recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Ensure that the server's reply matches our expected response for a
	// duplicate send. The failure message reports expDupReply, since that
	// is the value the reply was compared against.
	if !reflect.DeepEqual(reply, test.expDupReply) {
		t.Fatalf("[test %d] expected dup reply %v, got %v",
			i, test.expDupReply, reply)
	}

	// Finally, check that the server tore down the connection.
	assertConnClosed(t, peer, 2*timeoutDuration)
}

// stateUpdateTestCase describes one scenario exercised by
// TestServerStateUpdates.
type stateUpdateTestCase struct {
	// name is the subtest name passed to t.Run.
	name string

	// initMsg is the Init message sent on every connection the test opens.
	initMsg *wtwire.Init

	// createMsg is the CreateSession message used to register the session
	// before any updates are sent.
	createMsg *wtwire.CreateSession

	// updates are the StateUpdate messages sent in order. A nil entry means
	// the test waits for the current connection to be closed and then
	// reconnects with the same session id.
	updates []*wtwire.StateUpdate

	// replies holds the expected reply for each entry in updates, matched
	// by index. Entries at the index of a nil update are not checked.
	replies []*wtwire.StateUpdateReply
}

// stateUpdateTests is the table of scenarios run by TestServerStateUpdates.
// Each case's replies slice is indexed in parallel with its updates slice. Case
// names are kept unique so that every subtest can be identified and selected by
// name.
var stateUpdateTests = []stateUpdateTestCase{
	// Valid update sequence, send seqnum == lastapplied as last update.
	{
		name: "perm fail after sending seqnum equal lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 3, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{
				Code:        wtwire.CodePermanentFailure,
				LastApplied: 3,
			},
		},
	},
	// Send update that skips next expected sequence number.
	{
		name: "skip sequence number",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 2, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{
				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
				LastApplied: 0,
			},
		},
	},
	// Send update that reverts to older sequence number.
	{
		name: "revert to older seqnum",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			{SeqNum: 1, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{
				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
				LastApplied: 2,
			},
		},
	},
	// Send update echoing a last applied that is lower than previous value.
	{
		name: "revert to older lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 1},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.StateUpdateCodeClientBehind, LastApplied: 3},
		},
	},
	// Valid update sequence with disconnection, ensure the session resumes.
	// Client echoes last applied values as they are received.
	{
		name: "resume after disconnect",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Valid update sequence with disconnection, resume next update. Client
	// doesn't echo last applied until last message.
	{
		name: "resume after disconnect lagging lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 3, LastApplied: 0},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Valid update sequence with disconnection, resume last update. Client
	// doesn't echo last applied until last message. After reconnecting, the
	// client re-sends the last update it sent before the disconnect.
	{
		name: "resume last update after disconnect lagging lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 2, LastApplied: 0},
			{SeqNum: 3, LastApplied: 0},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Send update with sequence number that exceeds MaxUpdates.
	{
		name: "seqnum exceed maxupdates",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{
				Code:        wtwire.StateUpdateCodeMaxUpdatesExceeded,
				LastApplied: 3,
			},
		},
	},
	// Ensure sequence number 0 causes permanent failure.
	{
		name: "perm fail after seqnum 0",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeDefault,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 0, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{
				Code:        wtwire.CodePermanentFailure,
				LastApplied: 0,
			},
		},
	},
}

// TestServerStateUpdates runs every entry of stateUpdateTests as its own
// subtest. Each scenario registers a session and then sends a series of
// StateUpdate messages, checking the reply the server gives to each one. The
// table covers failure codes returned when client and server disagree, as well
// as clients that disconnect, reconnect, and carry on updating.
func TestServerStateUpdates(t *testing.T) {
	t.Parallel()

	for idx := range stateUpdateTests {
		// Take a per-iteration copy for the subtest closure.
		tc := stateUpdateTests[idx]

		t.Run(tc.name, func(t *testing.T) {
			testServerStateUpdates(t, tc)
		})
	}
}

// testServerStateUpdates runs a single stateUpdateTestCase against a fresh
// server. It registers a session with the case's CreateSession message,
// reconnects with the same session id, and then sends each StateUpdate in
// order, comparing every reply with the expected one at the same index. A nil
// update makes the test wait for the connection to close and reconnect before
// continuing.
func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
	const timeoutDuration = 100 * time.Millisecond

	// The replies are looked up by update index, so a malformed table entry
	// would otherwise surface as an index-out-of-range panic.
	if len(test.replies) != len(test.updates) {
		t.Fatalf("test case has %d updates but %d replies",
			len(test.updates), len(test.replies))
	}

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create a new client and connect to the server.
	peerPub := randPubKey(t)
	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Register a session for this client to use in the subsequent tests.
	sendMsg(t, test.createMsg, peer, timeoutDuration)
	initReply := recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Fail if the server rejected our proposed CreateSession message.
	if initReply.Code != wtwire.CodeOK {
		t.Fatalf("server rejected session init")
	}

	// Check that the server closed the connection used to register the
	// session.
	assertConnClosed(t, peer, 2*timeoutDuration)

	// Now that the original connection has been closed, connect a new
	// client with the same session id.
	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the intended StateUpdate messages in series.
	for j, update := range test.updates {
		// A nil update signals that we should wait for the prior
		// connection to die, before re-registering with the same
		// session identifier.
		if update == nil {
			assertConnClosed(t, peer, 2*timeoutDuration)

			peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
			connect(t, s, peer, test.initMsg, timeoutDuration)

			continue
		}

		// Send the state update and verify it against our expected
		// response.
		sendMsg(t, update, peer, timeoutDuration)
		reply := recvReply(
			t, "MsgStateUpdateReply", peer, timeoutDuration,
		).(*wtwire.StateUpdateReply)

		// Both values are printed with the same verb so that expected
		// and actual replies can be compared directly.
		if !reflect.DeepEqual(reply, test.replies[j]) {
			t.Fatalf("[update %d] expected reply "+
				"%v, got %v", j,
				test.replies[j], reply)
		}
	}

	// Check that the final connection is properly cleaned up by the server.
	assertConnClosed(t, peer, 2*timeoutDuration)
}

// TestServerDeleteSession asserts the server's response to a DeleteSession
// request. It checks that the proper error code is returned when the session
// doesn't exist, and that a successful deletion does not disrupt other
// sessions.
func TestServerDeleteSession(t *testing.T) {
	// The database is created up front and shared with the server so that
	// the test can query session state directly.
	db := wtmock.NewTowerDB()

	localPub := randPubKey(t)

	// Initialize two distinct peers with different session ids.
	peerPub1 := randPubKey(t)
	peerPub2 := randPubKey(t)

	id1 := wtdb.NewSessionIDFromPubKey(peerPub1)
	id2 := wtdb.NewSessionIDFromPubKey(peerPub2)

	// Create closure to simplify assertions on session existence with the
	// server's database.
	hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) {
		t.Helper()

		_, err := db.GetSessionInfo(id)
		switch {
		case shouldHave && err != nil:
			t.Fatalf("expected server to have session %s, got: %v",
				id, err)
		case !shouldHave && err != wtdb.ErrSessionNotFound:
			t.Fatalf("expected ErrSessionNotFound for session %s, "+
				"got: %v", id, err)
		}
	}

	initMsg := wtwire.NewInitMessage(
		lnwire.NewRawFeatureVector(),
		testnetChainHash,
	)

	// createSession is sent by both peers to register their sessions.
	createSession := &wtwire.CreateSession{
		BlobType:     blob.TypeDefault,
		MaxUpdates:   1000,
		RewardBase:   0,
		RewardRate:   0,
		SweepFeeRate: 1,
	}

	const timeoutDuration = 100 * time.Millisecond

	s := initServer(t, db, timeoutDuration)
	defer s.Stop()

	// Create a session for peer2 so that the server's db isn't completely
	// empty. The CreateSessionReply is not read here; we only wait for the
	// connection to be closed.
	peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0)
	connect(t, s, peer2, initMsg, timeoutDuration)
	sendMsg(t, createSession, peer2, timeoutDuration)
	assertConnClosed(t, peer2, 2*timeoutDuration)

	// Our initial assertions are that peer2 has a valid session, but peer1
	// has not created one.
	hasSession(t, &id1, false)
	hasSession(t, &id2, true)

	// peer1Msgs is the scripted exchange for peer1: the message to send,
	// the reply expected back, and the database assertions to run after
	// the exchange completes.
	peer1Msgs := []struct {
		send   wtwire.Message
		recv   wtwire.Message
		assert func(t *testing.T)
	}{
		{
			// Deleting unknown session should fail.
			send: &wtwire.DeleteSession{},
			recv: &wtwire.DeleteSessionReply{
				Code: wtwire.DeleteSessionCodeNotFound,
			},
			assert: func(t *testing.T) {
				// Peer2 should still be only session.
				hasSession(t, &id1, false)
				hasSession(t, &id2, true)
			},
		},
		{
			// Create session for peer1.
			send: createSession,
			recv: &wtwire.CreateSessionReply{
				Code: wtwire.CodeOK,
				Data: []byte{},
			},
			assert: func(t *testing.T) {
				// Both peers should have sessions.
				hasSession(t, &id1, true)
				hasSession(t, &id2, true)
			},
		},

		{
			// Delete peer1's session.
			send: &wtwire.DeleteSession{},
			recv: &wtwire.DeleteSessionReply{
				Code: wtwire.CodeOK,
			},
			assert: func(t *testing.T) {
				// Peer1's session should have been removed.
				hasSession(t, &id1, false)
				hasSession(t, &id2, true)
			},
		},
	}

	// Now as peer1, process the canned messages defined above. This will:
	// 1. Try to delete an unknown session and get a not found error code.
	// 2. Create a new session using the same parameters as peer2.
	// 3. Delete the newly created session and get an OK.
	for _, msg := range peer1Msgs {
		// Each exchange uses a fresh connection with peer1's session id.
		peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0)
		connect(t, s, peer1, initMsg, timeoutDuration)
		sendMsg(t, msg.send, peer1, timeoutDuration)
		reply := recvReply(
			t, msg.recv.MsgType().String(), peer1, timeoutDuration,
		)

		if !reflect.DeepEqual(reply, msg.recv) {
			t.Fatalf("expected reply: %v, got: %v", msg.recv, reply)
		}

		assertConnClosed(t, peer1, 2*timeoutDuration)

		// Invoke assertions after completing the request/response
		// dance.
		msg.assert(t)
	}
}

// connect registers peer with the server as an inbound connection and then
// performs the Init exchange: it sends initMsg and waits for the server to
// answer with an Init message of its own.
func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer,
	initMsg *wtwire.Init, timeout time.Duration) {

	t.Helper()

	// Hand the connection to the server.
	s.InboundPeerConnected(peer)

	// Our Init goes out, the server's Init comes back. Only the type of the
	// reply is checked; its contents are discarded.
	sendMsg(t, initMsg, peer, timeout)
	_ = recvReply(t, "MsgInit", peer, timeout)
}

// sendMsg serializes msg and delivers it to the server through the mock peer's
// incoming channel. The test fails if the message cannot be encoded, or if the
// send is not accepted within twice the given timeout.
func sendMsg(t *testing.T, msg wtwire.Message,
	peer *wtmock.MockPeer, timeout time.Duration) {

	t.Helper()

	encoded := new(bytes.Buffer)
	if _, err := wtwire.WriteMessage(encoded, msg, 0); err != nil {
		t.Fatalf("unable to encode %T message: %v", msg, err)
	}

	deadline := time.After(2 * timeout)

	select {
	case <-deadline:
		t.Fatalf("unable to send %T message", msg)
	case peer.IncomingMsgs <- encoded.Bytes():
	}
}

// recvReply waits up to twice the given timeout for the server to write a
// message to the peer, decodes it, and returns it. The name selects which
// concrete type the decoded message must have: "MsgInit",
// "MsgCreateSessionReply", "MsgStateUpdateReply" and "MsgDeleteSessionReply"
// are checked against their corresponding wtwire types. Any other name skips
// the type check. The test fails on timeout, decode error, or type mismatch.
func recvReply(t *testing.T, name string, peer *wtmock.MockPeer,
	timeout time.Duration) wtwire.Message {

	t.Helper()

	var reply wtwire.Message

	select {
	case raw := <-peer.OutgoingMsgs:
		decoded, err := wtwire.ReadMessage(bytes.NewReader(raw), 0)
		if err != nil {
			t.Fatalf("unable to decode server reply: %v", err)
		}
		reply = decoded

	case <-time.After(2 * timeout):
		t.Fatalf("server did not reply")
	}

	// Names without a case below leave typeOK true, so their replies are
	// returned without a type check.
	typeOK := true
	switch name {
	case "MsgInit":
		_, typeOK = reply.(*wtwire.Init)
	case "MsgCreateSessionReply":
		_, typeOK = reply.(*wtwire.CreateSessionReply)
	case "MsgStateUpdateReply":
		_, typeOK = reply.(*wtwire.StateUpdateReply)
	case "MsgDeleteSessionReply":
		_, typeOK = reply.(*wtwire.DeleteSessionReply)
	}

	if !typeOK {
		t.Fatalf("expected %s reply message, got %T", name, reply)
	}

	return reply
}

// assertConnClosed fails the test unless the peer's Quit channel becomes
// readable within the given duration.
func assertConnClosed(t *testing.T, peer *wtmock.MockPeer, duration time.Duration) {
	t.Helper()

	expiry := time.NewTimer(duration)
	defer expiry.Stop()

	select {
	case <-expiry.C:
		t.Fatalf("expected connection to be closed")
	case <-peer.Quit:
	}
}