/*
 * Copyright (c) 2005, 2011, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
23
/*
 * @test
 * @bug 4836493
 * @summary Socket timeouts for SSLSockets causes data corruption.
 * @run main/othervm ClientTimeout
 *
 *     SunJSSE does not support dynamic system properties, no way to re-use
 *     system properties in samevm/agentvm mode.
 */
33
34import java.io.*;
35import java.net.*;
36import java.util.*;
37import java.security.*;
38import javax.net.ssl.*;
39
/**
 * Regression test driver: a server streams this test's own source file over
 * an SSL connection in irregular chunks with periodic stalls, while the
 * client reads it through a socket whose input stream injects
 * {@link SocketTimeoutException}s.  Both sides compute a digest of the data
 * and the client fails the test when the digests differ.
 */
public class ClientTimeout {

    /*
     * =============================================================
     * Set the various variables needed for the tests, then
     * specify what tests to run on each side.
     */

    /*
     * Should we run the client or server in a separate thread?
     * Both sides can throw exceptions, but do you have a preference
     * as to which side should be the main thread.
     */
    static boolean separateServerThread = true;

    /*
     * Where do we find the keystores?
     */
    static String pathToStores = "../../../../javax/net/ssl/etc";
    static String keyStoreFile = "keystore";
    static String trustStoreFile = "truststore";
    static String passwd = "passphrase";

    /*
     * Is the server ready to serve?
     */
    volatile static boolean serverReady = false;

    /*
     * Turn on SSL debugging?
     */
    static boolean debug = false;

    /*
     * Defines the rhythm of the simulated timeout exceptions: every read
     * through MyInputStream consumes the flag, and a read attempted while
     * the flag is false throws SocketTimeoutException.  The client re-arms
     * the flag before each handshake attempt and after each successful
     * application-data read.
     */
    static boolean rhythm = true;

    // use any free port by default
    volatile int serverPort = 0;

    // Failure recorded by whichever side runs in a forked thread.
    volatile Exception serverException = null;
    volatile Exception clientException = null;

    // Digest of the bytes the server sent; null until the transfer is done.
    volatile byte[] serverDigest = null;

    Thread clientThread = null;
    Thread serverThread = null;

    /*
     * If the client or server is doing some kind of object creation
     * that the other side depends on, and that thread prematurely
     * exits, you may experience a hang.  The test harness will
     * terminate all hung threads after its timeout has expired,
     * currently 3 minutes by default, but you might try to be
     * smart about it....
     */

