1/*
2 * Copyright (c) 2002, 2010, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/*
25 * @test
26 * @bug 4636628
27 * @summary HttpURLConnection duplicates HTTP GET requests when used with multiple threads
28*/
29
/*
 * This tests keep-alive behavior using chunked input streams.
 * It checks that keep-alive connections are reused and also
 * that requests are not being repeated (due to errors).
 *
 * It also checks that the keep-alive connections are closed eventually,
 * because the test will not terminate if the connections
 * are not closed by the keep-alive timer.
 */
39
40import java.net.*;
41import java.io.*;
42
43public class MultiThreadTest extends Thread {
44
45    /*
46     * Is debugging enabled - start with -d to enable.
47     */
48    static boolean debug = false;
49
50    static Object threadlock = new Object ();
51    static int threadCounter = 0;
52
53    static Object getLock() { return threadlock; }
54
55    static void debug(String msg) {
56        if (debug)
57            System.out.println(msg);
58    }
59
60    static int reqnum = 0;
61
62    void doRequest(String uri) throws Exception {
63        URL url = new URL(uri + "?foo="+reqnum);
64        reqnum ++;
65        HttpURLConnection http = (HttpURLConnection)url.openConnection();
66
67        InputStream in = http.getInputStream();
68        byte b[] = new byte[100];
69        int total = 0;
70        int n;
71        do {
72            n = in.read(b);
73            if (n > 0) total += n;
74        } while (n > 0);
75        debug ("client: read " + total + " bytes");
76        in.close();
77        http.disconnect();
78    }
79
80    String uri;
81    byte[] b;
82    int requests;
83
84    MultiThreadTest(int port, int requests) throws Exception {
85        uri = "http://localhost:" +
86                     port + "/foo.html";
87
88        b = new byte [256];
89        this.requests = requests;
90
91        synchronized (threadlock) {
92            threadCounter ++;
93        }
94    }
95
96    public void run () {
97        try {
98            for (int i=0; i<requests; i++) {
99                doRequest (uri);
100            }
101        } catch (Exception e) {
102            throw new RuntimeException (e.getMessage());
103        } finally {
104            synchronized (threadlock) {
105                threadCounter --;
106                if (threadCounter == 0) {
107                    threadlock.notifyAll();
108                }
109            }
110        }
111    }
112
113    static int threads=5;
114
115    public static void main(String args[]) throws Exception {
116
117        int x = 0, arg_len = args.length;
118        int requests = 20;
119
120        if (arg_len > 0 && args[0].equals("-d")) {
121            debug = true;
122            x = 1;
123            arg_len --;
124        }
125        if (arg_len > 0) {
126            threads = Integer.parseInt (args[x]);
127            requests = Integer.parseInt (args[x+1]);
128        }
129
130        /* start the server */
131        ServerSocket ss = new ServerSocket(0);
132        Server svr = new Server(ss);
133        svr.start();
134
135        Object lock = MultiThreadTest.getLock();
136        synchronized (lock) {
137            for (int i=0; i<threads; i++) {
138                MultiThreadTest t = new MultiThreadTest(ss.getLocalPort(), requests);
139                t.start ();
140            }
141            try {
142                lock.wait();
143            } catch (InterruptedException e) {}
144        }
145
146        // shutdown server - we're done.
147        svr.shutdown();
148
149        int cnt = svr.connectionCount();
150        MultiThreadTest.debug("Connections = " + cnt);
151        int reqs = Worker.getRequests ();
152        MultiThreadTest.debug("Requests = " + reqs);
153        System.out.println ("Connection count = " + cnt + " Request count = " + reqs);
154        if (cnt > threads) { // could be less
155            throw new RuntimeException ("Expected "+threads + " connections: used " +cnt);
156        }
157        if  (reqs != threads*requests) {
158            throw new RuntimeException ("Expected "+ threads*requests+ " requests: got " +reqs);
159        }
160    }
161}
162
    /*
     * Server thread to accept connections and create a worker thread
     * to service each connection.
     */
