You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
733 lines
21 KiB
733 lines
21 KiB
package brontide |
|
|
|
import ( |
|
"bytes" |
|
"encoding/hex" |
|
"fmt" |
|
"io" |
|
"math" |
|
"net" |
|
"testing" |
|
"testing/iotest" |
|
|
|
"github.com/btcsuite/btcd/btcec" |
|
"github.com/lightningnetwork/lnd/keychain" |
|
"github.com/lightningnetwork/lnd/lnwire" |
|
"github.com/lightningnetwork/lnd/tor" |
|
) |
|
|
|
// maybeNetConn carries the outcome of a connection attempt across a channel:
// either an established connection or the error that prevented it. The field
// order matters, since callers build values positionally.
type maybeNetConn struct {
	// conn is the connection produced by the attempt, if any.
	conn net.Conn

	// err is the error returned by the attempt, if any.
	err error
}
|
|
|
func makeListener() (*Listener, *lnwire.NetAddress, error) { |
|
// First, generate the long-term private keys for the brontide listener. |
|
localPriv, err := btcec.NewPrivateKey(btcec.S256()) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
localKeyECDH := &keychain.PrivKeyECDH{PrivKey: localPriv} |
|
|
|
// Having a port of ":0" means a random port, and interface will be |
|
// chosen for our listener. |
|
addr := "localhost:0" |
|
|
|
// Our listener will be local, and the connection remote. |
|
listener, err := NewListener(localKeyECDH, addr) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
netAddr := &lnwire.NetAddress{ |
|
IdentityKey: localPriv.PubKey(), |
|
Address: listener.Addr().(*net.TCPAddr), |
|
} |
|
|
|
return listener, netAddr, nil |
|
} |
|
|
|
func establishTestConnection() (net.Conn, net.Conn, func(), error) { |
|
listener, netAddr, err := makeListener() |
|
if err != nil { |
|
return nil, nil, nil, err |
|
} |
|
defer listener.Close() |
|
|
|
// Nos, generate the long-term private keys remote end of the connection |
|
// within our test. |
|
remotePriv, err := btcec.NewPrivateKey(btcec.S256()) |
|
if err != nil { |
|
return nil, nil, nil, err |
|
} |
|
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} |
|
|
|
// Initiate a connection with a separate goroutine, and listen with our |
|
// main one. If both errors are nil, then encryption+auth was |
|
// successful. |
|
remoteConnChan := make(chan maybeNetConn, 1) |
|
go func() { |
|
remoteConn, err := Dial( |
|
remoteKeyECDH, netAddr, |
|
tor.DefaultConnTimeout, net.DialTimeout, |
|
) |
|
remoteConnChan <- maybeNetConn{remoteConn, err} |
|
}() |
|
|
|
localConnChan := make(chan maybeNetConn, 1) |
|
go func() { |
|
localConn, err := listener.Accept() |
|
localConnChan <- maybeNetConn{localConn, err} |
|
}() |
|
|
|
remote := <-remoteConnChan |
|
if remote.err != nil { |
|
return nil, nil, nil, err |
|
} |
|
|
|
local := <-localConnChan |
|
if local.err != nil { |
|
return nil, nil, nil, err |
|
} |
|
|
|
cleanUp := func() { |
|
local.conn.Close() |
|
remote.conn.Close() |
|
} |
|
|
|
return local.conn, remote.conn, cleanUp, nil |
|
} |
|
|
|
func TestConnectionCorrectness(t *testing.T) { |
|
// Create a test connection, grabbing either side of the connection |
|
// into local variables. If the initial crypto handshake fails, then |
|
// we'll get a non-nil error here. |
|
localConn, remoteConn, cleanUp, err := establishTestConnection() |
|
if err != nil { |
|
t.Fatalf("unable to establish test connection: %v", err) |
|
} |
|
defer cleanUp() |
|
|
|
// Test out some message full-message reads. |
|
for i := 0; i < 10; i++ { |
|
msg := []byte(fmt.Sprintf("hello%d", i)) |
|
|
|
if _, err := localConn.Write(msg); err != nil { |
|
t.Fatalf("remote conn failed to write: %v", err) |
|
} |
|
|
|
readBuf := make([]byte, len(msg)) |
|
if _, err := remoteConn.Read(readBuf); err != nil { |
|
t.Fatalf("local conn failed to read: %v", err) |
|
} |
|
|
|
if !bytes.Equal(readBuf, msg) { |
|
t.Fatalf("messages don't match, %v vs %v", |
|
string(readBuf), string(msg)) |
|
} |
|
} |
|
|
|
// Now try incremental message reads. This simulates first writing a |
|
// message header, then a message body. |
|
outMsg := []byte("hello world") |
|
if _, err := localConn.Write(outMsg); err != nil { |
|
t.Fatalf("remote conn failed to write: %v", err) |
|
} |
|
|
|
readBuf := make([]byte, len(outMsg)) |
|
if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil { |
|
t.Fatalf("local conn failed to read: %v", err) |
|
} |
|
if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil { |
|
t.Fatalf("local conn failed to read: %v", err) |
|
} |
|
|
|
if !bytes.Equal(outMsg, readBuf) { |
|
t.Fatalf("messages don't match, %v vs %v", |
|
string(readBuf), string(outMsg)) |
|
} |
|
} |
|
|
|
// TestConecurrentHandshakes verifies the listener's ability to not be blocked |
|
// by other pending handshakes. This is tested by opening multiple tcp |
|
// connections with the listener, without completing any of the brontide acts. |
|
// The test passes if real brontide dialer connects while the others are |
|
// stalled. |
|
func TestConcurrentHandshakes(t *testing.T) { |
|
listener, netAddr, err := makeListener() |
|
if err != nil { |
|
t.Fatalf("unable to create listener connection: %v", err) |
|
} |
|
defer listener.Close() |
|
|
|
const nblocking = 5 |
|
|
|
// Open a handful of tcp connections, that do not complete any steps of |
|
// the brontide handshake. |
|
connChan := make(chan maybeNetConn) |
|
for i := 0; i < nblocking; i++ { |
|
go func() { |
|
conn, err := net.Dial("tcp", listener.Addr().String()) |
|
connChan <- maybeNetConn{conn, err} |
|
}() |
|
} |
|
|
|
// Receive all connections/errors from our blocking tcp dials. We make a |
|
// pass to gather all connections and errors to make sure we defer the |
|
// calls to Close() on all successful connections. |
|
tcpErrs := make([]error, 0, nblocking) |
|
for i := 0; i < nblocking; i++ { |
|
result := <-connChan |
|
if result.conn != nil { |
|
defer result.conn.Close() |
|
} |
|
if result.err != nil { |
|
tcpErrs = append(tcpErrs, result.err) |
|
} |
|
} |
|
for _, tcpErr := range tcpErrs { |
|
if tcpErr != nil { |
|
t.Fatalf("unable to tcp dial listener: %v", tcpErr) |
|
} |
|
} |
|
|
|
// Now, construct a new private key and use the brontide dialer to |
|
// connect to the listener. |
|
remotePriv, err := btcec.NewPrivateKey(btcec.S256()) |
|
if err != nil { |
|
t.Fatalf("unable to generate private key: %v", err) |
|
} |
|
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv} |
|
|
|
go func() { |
|
remoteConn, err := Dial( |
|
remoteKeyECDH, netAddr, |
|
tor.DefaultConnTimeout, net.DialTimeout, |
|
) |
|
connChan <- maybeNetConn{remoteConn, err} |
|
}() |
|
|
|
// This connection should be accepted without error, as the brontide |
|
// connection should bypass stalled tcp connections. |
|
conn, err := listener.Accept() |
|
if err != nil { |
|
t.Fatalf("unable to accept dial: %v", err) |
|
} |
|
defer conn.Close() |
|
|
|
result := <-connChan |
|
if result.err != nil { |
|
t.Fatalf("unable to dial %v: %v", netAddr, result.err) |
|
} |
|
result.conn.Close() |
|
} |
|
|
|
func TestMaxPayloadLength(t *testing.T) { |
|
t.Parallel() |
|
|
|
b := Machine{} |
|
b.split() |
|
|
|
// Create a payload that's only *slightly* above the maximum allotted |
|
// payload length. |
|
payloadToReject := make([]byte, math.MaxUint16+1) |
|
|
|
// A write of the payload generated above to the state machine should |
|
// be rejected as it's over the max payload length. |
|
err := b.WriteMessage(payloadToReject) |
|
if err != ErrMaxMessageLengthExceeded { |
|
t.Fatalf("payload is over the max allowed length, the write " + |
|
"should have been rejected") |
|
} |
|
|
|
// Generate another payload which should be accepted as a valid |
|
// payload. |
|
payloadToAccept := make([]byte, math.MaxUint16-1) |
|
if err := b.WriteMessage(payloadToAccept); err != nil { |
|
t.Fatalf("write for payload was rejected, should have been " + |
|
"accepted") |
|
} |
|
|
|
// Generate a final payload which is only *slightly* above the max payload length |
|
// when the MAC is accounted for. |
|
payloadToReject = make([]byte, math.MaxUint16+1) |
|
|
|
// This payload should be rejected. |
|
err = b.WriteMessage(payloadToReject) |
|
if err != ErrMaxMessageLengthExceeded { |
|
t.Fatalf("payload is over the max allowed length, the write " + |
|
"should have been rejected") |
|
} |
|
} |
|
|
|
func TestWriteMessageChunking(t *testing.T) { |
|
// Create a test connection, grabbing either side of the connection |
|
// into local variables. If the initial crypto handshake fails, then |
|
// we'll get a non-nil error here. |
|
localConn, remoteConn, cleanUp, err := establishTestConnection() |
|
if err != nil { |
|
t.Fatalf("unable to establish test connection: %v", err) |
|
} |
|
defer cleanUp() |
|
|
|
// Attempt to write a message which is over 3x the max allowed payload |
|
// size. |
|
largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3) |
|
|
|
// Launch a new goroutine to write the large message generated above in |
|
// chunks. We spawn a new goroutine because otherwise, we may block as |
|
// the kernel waits for the buffer to flush. |
|
errCh := make(chan error) |
|
go func() { |
|
defer close(errCh) |
|
|
|
bytesWritten, err := localConn.Write(largeMessage) |
|
if err != nil { |
|
errCh <- fmt.Errorf("unable to write message: %v", err) |
|
return |
|
} |
|
|
|
// The entire message should have been written out to the remote |
|
// connection. |
|
if bytesWritten != len(largeMessage) { |
|
errCh <- fmt.Errorf("bytes not fully written") |
|
return |
|
} |
|
}() |
|
|
|
// Attempt to read the entirety of the message generated above. |
|
buf := make([]byte, len(largeMessage)) |
|
if _, err := io.ReadFull(remoteConn, buf); err != nil { |
|
t.Fatalf("unable to read message: %v", err) |
|
} |
|
|
|
err = <-errCh |
|
if err != nil { |
|
t.Fatal(err) |
|
} |
|
|
|
// Finally, the message the remote end of the connection received |
|
// should be identical to what we sent from the local connection. |
|
if !bytes.Equal(buf, largeMessage) { |
|
t.Fatalf("bytes don't match") |
|
} |
|
} |
|
|
|
// TestBolt0008TestVectors ensures that our implementation of brontide exactly |
|
// matches the test vectors within the specification. |
|
func TestBolt0008TestVectors(t *testing.T) { |
|
t.Parallel() |
|
|
|
// First, we'll generate the state of the initiator from the test |
|
// vectors at the appendix of BOLT-0008 |
|
initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" + |
|
"111111111111111111111111111111111111111111") |
|
if err != nil { |
|
t.Fatalf("unable to decode hex: %v", err) |
|
} |
|
initiatorPriv, _ := btcec.PrivKeyFromBytes( |
|
btcec.S256(), initiatorKeyBytes, |
|
) |
|
initiatorKeyECDH := &keychain.PrivKeyECDH{PrivKey: initiatorPriv} |
|
|
|
// We'll then do the same for the responder. |
|
responderKeyBytes, err := hex.DecodeString("212121212121212121212121" + |
|
"2121212121212121212121212121212121212121") |
|
if err != nil { |
|
t.Fatalf("unable to decode hex: %v", err) |
|
} |
|
responderPriv, responderPub := btcec.PrivKeyFromBytes( |
|
btcec.S256(), responderKeyBytes, |
|
) |
|
responderKeyECDH := &keychain.PrivKeyECDH{PrivKey: responderPriv} |
|
|
|
// With the initiator's key data parsed, we'll now define a custom |
|
// EphemeralGenerator function for the state machine to ensure that the |
|
// initiator and responder both generate the ephemeral public key |
|
// defined within the test vectors. |
|
initiatorEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) { |
|
e := "121212121212121212121212121212121212121212121212121212" + |
|
"1212121212" |
|
eBytes, err := hex.DecodeString(e) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), eBytes) |
|
return priv, nil |
|
}) |
|
responderEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) { |
|
e := "222222222222222222222222222222222222222222222222222" + |
|
"2222222222222" |
|
eBytes, err := hex.DecodeString(e) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), eBytes) |
|
return priv, nil |
|
}) |
|
|
|
// Finally, we'll create both brontide state machines, so we can begin |
|
// our test. |
|
initiator := NewBrontideMachine( |
|
true, initiatorKeyECDH, responderPub, initiatorEphemeral, |
|
) |
|
responder := NewBrontideMachine( |
|
false, responderKeyECDH, nil, responderEphemeral, |
|
) |
|
|
|
// We'll start with the initiator generating the initial payload for |
|
// act one. This should consist of exactly 50 bytes. We'll assert that |
|
// the payload return is _exactly_ the same as what's specified within |
|
// the test vectors. |
|
actOne, err := initiator.GenActOne() |
|
if err != nil { |
|
t.Fatalf("unable to generate act one: %v", err) |
|
} |
|
expectedActOne, err := hex.DecodeString("00036360e856310ce5d294e" + |
|
"8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df608655115" + |
|
"1f58b8afe6c195782c6a") |
|
if err != nil { |
|
t.Fatalf("unable to parse expected act one: %v", err) |
|
} |
|
if !bytes.Equal(expectedActOne, actOne[:]) { |
|
t.Fatalf("act one mismatch: expected %x, got %x", |
|
expectedActOne, actOne) |
|
} |
|
|
|
// With the assertion above passed, we'll now process the act one |
|
// payload with the responder of the crypto handshake. |
|
if err := responder.RecvActOne(actOne); err != nil { |
|
t.Fatalf("responder unable to process act one: %v", err) |
|
} |
|
|
|
// Next, we'll start the second act by having the responder generate |
|
// its contribution to the crypto handshake. We'll also verify that we |
|
// produce the _exact_ same byte stream as advertised within the spec's |
|
// test vectors. |
|
actTwo, err := responder.GenActTwo() |
|
if err != nil { |
|
t.Fatalf("unable to generate act two: %v", err) |
|
} |
|
expectedActTwo, err := hex.DecodeString("0002466d7fcae563e5cb09a0" + |
|
"d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac58" + |
|
"3c9ef6eafca3f730ae") |
|
if err != nil { |
|
t.Fatalf("unable to parse expected act two: %v", err) |
|
} |
|
if !bytes.Equal(expectedActTwo, actTwo[:]) { |
|
t.Fatalf("act two mismatch: expected %x, got %x", |
|
expectedActTwo, actTwo) |
|
} |
|
|
|
// Moving the handshake along, we'll also ensure that the initiator |
|
// accepts the act two payload. |
|
if err := initiator.RecvActTwo(actTwo); err != nil { |
|
t.Fatalf("initiator unable to process act two: %v", err) |
|
} |
|
|
|
// At the final step, we'll generate the last act from the initiator |
|
// and once again verify that it properly matches the test vectors. |
|
actThree, err := initiator.GenActThree() |
|
if err != nil { |
|
t.Fatalf("unable to generate act three: %v", err) |
|
} |
|
expectedActThree, err := hex.DecodeString("00b9e3a702e93e3a9948c2e" + |
|
"d6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8f" + |
|
"c28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba") |
|
if err != nil { |
|
t.Fatalf("unable to parse expected act three: %v", err) |
|
} |
|
if !bytes.Equal(expectedActThree, actThree[:]) { |
|
t.Fatalf("act three mismatch: expected %x, got %x", |
|
expectedActThree, actThree) |
|
} |
|
|
|
// Finally, we'll ensure that the responder itself also properly parses |
|
// the last payload in the crypto handshake. |
|
if err := responder.RecvActThree(actThree); err != nil { |
|
t.Fatalf("responder unable to process act three: %v", err) |
|
} |
|
|
|
// As a final assertion, we'll ensure that both sides have derived the |
|
// proper symmetric encryption keys. |
|
sendingKey, err := hex.DecodeString("969ab31b4d288cedf6218839b27a3e2" + |
|
"140827047f2c0f01bf5c04435d43511a9") |
|
if err != nil { |
|
t.Fatalf("unable to parse sending key: %v", err) |
|
} |
|
recvKey, err := hex.DecodeString("bb9020b8965f4df047e07f955f3c4b884" + |
|
"18984aadc5cdb35096b9ea8fa5c3442") |
|
if err != nil { |
|
t.Fatalf("unable to parse receiving key: %v", err) |
|
} |
|
|
|
chainKey, err := hex.DecodeString("919219dbb2920afa8db80f9a51787a840" + |
|
"bcf111ed8d588caf9ab4be716e42b01") |
|
if err != nil { |
|
t.Fatalf("unable to parse chaining key: %v", err) |
|
} |
|
|
|
if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) { |
|
t.Fatalf("sending key mismatch: expected %x, got %x", |
|
initiator.sendCipher.secretKey[:], sendingKey) |
|
} |
|
if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) { |
|
t.Fatalf("receiving key mismatch: expected %x, got %x", |
|
initiator.recvCipher.secretKey[:], recvKey) |
|
} |
|
if !bytes.Equal(initiator.chainingKey[:], chainKey) { |
|
t.Fatalf("chaining key mismatch: expected %x, got %x", |
|
initiator.chainingKey[:], chainKey) |
|
} |
|
|
|
if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) { |
|
t.Fatalf("sending key mismatch: expected %x, got %x", |
|
responder.sendCipher.secretKey[:], recvKey) |
|
} |
|
if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) { |
|
t.Fatalf("receiving key mismatch: expected %x, got %x", |
|
responder.recvCipher.secretKey[:], sendingKey) |
|
} |
|
if !bytes.Equal(responder.chainingKey[:], chainKey) { |
|
t.Fatalf("chaining key mismatch: expected %x, got %x", |
|
responder.chainingKey[:], chainKey) |
|
} |
|
|
|
// Now test as per section "transport-message test" in Test Vectors |
|
// (the transportMessageVectors ciphertexts are from this section of BOLT 8); |
|
// we do slightly greater than 1000 encryption/decryption operations |
|
// to ensure that the key rotation algorithm is operating as expected. |
|
// The starting point for enc/decr is already guaranteed correct from the |
|
// above tests of sendingKey, receivingKey, chainingKey. |
|
transportMessageVectors := map[int]string{ |
|
0: "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cb" + |
|
"cf25d2f214cf9ea1d95", |
|
1: "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837" + |
|
"d0f0d7ad833b1a256a1", |
|
500: "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb2" + |
|
"20f48364f7a4c68bf8", |
|
501: "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a6" + |
|
"0ca2c104b56b60e45bd", |
|
1000: "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3" + |
|
"c05c52b76b29b740f09", |
|
1001: "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e2" + |
|
"68338b1a16cf4ef2d36", |
|
} |
|
|
|
// Payload for every message is the string "hello". |
|
payload := []byte("hello") |
|
|
|
var buf bytes.Buffer |
|
|
|
for i := 0; i < 1002; i++ { |
|
err = initiator.WriteMessage(payload) |
|
if err != nil { |
|
t.Fatalf("could not write message %s", payload) |
|
} |
|
_, err = initiator.Flush(&buf) |
|
if err != nil { |
|
t.Fatalf("could not flush message: %v", err) |
|
} |
|
if val, ok := transportMessageVectors[i]; ok { |
|
binaryVal, err := hex.DecodeString(val) |
|
if err != nil { |
|
t.Fatalf("Failed to decode hex string %s", val) |
|
} |
|
if !bytes.Equal(buf.Bytes(), binaryVal) { |
|
t.Fatalf("Ciphertext %x was not equal to expected %s", |
|
buf.String()[:], val) |
|
} |
|
} |
|
|
|
// Responder decrypts the bytes, in every iteration, and |
|
// should always be able to decrypt the same payload message. |
|
plaintext, err := responder.ReadMessage(&buf) |
|
if err != nil { |
|
t.Fatalf("failed to read message in responder: %v", err) |
|
} |
|
|
|
// Ensure decryption succeeded |
|
if !bytes.Equal(plaintext, payload) { |
|
t.Fatalf("Decryption failed to receive plaintext: %s, got %s", |
|
payload, plaintext) |
|
} |
|
|
|
// Clear out the buffer for the next iteration |
|
buf.Reset() |
|
} |
|
} |
|
|
|
// timeoutWriter wraps an io.Writer with a byte budget. Writes are truncated to
// the remaining budget, and iotest.ErrTimeout is returned once the budget has
// been used up.
type timeoutWriter struct {
	// w is the underlying writer that receives the bytes.
	w io.Writer

	// n is the number of bytes that may still be written before the
	// writer reports a timeout.
	n int64
}

