lnd.xprv/watchtower/wtserver/server_test.go

904 lines
24 KiB
Go

package wtserver_test
import (
"bytes"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
var (
	// addr is the server's reward address given to watchtower clients.
	// The decode error is intentionally discarded because the input is a
	// hard-coded testnet address literal.
	addr, _ = btcutil.DecodeAddress(
		"mrX9vMRYLfVy1BnZbc5gZjuyaqH3ZW2ZHz", &chaincfg.TestNet3Params,
	)

	// addrScript is the output script paying to addr. The tests expect the
	// server to return it as the Data of CreateSessionReply for reward
	// sessions. The error is discarded for the same reason as above.
	addrScript, _ = txscript.PayToAddrScript(addr)

	// testnetChainHash is the testnet3 genesis hash. It is used both in the
	// server config and in every Init message sent by the test clients.
	testnetChainHash = *chaincfg.TestNet3Params.GenesisHash
)
// randPubKey returns the public half of a freshly generated secp256k1 key
// pair. The test is failed immediately if key generation returns an error.
func randPubKey(t *testing.T) *btcec.PublicKey {
	t.Helper()

	privKey, genErr := btcec.NewPrivateKey(btcec.S256())
	if genErr != nil {
		t.Fatalf("unable to generate pubkey: %v", genErr)
	}

	return privKey.PubKey()
}
// initServer builds a watchtower server backed by db, with both read and
// write timeouts set to timeout, and starts it. When db is nil, a fresh
// in-memory mock tower database is substituted. The server always hands out
// the package-level addr as its reward address and advertises
// testnetChainHash. Any construction or startup error fails the test.
func initServer(t *testing.T, db wtserver.DB,
	timeout time.Duration) wtserver.Interface {

	t.Helper()

	if db == nil {
		db = wtmock.NewTowerDB()
	}

	cfg := &wtserver.Config{
		DB:           db,
		ReadTimeout:  timeout,
		WriteTimeout: timeout,
		NewAddress: func() (btcutil.Address, error) {
			return addr, nil
		},
		ChainHash: testnetChainHash,
	}

	server, err := wtserver.New(cfg)
	if err != nil {
		t.Fatalf("unable to create server: %v", err)
	}

	if startErr := server.Start(); startErr != nil {
		t.Fatalf("unable to start server: %v", startErr)
	}

	return server
}
// TestServerOnlyAcceptOnePeer checks that the server will reject duplicate
// peers with the same session id by disconnecting them. This is accomplished by
// connecting two distinct peers with the same session id, and trying to send
// messages on both connections. Since one should be rejected, we verify that
// only one of the connections is able to send messages.
func TestServerOnlyAcceptOnePeer(t *testing.T) {
	t.Parallel()

	const timeoutDuration = 500 * time.Millisecond

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create two peers using the same session id.
	peerPub := randPubKey(t)
	peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)

	// Serialize an Init message to be sent by both peers. The local is named
	// initMsg rather than init to avoid confusion with Go's package
	// initializer function name.
	initMsg := wtwire.NewInitMessage(
		lnwire.NewRawFeatureVector(), testnetChainHash,
	)

	var b bytes.Buffer
	_, err := wtwire.WriteMessage(&b, initMsg, 0)
	if err != nil {
		t.Fatalf("unable to write message: %v", err)
	}

	msg := b.Bytes()

	// Connect both peers to the server simultaneously.
	s.InboundPeerConnected(peer1)
	s.InboundPeerConnected(peer2)

	// Use a timeout of twice the server's timeouts, to ensure the server
	// has time to process the messages. This single channel is shared by
	// every select below, so all remaining steps share one time budget.
	timeout := time.After(2 * timeoutDuration)

	// Try to send a message on either peer, and record the opposite peer as
	// the one we assume to be rejected.
	var (
		rejectedPeer *wtmock.MockPeer
		acceptedPeer *wtmock.MockPeer
	)
	select {
	case peer1.IncomingMsgs <- msg:
		acceptedPeer = peer1
		rejectedPeer = peer2
	case peer2.IncomingMsgs <- msg:
		acceptedPeer = peer2
		rejectedPeer = peer1
	case <-timeout:
		t.Fatalf("unable to send message via either peer")
	}

	// Try again to send a message, this time only via the assumed-rejected
	// peer. We expect our conservative timeout to expire, as the server
	// isn't reading from this peer. Before the timeout, the accepted peer
	// should also receive a reply to its Init message.
	select {
	case <-acceptedPeer.OutgoingMsgs:
		select {
		case rejectedPeer.IncomingMsgs <- msg:
			t.Fatalf("rejected peer should not have received message")
		case <-timeout:
			// Accepted peer got a reply, rejected peer got nothing.
		}
	case rejectedPeer.IncomingMsgs <- msg:
		t.Fatalf("rejected peer should not have received message")
	case <-timeout:
		t.Fatalf("accepted peer should have received init message")
	}
}
// createSessionTestCase describes one CreateSession scenario exercised by
// testServerCreateSession.
type createSessionTestCase struct {
	// name labels the subtest.
	name string

	// initMsg is sent by the client on every connection it opens.
	initMsg *wtwire.Init

	// createMsg is the CreateSession request sent on the first connection
	// and, when expDupReply is set, sent again on a later connection.
	createMsg *wtwire.CreateSession

	// expReply is the reply expected for the first createMsg. expDupReply
	// is the reply expected when the same createMsg is resent. A nil
	// expDupReply skips the duplicate attempt entirely.
	expReply, expDupReply *wtwire.CreateSessionReply

	// sendStateUpdate, when true, makes the test send a StateUpdate with
	// SeqNum 1 on an intermediate connection before resending createMsg.
	sendStateUpdate bool
}
// createSessionTests is the table consumed by TestServerCreateSession. Each
// entry sends createMsg once and compares the reply to expReply. When
// expDupReply is non-nil, the same createMsg is resent on a new connection
// and compared to expDupReply.
var createSessionTests = []createSessionTestCase{
// An altruist session created twice without any intervening updates: both
// attempts are expected to succeed with empty reply data.
{
name: "duplicate session create",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeAltruistCommit,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
},
// Once a StateUpdate (SeqNum 1) has been applied, re-sending the same
// CreateSession is expected to report CreateSessionCodeAlreadyExists along
// with LastApplied 1.
{
name: "duplicate session create after use",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeAltruistCommit,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: []byte{},
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CreateSessionCodeAlreadyExists,
LastApplied: 1,
Data: []byte{},
},
sendStateUpdate: true,
},
// A reward session: both the original and duplicate replies are expected to
// carry the server's reward script (addrScript) as Data.
{
name: "duplicate session create reward",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: blob.TypeRewardCommit,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: addrScript,
},
expDupReply: &wtwire.CreateSessionReply{
Code: wtwire.CodeOK,
Data: addrScript,
},
},
// BlobType 0 is expected to be rejected. No expDupReply is set, so no
// duplicate attempt is made.
{
name: "reject unsupported blob type",
initMsg: wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(),
testnetChainHash,
),
createMsg: &wtwire.CreateSession{
BlobType: 0,
MaxUpdates: 1000,
RewardBase: 0,
RewardRate: 0,
SweepFeeRate: 1,
},
expReply: &wtwire.CreateSessionReply{
Code: wtwire.CreateSessionCodeRejectBlobType,
Data: []byte{},
},
},
// TODO(conner): add policy rejection tests
}
// TestServerCreateSession runs every entry of createSessionTests as its own
// subtest, named after the entry, passing along the entry's table index so
// failure messages can identify it.
func TestServerCreateSession(t *testing.T) {
	t.Parallel()

	for idx := range createSessionTests {
		tc := createSessionTests[idx]
		t.Run(tc.name, func(t *testing.T) {
			testServerCreateSession(t, idx, tc)
		})
	}
}
// testServerCreateSession runs a single createSessionTestCase against a fresh
// server. It sends test.createMsg and checks the reply against test.expReply.
// When test.expDupReply is set, it optionally applies one StateUpdate and then
// resends the same CreateSession on a new connection, checking that reply
// against test.expDupReply. After each exchange it asserts that the server
// closed the connection. The parameter i is only used to label failures.
func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
	const timeoutDuration = 500 * time.Millisecond

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create a new client and connect to server.
	peerPub := randPubKey(t)
	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the CreateSession message, and wait for a reply.
	sendMsg(t, test.createMsg, peer, timeoutDuration)
	reply := recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Verify that the server's response matches our expectation. Replies
	// are struct pointers, so they are formatted with %v rather than %d.
	if !reflect.DeepEqual(reply, test.expReply) {
		t.Fatalf("[test %d] expected reply %v, got %v",
			i, test.expReply, reply)
	}

	// Assert that the server closes the connection after processing the
	// CreateSession.
	assertConnClosed(t, peer, 2*timeoutDuration)

	// If this test did not request sending a duplicate CreateSession, we can
	// continue to the next test.
	if test.expDupReply == nil {
		return
	}

	// Optionally consume one update on the session so that the duplicate
	// CreateSession below is sent against a session that has been used.
	if test.sendStateUpdate {
		peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
		connect(t, s, peer, test.initMsg, timeoutDuration)
		update := &wtwire.StateUpdate{
			SeqNum:     1,
			IsComplete: 1,
		}
		sendMsg(t, update, peer, timeoutDuration)

		assertConnClosed(t, peer, 2*timeoutDuration)
	}

	// Simulate a peer with the same session id connecting to the server
	// again.
	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the _same_ CreateSession message as the first attempt.
	sendMsg(t, test.createMsg, peer, timeoutDuration)
	reply = recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Ensure that the server's reply matches our expected response for a
	// duplicate send. Report expDupReply (not expReply) so the failure
	// message shows the expectation that was actually checked.
	if !reflect.DeepEqual(reply, test.expDupReply) {
		t.Fatalf("[test %d] expected dup reply %v, got %v",
			i, test.expDupReply, reply)
	}

	// Finally, check that the server tore down the connection.
	assertConnClosed(t, peer, 2*timeoutDuration)
}
// stateUpdateTestCase describes one StateUpdate scenario exercised by
// testServerStateUpdates.
type stateUpdateTestCase struct {
	// name labels the subtest.
	name string

	// initMsg is sent on every connection; createMsg registers the
	// session before any updates are sent.
	initMsg   *wtwire.Init
	createMsg *wtwire.CreateSession

	// updates and replies are read in lockstep: replies[j] is the reply
	// expected for updates[j]. A nil update means "wait for the server to
	// drop the connection, then reconnect", and its paired reply is unused.
	updates []*wtwire.StateUpdate
	replies []*wtwire.StateUpdateReply
}
// stateUpdateTests is the table consumed by TestServerStateUpdates. Each name
// must be unique. Otherwise t.Run silently appends a "#NN" suffix, which makes
// failures harder to attribute and the cases impossible to select
// individually with -run.
var stateUpdateTests = []stateUpdateTestCase{
	// Valid update sequence, send seqnum == lastapplied as last update.
	{
		name: "perm fail after sending seqnum equal lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 3, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{
				Code:        wtwire.CodePermanentFailure,
				LastApplied: 3,
			},
		},
	},
	// Send update that skips next expected sequence number.
	{
		name: "skip sequence number",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 2, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{
				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
				LastApplied: 0,
			},
		},
	},
	// Send update that reverts to older sequence number.
	{
		name: "revert to older seqnum",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			{SeqNum: 1, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{
				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
				LastApplied: 2,
			},
		},
	},
	// Send update echoing a last applied that is lower than previous value.
	{
		name: "revert to older lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 1},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.StateUpdateCodeClientBehind, LastApplied: 3},
		},
	},
	// Valid update sequence with disconnection; ensure the session resumes
	// after reconnecting. Client echoes last applied as they are received.
	{
		name: "resume after disconnect",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Valid update sequence with disconnection, resume next update. Client
	// doesn't echo last applied until last message.
	{
		name: "resume after disconnect lagging lastapplied",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 3, LastApplied: 0},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Valid update sequence with disconnection, resume by resending the last
	// update (SeqNum 2). Client doesn't echo last applied until last message.
	{
		name: "resume after disconnect lagging lastapplied resend last update",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   4,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 0},
			nil, // Wait for read timeout to drop conn, then reconnect.
			{SeqNum: 2, LastApplied: 0},
			{SeqNum: 3, LastApplied: 0},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			nil,
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{Code: wtwire.CodeOK, LastApplied: 4},
		},
	},
	// Send update with sequence number that exceeds MaxUpdates.
	{
		name: "seqnum exceed maxupdates",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 1, LastApplied: 0},
			{SeqNum: 2, LastApplied: 1},
			{SeqNum: 3, LastApplied: 2},
			{SeqNum: 4, LastApplied: 3},
		},
		replies: []*wtwire.StateUpdateReply{
			{Code: wtwire.CodeOK, LastApplied: 1},
			{Code: wtwire.CodeOK, LastApplied: 2},
			{Code: wtwire.CodeOK, LastApplied: 3},
			{
				Code:        wtwire.StateUpdateCodeMaxUpdatesExceeded,
				LastApplied: 3,
			},
		},
	},
	// Ensure sequence number 0 causes permanent failure.
	{
		name: "perm fail after seqnum 0",
		initMsg: wtwire.NewInitMessage(
			lnwire.NewRawFeatureVector(),
			testnetChainHash,
		),
		createMsg: &wtwire.CreateSession{
			BlobType:     blob.TypeAltruistCommit,
			MaxUpdates:   3,
			RewardBase:   0,
			RewardRate:   0,
			SweepFeeRate: 1,
		},
		updates: []*wtwire.StateUpdate{
			{SeqNum: 0, LastApplied: 0},
		},
		replies: []*wtwire.StateUpdateReply{
			{
				Code:        wtwire.CodePermanentFailure,
				LastApplied: 0,
			},
		},
	},
}
// TestServerStateUpdates exercises the server's handling of StateUpdate
// messages on an already-registered session. Each entry of stateUpdateTests
// runs as a named subtest. The cases cover the failure codes returned when
// client and server fall out of sync, as well as a client disconnecting,
// reconnecting, and continuing from its last successful update.
func TestServerStateUpdates(t *testing.T) {
	t.Parallel()

	for idx := range stateUpdateTests {
		tc := stateUpdateTests[idx]
		t.Run(tc.name, func(t *testing.T) {
			testServerStateUpdates(t, tc)
		})
	}
}
// testServerStateUpdates runs a single stateUpdateTestCase against a fresh
// server. It registers a session with test.createMsg, then reconnects with the
// same session id and sends each entry of test.updates. The server's reply to
// updates[j] is compared against replies[j]. A nil update waits for the server
// to drop the current connection and then reconnects before continuing.
func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
	// updates and replies are indexed in lockstep below. Reject malformed
	// table entries up front instead of panicking on an out-of-range index.
	if len(test.replies) != len(test.updates) {
		t.Fatalf("malformed test case: %d updates but %d replies",
			len(test.updates), len(test.replies))
	}

	const timeoutDuration = 100 * time.Millisecond

	s := initServer(t, nil, timeoutDuration)
	defer s.Stop()

	localPub := randPubKey(t)

	// Create a new client and connect to the server.
	peerPub := randPubKey(t)
	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Register a session for this client to use in the subsequent tests.
	sendMsg(t, test.createMsg, peer, timeoutDuration)
	initReply := recvReply(
		t, "MsgCreateSessionReply", peer, timeoutDuration,
	).(*wtwire.CreateSessionReply)

	// Fail if the server rejected our proposed CreateSession message.
	if initReply.Code != wtwire.CodeOK {
		t.Fatalf("server rejected session init")
	}

	// Check that the server closed the connection used to register the
	// session.
	assertConnClosed(t, peer, 2*timeoutDuration)

	// Now that the original connection has been closed, connect a new
	// client with the same session id.
	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
	connect(t, s, peer, test.initMsg, timeoutDuration)

	// Send the intended StateUpdate messages in series.
	for j, update := range test.updates {
		// A nil update signals that we should wait for the prior
		// connection to die before re-registering with the same session
		// identifier.
		if update == nil {
			assertConnClosed(t, peer, 2*timeoutDuration)

			peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
			connect(t, s, peer, test.initMsg, timeoutDuration)

			continue
		}

		// Send the state update and verify it against our expected
		// response. Replies are struct pointers, so format them with %v.
		sendMsg(t, update, peer, timeoutDuration)
		reply := recvReply(
			t, "MsgStateUpdateReply", peer, timeoutDuration,
		).(*wtwire.StateUpdateReply)

		if !reflect.DeepEqual(reply, test.replies[j]) {
			t.Fatalf("[update %d] expected reply "+
				"%v, got %v", j,
				test.replies[j], reply)
		}
	}

	// Check that the final connection is properly cleaned up by the server.
	assertConnClosed(t, peer, 2*timeoutDuration)
}
// TestServerDeleteSession asserts the server's responses to DeleteSession
// requests. Deleting a session that does not exist must yield
// DeleteSessionCodeNotFound. Deleting an existing session must succeed without
// disturbing sessions that belong to other peers.
func TestServerDeleteSession(t *testing.T) {
	db := wtmock.NewTowerDB()

	localPub := randPubKey(t)

	// Initialize two distinct peers with different session ids.
	peerPub1 := randPubKey(t)
	peerPub2 := randPubKey(t)

	id1 := wtdb.NewSessionIDFromPubKey(peerPub1)
	id2 := wtdb.NewSessionIDFromPubKey(peerPub2)

	// Create closure to simplify assertions on session existence with the
	// server's database.
	hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) {
		t.Helper()

		_, err := db.GetSessionInfo(id)
		switch {
		case shouldHave && err != nil:
			t.Fatalf("expected server to have session %s, got: %v",
				id, err)
		case !shouldHave && err != wtdb.ErrSessionNotFound:
			t.Fatalf("expected ErrSessionNotFound for session %s, "+
				"got: %v", id, err)
		}
	}

	initMsg := wtwire.NewInitMessage(
		lnwire.NewRawFeatureVector(),
		testnetChainHash,
	)

	createSession := &wtwire.CreateSession{
		BlobType:     blob.TypeAltruistCommit,
		MaxUpdates:   1000,
		RewardBase:   0,
		RewardRate:   0,
		SweepFeeRate: 1,
	}

	const timeoutDuration = 100 * time.Millisecond

	s := initServer(t, db, timeoutDuration)
	defer s.Stop()

	// Create a session for peer2 so that the server's db isn't completely
	// empty. Consume and check the server's reply so this precondition is
	// verified rather than assumed.
	peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0)
	connect(t, s, peer2, initMsg, timeoutDuration)
	sendMsg(t, createSession, peer2, timeoutDuration)
	createReply := recvReply(
		t, "MsgCreateSessionReply", peer2, timeoutDuration,
	).(*wtwire.CreateSessionReply)
	if createReply.Code != wtwire.CodeOK {
		t.Fatalf("unable to create session for peer2, code: %v",
			createReply.Code)
	}
	assertConnClosed(t, peer2, 2*timeoutDuration)

	// Our initial assertions are that peer2 has a valid session, but peer1
	// has not created one.
	hasSession(t, &id1, false)
	hasSession(t, &id2, true)

	peer1Msgs := []struct {
		send   wtwire.Message
		recv   wtwire.Message
		assert func(t *testing.T)
	}{
		{
			// Deleting unknown session should fail.
			send: &wtwire.DeleteSession{},
			recv: &wtwire.DeleteSessionReply{
				Code: wtwire.DeleteSessionCodeNotFound,
			},
			assert: func(t *testing.T) {
				// Peer2 should still be only session.
				hasSession(t, &id1, false)
				hasSession(t, &id2, true)
			},
		},
		{
			// Create session for peer1.
			send: createSession,
			recv: &wtwire.CreateSessionReply{
				Code: wtwire.CodeOK,
				Data: []byte{},
			},
			assert: func(t *testing.T) {
				// Both peers should have sessions.
				hasSession(t, &id1, true)
				hasSession(t, &id2, true)
			},
		},
		{
			// Delete peer1's session.
			send: &wtwire.DeleteSession{},
			recv: &wtwire.DeleteSessionReply{
				Code: wtwire.CodeOK,
			},
			assert: func(t *testing.T) {
				// Peer1's session should have been removed.
				hasSession(t, &id1, false)
				hasSession(t, &id2, true)
			},
		},
	}

	// Now as peer1, process the canned messages defined above. This will:
	// 1. Try to delete an unknown session and get a not found error code.
	// 2. Create a new session using the same parameters as peer2.
	// 3. Delete the newly created session and get an OK.
	for _, msg := range peer1Msgs {
		peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0)
		connect(t, s, peer1, initMsg, timeoutDuration)
		sendMsg(t, msg.send, peer1, timeoutDuration)
		reply := recvReply(
			t, msg.recv.MsgType().String(), peer1, timeoutDuration,
		)

		if !reflect.DeepEqual(reply, msg.recv) {
			t.Fatalf("expected reply: %v, got: %v", msg.recv, reply)
		}

		assertConnClosed(t, peer1, 2*timeoutDuration)

		// Invoke assertions after completing the request/response
		// dance.
		msg.assert(t)
	}
}
// connect hands peer to the server as a new inbound connection and completes
// the Init exchange. It sends initMsg and waits for the server's Init reply,
// which is discarded. Failures in either step fail the test via sendMsg and
// recvReply.
func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer,
	initMsg *wtwire.Init, timeout time.Duration) {

	t.Helper()

	s.InboundPeerConnected(peer)

	// Handshake: our Init out, the server's Init back.
	sendMsg(t, initMsg, peer, timeout)
	_ = recvReply(t, "MsgInit", peer, timeout)
}
// sendMsg encodes msg with wtwire.WriteMessage and pushes the resulting bytes
// onto the mock peer's IncomingMsgs channel, i.e. as input from the client
// side. The test fails if encoding returns an error or if the channel does
// not accept the bytes within twice the given timeout.
func sendMsg(t *testing.T, msg wtwire.Message,
	peer *wtmock.MockPeer, timeout time.Duration) {

	t.Helper()

	var encoded bytes.Buffer
	if _, encErr := wtwire.WriteMessage(&encoded, msg, 0); encErr != nil {
		t.Fatalf("unable to encode %T message: %v",
			msg, encErr)
	}

	deadline := time.After(2 * timeout)
	select {
	case peer.IncomingMsgs <- encoded.Bytes():
	case <-deadline:
		t.Fatalf("unable to send %T message", msg)
	}
}
// recvReply waits up to twice the given timeout for the server to emit a
// message on the mock peer's OutgoingMsgs channel, decodes it with
// wtwire.ReadMessage, and returns it.
//
// The name argument selects a type assertion on the decoded message. The
// recognized names are "MsgInit", "MsgCreateSessionReply",
// "MsgStateUpdateReply", and "MsgDeleteSessionReply". Any other name is
// returned without a type check.
//
// The test fails if no message arrives in time, if decoding fails, or if the
// decoded message does not match a recognized name.
func recvReply(t *testing.T, name string, peer *wtmock.MockPeer,
	timeout time.Duration) wtwire.Message {

	t.Helper()

	var msg wtwire.Message
	select {
	case b := <-peer.OutgoingMsgs:
		var err error
		msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
		if err != nil {
			t.Fatalf("unable to decode server "+
				"reply: %v", err)
		}

	case <-time.After(2 * timeout):
		t.Fatalf("server did not reply")
	}

	typeOK := true
	switch name {
	case "MsgInit":
		_, typeOK = msg.(*wtwire.Init)
	case "MsgCreateSessionReply":
		_, typeOK = msg.(*wtwire.CreateSessionReply)
	case "MsgStateUpdateReply":
		_, typeOK = msg.(*wtwire.StateUpdateReply)
	case "MsgDeleteSessionReply":
		_, typeOK = msg.(*wtwire.DeleteSessionReply)
	}
	// NOTE(review): names outside the list above are not type-checked.
	// Callers that pass MsgType().String() rely on it matching these
	// spellings; confirm against wtwire.MessageType.String.
	if !typeOK {
		t.Fatalf("expected %s reply message, "+
			"got %T", name, msg)
	}

	return msg
}
// assertConnClosed fails the test unless a receive on the mock peer's Quit
// channel succeeds within duration. The tests use this as the signal that the
// server has closed the peer's connection.
func assertConnClosed(t *testing.T, peer *wtmock.MockPeer, duration time.Duration) {
	t.Helper()

	expiry := time.NewTimer(duration)
	defer expiry.Stop()

	select {
	case <-peer.Quit:
		return
	case <-expiry.C:
		t.Fatalf("expected connection to be closed")
	}
}