1/*
2 * Copyright (c) 2009, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/* @test
25 * @bug 4927640
26 * @summary Tests the SCTP protocol implementation
27 * @author chegar
28 */
29
30import java.net.InetSocketAddress;
31import java.net.SocketAddress;
32import java.io.IOException;
33import java.util.concurrent.CountDownLatch;
34import java.nio.channels.AlreadyConnectedException;
35import java.nio.channels.AsynchronousCloseException;
36import java.nio.channels.NotYetBoundException;
37import java.nio.channels.ClosedByInterruptException;
38import java.nio.channels.ClosedChannelException;
39import com.sun.nio.sctp.SctpChannel;
40import com.sun.nio.sctp.SctpServerChannel;
41import static java.lang.System.out;
42import static java.lang.System.err;
43
44public class Accept {
45    static CountDownLatch acceptLatch = new CountDownLatch(1);
46    static CountDownLatch closeByIntLatch = new CountDownLatch(1);
47    static CountDownLatch asyncCloseLatch = new CountDownLatch(1);
48    AcceptServer server = null;
49
50    void test(String[] args) {
51        SocketAddress address = null;
52
53        if (!Util.isSCTPSupported()) {
54            out.println("SCTP protocol is not supported");
55            out.println("Test cannot be run");
56            return;
57        }
58
59        if (args.length == 2) {
60            /* requested to connecct to a specific address */
61            try {
62                int port = Integer.valueOf(args[1]);
63                address = new InetSocketAddress(args[0], port);
64            } catch (NumberFormatException nfe) {
65                err.println(nfe);
66            }
67        } else {
68            /* start server on local machine, default */
69            try {
70                server = new AcceptServer();
71                server.start();
72                address = server.address();
73                debug("Server started and listening on " + address);
74            } catch (IOException ioe) {
75                ioe.printStackTrace();
76                return;
77            }
78        }
79
80        doClient(address);
81    }
82
83    void doClient(SocketAddress peerAddress) {
84        SctpChannel channel = null;
85
86        try {
87            channel = SctpChannel.open(peerAddress, 0, 0);
88            acceptLatch.await();
89
90            /* for test 4 */
91            closeByIntLatch.await();
92            sleep(500);
93            server.thread().interrupt();
94
95            /* for test 5 */
96            asyncCloseLatch.await();
97            sleep(500);
98            server.channel().close();
99
100            /* wait for the server thread to finish */
101            join(server.thread(), 10000);
102        } catch (IOException ioe) {
103            unexpected(ioe);
104        } catch (InterruptedException ie) {
105            unexpected(ie);
106        } finally {
107            try { if (channel != null) channel.close(); }
108            catch (IOException e) { unexpected(e);}
109        }
110    }
111
112    class AcceptServer implements Runnable
113    {
114        final InetSocketAddress serverAddr;
115        private SctpServerChannel ssc;
116        private Thread serverThread;
117
118        public AcceptServer() throws IOException {
119            ssc = SctpServerChannel.open();
120
121            /* TEST 1: NotYetBoundException */
122            debug("TEST 1: NotYetBoundException");
123            try {
124                ssc.accept();
125                fail();
126            } catch (NotYetBoundException nybe) {
127                debug("  caught NotYetBoundException");
128                pass();
129            } catch (IOException ioe) {
130                unexpected(ioe);
131            }
132
133            ssc.bind(null);
134            java.util.Set<SocketAddress> addrs = ssc.getAllLocalAddresses();
135            if (addrs.isEmpty())
136                debug("addrs should not be empty");
137
138            serverAddr = (InetSocketAddress) addrs.iterator().next();
139
140            /* TEST 2: null if this channel is in non-blocking mode and no
141             *         association is available to be accepted  */
142            ssc.configureBlocking(false);
143            debug("TEST 2: non-blocking mode null");
144            try {
145                SctpChannel sc = ssc.accept();
146                check(sc == null, "non-blocking mode should return null");
147            } catch (IOException ioe) {
148                unexpected(ioe);
149            } finally {
150                ssc.configureBlocking(true);
151            }
152        }
153
154        void start() {
155            serverThread = new Thread(this, "AcceptServer-"  +
156                                              serverAddr.getPort());
157            serverThread.start();
158        }
159
160        InetSocketAddress address() {
161            return serverAddr;
162        }
163
164        SctpServerChannel channel() {
165            return ssc;
166        }
167
168        Thread thread() {
169            return serverThread;
170        }
171
172        @Override
173        public void run() {
174            SctpChannel sc = null;
175            try {
176                /* TEST 3: accepted channel */
177                debug("TEST 3: accepted channel");
178                sc = ssc.accept();
179
180                checkAcceptedChannel(sc);
181                acceptLatch.countDown();
182
183                /* TEST 4: ClosedByInterruptException */
184                debug("TEST 4: ClosedByInterruptException");
185                try {
186                    closeByIntLatch.countDown();
187                    ssc.accept();
188                    fail();
189                } catch (ClosedByInterruptException unused) {
190                    debug("  caught ClosedByInterruptException");
191                    pass();
192                }
193
194                /* TEST 5: AsynchronousCloseException */
195                debug("TEST 5: AsynchronousCloseException");
196                /* reset thread interrupt status */
197                Thread.currentThread().interrupted();
198
199                ssc = SctpServerChannel.open().bind(null);
200                try {
201                    asyncCloseLatch.countDown();
202                    ssc.accept();
203                    fail();
204                } catch (AsynchronousCloseException unused) {
205                    debug("  caught AsynchronousCloseException");
206                    pass();
207                }
208
209                /* TEST 6: ClosedChannelException */
210                debug("TEST 6: ClosedChannelException");
211                try {
212                    ssc.accept();
213                    fail();
214                } catch (ClosedChannelException unused) {
215                    debug("  caught ClosedChannelException");
216                    pass();
217                }
218                ssc = null;
219            } catch (IOException ioe) {
220                ioe.printStackTrace();
221            } finally {
222                try { if (ssc != null) ssc.close(); }
223                catch (IOException  ioe) { unexpected(ioe); }
224                try { if (sc != null) sc.close(); }
225                catch (IOException  ioe) { unexpected(ioe); }
226            }
227        }
228    }
229
230    void checkAcceptedChannel(SctpChannel sc) {
231        try {
232            debug("Checking accepted SctpChannel");
233            check(sc.association() != null,
234                  "accepted channel should have an association");
235            check(!(sc.getRemoteAddresses().isEmpty()),
236                  "accepted channel should be connected");
237            check(!(sc.isConnectionPending()),
238                  "accepted channel should not have a connection pending");
239            check(sc.isBlocking(),
240                  "accepted channel should be blocking");
241            try { sc.connect(new TestSocketAddress()); fail(); }
242            catch (AlreadyConnectedException unused) { pass(); }
243            try { sc.bind(new TestSocketAddress()); fail(); }
244            catch (AlreadyConnectedException unused) { pass(); }
245        } catch (IOException unused) { fail(); }
246    }
247
248    static class TestSocketAddress extends SocketAddress {}
249
250        //--------------------- Infrastructure ---------------------------
251    boolean debug = true;
252    volatile int passed = 0, failed = 0;
253    void pass() {passed++;}
254    void fail() {failed++; Thread.dumpStack();}
255    void fail(String msg) {err.println(msg); fail();}
256    void unexpected(Throwable t) {failed++; t.printStackTrace();}
257    void check(boolean cond) {if (cond) pass(); else fail();}
258    void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);}
259    void debug(String message) {if(debug) { out.println(message); }  }
260    void sleep(long millis) { try { Thread.currentThread().sleep(millis); }
261                          catch(InterruptedException ie) { unexpected(ie); }}
262    void join(Thread thread, long millis) { try { thread.join(millis); }
263                          catch(InterruptedException ie) { unexpected(ie); }}
264    public static void main(String[] args) throws Throwable {
265        Class<?> k = new Object(){}.getClass().getEnclosingClass();
266        try {k.getMethod("instanceMain",String[].class)
267                .invoke( k.newInstance(), (Object) args);}
268        catch (Throwable e) {throw e.getCause();}}
269    public void instanceMain(String[] args) throws Throwable {
270        try {test(args);} catch (Throwable t) {unexpected(t);}
271        out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
272        if (failed > 0) throw new AssertionError("Some tests failed");}
273
274}
275