/*  Title:      Pure/General/ssh.scala
    Author:     Makarius

SSH client based on JSch (see also http://www.jcraft.com/jsch/examples).
*/

package isabelle


import java.io.{InputStream, OutputStream, ByteArrayOutputStream}

import scala.collection.mutable

import com.jcraft.jsch.{JSch, Logger => JSch_Logger, Session => JSch_Session, SftpException,
  OpenSSHConfig, UserInfo, Channel => JSch_Channel, ChannelExec, ChannelSftp, SftpATTRS}


object SSH
{
  /* target machine: user@host syntax */

  object Target
  {
    val User_Host = "^([^@]+)@(.+)$".r

    /** Split `user@host` into `(user, host)`; without an `@` the whole string is the host
      * and the user is empty. */
    def parse(s: String): (String, String) =
      s match {
        case User_Host(user, host) => (user, host)
        case _ => ("", s)
      }

    /** Extractor yielding `List(user, host)`; fails only when the host part is empty. */
    def unapplySeq(s: String): Option[List[String]] =
      parse(s) match {
        case (_, "") => None
        case (user, host) => Some(List(user, host))
      }
  }

  val default_port = 22

  /** Non-positive port numbers select the default SSH port. */
  def make_port(port: Int): Int = if (port > 0) port else default_port

  /** `":port"` for non-default ports, empty for the default port. */
  def port_suffix(port: Int): String =
    if (port == default_port) "" else ":" + port

  /** `"user@"` for a proper (non-empty) user name, empty otherwise. */
  def user_prefix(user: String): String =
    proper_string(user) match {
      case None => ""
      case Some(name) => name + "@"
    }

  /** Connect timeout from option `ssh_connect_timeout`, in milliseconds. */
  def connect_timeout(options: Options): Int =
    options.seconds("ssh_connect_timeout").ms.toInt

  /** Server-alive interval from option `ssh_alive_interval`, in milliseconds. */
  def alive_interval(options: Options): Int =
    options.seconds("ssh_alive_interval").ms.toInt

  /** Maximum number of unanswered server-alive messages (option `ssh_alive_count_max`). */
  def alive_count_max(options: Options): Int =
    options.int("ssh_alive_count_max")


  /* init context */

  /** Create a JSch context from the ssh options.
    *
    * Requires the directory `ssh_config_dir` to exist (error otherwise). The method:
    *  - loads `ssh_config_file` as an OpenSSH config if it is a file;
    *  - creates an empty `known_hosts` file in the config directory if it is missing;
    *  - registers every existing file from the colon-separated `ssh_identity_files`. */
  def init_context(options: Options): Context =
  {
    val config_dir = Path.explode(options.string("ssh_config_dir"))
    if (!config_dir.is_dir) error("Bad ssh config directory: " + config_dir)

    val jsch = new JSch

    val config_file = Path.explode(options.string("ssh_config_file"))
    if (config_file.is_file)
      jsch.setConfigRepository(OpenSSHConfig.parseFile(File.platform_path(config_file)))

    val known_hosts = config_dir + Path.explode("known_hosts")
    if (!known_hosts.is_file) known_hosts.file.createNewFile
    jsch.setKnownHosts(File.platform_path(known_hosts))

    val identity_files =
      space_explode(':', options.string("ssh_identity_files")).map(Path.explode(_))
    for (identity_file <- identity_files if identity_file.is_file)
      jsch.addIdentity(File.platform_path(identity_file))

    new Context(options, jsch)
  }

  /** Convenience wrapper: fresh context via `init_context`, then `Context.open_session`. */
  def open_session(options: Options, host: String, user: String = "", port: Int = 0,
    proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
    permissive: Boolean = false): Session =
    init_context(options).open_session(host = host, user = user, port = port,
      proxy_host = proxy_host, proxy_user = proxy_user, proxy_port = proxy_port,
      permissive = permissive)

  /** JSch instance together with the options used for sessions opened from it. */
  class Context private[SSH](val options: Options, val jsch: JSch)
  {
    def update_options(new_options: Options): Context = new Context(new_options, jsch)

