ShortRSAKeyWithinTLS.java revision 4922:11e52d5ba64e
/*
 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
23238106Sdes
24238106Sdesimport java.io.*;
25238106Sdesimport java.net.*;
26238106Sdesimport java.util.*;
27238106Sdesimport java.security.*;
28238106Sdesimport javax.net.*;
29238106Sdesimport javax.net.ssl.*;
30238106Sdesimport java.lang.reflect.*;
31238106Sdes
32238106Sdesimport sun.security.util.KeyLength;
33238106Sdes
34238106Sdespublic class ShortRSAKeyWithinTLS {
35238106Sdes
36238106Sdes    /*
37238106Sdes     * =============================================================
38238106Sdes     * Set the various variables needed for the tests, then
39238106Sdes     * specify what tests to run on each side.
40238106Sdes     */
41238106Sdes
42238106Sdes    /*
43238106Sdes     * Should we run the client or server in a separate thread?
44238106Sdes     * Both sides can throw exceptions, but do you have a preference
45238106Sdes     * as to which side should be the main thread.
46238106Sdes     */
47238106Sdes    static boolean separateServerThread = false;
48238106Sdes
49238106Sdes    /*
50238106Sdes     * Is the server ready to serve?
51238106Sdes     */
52238106Sdes    volatile static boolean serverReady = false;
53238106Sdes
54238106Sdes    /*
55238106Sdes     * Turn on SSL debugging?
56238106Sdes     */
57238106Sdes    static boolean debug = false;
58238106Sdes
59238106Sdes    /*
60238106Sdes     * If the client or server is doing some kind of object creation
61238106Sdes     * that the other side depends on, and that thread prematurely
62238106Sdes     * exits, you may experience a hang.  The test harness will
63238106Sdes     * terminate all hung threads after its timeout has expired,
64238106Sdes     * currently 3 minutes by default, but you might try to be
65238106Sdes     * smart about it....
66238106Sdes     */
67238106Sdes
68238106Sdes    /*
69238106Sdes     * Define the server side of the test.
70238106Sdes     *
71238106Sdes     * If the server prematurely exits, serverReady will be set to true
72238106Sdes     * to avoid infinite hangs.
73238106Sdes     */
74238106Sdes    void doServerSide() throws Exception {
75238106Sdes
76238106Sdes        // load the key store
77238106Sdes        KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
78238106Sdes        ks.load(null, null);
79238106Sdes        System.out.println("Loaded keystore: Windows-MY");
80238106Sdes
81238106Sdes        // check key size
82238106Sdes        checkKeySize(ks);
83238106Sdes
84238106Sdes        // initialize the SSLContext
85238106Sdes        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
86238106Sdes        kmf.init(ks, null);
87238106Sdes
88238106Sdes        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
89238106Sdes        tmf.init(ks);
90238106Sdes
91238106Sdes        SSLContext ctx = SSLContext.getInstance("TLS");
92238106Sdes        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
93238106Sdes
94238106Sdes        ServerSocketFactory ssf = ctx.getServerSocketFactory();
95238106Sdes        SSLServerSocket sslServerSocket = (SSLServerSocket)
96238106Sdes                                ssf.createServerSocket(serverPort);
97238106Sdes        sslServerSocket.setNeedClientAuth(true);
98238106Sdes        serverPort = sslServerSocket.getLocalPort();
99238106Sdes        System.out.println("serverPort = " + serverPort);
100238106Sdes
101238106Sdes        /*
102238106Sdes         * Signal Client, we're ready for his connect.
103238106Sdes         */
104238106Sdes        serverReady = true;
105238106Sdes
106238106Sdes        SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
107238106Sdes        InputStream sslIS = sslSocket.getInputStream();
108238106Sdes        OutputStream sslOS = sslSocket.getOutputStream();
109238106Sdes
110238106Sdes        sslIS.read();
111238106Sdes        sslOS.write(85);
112238106Sdes        sslOS.flush();
113238106Sdes
114238106Sdes        sslSocket.close();
115238106Sdes    }
116238106Sdes
117238106Sdes    /*
118238106Sdes     * Define the client side of the test.
119238106Sdes     *
120238106Sdes     * If the server prematurely exits, serverReady will be set to true
121238106Sdes     * to avoid infinite hangs.
122238106Sdes     */
123238106Sdes    void doClientSide() throws Exception {
124238106Sdes
125238106Sdes        /*
126238106Sdes         * Wait for server to get started.
127238106Sdes         */
128238106Sdes        while (!serverReady) {
129238106Sdes            Thread.sleep(50);
130238106Sdes        }
131238106Sdes
132238106Sdes        // load the key store
133238106Sdes        KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
134238106Sdes        ks.load(null, null);
135238106Sdes        System.out.println("Loaded keystore: Windows-MY");
136238106Sdes
137238106Sdes        // initialize the SSLContext
138238106Sdes        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
139238106Sdes        kmf.init(ks, null);
140238106Sdes
141238106Sdes        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
142238106Sdes        tmf.init(ks);
143238106Sdes
144238106Sdes        SSLContext ctx = SSLContext.getInstance("TLS");
145238106Sdes        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
146238106Sdes
147238106Sdes        SSLSocketFactory sslsf = ctx.getSocketFactory();
148238106Sdes        SSLSocket sslSocket = (SSLSocket)
149238106Sdes            sslsf.createSocket("localhost", serverPort);
150238106Sdes
151238106Sdes        if (clientProtocol != null) {
152238106Sdes            sslSocket.setEnabledProtocols(new String[] {clientProtocol});
153238106Sdes        }
154238106Sdes
155238106Sdes        if (clientCiperSuite != null) {
156238106Sdes            sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite});
157238106Sdes        }
158238106Sdes
159238106Sdes        InputStream sslIS = sslSocket.getInputStream();
160238106Sdes        OutputStream sslOS = sslSocket.getOutputStream();
161238106Sdes
162238106Sdes        sslOS.write(280);
163238106Sdes        sslOS.flush();
164238106Sdes        sslIS.read();
165238106Sdes
166238106Sdes        sslSocket.close();
167238106Sdes    }
168238106Sdes
169238106Sdes    private void checkKeySize(KeyStore ks) throws Exception {
170238106Sdes        PrivateKey privateKey = null;
171238106Sdes        PublicKey publicKey = null;
172238106Sdes
173238106Sdes        if (ks.containsAlias(keyAlias)) {
174238106Sdes            System.out.println("Loaded entry: " + keyAlias);
175238106Sdes            privateKey = (PrivateKey)ks.getKey(keyAlias, null);
176238106Sdes            publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey();
177238106Sdes
178238106Sdes            int privateKeySize = KeyLength.getKeySize(privateKey);
179238106Sdes            if (privateKeySize != keySize) {
180238106Sdes                throw new Exception("Expected key size is " + keySize +
181238106Sdes                        ", but the private key size is " + privateKeySize);
182238106Sdes            }
183238106Sdes
184238106Sdes            int publicKeySize = KeyLength.getKeySize(publicKey);
185238106Sdes            if (publicKeySize != keySize) {
186238106Sdes                throw new Exception("Expected key size is " + keySize +
187238106Sdes                        ", but the public key size is " + publicKeySize);
188238106Sdes            }
189238106Sdes        }
190238106Sdes    }
191238106Sdes
192238106Sdes    /*
193238106Sdes     * =============================================================
194238106Sdes     * The remainder is just support stuff
195238106Sdes     */
196238106Sdes
197238106Sdes    // use any free port by default
198238106Sdes    volatile int serverPort = 0;
199238106Sdes
200238106Sdes    volatile Exception serverException = null;
201238106Sdes    volatile Exception clientException = null;
202238106Sdes
203238106Sdes    private static String keyAlias;
204238106Sdes    private static int keySize;
205238106Sdes    private static String clientProtocol = null;
206238106Sdes    private static String clientCiperSuite = null;
207238106Sdes
208238106Sdes    private static void parseArguments(String[] args) {
209238106Sdes        keyAlias = args[0];
210238106Sdes        keySize = Integer.parseInt(args[1]);
211238106Sdes
212238106Sdes        if (args.length > 2) {
213            clientProtocol = args[2];
214        }
215
216        if (args.length > 3) {
217            clientCiperSuite = args[3];
218        }
219    }
220
221    public static void main(String[] args) throws Exception {
222        if (debug) {
223            System.setProperty("javax.net.debug", "all");
224        }
225
226        // Get the customized arguments.
227        parseArguments(args);
228
229        new ShortRSAKeyWithinTLS();
230    }
231
232    Thread clientThread = null;
233    Thread serverThread = null;
234
235    /*
236     * Primary constructor, used to drive remainder of the test.
237     *
238     * Fork off the other side, then do your work.
239     */
240    ShortRSAKeyWithinTLS() throws Exception {
241        try {
242            if (separateServerThread) {
243                startServer(true);
244                startClient(false);
245            } else {
246                startClient(true);
247                startServer(false);
248            }
249        } catch (Exception e) {
250            // swallow for now.  Show later
251        }
252
253        /*
254         * Wait for other side to close down.
255         */
256        if (separateServerThread) {
257            serverThread.join();
258        } else {
259            clientThread.join();
260        }
261
262        /*
263         * When we get here, the test is pretty much over.
264         * Which side threw the error?
265         */
266        Exception local;
267        Exception remote;
268        String whichRemote;
269
270        if (separateServerThread) {
271            remote = serverException;
272            local = clientException;
273            whichRemote = "server";
274        } else {
275            remote = clientException;
276            local = serverException;
277            whichRemote = "client";
278        }
279
280        /*
281         * If both failed, return the curthread's exception, but also
282         * print the remote side Exception
283         */
284        if ((local != null) && (remote != null)) {
285            System.out.println(whichRemote + " also threw:");
286            remote.printStackTrace();
287            System.out.println();
288            throw local;
289        }
290
291        if (remote != null) {
292            throw remote;
293        }
294
295        if (local != null) {
296            throw local;
297        }
298    }
299
300    void startServer(boolean newThread) throws Exception {
301        if (newThread) {
302            serverThread = new Thread() {
303                public void run() {
304                    try {
305                        doServerSide();
306                    } catch (Exception e) {
307                        /*
308                         * Our server thread just died.
309                         *
310                         * Release the client, if not active already...
311                         */
312                        System.err.println("Server died...");
313                        serverReady = true;
314                        serverException = e;
315                    }
316                }
317            };
318            serverThread.start();
319        } else {
320            try {
321                doServerSide();
322            } catch (Exception e) {
323                serverException = e;
324            } finally {
325                serverReady = true;
326            }
327        }
328    }
329
330    void startClient(boolean newThread) throws Exception {
331        if (newThread) {
332            clientThread = new Thread() {
333                public void run() {
334                    try {
335                        doClientSide();
336                    } catch (Exception e) {
337                        /*
338                         * Our client thread just died.
339                         */
340                        System.err.println("Client died...");
341                        clientException = e;
342                    }
343                }
344            };
345            clientThread.start();
346        } else {
347            try {
348                doClientSide();
349            } catch (Exception e) {
350                clientException = e;
351            }
352        }
353    }
354}
355
356