1/*
2 * Copyright (c) 2001, 2017, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
// SunJSSE does not support dynamic system properties, so there is no way to
// re-use system properties in samevm/agentvm mode.
26
27/*
28 * @test
29 * @bug   4366807
30 * @summary Need new APIs to get/set session timeout and session cache size.
31 * @run main/othervm SessionTimeOutTests
32 */
33
34import java.io.*;
35import java.net.InetSocketAddress;
36import java.net.SocketTimeoutException;
37
38import javax.net.ssl.*;
39import java.util.*;
40import java.security.*;
41import java.util.concurrent.CountDownLatch;
42import java.util.concurrent.TimeUnit;
43
44/**
45 * Session reuse time-out tests cover the cases below:
46 * 1. general test, i.e timeout is set to x and session invalidates when
47 * its lifetime exceeds x.
48 * 2. Effect of changing the timeout limit.
49 * The test suite does not cover the default timeout(24 hours) usage. This
50 * case has been tested independently.
51 *
52 * Invariant for passing this test is, at any given time,
53 * lifetime of a session < current_session_timeout, such that
54 * current_session_timeout > 0, for all sessions cached by the session
55 * context.
56 */
57
58public class SessionTimeOutTests {
59
60    /*
61     * =============================================================
62     * Set the various variables needed for the tests, then
63     * specify what tests to run on each side.
64     */
65
66    /*
67     * Should we run the client or server in a separate thread?
68     * Both sides can throw exceptions, but do you have a preference
69     * as to which side should be the main thread.
70     */
71    static boolean separateServerThread = true;
72
73    /*
74     * Where do we find the keystores?
75     */
76    static String pathToStores = "../etc";
77    static String keyStoreFile = "keystore";
78    static String trustStoreFile = "truststore";
79    static String passwd = "passphrase";
80
81    private static int PORTS = 3;
82
83    /*
84     * Is the server ready to serve?
85     */
86    private final CountDownLatch serverCondition = new CountDownLatch(PORTS);
87
88    /*
89     * Turn on SSL debugging?
90     */
91    static boolean debug = false;
92
93    /*
94     * If the client or server is doing some kind of object creation
95     * that the other side depends on, and that thread prematurely
96     * exits, you may experience a hang.  The test harness will
97     * terminate all hung threads after its timeout has expired,
98     * currently 3 minutes by default, but you might try to be
99     * smart about it....
100     */
101
102    /*
103     * Define the server side of the test.
104     */
105
106    /*
107     * A limit on the number of connections at any given time
108     */
109    static int MAX_ACTIVE_CONNECTIONS = 3;
110
111    /*
112     * Divide the max connections among the available server ports.
113     * The use of more than one server port ensures creation of more
114     * than one session.
115     */
116    private static final int serverConns = MAX_ACTIVE_CONNECTIONS / PORTS;
117    private static final int remainingConns = MAX_ACTIVE_CONNECTIONS % PORTS;
118
119    private static final int TIMEOUT = 30000; // in millisecond
120
121    void doServerSide(int slot, int serverConns) throws Exception {
122
123        SSLServerSocket sslServerSocket
124                = (SSLServerSocket) sslssf.createServerSocket(0);
125        sslServerSocket.setSoTimeout(TIMEOUT);
126        serverPorts[slot] = sslServerSocket.getLocalPort();
127
128        /*
129         * Signal Client, one server is ready for its connect.
130         */
131        serverCondition.countDown();
132
133        for (int nConnections = 0; nConnections < serverConns; nConnections++) {
134            SSLSocket sslSocket = null;
135            try {
136                sslSocket = (SSLSocket) sslServerSocket.accept();
137            }  catch (SocketTimeoutException ste) {
138                System.out.println(
139                        "No incoming client connection. Ignore in server side.");
140                continue;
141            }
142            InputStream sslIS = sslSocket.getInputStream();
143            OutputStream sslOS = sslSocket.getOutputStream();
144            sslIS.read();
145            sslSocket.getSession();
146            sslOS.write(85);
147            sslOS.flush();
148            sslSocket.close();
149        }
150    }
151
152    /*
153     * Define the client side of the test.
154     *
155     * If the server prematurely exits, serverReady will be set to zero
156     * to avoid infinite hangs.
157     */
158    void doClientSide() throws Exception {
159        /*
160         * Wait for server to get started.
161         */
162        if (!serverCondition.await(TIMEOUT, TimeUnit.MILLISECONDS)) {
163            System.out.println(
164                    "The server side is not ready yet. Ignore in client side.");
165            return;
166        }
167
168        SSLSocket sslSockets[] = new SSLSocket[MAX_ACTIVE_CONNECTIONS];
169        Vector<SSLSession> sessions = new Vector<>();
170        SSLSessionContext sessCtx = sslctx.getClientSessionContext();
171
172        sessCtx.setSessionTimeout(10); // in secs
173        int timeout = sessCtx.getSessionTimeout();
174        for (int nConnections = 0; nConnections < MAX_ACTIVE_CONNECTIONS;
175                nConnections++) {
176            // divide the connections among the available server ports
177            try {
178                SSLSocket sslSocket = (SSLSocket) sslsf.createSocket();
179                sslSocket.connect(new InetSocketAddress("localhost",
180                        serverPorts[nConnections % serverPorts.length]),
181                        TIMEOUT);
182                sslSockets[nConnections] = sslSocket;
183            } catch (IOException ioe) {
184                // The server side may be impacted by naughty test cases or
185                // third party routines, and cannot accept connections.
186                //
187                // Just ignore the test if the connection cannot be
188                // established.
189                System.out.println(
190                        "Cannot make a connection in time. Ignore in client side.");
191                continue;
192            }
193
194            InputStream sslIS = sslSockets[nConnections].getInputStream();
195            OutputStream sslOS = sslSockets[nConnections].getOutputStream();
196            sslOS.write(237);
197            sslOS.flush();
198            sslIS.read();
199            SSLSession sess = sslSockets[nConnections].getSession();
200            if (!sessions.contains(sess))
201                sessions.add(sess);
202        }
203        System.out.println();
204        System.out.println("Current timeout is set to: " + timeout);
205        System.out.println("Testing SSLSessionContext.getSession()......");
206        System.out.println("========================================"
207                                + "=======================");
208        System.out.println("Session                                 "
209                                + "Session-     Session");
210        System.out.println("                                        "
211                                + "lifetime     timedout?");
212        System.out.println("========================================"
213                                + "=======================");
214
215        for (int i = 0; i < sessions.size(); i++) {
216            SSLSession session = (SSLSession) sessions.elementAt(i);
217            long currentTime  = System.currentTimeMillis();
218            long creationTime = session.getCreationTime();
219            long lifetime = (currentTime - creationTime) / 1000;
220
221            System.out.print(session + "      " + lifetime + "            ");
222            if (sessCtx.getSession(session.getId()) == null) {
223                if (lifetime < timeout)
224                    // sessions can be garbage collected before the timeout
225                    // limit is reached
226                    System.out.println("Invalidated before timeout");
227                else
228                    System.out.println("YES");
229            } else {
230                System.out.println("NO");
231                if ((timeout != 0) && (lifetime > timeout)) {
232                    throw new Exception("Session timeout test failed for the"
233                        + " obove session");
234                }
235            }
236            // change the timeout
237            if (i == ((sessions.size()) / 2)) {
238                System.out.println();
239                sessCtx.setSessionTimeout(2); // in secs
240                timeout = sessCtx.getSessionTimeout();
241                System.out.println("timeout is changed to: " + timeout);
242                System.out.println();
243           }
244        }
245
246        // check the ids returned by the enumerator
247        Enumeration<byte[]> e = sessCtx.getIds();
248        System.out.println("----------------------------------------"
249                                + "-----------------------");
250        System.out.println("Testing SSLSessionContext.getId()......");
251        System.out.println();
252
253        SSLSession nextSess = null;
254        SSLSession sess;
255        for (int i = 0; i < sessions.size(); i++) {
256            sess = (SSLSession)sessions.elementAt(i);
257            String isTimedout = "YES";
258            long currentTime  = System.currentTimeMillis();
259            long creationTime  = sess.getCreationTime();
260            long lifetime = (currentTime - creationTime) / 1000;
261
262            if (nextSess != null) {
263                if (isEqualSessionId(nextSess.getId(), sess.getId())) {
264                    isTimedout = "NO";
265                    nextSess = null;
266                }
267            } else if (e.hasMoreElements()) {
268                nextSess = sessCtx.getSession((byte[]) e.nextElement());
269                if ((nextSess != null) && isEqualSessionId(nextSess.getId(),
270                                        sess.getId())) {
271                    nextSess = null;
272                    isTimedout = "NO";
273                }
274            }
275
276            /*
277             * A session not invalidated even after it's timeout?
278             */
279            if ((timeout != 0) && (lifetime > timeout) &&
280                        (isTimedout.equals("NO"))) {
281                throw new Exception("Session timeout test failed for session: "
282                                + sess + " lifetime: " + lifetime);
283            }
284            System.out.print(sess + "      " + lifetime);
285            if (((timeout == 0) || (lifetime < timeout)) &&
286                                  (isTimedout == "YES")) {
287                isTimedout = "Invalidated before timeout";
288            }
289
290            System.out.println("            " + isTimedout);
291        }
292        for (int i = 0; i < sslSockets.length; i++) {
293            sslSockets[i].close();
294        }
295        System.out.println("----------------------------------------"
296                                 + "-----------------------");
297        System.out.println("Session timeout test passed");
298    }
299
300    boolean isEqualSessionId(byte[] id1, byte[] id2) {
301        if (id1.length != id2.length)
302           return false;
303        else {
304            for (int i = 0; i < id1.length; i++) {
305                if (id1[i] != id2[i]) {
306                   return false;
307                }
308            }
309            return true;
310        }
311    }
312
313
314    /*
315     * =============================================================
316     * The remainder is just support stuff
317     */
318
319    int serverPorts[] = new int[PORTS];
320    static SSLServerSocketFactory sslssf;
321    static SSLSocketFactory sslsf;
322    static SSLContext sslctx;
323
324    volatile Exception serverException = null;
325    volatile Exception clientException = null;
326
327    public static void main(String[] args) throws Exception {
328        String keyFilename =
329            System.getProperty("test.src", "./") + "/" + pathToStores +
330                "/" + keyStoreFile;
331        String trustFilename =
332            System.getProperty("test.src", "./") + "/" + pathToStores +
333                "/" + trustStoreFile;
334
335        System.setProperty("javax.net.ssl.keyStore", keyFilename);
336        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
337        System.setProperty("javax.net.ssl.trustStore", trustFilename);
338        System.setProperty("javax.net.ssl.trustStorePassword", passwd);
339
340        if (debug)
341            System.setProperty("javax.net.debug", "all");
342
343        sslctx = SSLContext.getInstance("TLS");
344        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
345        KeyStore ks = KeyStore.getInstance("JKS");
346        ks.load(new FileInputStream(keyFilename), passwd.toCharArray());
347        kmf.init(ks, passwd.toCharArray());
348        sslctx.init(kmf.getKeyManagers(), null, null);
349        sslssf = (SSLServerSocketFactory) sslctx.getServerSocketFactory();
350        sslsf = (SSLSocketFactory) sslctx.getSocketFactory();
351
352        /*
353         * Start the tests.
354         */
355        new SessionTimeOutTests();
356    }
357
358    Thread clientThread = null;
359    Thread serverThread = null;
360
361    /*
362     * Primary constructor, used to drive remainder of the test.
363     *
364     * Fork off the other side, then do your work.
365     */
366    SessionTimeOutTests() throws Exception {
367
368        /*
369         * create the SSLServerSocket and SSLSocket factories
370         */
371
372        Exception startException = null;
373        try {
374            if (separateServerThread) {
375                for (int i = 0; i < serverPorts.length; i++) {
376                    // distribute remaining connections among the
377                    // available ports
378                    if (i < remainingConns)
379                        startServer(i, (serverConns + 1), true);
380                    else
381                        startServer(i, serverConns, true);
382                }
383                startClient(false);
384            } else {
385                startClient(true);
386                for (int i = 0; i < PORTS; i++) {
387                    if (i < remainingConns)
388                        startServer(i, (serverConns + 1), false);
389                    else
390                        startServer(i, serverConns, false);
391                }
392            }
393        } catch (Exception e) {
394            startException = e;
395        }
396
397        /*
398         * Wait for other side to close down.
399         */
400        if (separateServerThread) {
401            if (serverThread != null) {
402                serverThread.join();
403            }
404        } else {
405            if (clientThread != null) {
406                clientThread.join();
407            }
408        }
409
410        /*
411         * When we get here, the test is pretty much over.
412         * Which side threw the error?
413         */
414        Exception local;
415        Exception remote;
416
417        if (separateServerThread) {
418            remote = serverException;
419            local = clientException;
420        } else {
421            remote = clientException;
422            local = serverException;
423        }
424
425        Exception exception = null;
426
427        /*
428         * Check various exception conditions.
429         */
430        if ((local != null) && (remote != null)) {
431            // If both failed, return the curthread's exception.
432            local.initCause(remote);
433            exception = local;
434        } else if (local != null) {
435            exception = local;
436        } else if (remote != null) {
437            exception = remote;
438        } else if (startException != null) {
439            exception = startException;
440        }
441
442        /*
443         * If there was an exception *AND* a startException,
444         * output it.
445         */
446        if (exception != null) {
447            if (exception != startException && startException != null) {
448                exception.addSuppressed(startException);
449            }
450            throw exception;
451        }
452
453        // Fall-through: no exception to throw!
454    }
455
456    void startServer(final int slot, final int nConns, boolean newThread)
457            throws Exception {
458        if (newThread) {
459            serverThread = new Thread() {
460                public void run() {
461                    try {
462                        doServerSide(slot, nConns);
463                    } catch (Exception e) {
464                        /*
465                         * Our server thread just died.
466                         *
467                         * Release the client, if not active already...
468                         */
469                        System.err.println("Server died...");
470                        e.printStackTrace();
471                        serverException = e;
472                    }
473                }
474            };
475            serverThread.start();
476        } else {
477            try {
478                doServerSide(slot, nConns);
479            } catch (Exception e) {
480                serverException = e;
481            }
482        }
483    }
484
485    void startClient(boolean newThread)
486                 throws Exception {
487        if (newThread) {
488            clientThread = new Thread() {
489                public void run() {
490                    try {
491                        doClientSide();
492                    } catch (Exception e) {
493                        /*
494                         * Our client thread just died.
495                         */
496                        System.err.println("Client died...");
497                        clientException = e;
498                    }
499                }
500            };
501            clientThread.start();
502        } else {
503            try {
504                doClientSide();
505            } catch (Exception e) {
506                clientException = e;
507            }
508        }
509    }
510}
511