1/*
2 * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/*
25 *
26 */
27
28import java.net.*;
29import java.io.*;
30import java.nio.*;
31import java.nio.channels.*;
32import sun.net.www.MessageHeader;
33import java.util.*;
34
35public class TunnelProxy {
36
37    ServerSocketChannel schan;
38    int threads;
39    int cperthread;
40    Server[] servers;
41
42    /**
43     * Create a <code>TunnelProxy<code> instance with the specified callback object
44     * for handling requests. One thread is created to handle requests,
45     * and up to ten TCP connections will be handled simultaneously.
46     * @param cb the callback object which is invoked to handle each
47     *  incoming request
48     */
49
50    public TunnelProxy () throws IOException {
51        this (1, 10, 0);
52    }
53
54    /**
55     * Create a <code>TunnelProxy<code> instance with the specified number of
56     * threads and maximum number of connections per thread. This functions
57     * the same as the 4 arg constructor, where the port argument is set to zero.
58     * @param cb the callback object which is invoked to handle each
59     *     incoming request
60     * @param threads the number of threads to create to handle requests
61     *     in parallel
62     * @param cperthread the number of simultaneous TCP connections to
63     *     handle per thread
64     */
65
66    public TunnelProxy (int threads, int cperthread)
67        throws IOException {
68        this (threads, cperthread, 0);
69    }
70
71    /**
72     * Create a <code>TunnelProxy<code> instance with the specified number
73     * of threads and maximum number of connections per thread and running on
74     * the specified port. The specified number of threads are created to
75     * handle incoming requests, and each thread is allowed
76     * to handle a number of simultaneous TCP connections.
77     * @param cb the callback object which is invoked to handle
78     *  each incoming request
79     * @param threads the number of threads to create to handle
80     *  requests in parallel
81     * @param cperthread the number of simultaneous TCP connections
82     *  to handle per thread
83     * @param port the port number to bind the server to. <code>Zero</code>
84     *  means choose any free port.
85     */
86
87    public TunnelProxy (int threads, int cperthread, int port)
88        throws IOException {
89        schan = ServerSocketChannel.open ();
90        InetSocketAddress addr = new InetSocketAddress (port);
91        schan.socket().bind (addr);
92        this.threads = threads;
93        this.cperthread = cperthread;
94        servers = new Server [threads];
95        for (int i=0; i<threads; i++) {
96            servers[i] = new Server (schan, cperthread);
97            servers[i].start();
98        }
99    }
100
101    /** Tell all threads in the server to exit within 5 seconds.
102     *  This is an abortive termination. Just prior to the thread exiting
103     *  all channels in that thread waiting to be closed are forceably closed.
104     */
105
106    public void terminate () {
107        for (int i=0; i<threads; i++) {
108            servers[i].terminate ();
109        }
110    }
111
112    /**
113     * return the local port number to which the server is bound.
114     * @return the local port number
115     */
116
117    public int getLocalPort () {
118        return schan.socket().getLocalPort ();
119    }
120
121    static class Server extends Thread {
122
123        ServerSocketChannel schan;
124        Selector selector;
125        SelectionKey listenerKey;
126        SelectionKey key; /* the current key being processed */
127        ByteBuffer consumeBuffer;
128        int maxconn;
129        int nconn;
130        ClosedChannelList clist;
131        boolean shutdown;
132        Pipeline pipe1 = null;
133        Pipeline pipe2 = null;
134
135        Server (ServerSocketChannel schan, int maxconn) {
136            this.schan = schan;
137            this.maxconn = maxconn;
138            nconn = 0;
139            consumeBuffer = ByteBuffer.allocate (512);
140            clist = new ClosedChannelList ();
141            try {
142                selector = Selector.open ();
143                schan.configureBlocking (false);
144                listenerKey = schan.register (selector, SelectionKey.OP_ACCEPT);
145            } catch (IOException e) {
146                System.err.println ("Server could not start: " + e);
147            }
148        }
149
150        /* Stop the thread as soon as possible */
151        public synchronized void terminate () {
152            shutdown = true;
153            if (pipe1 != null) pipe1.terminate();
154            if (pipe2 != null) pipe2.terminate();
155        }
156
157        public void run ()  {
158            try {
159                while (true) {
160                    selector.select (1000);
161                    Set selected = selector.selectedKeys();
162                    Iterator iter = selected.iterator();
163                    while (iter.hasNext()) {
164                        key = (SelectionKey)iter.next();
165                        if (key.equals (listenerKey)) {
166                            SocketChannel sock = schan.accept ();
167                            if (sock == null) {
168                                /* false notification */
169                                iter.remove();
170                                continue;
171                            }
172                            sock.configureBlocking (false);
173                            sock.register (selector, SelectionKey.OP_READ);
174                            nconn ++;
175                            if (nconn == maxconn) {
176                                /* deregister */
177                                listenerKey.cancel ();
178                                listenerKey = null;
179                            }
180                        } else {
181                            if (key.isReadable()) {
182                                boolean closed;
183                                SocketChannel chan = (SocketChannel) key.channel();
184                                if (key.attachment() != null) {
185                                    closed = consume (chan);
186                                } else {
187                                    closed = read (chan, key);
188                                }
189                                if (closed) {
190                                    chan.close ();
191                                    key.cancel ();
192                                    if (nconn == maxconn) {
193                                        listenerKey = schan.register (selector, SelectionKey.OP_ACCEPT);
194                                    }
195                                    nconn --;
196                                }
197                            }
198                        }
199                        iter.remove();
200                    }
201                    clist.check();
202                    if (shutdown) {
203                        clist.terminate ();
204                        return;
205                    }
206                }
207            } catch (IOException e) {
208                System.out.println ("Server exception: " + e);
209                // TODO finish
210            }
211        }
212
213        /* read all the data off the channel without looking at it
214             * return true if connection closed
215             */
216        boolean consume (SocketChannel chan) {
217            try {
218                consumeBuffer.clear ();
219                int c = chan.read (consumeBuffer);
220                if (c == -1)
221                    return true;
222            } catch (IOException e) {
223                return true;
224            }
225            return false;
226        }
227
228        /* return true if the connection is closed, false otherwise */
229
230        private boolean read (SocketChannel chan, SelectionKey key) {
231            HttpTransaction msg;
232            boolean res;
233            try {
234                InputStream is = new BufferedInputStream (new NioInputStream (chan));
235                String requestline = readLine (is);
236                MessageHeader mhead = new MessageHeader (is);
237                String[] req = requestline.split (" ");
238                if (req.length < 2) {
239                    /* invalid request line */
240                    return false;
241                }
242                String cmd = req[0];
243                URI uri = null;
244                if (!("CONNECT".equalsIgnoreCase(cmd))) {
245                    // we expect CONNECT command
246                    return false;
247                }
248                try {
249                    uri = new URI("http://" + req[1]);
250                } catch (URISyntaxException e) {
251                    System.err.println ("Invalid URI: " + e);
252                    res = true;
253                }
254
255                // CONNECT ack
256                OutputStream os = new BufferedOutputStream(new NioOutputStream(chan));
257                byte[] ack = "HTTP/1.1 200 Connection established\r\n\r\n".getBytes();
258                os.write(ack, 0, ack.length);
259                os.flush();
260
261                // tunnel anything else
262                tunnel(is, os, uri);
263
264                res = false;
265            } catch (IOException e) {
266                res = true;
267            }
268            return res;
269        }
270
271        private void tunnel(InputStream fromClient, OutputStream toClient, URI serverURI) throws IOException {
272            Socket sockToServer = new Socket(serverURI.getHost(), serverURI.getPort());
273            OutputStream toServer = sockToServer.getOutputStream();
274            InputStream fromServer = sockToServer.getInputStream();
275
276            pipe1 = new Pipeline(fromClient, toServer);
277            pipe2 = new Pipeline(fromServer, toClient);
278            // start pump
279            pipe1.start();
280            pipe2.start();
281            // wait them to end
282            try {
283                pipe1.join();
284            } catch (InterruptedException e) {
285                // No-op
286            } finally {
287                sockToServer.close();
288            }
289        }
290
291        private String readLine (InputStream is) throws IOException {
292            boolean done=false, readCR=false;
293            byte[] b = new byte [512];
294            int c, l = 0;
295
296            while (!done) {
297                c = is.read ();
298                if (c == '\n' && readCR) {
299                    done = true;
300                } else {
301                    if (c == '\r' && !readCR) {
302                        readCR = true;
303                    } else {
304                        b[l++] = (byte)c;
305                    }
306                }
307            }
308            return new String (b);
309        }
310
311        /** close the channel associated with the current key by:
312         * 1. shutdownOutput (send a FIN)
313         * 2. mark the key so that incoming data is to be consumed and discarded
314         * 3. After a period, close the socket
315         */
316
317        synchronized void orderlyCloseChannel (SelectionKey key) throws IOException {
318            SocketChannel ch = (SocketChannel)key.channel ();
319            ch.socket().shutdownOutput();
320            key.attach (this);
321            clist.add (key);
322        }
323
324        synchronized void abortiveCloseChannel (SelectionKey key) throws IOException {
325            SocketChannel ch = (SocketChannel)key.channel ();
326            Socket s = ch.socket ();
327            s.setSoLinger (true, 0);
328            ch.close();
329        }
330    }
331
332
333    /**
334     * Implements blocking reading semantics on top of a non-blocking channel
335     */
336
337    static class NioInputStream extends InputStream {
338        SocketChannel channel;
339        Selector selector;
340        ByteBuffer chanbuf;
341        SelectionKey key;
342        int available;
343        byte[] one;
344        boolean closed;
345        ByteBuffer markBuf; /* reads may be satisifed from this buffer */
346        boolean marked;
347        boolean reset;
348        int readlimit;
349
350        public NioInputStream (SocketChannel chan) throws IOException {
351            this.channel = chan;
352            selector = Selector.open();
353            chanbuf = ByteBuffer.allocate (1024);
354            key = chan.register (selector, SelectionKey.OP_READ);
355            available = 0;
356            one = new byte[1];
357            closed = marked = reset = false;
358        }
359
360        public synchronized int read (byte[] b) throws IOException {
361            return read (b, 0, b.length);
362        }
363
364        public synchronized int read () throws IOException {
365            return read (one, 0, 1);
366        }
367
368        public synchronized int read (byte[] b, int off, int srclen) throws IOException {
369
370            int canreturn, willreturn;
371
372            if (closed)
373                return -1;
374
375            if (reset) { /* satisfy from markBuf */
376                canreturn = markBuf.remaining ();
377                willreturn = canreturn>srclen ? srclen : canreturn;
378                markBuf.get(b, off, willreturn);
379                if (canreturn == willreturn) {
380                    reset = false;
381                }
382            } else { /* satisfy from channel */
383                canreturn = available();
384                if (canreturn == 0) {
385                    block ();
386                    canreturn = available();
387                }
388                willreturn = canreturn>srclen ? srclen : canreturn;
389                chanbuf.get(b, off, willreturn);
390                available -= willreturn;
391
392                if (marked) { /* copy into markBuf */
393                    try {
394                        markBuf.put (b, off, willreturn);
395                    } catch (BufferOverflowException e) {
396                        marked = false;
397                    }
398                }
399            }
400            return willreturn;
401        }
402
403        public synchronized int available () throws IOException {
404            if (closed)
405                throw new IOException ("Stream is closed");
406
407            if (reset)
408                return markBuf.remaining();
409
410            if (available > 0)
411                return available;
412
413            chanbuf.clear ();
414            available = channel.read (chanbuf);
415            if (available > 0)
416                chanbuf.flip();
417            else if (available == -1)
418                throw new IOException ("Stream is closed");
419            return available;
420        }
421
422        /**
423         * block() only called when available==0 and buf is empty
424         */
425        private synchronized void block () throws IOException {
426            //assert available == 0;
427            int n = selector.select ();
428            //assert n == 1;
429            selector.selectedKeys().clear();
430            available ();
431        }
432
433        public void close () throws IOException {
434            if (closed)
435                return;
436            channel.close ();
437            closed = true;
438        }
439
440        public synchronized void mark (int readlimit) {
441            if (closed)
442                return;
443            this.readlimit = readlimit;
444            markBuf = ByteBuffer.allocate (readlimit);
445            marked = true;
446            reset = false;
447        }
448
449        public synchronized void reset () throws IOException {
450            if (closed )
451                return;
452            if (!marked)
453                throw new IOException ("Stream not marked");
454            marked = false;
455            reset = true;
456            markBuf.flip ();
457        }
458    }
459
460    static class NioOutputStream extends OutputStream {
461        SocketChannel channel;
462        ByteBuffer buf;
463        SelectionKey key;
464        Selector selector;
465        boolean closed;
466        byte[] one;
467
468        public NioOutputStream (SocketChannel channel) throws IOException {
469            this.channel = channel;
470            selector = Selector.open ();
471            key = channel.register (selector, SelectionKey.OP_WRITE);
472            closed = false;
473            one = new byte [1];
474        }
475
476        public synchronized void write (int b) throws IOException {
477            one[0] = (byte)b;
478            write (one, 0, 1);
479        }
480
481        public synchronized void write (byte[] b) throws IOException {
482            write (b, 0, b.length);
483        }
484
485        public synchronized void write (byte[] b, int off, int len) throws IOException {
486            if (closed)
487                throw new IOException ("stream is closed");
488
489            buf = ByteBuffer.allocate (len);
490            buf.put (b, off, len);
491            buf.flip ();
492            int n;
493            while ((n = channel.write (buf)) < len) {
494                len -= n;
495                if (len == 0)
496                    return;
497                selector.select ();
498                selector.selectedKeys().clear ();
499            }
500        }
501
502        public void close () throws IOException {
503            if (closed)
504                return;
505            channel.close ();
506            closed = true;
507        }
508    }
509
510    /*
511     * Pipeline object :-
512     * 1) Will pump every byte from its input stream to output stream
513     * 2) Is an 'active object'
514     */
515    static class Pipeline implements Runnable {
516        InputStream in;
517        OutputStream out;
518        Thread t;
519
520        public Pipeline(InputStream is, OutputStream os) {
521            in = is;
522            out = os;
523        }
524
525        public void start() {
526            t = new Thread(this);
527            t.start();
528        }
529
530        public void join() throws InterruptedException {
531            t.join();
532        }
533
534        public void terminate() {
535            t.interrupt();
536        }
537
538        public void run() {
539            byte[] buffer = new byte[10000];
540            try {
541                while (!Thread.interrupted()) {
542                    int len;
543                    while ((len = in.read(buffer)) != -1) {
544                        out.write(buffer, 0, len);
545                        out.flush();
546                    }
547                }
548            } catch(IOException e) {
549                // No-op
550            } finally {
551            }
552        }
553    }
554
555    /**
556     * Utilities for synchronization. A condition is
557     * identified by a string name, and is initialized
558     * upon first use (ie. setCondition() or waitForCondition()). Threads
559     * are blocked until some thread calls (or has called) setCondition() for the same
560     * condition.
561     * <P>
562     * A rendezvous built on a condition is also provided for synchronizing
563     * N threads.
564     */
565
566    private static HashMap conditions = new HashMap();
567
568    /*
569     * Modifiable boolean object
570     */
571    private static class BValue {
572        boolean v;
573    }
574
575    /*
576     * Modifiable int object
577     */
578    private static class IValue {
579        int v;
580        IValue (int i) {
581            v =i;
582        }
583    }
584
585
586    private static BValue getCond (String condition) {
587        synchronized (conditions) {
588            BValue cond = (BValue) conditions.get (condition);
589            if (cond == null) {
590                cond = new BValue();
591                conditions.put (condition, cond);
592            }
593            return cond;
594        }
595    }
596
597    /**
598     * Set the condition to true. Any threads that are currently blocked
599     * waiting on the condition, will be unblocked and allowed to continue.
600     * Threads that subsequently call waitForCondition() will not block.
601     * If the named condition did not exist prior to the call, then it is created
602     * first.
603     */
604
605    public static void setCondition (String condition) {
606        BValue cond = getCond (condition);
607        synchronized (cond) {
608            if (cond.v) {
609                return;
610            }
611            cond.v = true;
612            cond.notifyAll();
613        }
614    }
615
616    /**
617     * If the named condition does not exist, then it is created and initialized
618     * to false. If the condition exists or has just been created and its value
619     * is false, then the thread blocks until another thread sets the condition.
620     * If the condition exists and is already set to true, then this call returns
621     * immediately without blocking.
622     */
623
624    public static void waitForCondition (String condition) {
625        BValue cond = getCond (condition);
626        synchronized (cond) {
627            if (!cond.v) {
628                try {
629                    cond.wait();
630                } catch (InterruptedException e) {}
631            }
632        }
633    }
634
635    /* conditions must be locked when accessing this */
636    static HashMap rv = new HashMap();
637
638    /**
639     * Force N threads to rendezvous (ie. wait for each other) before proceeding.
640     * The first thread(s) to call are blocked until the last
641     * thread makes the call. Then all threads continue.
642     * <p>
643     * All threads that call with the same condition name, must use the same value
644     * for N (or the results may be not be as expected).
645     * <P>
646     * Obviously, if fewer than N threads make the rendezvous then the result
647     * will be a hang.
648     */
649
650    public static void rendezvous (String condition, int N) {
651        BValue cond;
652        IValue iv;
653        String name = "RV_"+condition;
654
655        /* get the condition */
656
657        synchronized (conditions) {
658            cond = (BValue)conditions.get (name);
659            if (cond == null) {
660                /* we are first caller */
661                if (N < 2) {
662                    throw new RuntimeException ("rendezvous must be called with N >= 2");
663                }
664                cond = new BValue ();
665                conditions.put (name, cond);
666                iv = new IValue (N-1);
667                rv.put (name, iv);
668            } else {
669                /* already initialised, just decrement the counter */
670                iv = (IValue) rv.get (name);
671                iv.v --;
672            }
673        }
674
675        if (iv.v > 0) {
676            waitForCondition (name);
677        } else {
678            setCondition (name);
679            synchronized (conditions) {
680                clearCondition (name);
681                rv.remove (name);
682            }
683        }
684    }
685
686    /**
687     * If the named condition exists and is set then remove it, so it can
688     * be re-initialized and used again. If the condition does not exist, or
689     * exists but is not set, then the call returns without doing anything.
690     * Note, some higher level synchronization
691     * may be needed between clear and the other operations.
692     */
693
694    public static void clearCondition(String condition) {
695        BValue cond;
696        synchronized (conditions) {
697            cond = (BValue) conditions.get (condition);
698            if (cond == null) {
699                return;
700            }
701            synchronized (cond) {
702                if (cond.v) {
703                    conditions.remove (condition);
704                }
705            }
706        }
707    }
708}
709