/*  Title:      Pure/General/ssh.scala
    Author:     Makarius

SSH client based on JSch (see also http://www.jcraft.com/jsch/examples).
*/

package isabelle


import java.io.{InputStream, OutputStream, ByteArrayOutputStream}

import scala.collection.{mutable, JavaConversions}

import com.jcraft.jsch.{JSch, Logger => JSch_Logger, Session => JSch_Session, SftpException,
  OpenSSHConfig, UserInfo, Channel => JSch_Channel, ChannelExec, ChannelSftp, SftpATTRS}


object SSH
{
  /* target machine: user@host syntax */

  object Target
  {
    val Pattern = "^([^@]+)@(.+)$".r

    // Split "user@host" into (user, host); anything that does not match the
    // pattern is taken as a bare host with empty user.
    def parse(s: String): (String, String) =
      s match {
        case Pattern(user, host) => (user, host)
        case _ => ("", s)
      }

    // Extractor variant of parse: fails only when the resulting host is empty.
    def unapplySeq(s: String): Option[List[String]] =
      parse(s) match {
        case (_, "") => None
        case (user, host) => Some(List(user, host))
      }
  }

  val default_port = 22

  // Non-positive port numbers stand for the default port.
  def make_port(port: Int): Int = if (port > 0) port else default_port

  // Timing parameters, taken from system options (times converted to milliseconds).

  def connect_timeout(options: Options): Int =
    options.seconds("ssh_connect_timeout").ms.toInt

  def alive_interval(options: Options): Int =
    options.seconds("ssh_alive_interval").ms.toInt

  def alive_count_max(options: Options): Int =
    options.int("ssh_alive_count_max")


  /* init context */

  // Create a JSch instance configured from the options "ssh_config_dir",
  // "ssh_config_file" and "ssh_identity_files" (colon-separated).  The config
  // directory must exist; its "known_hosts" file is created on demand.
  // Identity files that do not exist are skipped silently.
  def init_context(options: Options): Context =
  {
    val config_dir = Path.explode(options.string("ssh_config_dir"))
    if (!config_dir.is_dir) error("Bad ssh config directory: " + config_dir)

    val jsch = new JSch

    val config_file = Path.explode(options.string("ssh_config_file"))
    if (config_file.is_file)
      jsch.setConfigRepository(OpenSSHConfig.parseFile(File.platform_path(config_file)))

    val known_hosts = config_dir + Path.explode("known_hosts")
    if (!known_hosts.is_file) known_hosts.file.createNewFile()
    jsch.setKnownHosts(File.platform_path(known_hosts))

    val identity_files =
      space_explode(':', options.string("ssh_identity_files")).map(Path.explode(_))
    for (identity_file <- identity_files if identity_file.is_file)
      jsch.addIdentity(File.platform_path(identity_file))

    new Context(options, jsch)
  }

  // Convenience: fresh context plus Context.open_session with the same arguments.
  def open_session(options: Options, host: String, user: String = "", port: Int = 0,
      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
      permissive: Boolean = false): Session =
    init_context(options).open_session(host = host, user = user, port = port,
      proxy_host = proxy_host, proxy_user = proxy_user, proxy_port = proxy_port,
      permissive = permissive)

  class Context private[SSH](val options: Options, val jsch: JSch)
  {
    // Same JSch instance, different options.
    def update_options(new_options: Options): Context = new Context(new_options, jsch)

    // Open and connect a single JSch session and wrap it as SSH.Session.
    // An empty user is passed to JSch as null; on_close is invoked by
    // Session.close after the connection has been shut down.
    def connect_session(host: String, user: String = "", port: Int = 0,
      host_key_permissive: Boolean = false,
      host_key_alias: String = "",
      on_close: () => Unit = () => ()): Session =
    {
      val session = jsch.getSession(proper_string(user).orNull, host, make_port(port))

      session.setUserInfo(No_User_Info)
      session.setServerAliveInterval(alive_interval(options))
      session.setServerAliveCountMax(alive_count_max(options))
      session.setConfig("MaxAuthTries", "3")
      if (host_key_permissive) session.setConfig("StrictHostKeyChecking", "no")
      if (host_key_alias != "") session.setHostKeyAlias(host_key_alias)

      if (options.bool("ssh_compression")) {
        session.setConfig("compression.s2c", "zlib@openssh.com,zlib,none")
        session.setConfig("compression.c2s", "zlib@openssh.com,zlib,none")
        session.setConfig("compression_level", "9")
      }
      session.connect(connect_timeout(options))

      // The Session constructor opens the sftp channel and may fail: do not
      // leave the already connected JSch session behind in that case.
      // on_close is deliberately not invoked here, since no Session was handed
      // out; callers that pass on_close clean up their own resources on failure.
      try { new Session(options, session, on_close) }
      catch { case exn: Throwable => session.disconnect(); throw exn }
    }

