ShortRSAKeyWithinTLS.java revision 6718:e7cce63bf293
/*
 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
23
24import java.io.*;
25import java.net.*;
26import java.util.*;
27import java.security.*;
28import javax.net.*;
29import javax.net.ssl.*;
30import java.lang.reflect.*;
31
32import sun.security.util.KeyUtil;
33
34public class ShortRSAKeyWithinTLS {
35
36    /*
37     * =============================================================
38     * Set the various variables needed for the tests, then
39     * specify what tests to run on each side.
40     */
41
42    /*
43     * Should we run the client or server in a separate thread?
44     * Both sides can throw exceptions, but do you have a preference
45     * as to which side should be the main thread.
46     */
47    static boolean separateServerThread = false;
48
49    /*
50     * Is the server ready to serve?
51     */
52    volatile static boolean serverReady = false;
53
54    /*
55     * Turn on SSL debugging?
56     */
57    static boolean debug = false;
58
59    /*
60     * If the client or server is doing some kind of object creation
61     * that the other side depends on, and that thread prematurely
62     * exits, you may experience a hang.  The test harness will
63     * terminate all hung threads after its timeout has expired,
64     * currently 3 minutes by default, but you might try to be
65     * smart about it....
66     */
67
68    /*
69     * Define the server side of the test.
70     *
71     * If the server prematurely exits, serverReady will be set to true
72     * to avoid infinite hangs.
73     */
74    void doServerSide() throws Exception {
75
76        // load the key store
77        KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
78        ks.load(null, null);
79        System.out.println("Loaded keystore: Windows-MY");
80
81        // check key size
82        checkKeySize(ks);
83
84        // initialize the SSLContext
85        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
86        kmf.init(ks, null);
87
88        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
89        tmf.init(ks);
90
91        SSLContext ctx = SSLContext.getInstance("TLS");
92        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
93
94        ServerSocketFactory ssf = ctx.getServerSocketFactory();
95        SSLServerSocket sslServerSocket = (SSLServerSocket)
96                                ssf.createServerSocket(serverPort);
97        sslServerSocket.setNeedClientAuth(true);
98        serverPort = sslServerSocket.getLocalPort();
99        System.out.println("serverPort = " + serverPort);
100
101        /*
102         * Signal Client, we're ready for his connect.
103         */
104        serverReady = true;
105
106        SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
107        InputStream sslIS = sslSocket.getInputStream();
108        OutputStream sslOS = sslSocket.getOutputStream();
109
110        sslIS.read();
111        sslOS.write(85);
112        sslOS.flush();
113
114        sslSocket.close();
115    }
116
117    /*
118     * Define the client side of the test.
119     *
120     * If the server prematurely exits, serverReady will be set to true
121     * to avoid infinite hangs.
122     */
123    void doClientSide() throws Exception {
124
125        /*
126         * Wait for server to get started.
127         */
128        while (!serverReady) {
129            Thread.sleep(50);
130        }
131
132        // load the key store
133        KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
134        ks.load(null, null);
135        System.out.println("Loaded keystore: Windows-MY");
136
137        // initialize the SSLContext
138        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
139        kmf.init(ks, null);
140
141        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
142        tmf.init(ks);
143
144        SSLContext ctx = SSLContext.getInstance("TLS");
145        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
146
147        SSLSocketFactory sslsf = ctx.getSocketFactory();
148        SSLSocket sslSocket = (SSLSocket)
149            sslsf.createSocket("localhost", serverPort);
150
151        if (clientProtocol != null) {
152            sslSocket.setEnabledProtocols(new String[] {clientProtocol});
153        }
154
155        if (clientCiperSuite != null) {
156            sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite});
157        }
158
159        InputStream sslIS = sslSocket.getInputStream();
160        OutputStream sslOS = sslSocket.getOutputStream();
161
162        sslOS.write(280);
163        sslOS.flush();
164        sslIS.read();
165
166        sslSocket.close();
167    }
168
169    private void checkKeySize(KeyStore ks) throws Exception {
170        PrivateKey privateKey = null;
171        PublicKey publicKey = null;
172
173        if (ks.containsAlias(keyAlias)) {
174            System.out.println("Loaded entry: " + keyAlias);
175            privateKey = (PrivateKey)ks.getKey(keyAlias, null);
176            publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey();
177
178            int privateKeySize = KeyUtil.getKeySize(privateKey);
179            if (privateKeySize != keySize) {
180                throw new Exception("Expected key size is " + keySize +
181                        ", but the private key size is " + privateKeySize);
182            }
183
184            int publicKeySize = KeyUtil.getKeySize(publicKey);
185            if (publicKeySize != keySize) {
186                throw new Exception("Expected key size is " + keySize +
187                        ", but the public key size is " + publicKeySize);
188            }
189        }
190    }
191
192    /*
193     * =============================================================
194     * The remainder is just support stuff
195     */
196
197    // use any free port by default
198    volatile int serverPort = 0;
199
200    volatile Exception serverException = null;
201    volatile Exception clientException = null;
202
203    private static String keyAlias;
204    private static int keySize;
205    private static String clientProtocol = null;
206    private static String clientCiperSuite = null;
207
208    private static void parseArguments(String[] args) {
209        keyAlias = args[0];
210        keySize = Integer.parseInt(args[1]);
211
212        if (args.length > 2) {
213            clientProtocol = args[2];
214        }
215
216        if (args.length > 3) {
217            clientCiperSuite = args[3];
218        }
219    }
220
221    public static void main(String[] args) throws Exception {
222        if (debug) {
223            System.setProperty("javax.net.debug", "all");
224        }
225
226        // Get the customized arguments.
227        parseArguments(args);
228
229        new ShortRSAKeyWithinTLS();
230    }
231
232    Thread clientThread = null;
233    Thread serverThread = null;
234
235    /*
236     * Primary constructor, used to drive remainder of the test.
237     *
238     * Fork off the other side, then do your work.
239     */
240    ShortRSAKeyWithinTLS() throws Exception {
241        try {
242            if (separateServerThread) {
243                startServer(true);
244                startClient(false);
245            } else {
246                startClient(true);
247                startServer(false);
248            }
249        } catch (Exception e) {
250            // swallow for now.  Show later
251        }
252
253        /*
254         * Wait for other side to close down.
255         */
256        if (separateServerThread) {
257            serverThread.join();
258        } else {
259            clientThread.join();
260        }
261
262        /*
263         * When we get here, the test is pretty much over.
264         * Which side threw the error?
265         */
266        Exception local;
267        Exception remote;
268        String whichRemote;
269
270        if (separateServerThread) {
271            remote = serverException;
272            local = clientException;
273            whichRemote = "server";
274        } else {
275            remote = clientException;
276            local = serverException;
277            whichRemote = "client";
278        }
279
280        /*
281         * If both failed, return the curthread's exception, but also
282         * print the remote side Exception
283         */
284        if ((local != null) && (remote != null)) {
285            System.out.println(whichRemote + " also threw:");
286            remote.printStackTrace();
287            System.out.println();
288            throw local;
289        }
290
291        if (remote != null) {
292            throw remote;
293        }
294
295        if (local != null) {
296            throw local;
297        }
298    }
299
300    void startServer(boolean newThread) throws Exception {
301        if (newThread) {
302            serverThread = new Thread() {
303                public void run() {
304                    try {
305                        doServerSide();
306                    } catch (Exception e) {
307                        /*
308                         * Our server thread just died.
309                         *
310                         * Release the client, if not active already...
311                         */
312                        System.err.println("Server died...");
313                        serverReady = true;
314                        serverException = e;
315                    }
316                }
317            };
318            serverThread.start();
319        } else {
320            try {
321                doServerSide();
322            } catch (Exception e) {
323                serverException = e;
324            } finally {
325                serverReady = true;
326            }
327        }
328    }
329
330    void startClient(boolean newThread) throws Exception {
331        if (newThread) {
332            clientThread = new Thread() {
333                public void run() {
334                    try {
335                        doClientSide();
336                    } catch (Exception e) {
337                        /*
338                         * Our client thread just died.
339                         */
340                        System.err.println("Client died...");
341                        clientException = e;
342                    }
343                }
344            };
345            clientThread.start();
346        } else {
347            try {
348                doClientSide();
349            } catch (Exception e) {
350                clientException = e;
351            }
352        }
353    }
354}
355
356