    /* Connect a single JSch session and wrap it as `Session`.
       `nominal_host` is used as the host-key alias (relevant when connecting through a
       forwarded local port) and as the reported host. `on_close` runs when the resulting
       `Session` is closed. On failure the JSch session is disconnected and the exception
       rethrown; `on_close` is NOT run (callers clean up their own resources). */
    private def connect_session(host: String, user: String = "", port: Int = 0,
      host_key_permissive: Boolean = false,
      nominal_host: String = "",
      nominal_user: String = "",
      on_close: () => Unit = () => ()): Session =
    {
      val session = jsch.getSession(proper_string(user).orNull, host, make_port(port))

      session.setUserInfo(No_User_Info)
      session.setServerAliveInterval(alive_interval(options))
      session.setServerAliveCountMax(alive_count_max(options))
      session.setConfig("MaxAuthTries", "3")
      if (host_key_permissive) session.setConfig("StrictHostKeyChecking", "no")
      if (nominal_host != "") session.setHostKeyAlias(nominal_host)

      if (options.bool("ssh_compression")) {
        session.setConfig("compression.s2c", "zlib@openssh.com,zlib,none")
        session.setConfig("compression.c2s", "zlib@openssh.com,zlib,none")
        session.setConfig("compression_level", "9")
      }

      try {
        session.connect(connect_timeout(options))
        // the Session constructor opens the sftp channel, which may fail as well
        new Session(options, session, on_close,
          proper_string(nominal_host) getOrElse host,
          proper_string(nominal_user) getOrElse user)
      }
      catch { case exn: Throwable => session.disconnect; throw exn }
    }

    /** Open an SSH session, optionally tunnelled through `proxy_host`.
      *
      * With a proxy, a local port forwarding to `host:port` is set up on the proxy session
      * and the target session connects to that local port. Closing the target session
      * also closes the forwarding and the proxy session.
      *
      * NOTE(review): `permissive` (disables strict host-key checking) is only applied to
      * the connection through the proxy; a direct connection ignores it. */
    def open_session(host: String, user: String = "", port: Int = 0,
      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
      permissive: Boolean = false): Session =
    {
      if (proxy_host == "") connect_session(host = host, user = user, port = port)
      else {
        val proxy = connect_session(host = proxy_host, port = proxy_port, user = proxy_user)

        val fw =
          try { proxy.port_forwarding(remote_host = host, remote_port = make_port(port)) }
          catch { case exn: Throwable => proxy.close; throw exn }

        try {
          connect_session(host = fw.local_host, port = fw.local_port,
            host_key_permissive = permissive,
            nominal_host = host, nominal_user = user, user = user,
            on_close = () => { fw.close; proxy.close })
        }
        catch { case exn: Throwable => fw.close; proxy.close; throw exn }
      }
    }
  }


  /* logging */

  /** Install a global JSch logger (`verbose`), or remove it; `debug` enables DEBUG level. */
  def logging(verbose: Boolean = true, debug: Boolean = false)
  {
    JSch.setLogger(if (verbose) new Logger(debug) else null)
  }

  /* Route JSch log messages to Isabelle output channels by severity. */
  private class Logger(debug: Boolean) extends JSch_Logger
  {
    def isEnabled(level: Int): Boolean = level != JSch_Logger.DEBUG || debug

    def log(level: Int, msg: String)
    {
      level match {
        case JSch_Logger.ERROR | JSch_Logger.FATAL => Output.error_message(msg)
        case JSch_Logger.WARN => Output.warning(msg)
        case _ => Output.writeln(msg)
      }
    }
  }


  /* user info */

  /** Non-interactive user info: never supplies passwords/passphrases and declines all
    * yes/no prompts; informational messages go to `Output.writeln`. */
  object No_User_Info extends UserInfo
  {
    def getPassphrase: String = null
    def getPassword: String = null
    def promptPassword(msg: String): Boolean = false
    def promptPassphrase(msg: String): Boolean = false
    def promptYesNo(msg: String): Boolean = false
    def showMessage(msg: String): Unit = Output.writeln(msg)
  }


