1/*  Title:      Pure/General/ssh.scala
2    Author:     Makarius
3
4SSH client based on JSch (see also http://www.jcraft.com/jsch/examples).
5*/
6
7package isabelle
8
9
10import java.io.{InputStream, OutputStream, ByteArrayOutputStream}
11
12import scala.collection.mutable
13
14import com.jcraft.jsch.{JSch, Logger => JSch_Logger, Session => JSch_Session, SftpException,
15  OpenSSHConfig, UserInfo, Channel => JSch_Channel, ChannelExec, ChannelSftp, SftpATTRS}
16
17
18object SSH
19{
20  /* target machine: user@host syntax */
21
22  object Target
23  {
24    val User_Host = "^([^@]+)@(.+)$".r
25
26    def parse(s: String): (String, String) =
27      s match {
28        case User_Host(user, host) => (user, host)
29        case _ => ("", s)
30      }
31
32    def unapplySeq(s: String): Option[List[String]] =
33      parse(s) match {
34        case (_, "") => None
35        case (user, host) => Some(List(user, host))
36      }
37  }
38
39  val default_port = 22
40  def make_port(port: Int): Int = if (port > 0) port else default_port
41
42  def port_suffix(port: Int): String =
43    if (port == default_port) "" else ":" + port
44
45  def user_prefix(user: String): String =
46    proper_string(user) match {
47      case None => ""
48      case Some(name) => name + "@"
49    }
50
51  def connect_timeout(options: Options): Int =
52    options.seconds("ssh_connect_timeout").ms.toInt
53
54  def alive_interval(options: Options): Int =
55    options.seconds("ssh_alive_interval").ms.toInt
56
57  def alive_count_max(options: Options): Int =
58    options.int("ssh_alive_count_max")
59
60
61  /* init context */
62
63  def init_context(options: Options): Context =
64  {
65    val config_dir = Path.explode(options.string("ssh_config_dir"))
66    if (!config_dir.is_dir) error("Bad ssh config directory: " + config_dir)
67
68    val jsch = new JSch
69
70    val config_file = Path.explode(options.string("ssh_config_file"))
71    if (config_file.is_file)
72      jsch.setConfigRepository(OpenSSHConfig.parseFile(File.platform_path(config_file)))
73
74    val known_hosts = config_dir + Path.explode("known_hosts")
75    if (!known_hosts.is_file) known_hosts.file.createNewFile
76    jsch.setKnownHosts(File.platform_path(known_hosts))
77
78    val identity_files =
79      space_explode(':', options.string("ssh_identity_files")).map(Path.explode(_))
80    for (identity_file <- identity_files if identity_file.is_file)
81      jsch.addIdentity(File.platform_path(identity_file))
82
83    new Context(options, jsch)
84  }
85
86  def open_session(options: Options, host: String, user: String = "", port: Int = 0,
87      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
88      permissive: Boolean = false): Session =
89    init_context(options).open_session(host = host, user = user, port = port,
90      proxy_host = proxy_host, proxy_user = proxy_user, proxy_port = proxy_port,
91      permissive = permissive)
92
93  class Context private[SSH](val options: Options, val jsch: JSch)
94  {
95    def update_options(new_options: Options): Context = new Context(new_options, jsch)
96
97    private def connect_session(host: String, user: String = "", port: Int = 0,
98      host_key_permissive: Boolean = false,
99      nominal_host: String = "",
100      nominal_user: String = "",
101      on_close: () => Unit = () => ()): Session =
102    {
103      val session = jsch.getSession(proper_string(user).orNull, host, make_port(port))
104
105      session.setUserInfo(No_User_Info)
106      session.setServerAliveInterval(alive_interval(options))
107      session.setServerAliveCountMax(alive_count_max(options))
108      session.setConfig("MaxAuthTries", "3")
109      if (host_key_permissive) session.setConfig("StrictHostKeyChecking", "no")
110      if (nominal_host != "") session.setHostKeyAlias(nominal_host)
111
112      if (options.bool("ssh_compression")) {
113        session.setConfig("compression.s2c", "zlib@openssh.com,zlib,none")
114        session.setConfig("compression.c2s", "zlib@openssh.com,zlib,none")
115        session.setConfig("compression_level", "9")
116      }
117      session.connect(connect_timeout(options))
118      new Session(options, session, on_close,
119        proper_string(nominal_host) getOrElse host,
120        proper_string(nominal_user) getOrElse user)
121    }
122
123    def open_session(host: String, user: String = "", port: Int = 0,
124      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
125      permissive: Boolean = false): Session =
126    {
127      if (proxy_host == "") connect_session(host = host, user = user, port = port)
128      else {
129        val proxy = connect_session(host = proxy_host, port = proxy_port, user = proxy_user)
130
131        val fw =
132          try { proxy.port_forwarding(remote_host = host, remote_port = make_port(port)) }
133          catch { case exn: Throwable => proxy.close; throw exn }
134
135        try {
136          connect_session(host = fw.local_host, port = fw.local_port,
137            host_key_permissive = permissive,
138            nominal_host = host, nominal_user = user, user = user,
139            on_close = () => { fw.close; proxy.close })
140        }
141        catch { case exn: Throwable => fw.close; proxy.close; throw exn }
142      }
143    }
144  }
145
146
147  /* logging */
148
149  def logging(verbose: Boolean = true, debug: Boolean = false)
150  {
151    JSch.setLogger(if (verbose) new Logger(debug) else null)
152  }
153
154  private class Logger(debug: Boolean) extends JSch_Logger
155  {
156    def isEnabled(level: Int): Boolean = level != JSch_Logger.DEBUG || debug
157
158    def log(level: Int, msg: String)
159    {
160      level match {
161        case JSch_Logger.ERROR | JSch_Logger.FATAL => Output.error_message(msg)
162        case JSch_Logger.WARN => Output.warning(msg)
163        case _ => Output.writeln(msg)
164      }
165    }
166  }
167
168
169  /* user info */
170
171  object No_User_Info extends UserInfo
172  {
173    def getPassphrase: String = null
174    def getPassword: String = null
175    def promptPassword(msg: String): Boolean = false
176    def promptPassphrase(msg: String): Boolean = false
177    def promptYesNo(msg: String): Boolean = false
178    def showMessage(msg: String): Unit = Output.writeln(msg)
179  }
180
181
182  /* port forwarding */
183
184  object Port_Forwarding
185  {
186    def open(ssh: Session, ssh_close: Boolean,
187      local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
188    {
189      val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
190      new Port_Forwarding(ssh, ssh_close, local_host, port, remote_host, remote_port)
191    }
192  }
193
194  class Port_Forwarding private[SSH](
195    ssh: SSH.Session,
196    ssh_close: Boolean,
197    val local_host: String,
198    val local_port: Int,
199    val remote_host: String,
200    val remote_port: Int) extends AutoCloseable
201  {
202    override def toString: String =
203      local_host + ":" + local_port + ":" + remote_host + ":" + remote_port
204
205    def close()
206    {
207      ssh.session.delPortForwardingL(local_host, local_port)
208      if (ssh_close) ssh.close()
209    }
210  }
211
212
213  /* Sftp channel */
214
215  type Attrs = SftpATTRS
216
217  sealed case class Dir_Entry(name: String, is_dir: Boolean)
218  {
219    def is_file: Boolean = !is_dir
220  }
221
222
223  /* exec channel */
224
225  private val exec_wait_delay = Time.seconds(0.3)
226
227  class Exec private[SSH](session: Session, channel: ChannelExec) extends AutoCloseable
228  {
229    override def toString: String = "exec " + session.toString
230
231    def close() { channel.disconnect }
232
233    val exit_status: Future[Int] =
234      Future.thread("ssh_wait") {
235        while (!channel.isClosed) Thread.sleep(exec_wait_delay.ms)
236        channel.getExitStatus
237      }
238
239    val stdin: OutputStream = channel.getOutputStream
240    val stdout: InputStream = channel.getInputStream
241    val stderr: InputStream = channel.getErrStream
242
243    // connect after preparing streams
244    channel.connect(connect_timeout(session.options))
245
246    def result(
247      progress_stdout: String => Unit = (_: String) => (),
248      progress_stderr: String => Unit = (_: String) => (),
249      strict: Boolean = true): Process_Result =
250    {
251      stdin.close
252
253      def read_lines(stream: InputStream, progress: String => Unit): List[String] =
254      {
255        val result = new mutable.ListBuffer[String]
256        val line_buffer = new ByteArrayOutputStream(100)
257        def line_flush()
258        {
259          val line = Library.trim_line(line_buffer.toString(UTF8.charset_name))
260          progress(line)
261          result += line
262          line_buffer.reset
263        }
264
265        var c = 0
266        var finished = false
267        while (!finished) {
268          while ({ c = stream.read; c != -1 && c != 10 }) line_buffer.write(c)
269          if (c == 10) line_flush()
270          else if (channel.isClosed) {
271            if (line_buffer.size > 0) line_flush()
272            finished = true
273          }
274          else Thread.sleep(exec_wait_delay.ms)
275        }
276
277        result.toList
278      }
279
280      val out_lines = Future.thread("ssh_stdout") { read_lines(stdout, progress_stdout) }
281      val err_lines = Future.thread("ssh_stderr") { read_lines(stderr, progress_stderr) }
282
283      def terminate()
284      {
285        close
286        out_lines.join
287        err_lines.join
288        exit_status.join
289      }
290
291      val rc =
292        try { exit_status.join }
293        catch { case Exn.Interrupt() => terminate(); Exn.Interrupt.return_code }
294
295      close
296      if (strict && rc == Exn.Interrupt.return_code) throw Exn.Interrupt()
297
298      Process_Result(rc, out_lines.join, err_lines.join)
299    }
300  }
301
302
303  /* session */
304
305  class Session private[SSH](
306    val options: Options,
307    val session: JSch_Session,
308    on_close: () => Unit,
309    val nominal_host: String,
310    val nominal_user: String) extends System with AutoCloseable
311  {
312    def update_options(new_options: Options): Session =
313      new Session(new_options, session, on_close, nominal_host, nominal_user)
314
315    def host: String = if (session.getHost == null) "" else session.getHost
316
317    override def hg_url: String =
318      "ssh://" + user_prefix(nominal_user) + nominal_host + "/"
319
320    override def prefix: String =
321      user_prefix(session.getUserName) + host + port_suffix(session.getPort) + ":"
322
323    override def toString: String =
324      user_prefix(session.getUserName) + host + port_suffix(session.getPort) +
325      (if (session.isConnected) "" else " (disconnected)")
326
327
328    /* port forwarding */
329
330    def port_forwarding(
331        remote_port: Int, remote_host: String = "localhost",
332        local_port: Int = 0, local_host: String = "localhost",
333        ssh_close: Boolean = false): Port_Forwarding =
334      Port_Forwarding.open(this, ssh_close, local_host, local_port, remote_host, remote_port)
335
336
337    /* sftp channel */
338
339    val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]
340    sftp.connect(connect_timeout(options))
341
342    def close() { sftp.disconnect; session.disconnect; on_close() }
343
344    val settings: Map[String, String] =
345    {
346      val home = sftp.getHome
347      Map("HOME" -> home, "USER_HOME" -> home)
348    }
349    override def expand_path(path: Path): Path = path.expand_env(settings)
350    def remote_path(path: Path): String = expand_path(path).implode
351    override def bash_path(path: Path): String = Bash.string(remote_path(path))
352
353    def chmod(permissions: Int, path: Path): Unit = sftp.chmod(permissions, remote_path(path))
354    def mv(path1: Path, path2: Path): Unit = sftp.rename(remote_path(path1), remote_path(path2))
355    def rm(path: Path): Unit = sftp.rm(remote_path(path))
356    def mkdir(path: Path): Unit = sftp.mkdir(remote_path(path))
357    def rmdir(path: Path): Unit = sftp.rmdir(remote_path(path))
358
359    private def test_entry(path: Path, as_dir: Boolean): Boolean =
360      try {
361        val is_dir = sftp.stat(remote_path(path)).isDir
362        if (as_dir) is_dir else !is_dir
363      }
364      catch { case _: SftpException => false }
365
366    override def is_dir(path: Path): Boolean = test_entry(path, true)
367    override def is_file(path: Path): Boolean = test_entry(path, false)
368
369    def is_link(path: Path): Boolean =
370      try { sftp.lstat(remote_path(path)).isLink }
371      catch { case _: SftpException => false }
372
373    override def mkdirs(path: Path): Unit =
374      if (!is_dir(path)) {
375        execute(
376          "perl -e \"use File::Path make_path; make_path('" + remote_path(path) + "');\"")
377        if (!is_dir(path)) error("Failed to create directory: " + quote(remote_path(path)))
378      }
379
380    def read_dir(path: Path): List[Dir_Entry] =
381    {
382      if (!is_dir(path)) error("No such directory: " + path.toString)
383
384      val dir_name = remote_path(path)
385      val dir = sftp.ls(dir_name)
386      (for {
387        i <- (0 until dir.size).iterator
388        a = dir.get(i).asInstanceOf[AnyRef]
389        name = Untyped.get[String](a, "filename")
390        attrs = Untyped.get[Attrs](a, "attrs")
391        if name != "." && name != ".."
392      }
393      yield {
394        Dir_Entry(name,
395          if (attrs.isLink) {
396            try { sftp.stat(dir_name + "/" + name).isDir }
397            catch { case _: SftpException => false }
398          }
399          else attrs.isDir)
400      }).toList.sortBy(_.name)
401    }
402
403    def find_files(
404      start: Path,
405      pred: Path => Boolean = _ => true,
406      include_dirs: Boolean = false,
407      follow_links: Boolean = false): List[Path] =
408    {
409      val result = new mutable.ListBuffer[Path]
410      def check(path: Path) { if (pred(path)) result += path }
411
412      def find(dir: Path)
413      {
414        if (include_dirs) check(dir)
415        if (follow_links || !is_link(dir)) {
416          for (entry <- read_dir(dir)) {
417            val path = dir + Path.basic(entry.name)
418            if (entry.is_file) check(path) else find(path)
419          }
420        }
421      }
422      if (is_file(start)) check(start) else find(start)
423
424      result.toList
425    }
426
427    def open_input(path: Path): InputStream = sftp.get(remote_path(path))
428    def open_output(path: Path): OutputStream = sftp.put(remote_path(path))
429
430    def read_file(path: Path, local_path: Path): Unit =
431      sftp.get(remote_path(path), File.platform_path(local_path))
432    def read_bytes(path: Path): Bytes = using(open_input(path))(Bytes.read_stream(_))
433    def read(path: Path): String = using(open_input(path))(File.read_stream(_))
434
435    def write_file(path: Path, local_path: Path): Unit =
436      sftp.put(File.platform_path(local_path), remote_path(path))
437    def write_bytes(path: Path, bytes: Bytes): Unit =
438      using(open_output(path))(bytes.write_stream(_))
439    def write(path: Path, text: String): Unit =
440      using(open_output(path))(stream => Bytes(text).write_stream(stream))
441
442
443    /* exec channel */
444
445    def exec(command: String): Exec =
446    {
447      val channel = session.openChannel("exec").asInstanceOf[ChannelExec]
448      channel.setCommand("export USER_HOME=\"$HOME\"\n" + command)
449      new Exec(this, channel)
450    }
451
452    override def execute(command: String,
453        progress_stdout: String => Unit = (_: String) => (),
454        progress_stderr: String => Unit = (_: String) => (),
455        strict: Boolean = true): Process_Result =
456      exec(command).result(progress_stdout, progress_stderr, strict)
457
458
459    /* tmp dirs */
460
461    def rm_tree(dir: Path): Unit = rm_tree(remote_path(dir))
462
463    def rm_tree(remote_dir: String): Unit =
464      execute("rm -r -f " + Bash.string(remote_dir)).check
465
466    def tmp_dir(): String =
467      execute("mktemp -d -t tmp.XXXXXXXXXX").check.out
468
469    def with_tmp_dir[A](body: Path => A): A =
470    {
471      val remote_dir = tmp_dir()
472      try { body(Path.explode(remote_dir)) } finally { rm_tree(remote_dir) }
473    }
474  }
475
476
477
478  /* system operations */
479
480  trait System
481  {
482    def hg_url: String = ""
483    def prefix: String = ""
484
485    def expand_path(path: Path): Path = path.expand
486    def bash_path(path: Path): String = File.bash_path(path)
487    def is_dir(path: Path): Boolean = path.is_dir
488    def is_file(path: Path): Boolean = path.is_file
489    def mkdirs(path: Path): Unit = Isabelle_System.mkdirs(path)
490
491    def execute(command: String,
492        progress_stdout: String => Unit = (_: String) => (),
493        progress_stderr: String => Unit = (_: String) => (),
494        strict: Boolean = true): Process_Result =
495      Isabelle_System.bash(command, progress_stdout = progress_stdout,
496        progress_stderr = progress_stderr, strict = strict)
497  }
498
499  object Local extends System
500}
501