1/*
2 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24//
// SunJSSE does not support dynamic system properties, so there is no way
// to re-use system properties in samevm/agentvm mode.
27//
28
29/**
30 * @test
31 * @bug 7068321
32 * @summary Support TLS Server Name Indication (SNI) Extension in JSSE Server
33 * @library ../templates
34 * @build SSLCapabilities SSLExplorer
35 * @run main/othervm SSLSocketExplorerWithSrvSNI
36 */
37
38import java.io.*;
39import java.nio.*;
40import java.nio.channels.*;
41import java.util.*;
42import java.net.*;
43import javax.net.ssl.*;
44
45public class SSLSocketExplorerWithSrvSNI {
46
47    /*
48     * =============================================================
49     * Set the various variables needed for the tests, then
50     * specify what tests to run on each side.
51     */
52
53    /*
54     * Should we run the client or server in a separate thread?
55     * Both sides can throw exceptions, but do you have a preference
56     * as to which side should be the main thread.
57     */
58    static boolean separateServerThread = true;
59
60    /*
61     * Where do we find the keystores?
62     */
63    static String pathToStores = "../etc";
64    static String keyStoreFile = "keystore";
65    static String trustStoreFile = "truststore";
66    static String passwd = "passphrase";
67
68    /*
69     * Is the server ready to serve?
70     */
71    volatile static boolean serverReady = false;
72
73    /*
74     * Turn on SSL debugging?
75     */
76    static boolean debug = false;
77
78    /*
79     * If the client or server is doing some kind of object creation
80     * that the other side depends on, and that thread prematurely
81     * exits, you may experience a hang.  The test harness will
82     * terminate all hung threads after its timeout has expired,
83     * currently 3 minutes by default, but you might try to be
84     * smart about it....
85     */
86
87    /*
88     * Define the server side of the test.
89     *
90     * If the server prematurely exits, serverReady will be set to true
91     * to avoid infinite hangs.
92     */
93    void doServerSide() throws Exception {
94
95        ServerSocket serverSocket = new ServerSocket(serverPort);
96
97        // Signal Client, we're ready for his connect.
98        serverPort = serverSocket.getLocalPort();
99        serverReady = true;
100
101        Socket socket = serverSocket.accept();
102        InputStream ins = socket.getInputStream();
103
104        byte[] buffer = new byte[0xFF];
105        int position = 0;
106        SSLCapabilities capabilities = null;
107
108        // Read the header of TLS record
109        while (position < SSLExplorer.RECORD_HEADER_SIZE) {
110            int count = SSLExplorer.RECORD_HEADER_SIZE - position;
111            int n = ins.read(buffer, position, count);
112            if (n < 0) {
113                throw new Exception("unexpected end of stream!");
114            }
115            position += n;
116        }
117
118        int recordLength = SSLExplorer.getRequiredSize(buffer, 0, position);
119        if (buffer.length < recordLength) {
120            buffer = Arrays.copyOf(buffer, recordLength);
121        }
122
123        while (position < recordLength) {
124            int count = recordLength - position;
125            int n = ins.read(buffer, position, count);
126            if (n < 0) {
127                throw new Exception("unexpected end of stream!");
128            }
129            position += n;
130        }
131
132        capabilities = SSLExplorer.explore(buffer, 0, recordLength);;
133        if (capabilities != null) {
134            System.out.println("Record version: " +
135                    capabilities.getRecordVersion());
136            System.out.println("Hello version: " +
137                    capabilities.getHelloVersion());
138        }
139
140        SSLSocketFactory sslsf =
141            (SSLSocketFactory) SSLSocketFactory.getDefault();
142        ByteArrayInputStream bais =
143            new ByteArrayInputStream(buffer, 0, position);
144        SSLSocket sslSocket = (SSLSocket)sslsf.createSocket(socket, bais, true);
145
146        SNIMatcher matcher = SNIHostName.createSNIMatcher(
147                                                serverAcceptableHostname);
148        Collection<SNIMatcher> matchers = new ArrayList<>(1);
149        matchers.add(matcher);
150        SSLParameters params = sslSocket.getSSLParameters();
151        params.setSNIMatchers(matchers);
152        sslSocket.setSSLParameters(params);
153
154        InputStream sslIS = sslSocket.getInputStream();
155        OutputStream sslOS = sslSocket.getOutputStream();
156
157        sslIS.read();
158        sslOS.write(85);
159        sslOS.flush();
160
161        ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
162        checkCapabilities(capabilities, session);
163
164        sslSocket.close();
165        serverSocket.close();
166    }
167
168
169    /*
170     * Define the client side of the test.
171     *
172     * If the server prematurely exits, serverReady will be set to true
173     * to avoid infinite hangs.
174     */
175    void doClientSide() throws Exception {
176
177        /*
178         * Wait for server to get started.
179         */
180        while (!serverReady) {
181            Thread.sleep(50);
182        }
183
184        SSLSocketFactory sslsf =
185            (SSLSocketFactory) SSLSocketFactory.getDefault();
186        SSLSocket sslSocket = (SSLSocket)
187            sslsf.createSocket("localhost", serverPort);
188
189        InputStream sslIS = sslSocket.getInputStream();
190        OutputStream sslOS = sslSocket.getOutputStream();
191
192        sslOS.write(280);
193        sslOS.flush();
194        sslIS.read();
195
196        ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
197        checkSNIInSession(session);
198
199        sslSocket.close();
200    }
201
202    private static String clientRequestedHostname = "www.example.com";
203    private static String serverAcceptableHostname =
204                                                "www\\.example\\.(com|org)";
205
206    void checkCapabilities(SSLCapabilities capabilities,
207            ExtendedSSLSession session) throws Exception {
208
209        List<SNIServerName> sessionSNI = session.getRequestedServerNames();
210        if (!sessionSNI.equals(capabilities.getServerNames())) {
211            for (SNIServerName sni : sessionSNI) {
212                System.out.println("SNI in session is " + sni);
213            }
214
215            List<SNIServerName> capaSNI = capabilities.getServerNames();
216            for (SNIServerName sni : capaSNI) {
217                System.out.println("SNI in session is " + sni);
218            }
219
220            throw new Exception(
221                    "server name indication does not match capabilities");
222        }
223
224        checkSNIInSession(session);
225    }
226
227    void checkSNIInSession(ExtendedSSLSession session) throws Exception {
228        List<SNIServerName> sessionSNI = session.getRequestedServerNames();
229        if (!sessionSNI.isEmpty()) {
230            throw new Exception(
231                    "should be empty request server name indication");
232        }
233    }
234
235    /*
236     * =============================================================
237     * The remainder is just support stuff
238     */
239
240    // use any free port by default
241    volatile int serverPort = 0;
242
243    volatile Exception serverException = null;
244    volatile Exception clientException = null;
245
246    public static void main(String[] args) throws Exception {
247        String keyFilename =
248            System.getProperty("test.src", ".") + "/" + pathToStores +
249                "/" + keyStoreFile;
250        String trustFilename =
251            System.getProperty("test.src", ".") + "/" + pathToStores +
252                "/" + trustStoreFile;
253
254        System.setProperty("javax.net.ssl.keyStore", keyFilename);
255        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
256        System.setProperty("javax.net.ssl.trustStore", trustFilename);
257        System.setProperty("javax.net.ssl.trustStorePassword", passwd);
258
259        if (debug)
260            System.setProperty("javax.net.debug", "all");
261
262        /*
263         * Start the tests.
264         */
265        new SSLSocketExplorerWithSrvSNI();
266    }
267
268    Thread clientThread = null;
269    Thread serverThread = null;
270
271    /*
272     * Primary constructor, used to drive remainder of the test.
273     *
274     * Fork off the other side, then do your work.
275     */
276    SSLSocketExplorerWithSrvSNI() throws Exception {
277        try {
278            if (separateServerThread) {
279                startServer(true);
280                startClient(false);
281            } else {
282                startClient(true);
283                startServer(false);
284            }
285        } catch (Exception e) {
286            // swallow for now.  Show later
287        }
288
289        /*
290         * Wait for other side to close down.
291         */
292        if (separateServerThread) {
293            serverThread.join();
294        } else {
295            clientThread.join();
296        }
297
298        /*
299         * When we get here, the test is pretty much over.
300         * Which side threw the error?
301         */
302        Exception local;
303        Exception remote;
304        String whichRemote;
305
306        if (separateServerThread) {
307            remote = serverException;
308            local = clientException;
309            whichRemote = "server";
310        } else {
311            remote = clientException;
312            local = serverException;
313            whichRemote = "client";
314        }
315
316        /*
317         * If both failed, return the curthread's exception, but also
318         * print the remote side Exception
319         */
320        if ((local != null) && (remote != null)) {
321            System.out.println(whichRemote + " also threw:");
322            remote.printStackTrace();
323            System.out.println();
324            throw local;
325        }
326
327        if (remote != null) {
328            throw remote;
329        }
330
331        if (local != null) {
332            throw local;
333        }
334    }
335
336    void startServer(boolean newThread) throws Exception {
337        if (newThread) {
338            serverThread = new Thread() {
339                public void run() {
340                    try {
341                        doServerSide();
342                    } catch (Exception e) {
343                        /*
344                         * Our server thread just died.
345                         *
346                         * Release the client, if not active already...
347                         */
348                        System.err.println("Server died...");
349                        serverReady = true;
350                        serverException = e;
351                    }
352                }
353            };
354            serverThread.start();
355        } else {
356            try {
357                doServerSide();
358            } catch (Exception e) {
359                serverException = e;
360            } finally {
361                serverReady = true;
362            }
363        }
364    }
365
366    void startClient(boolean newThread) throws Exception {
367        if (newThread) {
368            clientThread = new Thread() {
369                public void run() {
370                    try {
371                        doClientSide();
372                    } catch (Exception e) {
373                        /*
374                         * Our client thread just died.
375                         */
376                        System.err.println("Client died...");
377                        clientException = e;
378                    }
379                }
380            };
381            clientThread.start();
382        } else {
383            try {
384                doClientSide();
385            } catch (Exception e) {
386                clientException = e;
387            }
388        }
389    }
390}
391