  /* port forwarding */

  object Port_Forwarding
  {
    /** Establish local port forwarding `local_host:local_port -> remote_host:remote_port`.
      * The actual local port (as reported by JSch) is recorded in the result. */
    def open(ssh: Session, ssh_close: Boolean,
      local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
    {
      val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
      new Port_Forwarding(ssh, ssh_close, local_host, port, remote_host, remote_port)
    }
  }

  /** Active local port forwarding on `ssh`; if `ssh_close` is set, closing the forwarding
    * also closes the owning SSH session. */
  class Port_Forwarding private[SSH](
    ssh: SSH.Session,
    ssh_close: Boolean,
    val local_host: String,
    val local_port: Int,
    val remote_host: String,
    val remote_port: Int) extends AutoCloseable
  {
    override def toString: String =
      local_host + ":" + local_port + ":" + remote_host + ":" + remote_port

    def close()
    {
      // close the owning session even if removing the forwarding fails
      try { ssh.session.delPortForwardingL(local_host, local_port) }
      finally { if (ssh_close) ssh.close() }
    }
  }


  /* Sftp channel */

  type Attrs = SftpATTRS

  sealed case class Dir_Entry(name: String, is_dir: Boolean)
  {
    def is_file: Boolean = !is_dir
  }


  /* exec channel */

  private val exec_wait_delay = Time.seconds(0.3)

  /** Running remote command on an exec channel. The channel is connected during
    * construction, after its streams have been obtained. */
  class Exec private[SSH](session: Session, channel: ChannelExec) extends AutoCloseable
  {
    override def toString: String = "exec " + session.toString

    def close() { channel.disconnect }

    /* Exit status, obtained by polling until the channel is closed. */
    val exit_status: Future[Int] =
      Future.thread("ssh_wait") {
        while (!channel.isClosed) Thread.sleep(exec_wait_delay.ms)
        channel.getExitStatus
      }

    val stdin: OutputStream = channel.getOutputStream
    val stdout: InputStream = channel.getInputStream
    val stderr: InputStream = channel.getErrStream

    // connect after preparing streams
    channel.connect(connect_timeout(session.options))

    /** Close stdin, collect stdout/stderr line-wise (feeding each line to the respective
      * progress function) and wait for the exit status. The channel is closed afterwards.
      *
      * On interrupt the channel is closed, the reader threads are joined and the return
      * code becomes `Exn.Interrupt.return_code`; if `strict`, that code raises
      * `Exn.Interrupt` instead of returning. */
    def result(
      progress_stdout: String => Unit = (_: String) => (),
      progress_stderr: String => Unit = (_: String) => (),
      strict: Boolean = true): Process_Result =
    {
      stdin.close

      /* Read LF-separated lines (decoded as UTF-8, then passed through
         Library.trim_line) until end of stream on a closed channel. A trailing
         partial line is flushed as well. */
      def read_lines(stream: InputStream, progress: String => Unit): List[String] =
      {
        val result = new mutable.ListBuffer[String]
        val line_buffer = new ByteArrayOutputStream(100)
        def line_flush()
        {
          val line = Library.trim_line(line_buffer.toString(UTF8.charset_name))
          progress(line)
          result += line
          line_buffer.reset
        }

        var c = 0
        var finished = false
        while (!finished) {
          while ({ c = stream.read; c != -1 && c != 10 }) line_buffer.write(c)
          if (c == 10) line_flush()
          else if (channel.isClosed) {
            if (line_buffer.size > 0) line_flush()
            finished = true
          }
          else Thread.sleep(exec_wait_delay.ms)
        }

        result.toList
      }

      val out_lines = Future.thread("ssh_stdout") { read_lines(stdout, progress_stdout) }
      val err_lines = Future.thread("ssh_stderr") { read_lines(stderr, progress_stderr) }

      def terminate()
      {
        close
        out_lines.join
        err_lines.join
        exit_status.join
      }

      val rc =
        try { exit_status.join }
        catch { case Exn.Interrupt() => terminate(); Exn.Interrupt.return_code }

      close
      if (strict && rc == Exn.Interrupt.return_code) throw Exn.Interrupt()

      Process_Result(rc, out_lines.join, err_lines.join)
    }
  }


