1/*
2 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24import java.io.ByteArrayOutputStream;
25import java.io.FilterInputStream;
26import java.io.IOException;
27import java.io.InputStream;
28import java.io.OutputStream;
29import java.io.Serializable;
30import java.net.InetAddress;
31import java.net.ServerSocket;
32import java.net.Socket;
33import java.net.SocketAddress;
34import java.net.SocketException;
35import java.net.SocketOption;
36import java.nio.channels.ServerSocketChannel;
37import java.nio.channels.SocketChannel;
38import java.rmi.server.RMIClientSocketFactory;
39import java.rmi.server.RMIServerSocketFactory;
40import java.rmi.server.RMISocketFactory;
41import java.util.ArrayList;
42import java.util.Arrays;
43import java.util.List;
44import java.util.Objects;
45import java.util.Set;
46
47import org.testng.Assert;
48import org.testng.TestNG;
49import org.testng.annotations.Test;
50import org.testng.annotations.DataProvider;
51
52
/**
 * An RMISocketFactory that logs the contents of RMI streams and can match
 * and replace bytes in output streams to simulate failures.
 */
57public class TestSocketFactory extends RMISocketFactory
58        implements RMIClientSocketFactory, RMIServerSocketFactory, Serializable {
59
60    private static final long serialVersionUID = 1L;
61
62    private volatile transient byte[] matchBytes;
63
64    private volatile transient byte[] replaceBytes;
65
66    private transient final List<InterposeSocket> sockets = new ArrayList<>();
67
68    private transient final List<InterposeServerSocket> serverSockets = new ArrayList<>();
69
70    public static final boolean DEBUG = false;
71
72    /**
73     * Debugging output can be synchronized with logging of RMI actions.
74     *
75     * @param format a printf format
76     * @param args   any args
77     */
78    private static void DEBUG(String format, Object... args) {
79        if (DEBUG) {
80            System.err.printf(format, args);
81        }
82    }
83
84    /**
85     * Create a socket factory that creates InputStreams that log
86     * and OutputStreams that log .
87     */
88    public TestSocketFactory() {
89        this.matchBytes = new byte[0];
90        this.replaceBytes = this.matchBytes;
91        System.out.printf("Creating TestSocketFactory()%n");
92    }
93
94    public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
95        this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
96        this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
97        sockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
98        serverSockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
99
100    }
101
102    @Override
103    public Socket createSocket(String host, int port) throws IOException {
104        Socket socket = RMISocketFactory.getDefaultSocketFactory()
105                .createSocket(host, port);
106        InterposeSocket s = new InterposeSocket(socket, matchBytes, replaceBytes);
107        sockets.add(s);
108        return s;
109    }
110
111    /**
112     * Return the current list of sockets.
113     * @return Return a snapshot of the current list of sockets
114     */
115    public List<InterposeSocket> getSockets() {
116        List<InterposeSocket> snap = new ArrayList<>(sockets);
117        return snap;
118    }
119
120    @Override
121    public ServerSocket createServerSocket(int port) throws IOException {
122
123        ServerSocket serverSocket = RMISocketFactory.getDefaultSocketFactory()
124                .createServerSocket(port);
125        InterposeServerSocket ss = new InterposeServerSocket(serverSocket, matchBytes, replaceBytes);
126        serverSockets.add(ss);
127        return ss;
128    }
129
130    /**
131     * Return the current list of server sockets.
132     * @return Return a snapshot of the current list of server sockets
133     */
134    public List<InterposeServerSocket> getServerSockets() {
135        List<InterposeServerSocket> snap = new ArrayList<>(serverSockets);
136        return snap;
137    }
138
139    /**
140     * An InterposeSocket wraps a socket that produces InputStreams
141     * and OutputStreams that log the traffic.
142     * The OutputStreams it produces match an array of bytes and replace them.
143     * Useful for injecting protocol and content errors.
144     */
145    public static class InterposeSocket extends Socket {
146        private final Socket socket;
147        private InputStream in;
148        private MatchReplaceOutputStream out;
149        private volatile byte[] matchBytes;
150        private volatile byte[] replaceBytes;
151        private final ByteArrayOutputStream inLogStream;
152        private final ByteArrayOutputStream outLogStream;
153        private final String name;
154        private static volatile int num = 0;    // index for created InterposeSockets
155
156        public InterposeSocket(Socket socket, byte[] matchBytes, byte[] replaceBytes) {
157            this.socket = socket;
158            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
159            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
160            this.inLogStream = new ByteArrayOutputStream();
161            this.outLogStream = new ByteArrayOutputStream();
162            this.name = "IS" + ++num + "::"
163                    + Thread.currentThread().getName() + ": "
164                    + socket.getLocalPort() + " <  " + socket.getPort();
165        }
166
167        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
168            this.matchBytes = matchBytes;
169            this.replaceBytes = replaceBytes;
170            out.setMatchReplaceBytes(matchBytes, replaceBytes);
171        }
172
173        @Override
174        public void connect(SocketAddress endpoint) throws IOException {
175            socket.connect(endpoint);
176        }
177
178        @Override
179        public void connect(SocketAddress endpoint, int timeout) throws IOException {
180            socket.connect(endpoint, timeout);
181        }
182
183        @Override
184        public void bind(SocketAddress bindpoint) throws IOException {
185            socket.bind(bindpoint);
186        }
187
188        @Override
189        public InetAddress getInetAddress() {
190            return socket.getInetAddress();
191        }
192
193        @Override
194        public InetAddress getLocalAddress() {
195            return socket.getLocalAddress();
196        }
197
198        @Override
199        public int getPort() {
200            return socket.getPort();
201        }
202
203        @Override
204        public int getLocalPort() {
205            return socket.getLocalPort();
206        }
207
208        @Override
209        public SocketAddress getRemoteSocketAddress() {
210            return socket.getRemoteSocketAddress();
211        }
212
213        @Override
214        public SocketAddress getLocalSocketAddress() {
215            return socket.getLocalSocketAddress();
216        }
217
218        @Override
219        public SocketChannel getChannel() {
220            return socket.getChannel();
221        }
222
223        @Override
224        public synchronized void close() throws IOException {
225            socket.close();
226        }
227
228        @Override
229        public String toString() {
230            return "InterposeSocket " + name + ": " + socket.toString();
231        }
232
233        @Override
234        public boolean isConnected() {
235            return socket.isConnected();
236        }
237
238        @Override
239        public boolean isBound() {
240            return socket.isBound();
241        }
242
243        @Override
244        public boolean isClosed() {
245            return socket.isClosed();
246        }
247
248        @Override
249        public <T> Socket setOption(SocketOption<T> name, T value) throws IOException {
250            return socket.setOption(name, value);
251        }
252
253        @Override
254        public <T> T getOption(SocketOption<T> name) throws IOException {
255            return socket.getOption(name);
256        }
257
258        @Override
259        public Set<SocketOption<?>> supportedOptions() {
260            return socket.supportedOptions();
261        }
262
263        @Override
264        public synchronized InputStream getInputStream() throws IOException {
265            if (in == null) {
266                in = socket.getInputStream();
267                String name = Thread.currentThread().getName() + ": "
268                        + socket.getLocalPort() + " <  " + socket.getPort();
269                in = new LoggingInputStream(in, name, inLogStream);
270                DEBUG("Created new InterposeInputStream: %s%n", name);
271            }
272            return in;
273        }
274
275        @Override
276        public synchronized OutputStream getOutputStream() throws IOException {
277            if (out == null) {
278                OutputStream o = socket.getOutputStream();
279                String name = Thread.currentThread().getName() + ": "
280                        + socket.getLocalPort() + "  > " + socket.getPort();
281                out = new MatchReplaceOutputStream(o, name, outLogStream, matchBytes, replaceBytes);
282                DEBUG("Created new MatchReplaceOutputStream: %s%n", name);
283            }
284            return out;
285        }
286
287        /**
288         * Return the bytes logged from the input stream.
289         * @return Return the bytes logged from the input stream.
290         */
291        public byte[] getInLogBytes() {
292            return inLogStream.toByteArray();
293        }
294
295        /**
296         * Return the bytes logged from the output stream.
297         * @return Return the bytes logged from the output stream.
298         */
299        public byte[] getOutLogBytes() {
300            return outLogStream.toByteArray();
301        }
302
303    }
304
305    /**
306     * InterposeServerSocket is a ServerSocket that wraps each Socket it accepts
307     * with an InterposeSocket so that its input and output streams can be monitored.
308     */
309    public static class InterposeServerSocket extends ServerSocket {
310        private final ServerSocket socket;
311        private volatile byte[] matchBytes;
312        private volatile byte[] replaceBytes;
313        private final List<InterposeSocket> sockets = new ArrayList<>();
314
315        public InterposeServerSocket(ServerSocket socket, byte[] matchBytes, byte[] replaceBytes) throws IOException {
316            this.socket = socket;
317            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
318            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
319        }
320
321        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
322            this.matchBytes = matchBytes;
323            this.replaceBytes = replaceBytes;
324            sockets.forEach(s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
325        }
326        /**
327         * Return a snapshot of the current list of sockets created from this server socket.
328         * @return Return a snapshot of the current list of sockets
329         */
330        public List<InterposeSocket> getSockets() {
331            List<InterposeSocket> snap = new ArrayList<>(sockets);
332            return snap;
333        }
334
335        @Override
336        public void bind(SocketAddress endpoint) throws IOException {
337            socket.bind(endpoint);
338        }
339
340        @Override
341        public void bind(SocketAddress endpoint, int backlog) throws IOException {
342            socket.bind(endpoint, backlog);
343        }
344
345        @Override
346        public InetAddress getInetAddress() {
347            return socket.getInetAddress();
348        }
349
350        @Override
351        public int getLocalPort() {
352            return socket.getLocalPort();
353        }
354
355        @Override
356        public SocketAddress getLocalSocketAddress() {
357            return socket.getLocalSocketAddress();
358        }
359
360        @Override
361        public Socket accept() throws IOException {
362            Socket s = socket.accept();
363            InterposeSocket socket = new InterposeSocket(s, matchBytes, replaceBytes);
364            sockets.add(socket);
365            return socket;
366        }
367
368        @Override
369        public void close() throws IOException {
370            socket.close();
371        }
372
373        @Override
374        public ServerSocketChannel getChannel() {
375            return socket.getChannel();
376        }
377
378        @Override
379        public boolean isClosed() {
380            return socket.isClosed();
381        }
382
383        @Override
384        public String toString() {
385            return socket.toString();
386        }
387
388        @Override
389        public <T> ServerSocket setOption(SocketOption<T> name, T value) throws IOException {
390            return socket.setOption(name, value);
391        }
392
393        @Override
394        public <T> T getOption(SocketOption<T> name) throws IOException {
395            return socket.getOption(name);
396        }
397
398        @Override
399        public Set<SocketOption<?>> supportedOptions() {
400            return socket.supportedOptions();
401        }
402
403        @Override
404        public synchronized void setSoTimeout(int timeout) throws SocketException {
405            socket.setSoTimeout(timeout);
406        }
407
408        @Override
409        public synchronized int getSoTimeout() throws IOException {
410            return socket.getSoTimeout();
411        }
412    }
413
414    /**
415     * LoggingInputStream is a stream and logs all bytes read to it.
416     * For identification it is given a name.
417     */
418    public static class LoggingInputStream extends FilterInputStream {
419        private int bytesIn = 0;
420        private final String name;
421        private final OutputStream log;
422
423        public LoggingInputStream(InputStream in, String name, OutputStream log) {
424            super(in);
425            this.name = name;
426            this.log = log;
427        }
428
429        @Override
430        public int read() throws IOException {
431            int b = super.read();
432            if (b >= 0) {
433                log.write(b);
434                bytesIn++;
435            }
436            return b;
437        }
438
439        @Override
440        public int read(byte[] b, int off, int len) throws IOException {
441            int bytes = super.read(b, off, len);
442            if (bytes > 0) {
443                log.write(b, off, bytes);
444                bytesIn += bytes;
445            }
446            return bytes;
447        }
448
449        @Override
450        public int read(byte[] b) throws IOException {
451            return read(b, 0, b.length);
452        }
453
454        @Override
455        public void close() throws IOException {
456            super.close();
457        }
458
459        @Override
460        public String toString() {
461            return String.format("%s: In: (%d)", name, bytesIn);
462        }
463    }
464
465    /**
466     * An OutputStream that replaces one string of bytes with another.
467     * If any range matches, the match starts after the partial match.
468     */
469    static class MatchReplaceOutputStream extends OutputStream {
470        private final OutputStream out;
471        private final String name;
472        private volatile byte[] matchBytes;
473        private volatile byte[] replaceBytes;
474        int matchIndex;
475        private int bytesOut = 0;
476        private final OutputStream log;
477
478        MatchReplaceOutputStream(OutputStream out, String name, OutputStream log,
479                                 byte[] matchBytes, byte[] replaceBytes) {
480            this.out = out;
481            this.name = name;
482            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
483            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
484            matchIndex = 0;
485            this.log = log;
486        }
487
488        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
489            this.matchBytes = matchBytes;
490            this.replaceBytes = replaceBytes;
491            matchIndex = 0;
492        }
493
494
495        public void write(int b) throws IOException {
496            b = b & 0xff;
497            if (matchBytes.length == 0) {
498                out.write(b);
499                log.write(b);
500                bytesOut++;
501                return;
502            }
503            if (b == (matchBytes[matchIndex] & 0xff)) {
504                if (++matchIndex >= matchBytes.length) {
505                    matchIndex = 0;
506                    DEBUG( "TestSocketFactory MatchReplace %s replaced %d bytes at offset: %d (x%04x)%n",
507                            name, replaceBytes.length, bytesOut, bytesOut);
508                    out.write(replaceBytes);
509                    log.write(replaceBytes);
510                    bytesOut += replaceBytes.length;
511                }
512            } else {
513                if (matchIndex > 0) {
514                    // mismatch, write out any that matched already
515                    if (matchIndex > 0) // Only non-trivial matches
516                        DEBUG( "Partial match %s matched %d bytes at offset: %d (0x%04x), expected: x%02x, actual: x%02x%n",
517                                name, matchIndex, bytesOut, bytesOut,  matchBytes[matchIndex], b);
518                    out.write(matchBytes, 0, matchIndex);
519                    log.write(matchBytes, 0, matchIndex);
520                    bytesOut += matchIndex;
521                    matchIndex = 0;
522                }
523                if (b == (matchBytes[matchIndex] & 0xff)) {
524                    matchIndex++;
525                } else {
526                    out.write(b);
527                    log.write(b);
528                    bytesOut++;
529                }
530            }
531        }
532
533        @Override
534        public String toString() {
535            return String.format("%s: Out: (%d)", name, bytesOut);
536        }
537    }
538
539    private static byte[] orig = new byte[]{
540            (byte) 0x80, 0x05,
541            0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18
542            0x6A, 0x61, 0x76, 0x61, 0x2E, 0x72, 0x6D, 0x69, 0x2E, // "java.rmi."
543            0x64, 0x67, 0x63, 0x2E, 0x4C, 0x65, 0x61, 0x73, 0x65  // "dgc.Lease"
544    };
545    private static byte[] repl = new byte[]{
546            (byte) 0x80, 0x05,
547            0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18
548            0x6A, 0x61, 0x76, 0x61, 0x2E, (byte) 'l', (byte) 'a', (byte) 'n', (byte) 'g',
549            0x2E, (byte) 'R', (byte) 'u', (byte) 'n', (byte) 'n', (byte) 'a', (byte) 'b', (byte) 'l',
550            (byte) 'e'
551    };
552
553    @DataProvider(name = "MatchReplaceData")
554    static Object[][] matchReplaceData() {
555        byte[] empty = new byte[0];
556        byte[] byte1 = new byte[]{1, 2, 3, 4, 5, 6};
557        byte[] bytes2 = new byte[]{1, 2, 4, 3, 5, 6};
558        byte[] bytes3 = new byte[]{6, 5, 4, 3, 2, 1};
559        byte[] bytes4 = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 6};
560        byte[] bytes4a = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 7};  // mostly matches bytes4
561        byte[] bytes5 = new byte[]{0x30, 0x40, 5, 6};
562        byte[] bytes6 = new byte[]{1, 2, 0x10, 0x20, 0x30};
563
564        return new Object[][]{
565                {new byte[]{}, new byte[]{}, empty, empty},
566                {new byte[]{}, new byte[]{}, byte1, byte1},
567                {new byte[]{3, 4}, new byte[]{4, 3}, byte1, bytes2}, //swap bytes
568                {new byte[]{3, 4}, new byte[]{0x10, 0x20, 0x30, 0x40}, byte1, bytes4}, // insert
569                {new byte[]{1, 2, 0x10, 0x20}, new byte[]{}, bytes4, bytes5}, // delete head
570                {new byte[]{0x40, 5, 6}, new byte[]{}, bytes4, bytes6},   // delete tail
571                {new byte[]{0x40, 0x50}, new byte[]{0x60, 0x50}, bytes4, bytes4}, // partial match, replace nothing
572                {bytes4a, bytes3, bytes4, bytes4}, // long partial match, not replaced
573                {orig, repl, orig, repl},
574        };
575    }
576
577    @Test(enabled = true, dataProvider = "MatchReplaceData")
578    static void test3(byte[] match, byte[] replace,
579                      byte[] input, byte[] expected) {
580        System.out.printf("match: %s, replace: %s%n", Arrays.toString(match), Arrays.toString(replace));
581        try (ByteArrayOutputStream output = new ByteArrayOutputStream();
582        ByteArrayOutputStream log = new ByteArrayOutputStream();
583             OutputStream out = new MatchReplaceOutputStream(output, "test3",
584                     log, match, replace)) {
585            out.write(input);
586            byte[] actual = output.toByteArray();
587            long index = Arrays.mismatch(actual, expected);
588
589            if (index >= 0) {
590                System.out.printf("array mismatch, offset: %d%n", index);
591                System.out.printf("actual: %s%n", Arrays.toString(actual));
592                System.out.printf("expected: %s%n", Arrays.toString(expected));
593            }
594            Assert.assertEquals(actual, expected, "match/replace fail");
595        } catch (IOException ioe) {
596            Assert.fail("unexpected exception", ioe);
597        }
598    }
599
600
601
602}
603