1/*
2 * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24import java.io.Closeable;
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.io.Serializable;
29import java.net.ServerSocket;
30import java.net.Socket;
31import java.net.UnknownHostException;
32import java.util.ArrayList;
33import java.util.Arrays;
34import java.util.HashMap;
35import java.util.Map;
36import java.util.StringJoiner;
37import javax.security.auth.callback.Callback;
38import javax.security.auth.callback.CallbackHandler;
39import javax.security.auth.callback.NameCallback;
40import javax.security.auth.callback.PasswordCallback;
41import javax.security.auth.callback.UnsupportedCallbackException;
42import javax.security.sasl.AuthorizeCallback;
43import javax.security.sasl.RealmCallback;
44import javax.security.sasl.RealmChoiceCallback;
45import javax.security.sasl.Sasl;
46import javax.security.sasl.SaslClient;
47import javax.security.sasl.SaslException;
48import javax.security.sasl.SaslServer;
49
/*
 * @test
 * @bug 8049814
 * @summary Java SASL server and client tests with the CRAM-MD5 and
 *          DIGEST-MD5 mechanisms. The tests try different QOP values on
 *          the client and server sides.
 * @modules java.security.sasl/javax.security.sasl
 */
public class ClientServerTest {

    // Interval, in milliseconds, between checks while Server.start() waits for
    // the server thread to bind its listening socket.
    private static final int DELAY = 100;
    private static final String LOCALHOST = "localhost";
    private static final String DIGEST_MD5 = "DIGEST-MD5";
    private static final String CRAM_MD5 = "CRAM-MD5";
    private static final String PROTOCOL = "saslservice";
    private static final String USER_ID = "sasltester";
    private static final String PASSWD = "password";
    private static final String QOP_AUTH = "auth";
    private static final String QOP_AUTH_CONF = "auth-conf";
    private static final String QOP_AUTH_INT = "auth-int";
    private static final String AUTHID_SASL_TESTER = "sasl_tester";
    // Mechanism list the server sends first; declared as an ArrayList because
    // the client checks "instanceof ArrayList" on the deserialized object.
    private static final ArrayList<String> SUPPORT_MECHS = new ArrayList<>();

    static {
        SUPPORT_MECHS.add(DIGEST_MD5);
        SUPPORT_MECHS.add(CRAM_MD5);
    }