  /* session */

  /** Connected SSH session with an sftp channel (opened during construction) for file
    * operations and exec channels for remote commands. `on_close` runs after the sftp
    * channel and the session have been disconnected. */
  class Session private[SSH](
    val options: Options,
    val session: JSch_Session,
    on_close: () => Unit,
    val nominal_host: String,
    val nominal_user: String) extends System with AutoCloseable
  {
    def update_options(new_options: Options): Session =
      new Session(new_options, session, on_close, nominal_host, nominal_user)

    def host: String = if (session.getHost == null) "" else session.getHost

    override def hg_url: String =
      "ssh://" + user_prefix(nominal_user) + nominal_host + "/"

    override def prefix: String =
      user_prefix(session.getUserName) + host + port_suffix(session.getPort) + ":"

    override def toString: String =
      user_prefix(session.getUserName) + host + port_suffix(session.getPort) +
      (if (session.isConnected) "" else " (disconnected)")


    /* port forwarding */

    def port_forwarding(
        remote_port: Int, remote_host: String = "localhost",
        local_port: Int = 0, local_host: String = "localhost",
        ssh_close: Boolean = false): Port_Forwarding =
      Port_Forwarding.open(this, ssh_close, local_host, local_port, remote_host, remote_port)


    /* sftp channel */

    val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]
    sftp.connect(connect_timeout(options))

    def close()
    {
      // every cleanup step runs even if an earlier one throws
      try { sftp.disconnect }
      finally {
        try { session.disconnect }
        finally { on_close() }
      }
    }

    /* Remote HOME (as reported by sftp), used to expand HOME and USER_HOME in paths. */
    val settings: Map[String, String] =
    {
      val home = sftp.getHome
      Map("HOME" -> home, "USER_HOME" -> home)
    }
    override def expand_path(path: Path): Path = path.expand_env(settings)
    def remote_path(path: Path): String = expand_path(path).implode
    override def bash_path(path: Path): String = Bash.string(remote_path(path))

    def chmod(permissions: Int, path: Path): Unit = sftp.chmod(permissions, remote_path(path))
    def mv(path1: Path, path2: Path): Unit = sftp.rename(remote_path(path1), remote_path(path2))
    def rm(path: Path): Unit = sftp.rm(remote_path(path))
    def mkdir(path: Path): Unit = sftp.mkdir(remote_path(path))
    def rmdir(path: Path): Unit = sftp.rmdir(remote_path(path))

    /* Stat (following links) and check the entry kind; sftp errors count as "no". */
    private def test_entry(path: Path, as_dir: Boolean): Boolean =
      try {
        val is_dir = sftp.stat(remote_path(path)).isDir
        if (as_dir) is_dir else !is_dir
      }
      catch { case _: SftpException => false }

    override def is_dir(path: Path): Boolean = test_entry(path, true)
    override def is_file(path: Path): Boolean = test_entry(path, false)

    def is_link(path: Path): Boolean =
      try { sftp.lstat(remote_path(path)).isLink }
      catch { case _: SftpException => false }

    /** Create `path` including missing parents on the remote side; error if it still is
      * not a directory afterwards. The path is shell-quoted via `bash_path`. */
    override def mkdirs(path: Path): Unit =
      if (!is_dir(path)) {
        execute("mkdir -p " + bash_path(path))
        if (!is_dir(path)) error("Failed to create directory: " + quote(remote_path(path)))
      }

    /** Entries of a remote directory without "." and "..", sorted by name. Symbolic links
      * are classified by their target; dangling links count as files. */
    def read_dir(path: Path): List[Dir_Entry] =
    {
      if (!is_dir(path)) error("No such directory: " + path.toString)

      val dir_name = remote_path(path)
      val dir = sftp.ls(dir_name)
      (for {
        i <- (0 until dir.size).iterator
        a = dir.get(i).asInstanceOf[AnyRef]
        name = Untyped.get[String](a, "filename")
        attrs = Untyped.get[Attrs](a, "attrs")
        if name != "." && name != ".."
      }
      yield {
        Dir_Entry(name,
          if (attrs.isLink) {
            try { sftp.stat(dir_name + "/" + name).isDir }
            catch { case _: SftpException => false }
          }
          else attrs.isDir)
      }).toList.sortBy(_.name)
    }