    // Open a session to host, either directly or via local port forwarding
    // through proxy_host.  In the proxied case the connection goes to the local
    // end of the tunnel, with host as host key alias; closing the resulting
    // session also closes the tunnel and the proxy session.
    //
    // Note that permissive is only consulted for the proxied connection: the
    // direct branch does not pass it on.
    def open_session(host: String, user: String = "", port: Int = 0,
      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
      permissive: Boolean = false): Session =
    {
      if (proxy_host == "") connect_session(host = host, user = user, port = port)
      else {
        val proxy = connect_session(host = proxy_host, port = proxy_port, user = proxy_user)

        val fw =
          try { proxy.port_forwarding(remote_host = host, remote_port = make_port(port)) }
          catch { case exn: Throwable => proxy.close(); throw exn }

        // Tear down the tunnel, but close the proxy session in any case.
        def close_tunnel(): Unit =
          try { fw.close() }
          finally { proxy.close() }

        try {
          connect_session(host = fw.local_host, port = fw.local_port,
            host_key_permissive = permissive, host_key_alias = host,
            user = user, on_close = () => close_tunnel())
        }
        catch { case exn: Throwable => close_tunnel(); throw exn }
      }
    }
  }


  /* logging */

  // Install (verbose) or remove (not verbose) the global JSch logger.
  def logging(verbose: Boolean = true, debug: Boolean = false): Unit =
  {
    JSch.setLogger(if (verbose) new Logger(debug) else null)
  }

  private class Logger(debug: Boolean) extends JSch_Logger
  {
    // Everything is enabled, except for DEBUG messages unless requested.
    def isEnabled(level: Int): Boolean = level != JSch_Logger.DEBUG || debug

    def log(level: Int, msg: String): Unit =
    {
      level match {
        case JSch_Logger.ERROR | JSch_Logger.FATAL => Output.error_message(msg)
        case JSch_Logger.WARN => Output.warning(msg)
        case _ => Output.writeln(msg)
      }
    }
  }


  /* user info */

  // Non-interactive user info: provides no credentials, answers every prompt
  // with "false", and merely prints messages.
  object No_User_Info extends UserInfo
  {
    def getPassphrase: String = null
    def getPassword: String = null
    def promptPassword(msg: String): Boolean = false
    def promptPassphrase(msg: String): Boolean = false
    def promptYesNo(msg: String): Boolean = false
    def showMessage(msg: String): Unit = Output.writeln(msg)
  }


  /* port forwarding */

  object Port_Forwarding
  {
    // Set up local port forwarding on the given session.  The port number
    // returned by JSch is recorded as local_port of the result.
    def open(ssh: Session, ssh_close: Boolean,
      local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
    {
      val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
      new Port_Forwarding(ssh, ssh_close, local_host, port, remote_host, remote_port)
    }
  }

  class Port_Forwarding private[SSH](
    ssh: SSH.Session,
    ssh_close: Boolean,
    val local_host: String,
    val local_port: Int,
    val remote_host: String,
    val remote_port: Int)
  {
    override def toString: String =
      local_host + ":" + local_port + ":" + remote_host + ":" + remote_port

    // Remove the forwarding; with ssh_close the underlying session is closed
    // as well, even if removing the forwarding fails.
    def close(): Unit =
    {
      try { ssh.session.delPortForwardingL(local_host, local_port) }
      finally { if (ssh_close) ssh.close() }
    }
  }


  /* Sftp channel */

  type Attrs = SftpATTRS

  sealed case class Dir_Entry(name: Path, attrs: Attrs)
  {
    def is_file: Boolean = attrs.isReg
    def is_dir: Boolean = attrs.isDir
  }


  /* exec channel */

  // polling interval for channel status and pending output
  private val exec_wait_delay = Time.seconds(0.3)

  class Exec private[SSH](session: Session, channel: ChannelExec)
  {
    override def toString: String = "exec " + session.toString

    def close(): Unit = channel.disconnect()

    val stdin: OutputStream = channel.getOutputStream
    val stdout: InputStream = channel.getInputStream
    val stderr: InputStream = channel.getErrStream

    // connect after preparing streams
    channel.connect(connect_timeout(session.options))

    // Poll until the channel is closed, then fetch the exit status.  The thread
    // is started only after a successful connect, so that a failing connect
    // (which aborts the constructor) cannot leave a polling thread behind.
    val exit_status: Future[Int] =
      Future.thread("ssh_wait") {
        while (!channel.isClosed) Thread.sleep(exec_wait_delay.ms)
        channel.getExitStatus
      }