    /*
     * Define the server side of the test.
     *
     * If the server prematurely exits, serverReady will be set to true
     * to avoid infinite hangs.
     *
     * The listening socket and the accepted connection are always closed,
     * even when the transfer fails, so the peer is not left blocked on a
     * half-dead connection.
     */
    void doServerSide() throws Exception {
        SSLServerSocketFactory sslssf =
            (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
        SSLServerSocket sslServerSocket =
            (SSLServerSocket) sslssf.createServerSocket(serverPort);

        try {
            serverPort = sslServerSocket.getLocalPort();

            /*
             * Signal Client, we're ready for his connect.
             */
            serverReady = true;

            SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
            try {
                OutputStream sslOS = sslSocket.getOutputStream();
                sslSocket.startHandshake();

                // Publish the digest before closing the connection, so it
                // is already available when the client observes EOF.
                serverDigest = transferSourceFile(sslOS);
            } finally {
                sslSocket.close();
            }
        } finally {
            sslServerSocket.close();
        }
    }

    /*
     * Stream this test's source file to the given output stream and return
     * the SHA digest of the bytes that were sent.
     *
     * The chunk size grows by one on every iteration (starting at zero and
     * wrapping at the buffer size), and the sender sleeps after every third
     * chunk so that the client's read timeout expires in mid-transfer.
     */
    private byte[] transferSourceFile(OutputStream sslOS) throws Exception {
        String transFilename =
            System.getProperty("test.src", "./") + "/" +
                        this.getClass().getName() + ".java";
        MessageDigest md = MessageDigest.getInstance("SHA");
        DigestInputStream transIns = new DigestInputStream(
                new FileInputStream(transFilename), md);

        try {
            byte[] bytes = new byte[2000];
            int i = 0;
            while (true) {
                // reset the cycle
                if (i >= bytes.length) {
                    i = 0;
                }

                int length = transIns.read(bytes, 0, i++);
                if (length == -1) {
                    break;
                }

                sslOS.write(bytes, 0, length);
                sslOS.flush();

                if (i % 3 == 0) {
                    Thread.sleep(300);  // Stall past the timeout...
                }
            }
            return md.digest();
        } finally {
            transIns.close();
        }
    }

    /*
     * Define the client side of the test.
     *
     * If the server prematurely exits, serverReady will be set to true
     * to avoid infinite hangs.  The client additionally checks
     * serverException while waiting, so a server that died in its own
     * thread makes the client fail instead of waiting forever.
     */
    void doClientSide() throws Exception {
        /*
         * Wait for server to get started.
         */
        while (!serverReady) {
            Thread.sleep(50);
        }
        if (serverException != null) {
            throw new Exception(
                    "Server died before the client could connect",
                    serverException);
        }

        // Plain socket whose input stream injects timeout exceptions.
        Socket baseSocket = new Socket("localhost", serverPort) {
            MyInputStream ins = null;

            @Override
            public InputStream getInputStream() throws IOException {
                if (ins == null) {
                    ins = new MyInputStream(super.getInputStream());
                }
                return ins;
            }
        };

        try {
            SSLSocketFactory sslsf =
                (SSLSocketFactory) SSLSocketFactory.getDefault();
            SSLSocket sslSocket = (SSLSocket)
                sslsf.createSocket(baseSocket, "localhost", serverPort, true);

            InputStream sslIS = sslSocket.getInputStream();

            // handshaking
            sslSocket.setSoTimeout(100); // The stall timeout.
            while (true) {
                try {
                    rhythm = true;
                    sslSocket.startHandshake();
                    break;
                } catch (SocketTimeoutException e) {
                    System.out.println("Handshaker exception: "
                            + e.getMessage());
                }
            }

            // read application data from server
            MessageDigest md = MessageDigest.getInstance("SHA");
            DigestInputStream transIns = new DigestInputStream(sslIS, md);
            byte[] bytes = new byte[2000];
            while (true) {
                try {
                    rhythm = true;

                    while (transIns.read(bytes, 0, 17) != -1) {
                        rhythm = true;
                    }
                    break;
                } catch (SocketTimeoutException e) {
                    System.out.println("InputStream Exception: "
                            + e.getMessage());
                }
            }

            // Wait for the server to publish its digest, but give up as
            // soon as the server thread has reported a failure.
            while (serverDigest == null) {
                if (serverException != null) {
                    throw new Exception(
                            "Server died before publishing its digest",
                            serverException);
                }
                Thread.sleep(20);
            }

            byte[] cliDigest = md.digest();
            if (!Arrays.equals(cliDigest, serverDigest)) {
                throw new Exception("Application data trans error");
            }

            transIns.close();
            sslSocket.close();
        } finally {
            // No-op after a normal close; releases the connection when an
            // exception skipped the close calls above.
            baseSocket.close();
        }
    }

    /*
     * =============================================================
     * The remainder is just support stuff
     */

    /*
     * Input stream wrapper that simulates read timeouts.
     *
     * A read is only forwarded to the wrapped stream while
     * ClientTimeout.rhythm is true, and forwarding a read clears the flag.
     * A read attempted while the flag is false throws
     * SocketTimeoutException without touching the wrapped stream.
     */
    static class MyInputStream extends InputStream {
        InputStream ins = null;

        public MyInputStream(InputStream ins) {
            this.ins = ins;
        }

        /*
         * Reads one byte and returns its value (0-255), or -1 at end of
         * stream, as required by the InputStream.read() contract.
         */
        @Override
        public int read() throws IOException {
            consumeRhythm();
            return ins.read();
        }

        @Override
        public int read(byte[] data, int offset, int len) throws IOException {
            consumeRhythm();
            return ins.read(data, offset, len);
        }

        /*
         * Throws the simulated timeout when the rhythm flag is not armed;
         * otherwise disarms it so that the next read times out unless the
         * caller re-arms the flag first.
         */
        private void consumeRhythm() throws SocketTimeoutException {
            if (!ClientTimeout.rhythm) {
                throw new SocketTimeoutException(
                        "Throwing a timeout exception");
            }
            ClientTimeout.rhythm = false;
        }
    }

    /*
     * Point JSSE at the test keystore and truststore through system
     * properties, optionally enable SSL debugging, then run the test.
     */
    public static void main(String[] args) throws Exception {
        String keyFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + keyStoreFile;
        String trustFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + trustStoreFile;

        System.setProperty("javax.net.ssl.keyStore", keyFilename);
        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
        System.setProperty("javax.net.ssl.trustStore", trustFilename);
        System.setProperty("javax.net.ssl.trustStorePassword", passwd);

        if (debug) {
            System.setProperty("javax.net.debug", "all");
        }

        /*
         * Start the tests.
         */
        new ClientTimeout();
    }

    /*
     * Primary constructor, used to drive remainder of the test.
     *
     * Fork off the other side, then do your work.
     */
    ClientTimeout() throws Exception {
        if (separateServerThread) {
            startServer(true);
            startClient(false);
        } else {
            startClient(true);
            startServer(false);
        }

        /*
         * Wait for other side to close down.
         */
        if (separateServerThread) {
            serverThread.join();
        } else {
            clientThread.join();
        }

        /*
         * When we get here, the test is pretty much over.
         *
         * If the main thread excepted, that propagates back
         * immediately.  If the other thread threw an exception, we
         * should report back.
         */
        if (serverException != null) {
            System.out.print("Server Exception:");
            throw serverException;
        }
        if (clientException != null) {
            System.out.print("Client Exception:");
            throw clientException;
        }
    }

    /*
     * Run the server side, either in a new thread or in the calling thread.
     * A failure in the forked thread is stored in serverException.
     */
    void startServer(boolean newThread) throws Exception {
        if (newThread) {
            serverThread = new Thread() {
                @Override
                public void run() {
                    try {
                        doServerSide();
                    } catch (Exception e) {
                        /*
                         * Our server thread just died.
                         *
                         * Record the failure first, then release the
                         * client, if not active already, so that a client
                         * woken by serverReady always sees the exception.
                         */
                        System.err.println("Server died...");
                        System.err.println(e);
                        serverException = e;
                        serverReady = true;
                    }
                }
            };
            serverThread.start();
        } else {
            doServerSide();
        }
    }

    /*
     * Run the client side, either in a new thread or in the calling thread.
     * A failure in the forked thread is stored in clientException.
     */
    void startClient(boolean newThread) throws Exception {
        if (newThread) {
            clientThread = new Thread() {
                @Override
                public void run() {
                    try {
                        doClientSide();
                    } catch (Exception e) {
                        /*
                         * Our client thread just died.
                         */
                        System.err.println("Client died...");
                        System.err.println(e);
                        clientException = e;
                    }
                }
            };
            clientThread.start();
        } else {
            doClientSide();
        }
    }
}
371