1/*  Title:      Pure/General/ssh.scala
2    Author:     Makarius
3
4SSH client based on JSch (see also http://www.jcraft.com/jsch/examples).
5*/
6
7package isabelle
8
9
10import java.io.{InputStream, OutputStream, ByteArrayOutputStream}
11
12import scala.collection.{mutable, JavaConversions}
13
14import com.jcraft.jsch.{JSch, Logger => JSch_Logger, Session => JSch_Session, SftpException,
15  OpenSSHConfig, UserInfo, Channel => JSch_Channel, ChannelExec, ChannelSftp, SftpATTRS}
16
17
18object SSH
19{
20  /* target machine: user@host syntax */
21
22  object Target
23  {
24    val Pattern = "^([^@]+)@(.+)$".r
25
26    def parse(s: String): (String, String) =
27      s match {
28        case Pattern(user, host) => (user, host)
29        case _ => ("", s)
30      }
31
32    def unapplySeq(s: String): Option[List[String]] =
33      parse(s) match {
34        case (_, "") => None
35        case (user, host) => Some(List(user, host))
36      }
37  }
38
39  val default_port = 22
40  def make_port(port: Int): Int = if (port > 0) port else default_port
41
42  def connect_timeout(options: Options): Int =
43    options.seconds("ssh_connect_timeout").ms.toInt
44
45  def alive_interval(options: Options): Int =
46    options.seconds("ssh_alive_interval").ms.toInt
47
48  def alive_count_max(options: Options): Int =
49    options.int("ssh_alive_count_max")
50
51
52  /* init context */
53
54  def init_context(options: Options): Context =
55  {
56    val config_dir = Path.explode(options.string("ssh_config_dir"))
57    if (!config_dir.is_dir) error("Bad ssh config directory: " + config_dir)
58
59    val jsch = new JSch
60
61    val config_file = Path.explode(options.string("ssh_config_file"))
62    if (config_file.is_file)
63      jsch.setConfigRepository(OpenSSHConfig.parseFile(File.platform_path(config_file)))
64
65    val known_hosts = config_dir + Path.explode("known_hosts")
66    if (!known_hosts.is_file) known_hosts.file.createNewFile
67    jsch.setKnownHosts(File.platform_path(known_hosts))
68
69    val identity_files =
70      space_explode(':', options.string("ssh_identity_files")).map(Path.explode(_))
71    for (identity_file <- identity_files if identity_file.is_file)
72      jsch.addIdentity(File.platform_path(identity_file))
73
74    new Context(options, jsch)
75  }
76
77  def open_session(options: Options, host: String, user: String = "", port: Int = 0,
78      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
79      permissive: Boolean = false): Session =
80    init_context(options).open_session(host = host, user = user, port = port,
81      proxy_host = proxy_host, proxy_user = proxy_user, proxy_port = proxy_port,
82      permissive = permissive)
83
84  class Context private[SSH](val options: Options, val jsch: JSch)
85  {
86    def update_options(new_options: Options): Context = new Context(new_options, jsch)
87
88    def connect_session(host: String, user: String = "", port: Int = 0,
89      host_key_permissive: Boolean = false,
90      host_key_alias: String = "",
91      on_close: () => Unit = () => ()): Session =
92    {
93      val session = jsch.getSession(proper_string(user) getOrElse null, host, make_port(port))
94
95      session.setUserInfo(No_User_Info)
96      session.setServerAliveInterval(alive_interval(options))
97      session.setServerAliveCountMax(alive_count_max(options))
98      session.setConfig("MaxAuthTries", "3")
99      if (host_key_permissive) session.setConfig("StrictHostKeyChecking", "no")
100      if (host_key_alias != "") session.setHostKeyAlias(host_key_alias)
101
102      if (options.bool("ssh_compression")) {
103        session.setConfig("compression.s2c", "zlib@openssh.com,zlib,none")
104        session.setConfig("compression.c2s", "zlib@openssh.com,zlib,none")
105        session.setConfig("compression_level", "9")
106      }
107      session.connect(connect_timeout(options))
108      new Session(options, session, on_close)
109    }
110
111    def open_session(host: String, user: String = "", port: Int = 0,
112      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0,
113      permissive: Boolean = false): Session =
114    {
115      if (proxy_host == "") connect_session(host = host, user = user, port = port)
116      else {
117        val proxy = connect_session(host = proxy_host, port = proxy_port, user = proxy_user)
118
119        val fw =
120          try { proxy.port_forwarding(remote_host = host, remote_port = make_port(port)) }
121          catch { case exn: Throwable => proxy.close; throw exn }
122
123        try {
124          connect_session(host = fw.local_host, port = fw.local_port,
125            host_key_permissive = permissive, host_key_alias = host,
126            user = user, on_close = () => { fw.close; proxy.close })
127        }
128        catch { case exn: Throwable => fw.close; proxy.close; throw exn }
129      }
130    }
131  }
132
133
134  /* logging */
135
136  def logging(verbose: Boolean = true, debug: Boolean = false)
137  {
138    JSch.setLogger(if (verbose) new Logger(debug) else null)
139  }
140
141  private class Logger(debug: Boolean) extends JSch_Logger
142  {
143    def isEnabled(level: Int): Boolean = level != JSch_Logger.DEBUG || debug
144
145    def log(level: Int, msg: String)
146    {
147      level match {
148        case JSch_Logger.ERROR | JSch_Logger.FATAL => Output.error_message(msg)
149        case JSch_Logger.WARN => Output.warning(msg)
150        case _ => Output.writeln(msg)
151      }
152    }
153  }
154
155
156  /* user info */
157
158  object No_User_Info extends UserInfo
159  {
160    def getPassphrase: String = null
161    def getPassword: String = null
162    def promptPassword(msg: String): Boolean = false
163    def promptPassphrase(msg: String): Boolean = false
164    def promptYesNo(msg: String): Boolean = false
165    def showMessage(msg: String): Unit = Output.writeln(msg)
166  }
167
168
169  /* port forwarding */
170
171  object Port_Forwarding
172  {
173    def open(ssh: Session, ssh_close: Boolean,
174      local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
175    {
176      val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
177      new Port_Forwarding(ssh, ssh_close, local_host, port, remote_host, remote_port)
178    }
179  }
180
181  class Port_Forwarding private[SSH](
182    ssh: SSH.Session,
183    ssh_close: Boolean,
184    val local_host: String,
185    val local_port: Int,
186    val remote_host: String,
187    val remote_port: Int)
188  {
189    override def toString: String =
190      local_host + ":" + local_port + ":" + remote_host + ":" + remote_port
191
192    def close()
193    {
194      ssh.session.delPortForwardingL(local_host, local_port)
195      if (ssh_close) ssh.close()
196    }
197  }
198
199
200  /* Sftp channel */
201
202  type Attrs = SftpATTRS
203
204  sealed case class Dir_Entry(name: Path, attrs: Attrs)
205  {
206    def is_file: Boolean = attrs.isReg
207    def is_dir: Boolean = attrs.isDir
208  }
209
210
211  /* exec channel */
212
213  private val exec_wait_delay = Time.seconds(0.3)
214
215  class Exec private[SSH](session: Session, channel: ChannelExec)
216  {
217    override def toString: String = "exec " + session.toString
218
219    def close() { channel.disconnect }
220
221    val exit_status: Future[Int] =
222      Future.thread("ssh_wait") {
223        while (!channel.isClosed) Thread.sleep(exec_wait_delay.ms)
224        channel.getExitStatus
225      }
226
227    val stdin: OutputStream = channel.getOutputStream
228    val stdout: InputStream = channel.getInputStream
229    val stderr: InputStream = channel.getErrStream
230
231    // connect after preparing streams
232    channel.connect(connect_timeout(session.options))
233
234    def result(
235      progress_stdout: String => Unit = (_: String) => (),
236      progress_stderr: String => Unit = (_: String) => (),
237      strict: Boolean = true): Process_Result =
238    {
239      stdin.close
240
241      def read_lines(stream: InputStream, progress: String => Unit): List[String] =
242      {
243        val result = new mutable.ListBuffer[String]
244        val line_buffer = new ByteArrayOutputStream(100)
245        def line_flush()
246        {
247          val line = Library.trim_line(line_buffer.toString(UTF8.charset_name))
248          progress(line)
249          result += line
250          line_buffer.reset
251        }
252
253        var c = 0
254        var finished = false
255        while (!finished) {
256          while ({ c = stream.read; c != -1 && c != 10 }) line_buffer.write(c)
257          if (c == 10) line_flush()
258          else if (channel.isClosed) {
259            if (line_buffer.size > 0) line_flush()
260            finished = true
261          }
262          else Thread.sleep(exec_wait_delay.ms)
263        }
264
265        result.toList
266      }
267
268      val out_lines = Future.thread("ssh_stdout") { read_lines(stdout, progress_stdout) }
269      val err_lines = Future.thread("ssh_stderr") { read_lines(stderr, progress_stderr) }
270
271      def terminate()
272      {
273        close
274        out_lines.join
275        err_lines.join
276        exit_status.join
277      }
278
279      val rc =
280        try { exit_status.join }
281        catch { case Exn.Interrupt() => terminate(); Exn.Interrupt.return_code }
282
283      close
284      if (strict && rc == Exn.Interrupt.return_code) throw Exn.Interrupt()
285
286      Process_Result(rc, out_lines.join, err_lines.join)
287    }
288  }
289
290
291  /* session */
292
293  class Session private[SSH](
294    val options: Options,
295    val session: JSch_Session,
296    on_close: () => Unit) extends System
297  {
298    def update_options(new_options: Options): Session =
299      new Session(new_options, session, on_close)
300
301    def user_prefix: String = if (session.getUserName == null) "" else session.getUserName + "@"
302    def host: String = if (session.getHost == null) "" else session.getHost
303    def port_suffix: String = if (session.getPort == default_port) "" else ":" + session.getPort
304    override def hg_url: String = "ssh://" + user_prefix + host + port_suffix + "/"
305
306    override def prefix: String =
307      user_prefix + host + port_suffix + ":"
308
309    override def toString: String =
310      user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)")
311
312
313    /* port forwarding */
314
315    def port_forwarding(
316        remote_port: Int, remote_host: String = "localhost",
317        local_port: Int = 0, local_host: String = "localhost",
318        ssh_close: Boolean = false): Port_Forwarding =
319      Port_Forwarding.open(this, ssh_close, local_host, local_port, remote_host, remote_port)
320
321
322    /* sftp channel */
323
324    val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]
325    sftp.connect(connect_timeout(options))
326
327    def close() { sftp.disconnect; session.disconnect; on_close() }
328
329    val settings: Map[String, String] =
330    {
331      val home = sftp.getHome
332      Map("HOME" -> home, "USER_HOME" -> home)
333    }
334    override def expand_path(path: Path): Path = path.expand_env(settings)
335    def remote_path(path: Path): String = expand_path(path).implode
336    override def bash_path(path: Path): String = Bash.string(remote_path(path))
337
338    def chmod(permissions: Int, path: Path): Unit = sftp.chmod(permissions, remote_path(path))
339    def mv(path1: Path, path2: Path): Unit = sftp.rename(remote_path(path1), remote_path(path2))
340    def rm(path: Path): Unit = sftp.rm(remote_path(path))
341    def mkdir(path: Path): Unit = sftp.mkdir(remote_path(path))
342    def rmdir(path: Path): Unit = sftp.rmdir(remote_path(path))
343
344    def stat(path: Path): Option[Dir_Entry] =
345      try { Some(Dir_Entry(expand_path(path), sftp.stat(remote_path(path)))) }
346      catch { case _: SftpException => None }
347
348    override def is_file(path: Path): Boolean = stat(path).map(_.is_file) getOrElse false
349    override def is_dir(path: Path): Boolean = stat(path).map(_.is_dir) getOrElse false
350
351    override def mkdirs(path: Path): Unit =
352      if (!is_dir(path)) {
353        execute(
354          "perl -e \"use File::Path make_path; make_path('" + remote_path(path) + "');\"")
355        if (!is_dir(path)) error("Failed to create directory: " + quote(remote_path(path)))
356      }
357
358    def read_dir(path: Path): List[Dir_Entry] =
359    {
360      val dir = sftp.ls(remote_path(path))
361      (for {
362        i <- (0 until dir.size).iterator
363        a = dir.get(i).asInstanceOf[AnyRef]
364        name = Untyped.get[String](a, "filename")
365        attrs = Untyped.get[Attrs](a, "attrs")
366        if name != "." && name != ".."
367      } yield Dir_Entry(Path.basic(name), attrs)).toList
368    }
369
370    def find_files(root: Path, pred: Dir_Entry => Boolean = _ => true): List[Dir_Entry] =
371    {
372      def find(dir: Path): List[Dir_Entry] =
373        read_dir(dir).flatMap(entry =>
374          {
375            val file = dir + entry.name
376            if (entry.is_dir) find(file)
377            else if (pred(entry)) List(entry.copy(name = file))
378            else Nil
379          })
380      find(root)
381    }
382
383    def open_input(path: Path): InputStream = sftp.get(remote_path(path))
384    def open_output(path: Path): OutputStream = sftp.put(remote_path(path))
385
386    def read_file(path: Path, local_path: Path): Unit =
387      sftp.get(remote_path(path), File.platform_path(local_path))
388    def read_bytes(path: Path): Bytes = using(open_input(path))(Bytes.read_stream(_))
389    def read(path: Path): String = using(open_input(path))(File.read_stream(_))
390
391    def write_file(path: Path, local_path: Path): Unit =
392      sftp.put(File.platform_path(local_path), remote_path(path))
393    def write_bytes(path: Path, bytes: Bytes): Unit =
394      using(open_output(path))(bytes.write_stream(_))
395    def write(path: Path, text: String): Unit =
396      using(open_output(path))(stream => Bytes(text).write_stream(stream))
397
398
399    /* exec channel */
400
401    def exec(command: String): Exec =
402    {
403      val channel = session.openChannel("exec").asInstanceOf[ChannelExec]
404      channel.setCommand("export USER_HOME=\"$HOME\"\n" + command)
405      new Exec(this, channel)
406    }
407
408    override def execute(command: String,
409        progress_stdout: String => Unit = (_: String) => (),
410        progress_stderr: String => Unit = (_: String) => (),
411        strict: Boolean = true): Process_Result =
412      exec(command).result(progress_stdout, progress_stderr, strict)
413
414
415    /* tmp dirs */
416
417    def rm_tree(dir: Path): Unit = rm_tree(remote_path(dir))
418
419    def rm_tree(remote_dir: String): Unit =
420      execute("rm -r -f " + Bash.string(remote_dir)).check
421
422    def tmp_dir(): String =
423      execute("mktemp -d -t tmp.XXXXXXXXXX").check.out
424
425    def with_tmp_dir[A](body: Path => A): A =
426    {
427      val remote_dir = tmp_dir()
428      try { body(Path.explode(remote_dir)) } finally { rm_tree(remote_dir) }
429    }
430  }
431
432
433
434  /* system operations */
435
436  trait System
437  {
438    def hg_url: String = ""
439    def prefix: String = ""
440
441    def expand_path(path: Path): Path = path.expand
442    def bash_path(path: Path): String = File.bash_path(path)
443    def is_file(path: Path): Boolean = path.is_file
444    def is_dir(path: Path): Boolean = path.is_dir
445    def mkdirs(path: Path): Unit = Isabelle_System.mkdirs(path)
446
447    def execute(command: String,
448        progress_stdout: String => Unit = (_: String) => (),
449        progress_stderr: String => Unit = (_: String) => (),
450        strict: Boolean = true): Process_Result =
451      Isabelle_System.bash(command, progress_stdout = progress_stdout,
452        progress_stderr = progress_stderr, strict = strict)
453  }
454
455  object Local extends System
456}
457