1/*
2 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24//
25// SunJSSE does not support dynamic system properties, no way to re-use
26// system properties in samevm/agentvm mode.
27//
28
29/**
30 * @test
31 * @bug 7068321
32 * @summary Support TLS Server Name Indication (SNI) Extension in JSSE Server
33 * @library ../templates
34 * @build SSLCapabilities SSLExplorer
35 * @run main/othervm SSLSocketExplorerWithCliSNI
36 */
37
38import java.io.*;
39import java.nio.*;
40import java.nio.channels.*;
41import java.util.*;
42import java.net.*;
43import javax.net.ssl.*;
44
45public class SSLSocketExplorerWithCliSNI {
46
47    /*
48     * =============================================================
49     * Set the various variables needed for the tests, then
50     * specify what tests to run on each side.
51     */
52
53    /*
54     * Should we run the client or server in a separate thread?
55     * Both sides can throw exceptions, but do you have a preference
56     * as to which side should be the main thread.
57     */
58    static boolean separateServerThread = true;
59
60    /*
61     * Where do we find the keystores?
62     */
63    static String pathToStores = "../etc";
64    static String keyStoreFile = "keystore";
65    static String trustStoreFile = "truststore";
66    static String passwd = "passphrase";
67
68    /*
69     * Is the server ready to serve?
70     */
71    volatile static boolean serverReady = false;
72
73    /*
74     * Turn on SSL debugging?
75     */
76    static boolean debug = false;
77
78    /*
79     * If the client or server is doing some kind of object creation
80     * that the other side depends on, and that thread prematurely
81     * exits, you may experience a hang.  The test harness will
82     * terminate all hung threads after its timeout has expired,
83     * currently 3 minutes by default, but you might try to be
84     * smart about it....
85     */
86
87    /*
88     * Define the server side of the test.
89     *
90     * If the server prematurely exits, serverReady will be set to true
91     * to avoid infinite hangs.
92     */
93    void doServerSide() throws Exception {
94
95        ServerSocket serverSocket = new ServerSocket(serverPort);
96
97        // Signal Client, we're ready for his connect.
98        serverPort = serverSocket.getLocalPort();
99        serverReady = true;
100
101        Socket socket = serverSocket.accept();
102        InputStream ins = socket.getInputStream();
103
104        byte[] buffer = new byte[0xFF];
105        int position = 0;
106        SSLCapabilities capabilities = null;
107
108        // Read the header of TLS record
109        while (position < SSLExplorer.RECORD_HEADER_SIZE) {
110            int count = SSLExplorer.RECORD_HEADER_SIZE - position;
111            int n = ins.read(buffer, position, count);
112            if (n < 0) {
113                throw new Exception("unexpected end of stream!");
114            }
115            position += n;
116        }
117
118        int recordLength = SSLExplorer.getRequiredSize(buffer, 0, position);
119        if (buffer.length < recordLength) {
120            buffer = Arrays.copyOf(buffer, recordLength);
121        }
122
123        while (position < recordLength) {
124            int count = recordLength - position;
125            int n = ins.read(buffer, position, count);
126            if (n < 0) {
127                throw new Exception("unexpected end of stream!");
128            }
129            position += n;
130        }
131
132        capabilities = SSLExplorer.explore(buffer, 0, recordLength);;
133        if (capabilities != null) {
134            System.out.println("Record version: " +
135                    capabilities.getRecordVersion());
136            System.out.println("Hello version: " +
137                    capabilities.getHelloVersion());
138        }
139
140        SSLSocketFactory sslsf =
141            (SSLSocketFactory) SSLSocketFactory.getDefault();
142        ByteArrayInputStream bais =
143            new ByteArrayInputStream(buffer, 0, position);
144        SSLSocket sslSocket = (SSLSocket)sslsf.createSocket(socket, bais, true);
145
146        InputStream sslIS = sslSocket.getInputStream();
147        OutputStream sslOS = sslSocket.getOutputStream();
148
149        sslIS.read();
150        sslOS.write(85);
151        sslOS.flush();
152
153        ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
154        checkCapabilities(capabilities, session);
155
156        sslSocket.close();
157        serverSocket.close();
158    }
159
160
161    /*
162     * Define the client side of the test.
163     *
164     * If the server prematurely exits, serverReady will be set to true
165     * to avoid infinite hangs.
166     */
167    void doClientSide() throws Exception {
168
169        /*
170         * Wait for server to get started.
171         */
172        while (!serverReady) {
173            Thread.sleep(50);
174        }
175
176        SSLSocketFactory sslsf =
177            (SSLSocketFactory) SSLSocketFactory.getDefault();
178        SSLSocket sslSocket = (SSLSocket)
179            sslsf.createSocket("localhost", serverPort);
180
181        SNIHostName serverName = new SNIHostName(clientRequestedHostname);
182        List<SNIServerName> serverNames = new ArrayList<>(1);
183        serverNames.add(serverName);
184        SSLParameters params = sslSocket.getSSLParameters();
185        params.setServerNames(serverNames);
186        sslSocket.setSSLParameters(params);
187
188        InputStream sslIS = sslSocket.getInputStream();
189        OutputStream sslOS = sslSocket.getOutputStream();
190
191        sslOS.write(280);
192        sslOS.flush();
193        sslIS.read();
194
195        ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
196        checkSNIInSession(session);
197
198        sslSocket.close();
199    }
200
201    private static String clientRequestedHostname = "www.example.com";
202    private static String serverAcceptableHostname =
203                                                "www\\.example\\.(com|org)";
204
205    void checkCapabilities(SSLCapabilities capabilities,
206            ExtendedSSLSession session) throws Exception {
207
208        List<SNIServerName> sessionSNI = session.getRequestedServerNames();
209        if (!sessionSNI.equals(capabilities.getServerNames())) {
210            for (SNIServerName sni : sessionSNI) {
211                System.out.println("SNI in session is " + sni);
212            }
213
214            List<SNIServerName> capaSNI = capabilities.getServerNames();
215            for (SNIServerName sni : capaSNI) {
216                System.out.println("SNI in session is " + sni);
217            }
218
219            throw new Exception(
220                    "server name indication does not match capabilities");
221        }
222
223        checkSNIInSession(session);
224    }
225
226    void checkSNIInSession(ExtendedSSLSession session) throws Exception {
227        List<SNIServerName> sessionSNI = session.getRequestedServerNames();
228        if (sessionSNI.isEmpty()) {
229            throw new Exception(
230                    "unexpected empty request server name indication");
231        }
232
233        if (sessionSNI.size() != 1) {
234            throw new Exception(
235                    "unexpected request server name indication");
236        }
237
238        SNIServerName serverName = sessionSNI.get(0);
239        if (!(serverName instanceof SNIHostName)) {
240            throw new Exception(
241                    "unexpected instance of request server name indication");
242        }
243
244        String hostname = ((SNIHostName)serverName).getAsciiName();
245        if (!clientRequestedHostname.equalsIgnoreCase(hostname)) {
246            throw new Exception(
247                    "unexpected request server name indication value");
248        }
249    }
250
251
252    /*
253     * =============================================================
254     * The remainder is just support stuff
255     */
256
257    // use any free port by default
258    volatile int serverPort = 0;
259
260    volatile Exception serverException = null;
261    volatile Exception clientException = null;
262
263    public static void main(String[] args) throws Exception {
264        String keyFilename =
265            System.getProperty("test.src", ".") + "/" + pathToStores +
266                "/" + keyStoreFile;
267        String trustFilename =
268            System.getProperty("test.src", ".") + "/" + pathToStores +
269                "/" + trustStoreFile;
270
271        System.setProperty("javax.net.ssl.keyStore", keyFilename);
272        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
273        System.setProperty("javax.net.ssl.trustStore", trustFilename);
274        System.setProperty("javax.net.ssl.trustStorePassword", passwd);
275
276        if (debug)
277            System.setProperty("javax.net.debug", "all");
278
279        /*
280         * Start the tests.
281         */
282        new SSLSocketExplorerWithCliSNI();
283    }
284
285    Thread clientThread = null;
286    Thread serverThread = null;
287
288    /*
289     * Primary constructor, used to drive remainder of the test.
290     *
291     * Fork off the other side, then do your work.
292     */
293    SSLSocketExplorerWithCliSNI() throws Exception {
294        try {
295            if (separateServerThread) {
296                startServer(true);
297                startClient(false);
298            } else {
299                startClient(true);
300                startServer(false);
301            }
302        } catch (Exception e) {
303            // swallow for now.  Show later
304        }
305
306        /*
307         * Wait for other side to close down.
308         */
309        if (separateServerThread) {
310            serverThread.join();
311        } else {
312            clientThread.join();
313        }
314
315        /*
316         * When we get here, the test is pretty much over.
317         * Which side threw the error?
318         */
319        Exception local;
320        Exception remote;
321        String whichRemote;
322
323        if (separateServerThread) {
324            remote = serverException;
325            local = clientException;
326            whichRemote = "server";
327        } else {
328            remote = clientException;
329            local = serverException;
330            whichRemote = "client";
331        }
332
333        /*
334         * If both failed, return the curthread's exception, but also
335         * print the remote side Exception
336         */
337        if ((local != null) && (remote != null)) {
338            System.out.println(whichRemote + " also threw:");
339            remote.printStackTrace();
340            System.out.println();
341            throw local;
342        }
343
344        if (remote != null) {
345            throw remote;
346        }
347
348        if (local != null) {
349            throw local;
350        }
351    }
352
353    void startServer(boolean newThread) throws Exception {
354        if (newThread) {
355            serverThread = new Thread() {
356                public void run() {
357                    try {
358                        doServerSide();
359                    } catch (Exception e) {
360                        /*
361                         * Our server thread just died.
362                         *
363                         * Release the client, if not active already...
364                         */
365                        System.err.println("Server died...");
366                        serverReady = true;
367                        serverException = e;
368                    }
369                }
370            };
371            serverThread.start();
372        } else {
373            try {
374                doServerSide();
375            } catch (Exception e) {
376                serverException = e;
377            } finally {
378                serverReady = true;
379            }
380        }
381    }
382
383    void startClient(boolean newThread) throws Exception {
384        if (newThread) {
385            clientThread = new Thread() {
386                public void run() {
387                    try {
388                        doClientSide();
389                    } catch (Exception e) {
390                        /*
391                         * Our client thread just died.
392                         */
393                        System.err.println("Client died...");
394                        clientException = e;
395                    }
396                }
397            };
398            clientThread.start();
399        } else {
400            try {
401                doClientSide();
402            } catch (Exception e) {
403                clientException = e;
404            }
405        }
406    }
407}
408