// NewTimeoutWriter returns an io.Writer that forwards at most n bytes to w and
// reports iotest.ErrTimeout once those n bytes have been written.
func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
	return &timeoutWriter{w: w, n: n}
}

// Write forwards as much of p as the remaining budget allows to the wrapped
// writer and returns the number of bytes it accepted. An error from the
// wrapped writer is returned as is; otherwise iotest.ErrTimeout is returned
// whenever the budget is exhausted after the write.
func (t *timeoutWriter) Write(p []byte) (int, error) {
	// Never hand the wrapped writer more than what is left of the budget.
	allowed := int64(len(p))
	if allowed > t.n {
		allowed = t.n
	}

	written, err := t.w.Write(p[:int(allowed)])
	t.n -= int64(written)

	if err != nil {
		return written, err
	}
	if t.n == 0 {
		return written, iotest.ErrTimeout
	}

	return written, nil
}
|
|
|
// payloadSize is the length, in bytes, of the message payload buffered by the
// flush tests.
const payloadSize = 10

type (
	// flushChunk describes one call to Flush within a flush test.
	flushChunk struct {
		// errAfter is the byte budget handed to the timeout writer
		// for this call. A negative value means the call writes
		// directly to the underlying writer with no timeout.
		errAfter int64

		// expN is the byte count Flush is expected to return.
		expN int

		// expErr is the error Flush is expected to return.
		expErr error
	}

	// flushTest is a named sequence of chunks used to flush a single
	// buffered message.
	flushTest struct {
		name   string
		chunks []flushChunk
	}
)
|
|
|
var flushTests = []flushTest{ |
|
{ |
|
name: "partial header write", |
|
chunks: []flushChunk{ |
|
// Write 18-byte header in two parts, 16 then 2. |
|
{ |
|
errAfter: encHeaderSize - 2, |
|
expN: 0, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
{ |
|
errAfter: 2, |
|
expN: 0, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
// Write payload and MAC in one go. |
|
{ |
|
errAfter: -1, |
|
expN: payloadSize, |
|
}, |
|
}, |
|
}, |
|
{ |
|
name: "full payload then full mac", |
|
chunks: []flushChunk{ |
|
// Write entire header and entire payload w/o MAC. |
|
{ |
|
errAfter: encHeaderSize + payloadSize, |
|
expN: payloadSize, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
// Write the entire MAC. |
|
{ |
|
errAfter: -1, |
|
expN: 0, |
|
}, |
|
}, |
|
}, |
|
{ |
|
name: "payload-only, straddle, mac-only", |
|
chunks: []flushChunk{ |
|
// Write header and all but last byte of payload. |
|
{ |
|
errAfter: encHeaderSize + payloadSize - 1, |
|
expN: payloadSize - 1, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
// Write last byte of payload and first byte of MAC. |
|
{ |
|
errAfter: 2, |
|
expN: 1, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
// Write 10 bytes of the MAC. |
|
{ |
|
errAfter: 10, |
|
expN: 0, |
|
expErr: iotest.ErrTimeout, |
|
}, |
|
// Write the remaining 5 MAC bytes. |
|
{ |
|
errAfter: -1, |
|
expN: 0, |
|
}, |
|
}, |
|
}, |
|
} |
|
|
|
// TestFlush asserts a Machine's ability to handle timeouts during Flush that |
|
// cause partial writes, and that the machine can properly resume writes on |
|
// subsequent calls to Flush. |
|
func TestFlush(t *testing.T) { |
|
// Run each test individually, to assert that they pass in isolation. |
|
for _, test := range flushTests { |
|
t.Run(test.name, func(t *testing.T) { |
|
var ( |
|
w bytes.Buffer |
|
b Machine |
|
) |
|
b.split() |
|
testFlush(t, test, &b, &w) |
|
}) |
|
} |
|
|
|
// Finally, run the tests serially as if all on one connection. |
|
t.Run("flush serial", func(t *testing.T) { |
|
var ( |
|
w bytes.Buffer |
|
b Machine |
|
) |
|
b.split() |
|
for _, test := range flushTests { |
|
testFlush(t, test, &b, &w) |
|
} |
|
}) |
|
} |
|
|
|
// testFlush buffers a message on the Machine, then flushes it to the io.Writer |
|
// in chunks. Once complete, a final call to flush is made to assert that Write |
|
// is not called again. |
|
func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) { |
|
payload := make([]byte, payloadSize) |
|
if err := b.WriteMessage(payload); err != nil { |
|
t.Fatalf("unable to write message: %v", err) |
|
} |
|
|
|
for _, chunk := range test.chunks { |
|
assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr) |
|
} |
|
|
|
// We should always be able to call Flush after a message has been |
|
// successfully written, and it should result in a NOP. |
|
assertFlush(t, b, w, 0, 0, nil) |
|
} |
|
|
|
// assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a |
|
// timeoutWriter will be used the flush should stop with iotest.ErrTimeout after |
|
// n bytes. The method asserts that the returned error matches expErr and that |
|
// the number of bytes written by Flush matches expN. |
|
func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int, |
|
expErr error) { |
|
|
|
t.Helper() |
|
|
|
if n >= 0 { |
|
w = NewTimeoutWriter(w, n) |
|
} |
|
nn, err := b.Flush(w) |
|
if err != expErr { |
|
t.Fatalf("expected flush err: %v, got: %v", expErr, err) |
|
} |
|
if nn != expN { |
|
t.Fatalf("expected n: %d, got: %d", expN, nn) |
|
} |
|
}
|
|
|