167    class Server extends Thread {
168        ServerSocket ss;
169        int connectionCount;
170        boolean shutdown = false;
171
172        Server(ServerSocket ss) {
173            this.ss = ss;
174        }
175
176        public synchronized int connectionCount() {
177            return connectionCount;
178        }
179
180        public synchronized void shutdown() {
181            shutdown = true;
182        }
183
184        public void run() {
185            try {
186                ss.setSoTimeout(2000);
187
188                for (;;) {
189                    Socket s;
190                    try {
191                        MultiThreadTest.debug("server: calling accept.");
192                        s = ss.accept();
193                        MultiThreadTest.debug("server: return accept.");
194                    } catch (SocketTimeoutException te) {
195                        MultiThreadTest.debug("server: STE");
196                        synchronized (this) {
197                            if (shutdown) {
198                                MultiThreadTest.debug("server: Shuting down.");
199                                return;
200                            }
201                        }
202                        continue;
203                    }
204
205                    int id;
206                    synchronized (this) {
207                        id = connectionCount++;
208                    }
209
210                    Worker w = new Worker(s, id);
211                    w.start();
212                    MultiThreadTest.debug("server: Started worker " + id);
213                }
214
215            } catch (Exception e) {
216                e.printStackTrace();
217            } finally {
218                try {
219                    ss.close();
220                } catch (Exception e) { }
221            }
222        }
223    }
224
    /*
     * Worker thread to service a single connection - it can service
     * multiple HTTP requests on the same connection.
     */
229    class Worker extends Thread {
230        Socket s;
231        int id;
232
233        Worker(Socket s, int id) {
234            this.s = s;
235            this.id = id;
236        }
237
238        static int requests = 0;
239        static Object rlock = new Object();
240
241        public static int getRequests () {
242            synchronized (rlock) {
243                return requests;
244            }
245        }
246        public static void incRequests () {
247            synchronized (rlock) {
248                requests++;
249            }
250        }
251
252        int readUntil (InputStream in, char[] seq) throws IOException {
253            int i=0, count=0;
254            while (true) {
255                int c = in.read();
256                if (c == -1)
257                    return -1;
258                count++;
259                if (c == seq[i]) {
260                    i++;
261                    if (i == seq.length)
262                        return count;
263                    continue;
264                } else {
265                    i = 0;
266                }
267            }
268        }
269
270        public void run() {
271            try {
272                int max = 400;
273                byte b[] = new byte[1000];
274                InputStream in = new BufferedInputStream (s.getInputStream());
275                // response to client
276                PrintStream out = new PrintStream(
277                                    new BufferedOutputStream(
278                                                s.getOutputStream() ));
279
280                for (;;) {
281
282                    // read entire request from client
283                    int n=0;
284
285                    n = readUntil (in, new char[] {'\r','\n', '\r','\n'});
286
287                    if (n <= 0) {
288                        MultiThreadTest.debug("worker: " + id + ": Shutdown");
289                        s.close();
290                        return;
291                    }
292
293                    MultiThreadTest.debug("worker " + id +
294                        ": Read request from client " +
295                        "(" + n + " bytes).");
296
297                    incRequests();
298                    out.print("HTTP/1.1 200 OK\r\n");
299                    out.print("Transfer-Encoding: chunked\r\n");
300                    out.print("Content-Type: text/html\r\n");
301                    out.print("Connection: Keep-Alive\r\n");
302                    out.print ("Keep-Alive: timeout=15, max="+max+"\r\n");
303                    out.print("\r\n");
304                    out.print ("6\r\nHello \r\n");
305                    out.print ("5\r\nWorld\r\n");
306                    out.print ("0\r\n\r\n");
307                    out.flush();
308
309                    if (--max == 0) {
310                        s.close();
311                        return;
312                    }
313                }
314
315            } catch (Exception e) {
316                e.printStackTrace();
317            } finally {
318                try {
319                    s.close();
320                } catch (Exception e) { }
321            }
322        }
323    }
324