1/*
2 * Copyright (c) 2005, 2011, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/*
25 * @test
26 * @bug 4836493
27 * @summary Socket timeouts for SSLSockets causes data corruption.
28 * @run main/othervm ServerTimeout
29 *
30 *     SunJSSE does not support dynamic system properties, no way to re-use
31 *     system properties in samevm/agentvm mode.
32 */
33
34import java.io.*;
35import java.net.*;
36import java.util.*;
37import java.security.*;
38import javax.net.ssl.*;
39
public class ServerTimeout {

    /*
     * =============================================================
     * Set the various variables needed for the tests, then
     * specify what tests to run on each side.
     */

    /*
     * Should we run the client or server in a separate thread?
     * Both sides can throw exceptions, but do you have a preference
     * as to which side should be the main thread.
     */
    static boolean separateServerThread = true;

    /*
     * Where do we find the keystores?
     */
    static String pathToStores = "../../../../javax/net/ssl/etc";
    static String keyStoreFile = "keystore";
    static String trustStoreFile = "truststore";
    static String passwd = "passphrase";

    /*
     * Is the server ready to serve?
     */
    volatile static boolean serverReady = false;

    /*
     * Turn on SSL debugging?
     */
    static boolean debug = false;

    /*
     * SO_TIMEOUT (ms) the server uses for every read, and the time (ms) the
     * client stalls after every third write.  The stall is deliberately
     * longer than the timeout so the server sees SocketTimeoutExceptions in
     * the middle of the transfer; the test then checks no data was lost.
     */
    static final int SERVER_READ_TIMEOUT = 100;
    static final int CLIENT_STALL = 300;

    /*
     * If the client or server is doing some kind of object creation
     * that the other side depends on, and that thread prematurely
     * exits, you may experience a hang.  The test harness will
     * terminate all hung threads after its timeout has expired,
     * currently 3 minutes by default, but you might try to be
     * smart about it....
     */