    // Close stdin, collect stdout and stderr line by line (reporting each line
    // via the progress functions), and wait for the exit status.  On interrupt
    // the channel is closed and all threads are joined; the result then carries
    // the interrupt return code, or the interrupt is re-raised if strict.
    def result(
      progress_stdout: String => Unit = (_: String) => (),
      progress_stderr: String => Unit = (_: String) => (),
      strict: Boolean = true): Process_Result =
    {
      stdin.close()

      def read_lines(stream: InputStream, progress: String => Unit): List[String] =
      {
        val result = new mutable.ListBuffer[String]
        val line_buffer = new ByteArrayOutputStream(100)
        def line_flush(): Unit =
        {
          val line = Library.trim_line(line_buffer.toString(UTF8.charset_name))
          progress(line)
          result += line
          line_buffer.reset()
        }

        var c = 0
        var finished = false
        while (!finished) {
          // accumulate bytes up to newline (10) or end of available input (-1)
          while ({ c = stream.read(); c != -1 && c != 10 }) line_buffer.write(c)
          if (c == 10) line_flush()
          else if (channel.isClosed) {
            // final line without newline terminator
            if (line_buffer.size > 0) line_flush()
            finished = true
          }
          else Thread.sleep(exec_wait_delay.ms)
        }

        result.toList
      }

      val out_lines = Future.thread("ssh_stdout") { read_lines(stdout, progress_stdout) }
      val err_lines = Future.thread("ssh_stderr") { read_lines(stderr, progress_stderr) }

      def terminate(): Unit =
      {
        close()
        out_lines.join
        err_lines.join
        exit_status.join
      }

      val rc =
        try { exit_status.join }
        catch { case Exn.Interrupt() => terminate(); Exn.Interrupt.return_code }

      close()
      if (strict && rc == Exn.Interrupt.return_code) throw Exn.Interrupt()

      Process_Result(rc, out_lines.join, err_lines.join)
    }
  }


  /* session */

  // Connected SSH session with its own sftp channel, which is opened and
  // connected as part of construction.
  class Session private[SSH](
    val options: Options,
    val session: JSch_Session,
    on_close: () => Unit) extends System
  {
    // New wrapper (with a new sftp channel) around the same JSch session.
    def update_options(new_options: Options): Session =
      new Session(new_options, session, on_close)

    def user_prefix: String = Option(session.getUserName).map(_ + "@").getOrElse("")
    def host: String = Option(session.getHost).getOrElse("")
    def port_suffix: String = if (session.getPort == default_port) "" else ":" + session.getPort
    override def hg_url: String = "ssh://" + user_prefix + host + port_suffix + "/"

    override def prefix: String =
      user_prefix + host + port_suffix + ":"

    override def toString: String =
      user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)")


    /* port forwarding */

    def port_forwarding(
        remote_port: Int, remote_host: String = "localhost",
        local_port: Int = 0, local_host: String = "localhost",
        ssh_close: Boolean = false): Port_Forwarding =
      Port_Forwarding.open(this, ssh_close, local_host, local_port, remote_host, remote_port)


    /* sftp channel */

    val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]
    sftp.connect(connect_timeout(options))

    // Disconnect the sftp channel and the session; on_close is invoked in any
    // case, so that dependent resources (e.g. a proxy tunnel) are released
    // even if a disconnect fails.
    def close(): Unit =
    {
      try { sftp.disconnect() }
      finally {
        try { session.disconnect() }
        finally { on_close() }
      }
    }

    // Remote environment used for path expansion: both HOME and USER_HOME
    // refer to the home directory reported by the sftp channel.
    val settings: Map[String, String] =
    {
      val home = sftp.getHome
      Map("HOME" -> home, "USER_HOME" -> home)
    }
    override def expand_path(path: Path): Path = path.expand_env(settings)
    def remote_path(path: Path): String = expand_path(path).implode
    override def bash_path(path: Path): String = Bash.string(remote_path(path))

    def chmod(permissions: Int, path: Path): Unit = sftp.chmod(permissions, remote_path(path))
    def mv(path1: Path, path2: Path): Unit = sftp.rename(remote_path(path1), remote_path(path2))
    def rm(path: Path): Unit = sftp.rm(remote_path(path))
    def mkdir(path: Path): Unit = sftp.mkdir(remote_path(path))
    def rmdir(path: Path): Unit = sftp.rmdir(remote_path(path))

    // Attributes of a remote path; any SftpException counts as "absent".
    def stat(path: Path): Option[Dir_Entry] =
      try { Some(Dir_Entry(expand_path(path), sftp.stat(remote_path(path)))) }
      catch { case _: SftpException => None }