    /** Recursively collect paths below `start` that satisfy `pred`. Directories are
      * included only if `include_dirs` is set; linked directories are descended into
      * only if `follow_links` is set. */
    def find_files(
      start: Path,
      pred: Path => Boolean = _ => true,
      include_dirs: Boolean = false,
      follow_links: Boolean = false): List[Path] =
    {
      val result = new mutable.ListBuffer[Path]
      def check(path: Path) { if (pred(path)) result += path }

      def find(dir: Path)
      {
        if (include_dirs) check(dir)
        if (follow_links || !is_link(dir)) {
          for (entry <- read_dir(dir)) {
            val path = dir + Path.basic(entry.name)
            if (entry.is_file) check(path) else find(path)
          }
        }
      }
      if (is_file(start)) check(start) else find(start)

      result.toList
    }

    def open_input(path: Path): InputStream = sftp.get(remote_path(path))
    def open_output(path: Path): OutputStream = sftp.put(remote_path(path))

    def read_file(path: Path, local_path: Path): Unit =
      sftp.get(remote_path(path), File.platform_path(local_path))
    def read_bytes(path: Path): Bytes = using(open_input(path))(Bytes.read_stream(_))
    def read(path: Path): String = using(open_input(path))(File.read_stream(_))

    def write_file(path: Path, local_path: Path): Unit =
      sftp.put(File.platform_path(local_path), remote_path(path))
    def write_bytes(path: Path, bytes: Bytes): Unit =
      using(open_output(path))(bytes.write_stream(_))
    def write(path: Path, text: String): Unit =
      using(open_output(path))(stream => Bytes(text).write_stream(stream))


    /* exec channel */

    /** Start `command` on a new exec channel; `USER_HOME` is exported as the remote
      * `$HOME` first. */
    def exec(command: String): Exec =
    {
      val channel = session.openChannel("exec").asInstanceOf[ChannelExec]
      channel.setCommand("export USER_HOME=\"$HOME\"\n" + command)
      new Exec(this, channel)
    }

    override def execute(command: String,
        progress_stdout: String => Unit = (_: String) => (),
        progress_stderr: String => Unit = (_: String) => (),
        strict: Boolean = true): Process_Result =
      exec(command).result(progress_stdout, progress_stderr, strict)


    /* tmp dirs */

    def rm_tree(dir: Path): Unit = rm_tree(remote_path(dir))

    def rm_tree(remote_dir: String): Unit =
      execute("rm -r -f " + Bash.string(remote_dir)).check

    def tmp_dir(): String =
      execute("mktemp -d -t tmp.XXXXXXXXXX").check.out

    /** Run `body` on a fresh remote temporary directory, removed afterwards. */
    def with_tmp_dir[A](body: Path => A): A =
    {
      val remote_dir = tmp_dir()
      try { body(Path.explode(remote_dir)) } finally { rm_tree(remote_dir) }
    }
  }



  /* system operations */

  /** File-system and command operations; the defaults act on the local machine. */
  trait System
  {
    def hg_url: String = ""
    def prefix: String = ""

    def expand_path(path: Path): Path = path.expand
    def bash_path(path: Path): String = File.bash_path(path)
    def is_dir(path: Path): Boolean = path.is_dir
    def is_file(path: Path): Boolean = path.is_file
    def mkdirs(path: Path): Unit = Isabelle_System.mkdirs(path)

    def execute(command: String,
        progress_stdout: String => Unit = (_: String) => (),
        progress_stderr: String => Unit = (_: String) => (),
        strict: Boolean = true): Process_Result =
      Isabelle_System.bash(command, progress_stdout = progress_stdout,
        progress_stderr = progress_stderr, strict = strict)
  }

  object Local extends System
}