    /*
     * Define the server side of the test.
     *
     * Accepts one connection, reads everything the client sends with a
     * short SO_TIMEOUT (so reads repeatedly time out mid-stream), and
     * compares the digest of the received bytes with the client's digest
     * of the bytes it sent.  A mismatch means a timed-out read lost or
     * corrupted application data.
     *
     * If the server prematurely exits, serverReady will be set to true
     * (by startServer) to avoid infinite hangs.
     */
    void doServerSide() throws Exception {
        SSLServerSocketFactory sslssf =
            (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
        SSLServerSocket sslServerSocket =
            (SSLServerSocket) sslssf.createServerSocket(serverPort);
        try {
            serverPort = sslServerSocket.getLocalPort();

            /*
             * Signal Client, we're ready for his connect.
             */
            serverReady = true;

            SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
            try {
                InputStream sslIS = sslSocket.getInputStream();
                sslSocket.startHandshake();

                // read application data from client
                MessageDigest md = MessageDigest.getInstance("SHA");
                DigestInputStream transIns = new DigestInputStream(sslIS, md);
                byte[] bytes = new byte[2000];
                sslSocket.setSoTimeout(SERVER_READ_TIMEOUT); // The stall timeout
                while (true) {
                    try {
                        while (transIns.read(bytes, 0, 17) != -1) {
                            // drain; the digest stream records every byte
                        }
                        break;
                    } catch (SocketTimeoutException e) {
                        // Expected while the client stalls: keep reading.
                        // The socket stays usable after a read timeout.
                        System.out.println("Server inputStream Exception: "
                                + e.getMessage());
                    }
                }

                // wait for the client to publish its digest
                waitForClientDigest();

                byte[] srvDigest = md.digest();
                if (!Arrays.equals(clientDigest, srvDigest)) {
                    throw new Exception("Application data trans error");
                }

                transIns.close();
            } finally {
                sslSocket.close();
            }
        } finally {
            // Previously leaked: the listening socket was never closed.
            sslServerSocket.close();
        }
    }

    /*
     * Block until the client side has published clientDigest.
     *
     * Throws instead of spinning forever when the client side has finished
     * (clientDone) without producing a digest, i.e. it failed.
     */
    private void waitForClientDigest() throws Exception {
        while (clientDigest == null) {
            // clientDone is written after clientDigest (both volatile), so
            // once clientDone reads true, re-reading clientDigest is current.
            if (clientDone && clientDigest == null) {
                throw new Exception(
                        "Client finished without producing a digest");
            }
            Thread.sleep(20);
        }
    }

    /*
     * Define the client side of the test.
     *
     * Sends this test's own source file to the server in chunks of varying
     * size, stalling longer than the server's read timeout after every
     * third chunk, and publishes the digest of everything sent in
     * clientDigest.  clientDone is always set on exit, even on failure.
     *
     * If the server prematurely exits, startServer sets serverReady to
     * true so this side does not wait forever for it.
     */
    void doClientSide() throws Exception {
        /*
         * Wait for server to get started.
         */
        while (!serverReady) {
            Thread.sleep(50);
        }

        try {
            SSLSocketFactory sslsf =
                (SSLSocketFactory) SSLSocketFactory.getDefault();
            SSLSocket sslSocket = (SSLSocket)
                sslsf.createSocket("localhost", serverPort);
            try {
                OutputStream sslOS = sslSocket.getOutputStream();
                sslSocket.startHandshake();

                // transfer a file to server
                String transFilename =
                        System.getProperty("test.src", "./") + "/" +
                                this.getClass().getName() + ".java";
                MessageDigest md = MessageDigest.getInstance("SHA");
                DigestInputStream transIns = new DigestInputStream(
                        new FileInputStream(transFilename), md);
                try {
                    byte[] bytes = new byte[2000];
                    int i = 0;
                    while (true) {
                        // reset the cycle of chunk sizes
                        if (i >= bytes.length) {
                            i = 0;
                        }

                        int length = transIns.read(bytes, 0, i++);
                        if (length == -1) {
                            break;
                        }
                        sslOS.write(bytes, 0, length);
                        sslOS.flush();

                        if (i % 3 == 0) {
                            Thread.sleep(CLIENT_STALL); // Stall past the timeout...
                        }
                    }
                    clientDigest = md.digest();
                } finally {
                    transIns.close();
                }
            } finally {
                // Closing on failure too lets the server see end-of-stream
                // instead of timing out forever.
                sslSocket.close();
            }
        } finally {
            clientDone = true;
        }
    }

    /*
     * =============================================================
     * The remainder is just support stuff
     */

    // use any free port by default
    volatile int serverPort = 0;

    volatile Exception serverException = null;
    volatile Exception clientException = null;

    volatile byte[] clientDigest = null;

    // Set once doClientSide has exited, successfully or not.
    volatile boolean clientDone = false;

    public static void main(String[] args) throws Exception {
        String keyFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + keyStoreFile;
        String trustFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + trustStoreFile;

        System.setProperty("javax.net.ssl.keyStore", keyFilename);
        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
        System.setProperty("javax.net.ssl.trustStore", trustFilename);
        System.setProperty("javax.net.ssl.trustStorePassword", passwd);

        if (debug) {
            System.setProperty("javax.net.debug", "all");
        }

        /*
         * Start the tests.
         */
        new ServerTimeout();
    }

    Thread clientThread = null;
    Thread serverThread = null;

    /*
     * Primary constructor, used to drive remainder of the test.
     *
     * Fork off the other side, then do your work.
     */
    ServerTimeout() throws Exception {
        if (separateServerThread) {
            startServer(true);
            startClient(false);
        } else {
            startClient(true);
            startServer(false);
        }

        /*
         * Wait for other side to close down.
         */
        if (separateServerThread) {
            serverThread.join();
        } else {
            clientThread.join();
        }

        /*
         * When we get here, the test is pretty much over.
         *
         * If the main thread excepted, that propagates back
         * immediately.  If the other thread threw an exception, we
         * should report back.
         */
        if (serverException != null) {
            System.out.print("Server Exception:");
            throw serverException;
        }
        if (clientException != null) {
            System.out.print("Client Exception:");
            throw clientException;
        }
    }

    void startServer(boolean newThread) throws Exception {
        if (newThread) {
            serverThread = new Thread() {
                @Override
                public void run() {
                    try {
                        doServerSide();
                    } catch (Exception e) {
                        /*
                         * Our server thread just died.
                         *
                         * Release the client, if not active already...
                         */
                        System.err.println("Server died...");
                        System.err.println(e);
                        serverReady = true;
                        serverException = e;
                    }
                }
            };
            serverThread.start();
        } else {
            doServerSide();
        }
    }

    void startClient(boolean newThread) throws Exception {
        if (newThread) {
            clientThread = new Thread() {
                @Override
                public void run() {
                    try {
                        doClientSide();
                    } catch (Exception e) {
                        /*
                         * Our client thread just died.
                         */
                        System.err.println("Client died...");
                        System.err.println(e);
                        clientException = e;
                    }
                }
            };
            clientThread.start();
        } else {
            doClientSide();
        }
    }
}
315