    override def is_file(path: Path): Boolean = stat(path).exists(_.is_file)
    override def is_dir(path: Path): Boolean = stat(path).exists(_.is_dir)

    // Create the directory with all its parents via a remote Perl command.
    // NOTE(review): the path is spliced into a single-quoted Perl string
    // without escaping; presumably Path rules out quote characters -- confirm.
    override def mkdirs(path: Path): Unit =
      if (!is_dir(path)) {
        execute(
          "perl -e \"use File::Path make_path; make_path('" + remote_path(path) + "');\"")
        if (!is_dir(path)) error("Failed to create directory: " + quote(remote_path(path)))
      }

    // Directory listing without "." and ".."; names are relative to path.
    // The fields of the entries returned by JSch are accessed via Untyped.get.
    def read_dir(path: Path): List[Dir_Entry] =
    {
      val dir = sftp.ls(remote_path(path))
      (for {
        i <- (0 until dir.size).iterator
        a = dir.get(i).asInstanceOf[AnyRef]
        name = Untyped.get[String](a, "filename")
        attrs = Untyped.get[Attrs](a, "attrs")
        if name != "." && name != ".."
      } yield Dir_Entry(Path.basic(name), attrs)).toList
    }

    // Recursive listing of all non-directory entries below root that satisfy
    // pred; the resulting names are prefixed by root.
    def find_files(root: Path, pred: Dir_Entry => Boolean = _ => true): List[Dir_Entry] =
    {
      def find(dir: Path): List[Dir_Entry] =
        read_dir(dir).flatMap(entry =>
          {
            val file = dir + entry.name
            if (entry.is_dir) find(file)
            else if (pred(entry)) List(entry.copy(name = file))
            else Nil
          })
      find(root)
    }

    def open_input(path: Path): InputStream = sftp.get(remote_path(path))
    def open_output(path: Path): OutputStream = sftp.put(remote_path(path))

    // remote path --> local_path
    def read_file(path: Path, local_path: Path): Unit =
      sftp.get(remote_path(path), File.platform_path(local_path))
    def read_bytes(path: Path): Bytes = using(open_input(path))(Bytes.read_stream(_))
    def read(path: Path): String = using(open_input(path))(File.read_stream(_))

    // local_path --> remote path
    def write_file(path: Path, local_path: Path): Unit =
      sftp.put(File.platform_path(local_path), remote_path(path))
    def write_bytes(path: Path, bytes: Bytes): Unit =
      using(open_output(path))(bytes.write_stream(_))
    def write(path: Path, text: String): Unit =
      using(open_output(path))(stream => Bytes(text).write_stream(stream))


    /* exec channel */

    // Start a remote command; USER_HOME is exported as copy of HOME beforehand.
    def exec(command: String): Exec =
    {
      val channel = session.openChannel("exec").asInstanceOf[ChannelExec]
      channel.setCommand("export USER_HOME=\"$HOME\"\n" + command)
      new Exec(this, channel)
    }

    override def execute(command: String,
        progress_stdout: String => Unit = (_: String) => (),
        progress_stderr: String => Unit = (_: String) => (),
        strict: Boolean = true): Process_Result =
      exec(command).result(progress_stdout, progress_stderr, strict)


    /* tmp dirs */

    def rm_tree(dir: Path): Unit = rm_tree(remote_path(dir))

    def rm_tree(remote_dir: String): Unit =
      execute("rm -r -f " + Bash.string(remote_dir)).check

    // Name of a fresh remote temporary directory, as printed by mktemp.
    def tmp_dir(): String =
      execute("mktemp -d -t tmp.XXXXXXXXXX").check.out

    // Run body on a fresh remote temporary directory, which is removed afterwards.
    def with_tmp_dir[A](body: Path => A): A =
    {
      val remote_dir = tmp_dir()
      try { body(Path.explode(remote_dir)) } finally { rm_tree(remote_dir) }
    }
  }



  /* system operations */

  // Operations that work uniformly on the local system (the defaults given
  // here) and on a remote SSH.Session (which overrides them).
  trait System
  {
    def hg_url: String = ""
    def prefix: String = ""

    def expand_path(path: Path): Path = path.expand
    def bash_path(path: Path): String = File.bash_path(path)
    def is_file(path: Path): Boolean = path.is_file
    def is_dir(path: Path): Boolean = path.is_dir
    def mkdirs(path: Path): Unit = Isabelle_System.mkdirs(path)

    def execute(command: String,
        progress_stdout: String => Unit = (_: String) => (),
        progress_stderr: String => Unit = (_: String) => (),
        strict: Boolean = true): Process_Result =
      Isabelle_System.bash(command, progress_stdout = progress_stdout,
        progress_stderr = progress_stderr, strict = strict)
  }

  object Local extends System
}