    /**
     * Runs every mechanism/QOP combination. All cases are executed even if an
     * earlier one fails; a RuntimeException is thrown at the end if any case
     * did not behave as expected.
     */
    public static void main(String[] args) throws Exception {
        String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
        String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
        String[] authQop = { QOP_AUTH };
        String[] authIntQop = { QOP_AUTH_INT };
        String[] authConfQop = { QOP_AUTH_CONF };
        String[] emptyQop = {};

        boolean success = true;

        // "auth" on both sides, for both mechanisms and with an explicit
        // authorization id: authentication is expected to succeed.
        success &= runTest("", CRAM_MD5, authQop, authQop, false);
        success &= runTest("", DIGEST_MD5, authQop, authQop, false);
        success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5, authQop, authQop,
                false);
        // The server's QOP is one of the client's acceptable values.
        success &= runTest("", DIGEST_MD5, allQops, authQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);
        // No QOP in common (including a server configured with an empty QOP
        // list): a SaslException is expected.
        success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);

        if (!success) {
            throw new RuntimeException("At least one test case failed");
        }

        System.out.println("Test passed");
    }

    /**
     * Runs one client/server exchange and compares the outcome with
     * {@code expectException}.
     *
     * @return true if a SaslException was thrown exactly when one was
     *         expected, false otherwise
     * @throws Exception any exception other than SaslException (for example
     *         a RuntimeException or a non-SASL IOException) propagates and
     *         aborts the run
     */
    private static boolean runTest(String authId, String mech,
            String[] clientQops, String[] serverQops, boolean expectException)
            throws Exception {

        System.out.println("AuthId:" + authId
                + " mechanism:" + mech
                + " clientQops: " + Arrays.toString(clientQops)
                + " serverQops: " + Arrays.toString(serverQops)
                + " expect exception:" + expectException);

        try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
            new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
                    .run();
            if (expectException) {
                System.out.println("Expected exception not thrown");
                return false;
            }
        } catch (SaslException e) {
            if (!expectException) {
                System.out.println("Unexpected exception: " + e);
                return false;
            }
            System.out.println("Expected exception: " + e);
        }

        return true;
    }

    /** Status carried by every {@link Message} exchanged between the peers. */
    static enum SaslStatus {
        SUCCESS, FAILURE, CONTINUE
    }

    /** A status plus optional SASL payload, sent with Java serialization. */
    static class Message implements Serializable {

        // Explicit UID so the serialized form does not depend on the
        // compiler-computed default.
        private static final long serialVersionUID = 1L;

        private final SaslStatus status;
        private final byte[] data;

        public Message(SaslStatus status, byte[] data) {
            this.status = status;
            this.data = data;
        }

        public SaslStatus getStatus() {
            return status;
        }

        /** Returns the payload array itself (not a copy); may be null. */
        public byte[] getData() {
            return data;
        }
    }

    /**
     * State shared by the test client and server: the host (also used as the
     * realm by the callback handler), the mechanism (null for the server), the
     * comma-separated {@link Sasl#QOP} value and the callback handler.
     */
    static class SaslPeer {

        final String host;
        final String mechanism;
        final String qop;
        final CallbackHandler callback;

        SaslPeer(String host, String authId, String... qops) {
            this(host, null, authId, qops);
        }

        SaslPeer(String host, String mechanism, String authId, String... qops) {
            this.host = host;
            this.mechanism = mechanism;

            StringJoiner sj = new StringJoiner(",");
            for (String q : qops) {
                sj.add(q);
            }
            qop = sj.toString();

            callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
        }

        /**
         * Casts a received object to Message.
         *
         * @throws RuntimeException if {@code ob} is not a Message
         */
        Message getMessage(Object ob) {
            if (!(ob instanceof Message)) {
                throw new RuntimeException("Expected an instance of Message");
            }
            return (Message) ob;
        }
    }

    /**
     * Single-use SASL server. {@link #start} runs it on a daemon thread that
     * binds an ephemeral port, accepts one connection, performs one SASL
     * exchange and exits.
     */
    static class Server extends SaslPeer implements Runnable, Closeable {

        private volatile boolean ready = false;
        private volatile ServerSocket ssocket;
        // Exception that terminated the server thread, if any; used to report
        // why the server never became ready.
        private volatile Exception failure;

        /**
         * Creates a server, starts it on a daemon thread and waits until its
         * listening socket is bound.
         *
         * @throws UnknownHostException declared for compatibility; not thrown
         *         by this implementation
         * @throws RuntimeException if the server thread terminates before the
         *         socket is bound, or if the wait is interrupted
         */
        static Server start(String host, String authId, String[] serverQops)
                throws UnknownHostException {
            Server server = new Server(host, authId, serverQops);
            Thread thread = new Thread(server);
            thread.setDaemon(true);
            thread.start();

            while (!server.ready) {
                // If the thread died before binding (e.g. ServerSocket creation
                // failed), 'ready' can never become true: fail instead of
                // spinning forever. 'ready' is re-read after isAlive() because
                // the thread may have set it just before terminating.
                if (!thread.isAlive() && !server.ready) {
                    throw new RuntimeException(
                            "Server thread terminated before it was ready",
                            server.failure);
                }
                try {
                    Thread.sleep(DELAY);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }

            return server;
        }

        Server(String host, String authId, String... qops) {
            super(host, authId, qops);
        }

        /** Returns the bound port; valid only after {@link #start} returns. */
        int getPort() {
            return ssocket.getLocalPort();
        }

        /**
         * Sends the supported mechanism list, reads the client's choice and
         * evaluates responses until the SaslServer completes. Each step is
         * answered with CONTINUE, and the final one with SUCCESS.
         */
        private void processConnection(SaslEndpoint endpoint)
                throws SaslException, IOException, ClassNotFoundException {
            System.out.println("process connection");
            endpoint.send(SUPPORT_MECHS);
            Object o = endpoint.receive();
            if (!(o instanceof String)) {
                throw new RuntimeException("Received unexpected object: " + o);
            }
            String mech = (String) o;
            SaslServer saslServer = createSaslServer(mech);
            Message msg = getMessage(endpoint.receive());
            while (!saslServer.isComplete()) {
                byte[] data = processData(msg.getData(), endpoint,
                        saslServer);
                if (saslServer.isComplete()) {
                    System.out.println("server is complete");
                    endpoint.send(new Message(SaslStatus.SUCCESS, data));
                } else {
                    System.out.println("server continues");
                    endpoint.send(new Message(SaslStatus.CONTINUE, data));
                    msg = getMessage(endpoint.receive());
                }
            }
        }

        /**
         * Evaluates one client response. If the SaslServer rejects it, a
         * FAILURE message is sent to the client before the exception is
         * rethrown.
         */
        private byte[] processData(byte[] data, SaslEndpoint endpoint,
                SaslServer server) throws SaslException, IOException {
            try {
                return server.evaluateResponse(data);
            } catch (SaslException e) {
                endpoint.send(new Message(SaslStatus.FAILURE, null));
                System.out.println("Error while processing data");
                throw e;
            }
        }

        private SaslServer createSaslServer(String mechanism)
                throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            return Sasl.createSaslServer(mechanism, PROTOCOL, host, props,
                    callback);
        }

        @Override
        public void run() {
            try (ServerSocket ss = new ServerSocket(0)) {
                ssocket = ss;
                System.out.println("server started on port " + getPort());
                ready = true;
                Socket socket = ss.accept();
                try (SaslEndpoint endpoint = new SaslEndpoint(socket)) {
                    System.out.println("server accepted connection");
                    processConnection(endpoint);
                }
            } catch (Exception e) {
                // Not rethrown: the client observes the failure (a FAILURE
                // message or a closed connection) and reports it. Recorded so
                // start() can report a startup failure, and logged for
                // diagnosis.
                failure = e;
                System.out.println("server terminated with: " + e);
            }
        }

        /** Closes the listening socket if it was bound and is still open. */
        @Override
        public void close() throws IOException {
            ServerSocket ss = ssocket;
            if (ss != null && !ss.isClosed()) {
                ss.close();
            }
        }
    }

    /** SASL client side of the test; {@link #run} performs one exchange. */
    static class Client extends SaslPeer {

        private final int port;

        Client(String host, int port, String mech, String authId,
                String... qops) {
            super(host, mech, authId, qops);
            this.port = port;
        }

        /**
         * Connects to the server, checks that the requested mechanism is
         * offered and drives the challenge/response exchange until the
         * client completes or the server reports FAILURE.
         *
         * @throws SaslException if the mechanism rejects a challenge, the
         *         server reports FAILURE, or the exchange ends inconsistently
         * @throws Exception on I/O or deserialization errors, or an
         *         unexpected message
         */
        public void run() throws Exception {
            System.out.println("Host:" + host + " port: "
                    + port);
            try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
                negotiateMechanism(endpoint);
                SaslClient client = createSaslClient();
                byte[] data = new byte[0];
                if (client.hasInitialResponse()) {
                    data = client.evaluateChallenge(data);
                }
                endpoint.send(new Message(SaslStatus.CONTINUE, data));
                Message msg = getMessage(endpoint.receive());
                while (!client.isComplete()
                        && msg.getStatus() != SaslStatus.FAILURE) {
                    switch (msg.getStatus()) {
                        case CONTINUE:
                            System.out.println("client continues");
                            data = client.evaluateChallenge(msg.getData());
                            endpoint.send(new Message(SaslStatus.CONTINUE,
                                    data));
                            msg = getMessage(endpoint.receive());
                            break;
                        case SUCCESS:
                            System.out.println("client succeeded");
                            data = client.evaluateChallenge(msg.getData());
                            if (data != null) {
                                throw new SaslException("data should be null");
                            }
                            // msg is not replaced here, so an incomplete
                            // client would re-evaluate the same SUCCESS
                            // payload forever.
                            if (!client.isComplete()) {
                                throw new SaslException(
                                        "client is not complete after SUCCESS");
                            }
                            break;
                        default:
                            throw new RuntimeException("Wrong status:"
                                    + msg.getStatus());
                    }
                }

                if (msg.getStatus() == SaslStatus.FAILURE) {
                    // A server-side rejection is an authentication failure;
                    // report it as SaslException so runTest() compares it
                    // with the expected outcome instead of aborting the run.
                    throw new SaslException("Status is FAILURE");
                }
            }

            System.out.println("Done");
        }

        private SaslClient createSaslClient() throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            return Sasl.createSaslClient(new String[] {mechanism}, USER_ID,
                    PROTOCOL, host, props, callback);
        }

        /**
         * Receives the server's mechanism list, verifies it contains this
         * client's mechanism and sends the mechanism name back.
         */
        private void negotiateMechanism(SaslEndpoint endpoint)
                throws ClassNotFoundException, IOException {
            Object o = endpoint.receive();
            if (o instanceof ArrayList) {
                ArrayList<?> list = (ArrayList<?>) o;
                if (!list.contains(mechanism)) {
                    throw new RuntimeException(
                            "Server does not support specified mechanism:"
                                    + mechanism);
                }
            } else {
                throw new RuntimeException(
                        "Expected an instance of ArrayList, but received " + o);
            }

            endpoint.send(mechanism);
        }

    }

    /**
     * Object-stream wrapper around a socket. The streams are created on first
     * use, so a side only waits for the peer's stream header when it actually
     * receives (the ObjectInputStream constructor blocks until the header
     * arrives). Closing the endpoint closes the socket.
     */
    static class SaslEndpoint implements AutoCloseable {

        private final Socket socket;
        private ObjectInputStream input;
        private ObjectOutputStream output;

        static SaslEndpoint create(String host, int port) throws IOException {
            return new SaslEndpoint(new Socket(host, port));
        }

        SaslEndpoint(Socket socket) throws IOException {
            this.socket = socket;
        }

        private ObjectInputStream getInput() throws IOException {
            if (input == null && socket != null) {
                input = new ObjectInputStream(socket.getInputStream());
            }
            return input;
        }

        private ObjectOutputStream getOutput() throws IOException {
            if (output == null && socket != null) {
                output = new ObjectOutputStream(socket.getOutputStream());
            }
            return output;
        }

        public Object receive() throws IOException, ClassNotFoundException {
            return getInput().readObject();
        }

        /** Writes {@code obj} and flushes so the peer can read it at once. */
        public void send(Object obj) throws IOException {
            getOutput().writeObject(obj);
            getOutput().flush();
        }

        @Override
        public void close() throws IOException {
            if (socket != null && !socket.isClosed()) {
                socket.close();
            }
        }

    }

    /**
     * Callback handler used by both peers. It supplies a fixed user name and
     * password, answers realm prompts with the configured realm and
     * authorizes every request.
     */
    static class TestCallbackHandler implements CallbackHandler {

        private final String userId;
        private final char[] passwd;
        private final String realm;
        // Replaced with userId on the first AuthorizeCallback if blank.
        private String authId;

        TestCallbackHandler(String userId, String passwd, String realm,
                String authId) {
            this.userId = userId;
            this.passwd = passwd.toCharArray();
            this.realm = realm;
            this.authId = authId;
        }

        @Override
        public void handle(Callback[] callbacks) throws IOException,
                UnsupportedCallbackException {
            for (Callback callback : callbacks) {
                if (callback instanceof NameCallback) {
                    System.out.println("NameCallback");
                    ((NameCallback) callback).setName(userId);
                } else if (callback instanceof PasswordCallback) {
                    System.out.println("PasswordCallback");
                    ((PasswordCallback) callback).setPassword(passwd);
                } else if (callback instanceof RealmCallback) {
                    System.out.println("RealmCallback");
                    ((RealmCallback) callback).setText(realm);
                } else if (callback instanceof RealmChoiceCallback) {
                    System.out.println("RealmChoiceCallback");
                    RealmChoiceCallback choice = (RealmChoiceCallback) callback;
                    if (realm == null) {
                        choice.setSelectedIndex(choice.getDefaultChoice());
                    } else {
                        // Selects the first matching realm; if none matches,
                        // nothing is selected.
                        String[] choices = choice.getChoices();
                        for (int j = 0; j < choices.length; j++) {
                            if (realm.equals(choices[j])) {
                                choice.setSelectedIndex(j);
                                break;
                            }
                        }
                    }
                } else if (callback instanceof AuthorizeCallback) {
                    System.out.println("AuthorizeCallback");
                    ((AuthorizeCallback) callback).setAuthorized(true);
                    if (authId == null || authId.trim().length() == 0) {
                        authId = userId;
                    }
                    ((AuthorizeCallback) callback).setAuthorizedID(authId);
                } else {
                    throw new UnsupportedCallbackException(callback);
                }
            }
        }
    }

}
478