1/* 2 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24import java.io.ByteArrayOutputStream; 25import java.io.FilterInputStream; 26import java.io.IOException; 27import java.io.InputStream; 28import java.io.OutputStream; 29import java.io.Serializable; 30import java.net.InetAddress; 31import java.net.ServerSocket; 32import java.net.Socket; 33import java.net.SocketAddress; 34import java.net.SocketException; 35import java.net.SocketOption; 36import java.nio.channels.ServerSocketChannel; 37import java.nio.channels.SocketChannel; 38import java.rmi.server.RMIClientSocketFactory; 39import java.rmi.server.RMIServerSocketFactory; 40import java.rmi.server.RMISocketFactory; 41import java.util.ArrayList; 42import java.util.Arrays; 43import java.util.List; 44import java.util.Objects; 45import java.util.Set; 46 47import org.testng.Assert; 48import org.testng.TestNG; 49import org.testng.annotations.Test; 50import org.testng.annotations.DataProvider; 51 52 53/** 54 * A RMISocketFactory utility factory to log RMI stream contents and to 55 * match and replace output stream contents to simulate failures. 56 */ 57public class TestSocketFactory extends RMISocketFactory 58 implements RMIClientSocketFactory, RMIServerSocketFactory, Serializable { 59 60 private static final long serialVersionUID = 1L; 61 62 private volatile transient byte[] matchBytes; 63 64 private volatile transient byte[] replaceBytes; 65 66 private transient final List<InterposeSocket> sockets = new ArrayList<>(); 67 68 private transient final List<InterposeServerSocket> serverSockets = new ArrayList<>(); 69 70 public static final boolean DEBUG = false; 71 72 /** 73 * Debugging output can be synchronized with logging of RMI actions. 74 * 75 * @param format a printf format 76 * @param args any args 77 */ 78 private static void DEBUG(String format, Object... args) { 79 if (DEBUG) { 80 System.err.printf(format, args); 81 } 82 } 83 84 /** 85 * Create a socket factory that creates InputStreams that log 86 * and OutputStreams that log . 87 */ 88 public TestSocketFactory() { 89 this.matchBytes = new byte[0]; 90 this.replaceBytes = this.matchBytes; 91 System.out.printf("Creating TestSocketFactory()%n"); 92 } 93 94 public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) { 95 this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes"); 96 this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes"); 97 sockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes)); 98 serverSockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes)); 99 100 } 101 102 @Override 103 public Socket createSocket(String host, int port) throws IOException { 104 Socket socket = RMISocketFactory.getDefaultSocketFactory() 105 .createSocket(host, port); 106 InterposeSocket s = new InterposeSocket(socket, matchBytes, replaceBytes); 107 sockets.add(s); 108 return s; 109 } 110 111 /** 112 * Return the current list of sockets. 113 * @return Return a snapshot of the current list of sockets 114 */ 115 public List<InterposeSocket> getSockets() { 116 List<InterposeSocket> snap = new ArrayList<>(sockets); 117 return snap; 118 } 119 120 @Override 121 public ServerSocket createServerSocket(int port) throws IOException { 122 123 ServerSocket serverSocket = RMISocketFactory.getDefaultSocketFactory() 124 .createServerSocket(port); 125 InterposeServerSocket ss = new InterposeServerSocket(serverSocket, matchBytes, replaceBytes); 126 serverSockets.add(ss); 127 return ss; 128 } 129 130 /** 131 * Return the current list of server sockets. 132 * @return Return a snapshot of the current list of server sockets 133 */ 134 public List<InterposeServerSocket> getServerSockets() { 135 List<InterposeServerSocket> snap = new ArrayList<>(serverSockets); 136 return snap; 137 } 138 139 /** 140 * An InterposeSocket wraps a socket that produces InputStreams 141 * and OutputStreams that log the traffic. 142 * The OutputStreams it produces match an array of bytes and replace them. 143 * Useful for injecting protocol and content errors. 144 */ 145 public static class InterposeSocket extends Socket { 146 private final Socket socket; 147 private InputStream in; 148 private MatchReplaceOutputStream out; 149 private volatile byte[] matchBytes; 150 private volatile byte[] replaceBytes; 151 private final ByteArrayOutputStream inLogStream; 152 private final ByteArrayOutputStream outLogStream; 153 private final String name; 154 private static volatile int num = 0; // index for created InterposeSockets 155 156 public InterposeSocket(Socket socket, byte[] matchBytes, byte[] replaceBytes) { 157 this.socket = socket; 158 this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes"); 159 this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes"); 160 this.inLogStream = new ByteArrayOutputStream(); 161 this.outLogStream = new ByteArrayOutputStream(); 162 this.name = "IS" + ++num + "::" 163 + Thread.currentThread().getName() + ": " 164 + socket.getLocalPort() + " < " + socket.getPort(); 165 } 166 167 public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) { 168 this.matchBytes = matchBytes; 169 this.replaceBytes = replaceBytes; 170 out.setMatchReplaceBytes(matchBytes, replaceBytes); 171 } 172 173 @Override 174 public void connect(SocketAddress endpoint) throws IOException { 175 socket.connect(endpoint); 176 } 177 178 @Override 179 public void connect(SocketAddress endpoint, int timeout) throws IOException { 180 socket.connect(endpoint, timeout); 181 } 182 183 @Override 184 public void bind(SocketAddress bindpoint) throws IOException { 185 socket.bind(bindpoint); 186 } 187 188 @Override 189 public InetAddress getInetAddress() { 190 return socket.getInetAddress(); 191 } 192 193 @Override 194 public InetAddress getLocalAddress() { 195 return socket.getLocalAddress(); 196 } 197 198 @Override 199 public int getPort() { 200 return socket.getPort(); 201 } 202 203 @Override 204 public int getLocalPort() { 205 return socket.getLocalPort(); 206 } 207 208 @Override 209 public SocketAddress getRemoteSocketAddress() { 210 return socket.getRemoteSocketAddress(); 211 } 212 213 @Override 214 public SocketAddress getLocalSocketAddress() { 215 return socket.getLocalSocketAddress(); 216 } 217 218 @Override 219 public SocketChannel getChannel() { 220 return socket.getChannel(); 221 } 222 223 @Override 224 public synchronized void close() throws IOException { 225 socket.close(); 226 } 227 228 @Override 229 public String toString() { 230 return "InterposeSocket " + name + ": " + socket.toString(); 231 } 232 233 @Override 234 public boolean isConnected() { 235 return socket.isConnected(); 236 } 237 238 @Override 239 public boolean isBound() { 240 return socket.isBound(); 241 } 242 243 @Override 244 public boolean isClosed() { 245 return socket.isClosed(); 246 } 247 248 @Override 249 public <T> Socket setOption(SocketOption<T> name, T value) throws IOException { 250 return socket.setOption(name, value); 251 } 252 253 @Override 254 public <T> T getOption(SocketOption<T> name) throws IOException { 255 return socket.getOption(name); 256 } 257 258 @Override 259 public Set<SocketOption<?>> supportedOptions() { 260 return socket.supportedOptions(); 261 } 262 263 @Override 264 public synchronized InputStream getInputStream() throws IOException { 265 if (in == null) { 266 in = socket.getInputStream(); 267 String name = Thread.currentThread().getName() + ": " 268 + socket.getLocalPort() + " < " + socket.getPort(); 269 in = new LoggingInputStream(in, name, inLogStream); 270 DEBUG("Created new InterposeInputStream: %s%n", name); 271 } 272 return in; 273 } 274 275 @Override 276 public synchronized OutputStream getOutputStream() throws IOException { 277 if (out == null) { 278 OutputStream o = socket.getOutputStream(); 279 String name = Thread.currentThread().getName() + ": " 280 + socket.getLocalPort() + " > " + socket.getPort(); 281 out = new MatchReplaceOutputStream(o, name, outLogStream, matchBytes, replaceBytes); 282 DEBUG("Created new MatchReplaceOutputStream: %s%n", name); 283 } 284 return out; 285 } 286 287 /** 288 * Return the bytes logged from the input stream. 289 * @return Return the bytes logged from the input stream. 290 */ 291 public byte[] getInLogBytes() { 292 return inLogStream.toByteArray(); 293 } 294 295 /** 296 * Return the bytes logged from the output stream. 297 * @return Return the bytes logged from the output stream. 298 */ 299 public byte[] getOutLogBytes() { 300 return outLogStream.toByteArray(); 301 } 302 303 } 304 305 /** 306 * InterposeServerSocket is a ServerSocket that wraps each Socket it accepts 307 * with an InterposeSocket so that its input and output streams can be monitored. 308 */ 309 public static class InterposeServerSocket extends ServerSocket { 310 private final ServerSocket socket; 311 private volatile byte[] matchBytes; 312 private volatile byte[] replaceBytes; 313 private final List<InterposeSocket> sockets = new ArrayList<>(); 314 315 public InterposeServerSocket(ServerSocket socket, byte[] matchBytes, byte[] replaceBytes) throws IOException { 316 this.socket = socket; 317 this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes"); 318 this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes"); 319 } 320 321 public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) { 322 this.matchBytes = matchBytes; 323 this.replaceBytes = replaceBytes; 324 sockets.forEach(s -> s.setMatchReplaceBytes(matchBytes, replaceBytes)); 325 } 326 /** 327 * Return a snapshot of the current list of sockets created from this server socket. 328 * @return Return a snapshot of the current list of sockets 329 */ 330 public List<InterposeSocket> getSockets() { 331 List<InterposeSocket> snap = new ArrayList<>(sockets); 332 return snap; 333 } 334 335 @Override 336 public void bind(SocketAddress endpoint) throws IOException { 337 socket.bind(endpoint); 338 } 339 340 @Override 341 public void bind(SocketAddress endpoint, int backlog) throws IOException { 342 socket.bind(endpoint, backlog); 343 } 344 345 @Override 346 public InetAddress getInetAddress() { 347 return socket.getInetAddress(); 348 } 349 350 @Override 351 public int getLocalPort() { 352 return socket.getLocalPort(); 353 } 354 355 @Override 356 public SocketAddress getLocalSocketAddress() { 357 return socket.getLocalSocketAddress(); 358 } 359 360 @Override 361 public Socket accept() throws IOException { 362 Socket s = socket.accept(); 363 InterposeSocket socket = new InterposeSocket(s, matchBytes, replaceBytes); 364 sockets.add(socket); 365 return socket; 366 } 367 368 @Override 369 public void close() throws IOException { 370 socket.close(); 371 } 372 373 @Override 374 public ServerSocketChannel getChannel() { 375 return socket.getChannel(); 376 } 377 378 @Override 379 public boolean isClosed() { 380 return socket.isClosed(); 381 } 382 383 @Override 384 public String toString() { 385 return socket.toString(); 386 } 387 388 @Override 389 public <T> ServerSocket setOption(SocketOption<T> name, T value) throws IOException { 390 return socket.setOption(name, value); 391 } 392 393 @Override 394 public <T> T getOption(SocketOption<T> name) throws IOException { 395 return socket.getOption(name); 396 } 397 398 @Override 399 public Set<SocketOption<?>> supportedOptions() { 400 return socket.supportedOptions(); 401 } 402 403 @Override 404 public synchronized void setSoTimeout(int timeout) throws SocketException { 405 socket.setSoTimeout(timeout); 406 } 407 408 @Override 409 public synchronized int getSoTimeout() throws IOException { 410 return socket.getSoTimeout(); 411 } 412 } 413 414 /** 415 * LoggingInputStream is a stream and logs all bytes read to it. 416 * For identification it is given a name. 417 */ 418 public static class LoggingInputStream extends FilterInputStream { 419 private int bytesIn = 0; 420 private final String name; 421 private final OutputStream log; 422 423 public LoggingInputStream(InputStream in, String name, OutputStream log) { 424 super(in); 425 this.name = name; 426 this.log = log; 427 } 428 429 @Override 430 public int read() throws IOException { 431 int b = super.read(); 432 if (b >= 0) { 433 log.write(b); 434 bytesIn++; 435 } 436 return b; 437 } 438 439 @Override 440 public int read(byte[] b, int off, int len) throws IOException { 441 int bytes = super.read(b, off, len); 442 if (bytes > 0) { 443 log.write(b, off, bytes); 444 bytesIn += bytes; 445 } 446 return bytes; 447 } 448 449 @Override 450 public int read(byte[] b) throws IOException { 451 return read(b, 0, b.length); 452 } 453 454 @Override 455 public void close() throws IOException { 456 super.close(); 457 } 458 459 @Override 460 public String toString() { 461 return String.format("%s: In: (%d)", name, bytesIn); 462 } 463 } 464 465 /** 466 * An OutputStream that replaces one string of bytes with another. 467 * If any range matches, the match starts after the partial match. 468 */ 469 static class MatchReplaceOutputStream extends OutputStream { 470 private final OutputStream out; 471 private final String name; 472 private volatile byte[] matchBytes; 473 private volatile byte[] replaceBytes; 474 int matchIndex; 475 private int bytesOut = 0; 476 private final OutputStream log; 477 478 MatchReplaceOutputStream(OutputStream out, String name, OutputStream log, 479 byte[] matchBytes, byte[] replaceBytes) { 480 this.out = out; 481 this.name = name; 482 this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes"); 483 this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes"); 484 matchIndex = 0; 485 this.log = log; 486 } 487 488 public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) { 489 this.matchBytes = matchBytes; 490 this.replaceBytes = replaceBytes; 491 matchIndex = 0; 492 } 493 494 495 public void write(int b) throws IOException { 496 b = b & 0xff; 497 if (matchBytes.length == 0) { 498 out.write(b); 499 log.write(b); 500 bytesOut++; 501 return; 502 } 503 if (b == (matchBytes[matchIndex] & 0xff)) { 504 if (++matchIndex >= matchBytes.length) { 505 matchIndex = 0; 506 DEBUG( "TestSocketFactory MatchReplace %s replaced %d bytes at offset: %d (x%04x)%n", 507 name, replaceBytes.length, bytesOut, bytesOut); 508 out.write(replaceBytes); 509 log.write(replaceBytes); 510 bytesOut += replaceBytes.length; 511 } 512 } else { 513 if (matchIndex > 0) { 514 // mismatch, write out any that matched already 515 if (matchIndex > 0) // Only non-trivial matches 516 DEBUG( "Partial match %s matched %d bytes at offset: %d (0x%04x), expected: x%02x, actual: x%02x%n", 517 name, matchIndex, bytesOut, bytesOut, matchBytes[matchIndex], b); 518 out.write(matchBytes, 0, matchIndex); 519 log.write(matchBytes, 0, matchIndex); 520 bytesOut += matchIndex; 521 matchIndex = 0; 522 } 523 if (b == (matchBytes[matchIndex] & 0xff)) { 524 matchIndex++; 525 } else { 526 out.write(b); 527 log.write(b); 528 bytesOut++; 529 } 530 } 531 } 532 533 @Override 534 public String toString() { 535 return String.format("%s: Out: (%d)", name, bytesOut); 536 } 537 } 538 539 private static byte[] orig = new byte[]{ 540 (byte) 0x80, 0x05, 541 0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18 542 0x6A, 0x61, 0x76, 0x61, 0x2E, 0x72, 0x6D, 0x69, 0x2E, // "java.rmi." 543 0x64, 0x67, 0x63, 0x2E, 0x4C, 0x65, 0x61, 0x73, 0x65 // "dgc.Lease" 544 }; 545 private static byte[] repl = new byte[]{ 546 (byte) 0x80, 0x05, 547 0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18 548 0x6A, 0x61, 0x76, 0x61, 0x2E, (byte) 'l', (byte) 'a', (byte) 'n', (byte) 'g', 549 0x2E, (byte) 'R', (byte) 'u', (byte) 'n', (byte) 'n', (byte) 'a', (byte) 'b', (byte) 'l', 550 (byte) 'e' 551 }; 552 553 @DataProvider(name = "MatchReplaceData") 554 static Object[][] matchReplaceData() { 555 byte[] empty = new byte[0]; 556 byte[] byte1 = new byte[]{1, 2, 3, 4, 5, 6}; 557 byte[] bytes2 = new byte[]{1, 2, 4, 3, 5, 6}; 558 byte[] bytes3 = new byte[]{6, 5, 4, 3, 2, 1}; 559 byte[] bytes4 = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 6}; 560 byte[] bytes4a = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 7}; // mostly matches bytes4 561 byte[] bytes5 = new byte[]{0x30, 0x40, 5, 6}; 562 byte[] bytes6 = new byte[]{1, 2, 0x10, 0x20, 0x30}; 563 564 return new Object[][]{ 565 {new byte[]{}, new byte[]{}, empty, empty}, 566 {new byte[]{}, new byte[]{}, byte1, byte1}, 567 {new byte[]{3, 4}, new byte[]{4, 3}, byte1, bytes2}, //swap bytes 568 {new byte[]{3, 4}, new byte[]{0x10, 0x20, 0x30, 0x40}, byte1, bytes4}, // insert 569 {new byte[]{1, 2, 0x10, 0x20}, new byte[]{}, bytes4, bytes5}, // delete head 570 {new byte[]{0x40, 5, 6}, new byte[]{}, bytes4, bytes6}, // delete tail 571 {new byte[]{0x40, 0x50}, new byte[]{0x60, 0x50}, bytes4, bytes4}, // partial match, replace nothing 572 {bytes4a, bytes3, bytes4, bytes4}, // long partial match, not replaced 573 {orig, repl, orig, repl}, 574 }; 575 } 576 577 @Test(enabled = true, dataProvider = "MatchReplaceData") 578 static void test3(byte[] match, byte[] replace, 579 byte[] input, byte[] expected) { 580 System.out.printf("match: %s, replace: %s%n", Arrays.toString(match), Arrays.toString(replace)); 581 try (ByteArrayOutputStream output = new ByteArrayOutputStream(); 582 ByteArrayOutputStream log = new ByteArrayOutputStream(); 583 OutputStream out = new MatchReplaceOutputStream(output, "test3", 584 log, match, replace)) { 585 out.write(input); 586 byte[] actual = output.toByteArray(); 587 long index = Arrays.mismatch(actual, expected); 588 589 if (index >= 0) { 590 System.out.printf("array mismatch, offset: %d%n", index); 591 System.out.printf("actual: %s%n", Arrays.toString(actual)); 592 System.out.printf("expected: %s%n", Arrays.toString(expected)); 593 } 594 Assert.assertEquals(actual, expected, "match/replace fail"); 595 } catch (IOException ioe) { 596 Assert.fail("unexpected exception", ioe); 597 } 598 } 599 600 601 602} 603