1/*  Title:      Pure/Tools/server.scala
2    Author:     Makarius
3
4Resident Isabelle servers.
5
6Message formats:
7  - short message (single line):
8      NAME ARGUMENT
9  - long message (multiple lines):
10      BYTE_LENGTH
11      NAME ARGUMENT
12
13Argument formats:
14  - Unit as empty string
15  - XML.Elem in YXML notation
16  - JSON.T in standard notation
17*/
18
19package isabelle
20
21
22import java.io.{BufferedInputStream, BufferedOutputStream, InputStreamReader, OutputStreamWriter,
23  IOException}
24import java.net.{Socket, SocketException, SocketTimeoutException, ServerSocket, InetAddress}
25
26
27object Server
28{
29  /* message argument */
30
31  object Argument
32  {
33    def is_name_char(c: Char): Boolean =
34      Symbol.is_ascii_letter(c) || Symbol.is_ascii_digit(c) || c == '_' || c == '.'
35
36    def split(msg: String): (String, String) =
37    {
38      val name = msg.takeWhile(is_name_char(_))
39      val argument = msg.substring(name.length).dropWhile(Symbol.is_ascii_blank(_))
40      (name, argument)
41    }
42
43    def print(arg: Any): String =
44      arg match {
45        case () => ""
46        case t: XML.Elem => YXML.string_of_tree(t)
47        case t: JSON.T => JSON.Format(t)
48      }
49
50    def parse(argument: String): Any =
51      if (argument == "") ()
52      else if (YXML.detect_elem(argument)) YXML.parse_elem(argument)
53      else JSON.parse(argument, strict = false)
54
55    def unapply(argument: String): Option[Any] =
56      try { Some(parse(argument)) }
57      catch { case ERROR(_) => None }
58  }
59
60
61  /* input command */
62
  object Command
  {
    /* A command handler is partial in its (context, argument) input: it is
       undefined when the parsed argument does not have the expected shape.
       Its result becomes the argument of the OK reply; a Task result denotes
       an asynchronous command (created via Context.make_task). */
    type T = PartialFunction[(Context, Any), Any]

    /* Command table, keyed by the name at the start of an input message. */
    private val table: Map[String, T] =
      Map(
        /* sorted list of all command names; the argument must be empty */
        "help" -> { case (_, ()) => table.keySet.toList.sorted },
        /* any argument is returned unchanged */
        "echo" -> { case (_, t) => t },
        /* shut down the whole server (see Server.shutdown) */
        "shutdown" -> { case (context, ()) => context.server.shutdown() },
        /* request cancellation of a task of this connection context, by task id */
        "cancel" ->
          { case (context, Server_Commands.Cancel(args)) => context.cancel_task(args.task) },
        /* asynchronous: session build, reporting through the task progress */
        "session_build" ->
          { case (context, Server_Commands.Session_Build(args)) =>
              context.make_task(task =>
                Server_Commands.Session_Build.command(args, progress = task.progress)._1)
          },
        /* asynchronous: start a session and register it with the server, so
           that later commands can refer to it by its id */
        "session_start" ->
          { case (context, Server_Commands.Session_Start(args)) =>
              context.make_task(task =>
                {
                  val (res, entry) =
                    Server_Commands.Session_Start.command(
                      args, progress = task.progress, log = context.server.log)
                  context.server.add_session(entry)
                  res
                })
          },
        /* asynchronous: unregister the session first, then stop it */
        "session_stop" ->
          { case (context, Server_Commands.Session_Stop(id)) =>
              context.make_task(_ =>
                {
                  val session = context.server.remove_session(id)
                  Server_Commands.Session_Stop.command(session)._1
                })
          },
        /* asynchronous: operate on the registered session args.session_id */
        "use_theories" ->
          { case (context, Server_Commands.Use_Theories(args)) =>
              context.make_task(task =>
                {
                  val session = context.server.the_session(args.session_id)
                  Server_Commands.Use_Theories.command(
                    args, session, id = task.id, progress = task.progress)._1
                })
          },
        /* synchronous: operate on the registered session args.session_id */
        "purge_theories" ->
          { case (context, Server_Commands.Purge_Theories(args)) =>
              val session = context.server.the_session(args.session_id)
              Server_Commands.Purge_Theories.command(args, session)._1
          })

    /* handler for the given command name, if any */
    def unapply(name: String): Option[T] = table.get(name)
  }
115
116
117  /* output reply */
118
119  class Error(val message: String, val json: JSON.Object.T = JSON.Object.empty)
120    extends RuntimeException(message)
121
122  def json_error(exn: Throwable): JSON.Object.T =
123    exn match {
124      case e: Error => Reply.error_message(e.message) ++ e.json
125      case ERROR(msg) => Reply.error_message(msg)
126      case _ if Exn.is_interrupt(exn) => Reply.error_message(Exn.message(exn))
127      case _ => JSON.Object.empty
128    }
129
130  object Reply extends Enumeration
131  {
132    val OK, ERROR, FINISHED, FAILED, NOTE = Value
133
134    def message(msg: String, kind: String = ""): JSON.Object.T =
135      JSON.Object(Markup.KIND -> proper_string(kind).getOrElse(Markup.WRITELN), "message" -> msg)
136
137    def error_message(msg: String): JSON.Object.T =
138      message(msg, kind = Markup.ERROR)
139
140    def unapply(msg: String): Option[(Reply.Value, Any)] =
141    {
142      if (msg == "") None
143      else {
144        val (name, argument) = Argument.split(msg)
145        for {
146          reply <-
147            try { Some(withName(name)) }
148            catch { case _: NoSuchElementException => None }
149          arg <- Argument.unapply(argument)
150        } yield (reply, arg)
151      }
152    }
153  }
154
155
156  /* socket connection */
157
158  object Connection
159  {
160    def apply(socket: Socket): Connection =
161      new Connection(socket)
162  }
163
164  class Connection private(socket: Socket) extends AutoCloseable
165  {
166    override def toString: String = socket.toString
167
168    def close() { socket.close }
169
170    def set_timeout(t: Time) { socket.setSoTimeout(t.ms.toInt) }
171
172    private val in = new BufferedInputStream(socket.getInputStream)
173    private val out = new BufferedOutputStream(socket.getOutputStream)
174    private val out_lock: AnyRef = new Object
175
176    def tty_loop(interrupt: Option[() => Unit] = None): TTY_Loop =
177      new TTY_Loop(
178        new OutputStreamWriter(out),
179        new InputStreamReader(in),
180        writer_lock = out_lock,
181        interrupt = interrupt)
182
183    def read_password(password: String): Boolean =
184      try { Byte_Message.read_line(in).map(_.text) == Some(password) }
185      catch { case _: IOException => false }
186
187    def read_message(): Option[String] =
188      try { Byte_Message.read_line_message(in).map(_.text) }
189      catch { case _: IOException => None }
190
191    def write_message(msg: String): Unit =
192      out_lock.synchronized { Byte_Message.write_line_message(out, Bytes(UTF8.bytes(msg))) }
193
194    def reply(r: Reply.Value, arg: Any)
195    {
196      val argument = Argument.print(arg)
197      write_message(if (argument == "") r.toString else r.toString + " " + argument)
198    }
199
200    def reply_ok(arg: Any) { reply(Reply.OK, arg) }
201    def reply_error(arg: Any) { reply(Reply.ERROR, arg) }
202    def reply_error_message(message: String, more: JSON.Object.Entry*): Unit =
203      reply_error(Reply.error_message(message) ++ more)
204
205    def notify(arg: Any) { reply(Reply.NOTE, arg) }
206  }
207
208
209  /* context with output channels */
210
211  class Context private[Server](val server: Server, connection: Connection)
212    extends AutoCloseable
213  {
214    context =>
215
216    def reply(r: Reply.Value, arg: Any) { connection.reply(r, arg) }
217    def notify(arg: Any) { connection.notify(arg) }
218    def message(kind: String, msg: String, more: JSON.Object.Entry*): Unit =
219      notify(Reply.message(msg, kind = kind) ++ more)
220    def writeln(msg: String, more: JSON.Object.Entry*): Unit = message(Markup.WRITELN, msg, more:_*)
221    def warning(msg: String, more: JSON.Object.Entry*): Unit = message(Markup.WARNING, msg, more:_*)
222    def error_message(msg: String, more: JSON.Object.Entry*): Unit =
223      message(Markup.ERROR, msg, more:_*)
224
225    def progress(more: JSON.Object.Entry*): Connection_Progress =
226      new Connection_Progress(context, more:_*)
227
228    override def toString: String = connection.toString
229
230
231    /* asynchronous tasks */
232
233    private val _tasks = Synchronized(Set.empty[Task])
234
235    def make_task(body: Task => JSON.Object.T): Task =
236    {
237      val task = new Task(context, body)
238      _tasks.change(_ + task)
239      task
240    }
241
242    def remove_task(task: Task): Unit =
243      _tasks.change(_ - task)
244
245    def cancel_task(id: UUID.T): Unit =
246      _tasks.change(tasks => { tasks.find(task => task.id == id).foreach(_.cancel); tasks })
247
248    def close()
249    {
250      while(_tasks.change_result(tasks => { tasks.foreach(_.cancel); (tasks.nonEmpty, tasks) }))
251      { _tasks.value.foreach(_.join) }
252    }
253  }
254
255  class Connection_Progress private[Server](context: Context, more: JSON.Object.Entry*)
256    extends Progress
257  {
258    override def echo(msg: String): Unit = context.writeln(msg, more:_*)
259    override def echo_warning(msg: String): Unit = context.warning(msg, more:_*)
260    override def echo_error_message(msg: String): Unit = context.error_message(msg, more:_*)
261
262    override def theory(theory: Progress.Theory)
263    {
264      val entries: List[JSON.Object.Entry] =
265        List("theory" -> theory.theory, "session" -> theory.session) :::
266          (theory.percentage match { case None => Nil case Some(p) => List("percentage" -> p) })
267      context.writeln(theory.message, entries ::: more.toList:_*)
268    }
269
270    override def nodes_status(nodes_status: Document_Status.Nodes_Status)
271    {
272      val json =
273        for ((name, node_status) <- nodes_status.present)
274          yield name.json + ("status" -> nodes_status(name).json)
275      context.notify(JSON.Object(Markup.KIND -> Markup.NODES_STATUS, Markup.NODES_STATUS -> json))
276    }
277
278    @volatile private var is_stopped = false
279    override def stopped: Boolean = is_stopped
280    def stop { is_stopped = true }
281
282    override def toString: String = context.toString
283  }
284
285  class Task private[Server](val context: Context, body: Task => JSON.Object.T)
286  {
287    task =>
288
289    val id: UUID.T = UUID.random()
290    val ident: JSON.Object.Entry = ("task" -> id.toString)
291
292    val progress: Connection_Progress = context.progress(ident)
293    def cancel { progress.stop }
294
295    private lazy val thread = Standard_Thread.fork("server_task")
296    {
297      Exn.capture { body(task) } match {
298        case Exn.Res(res) =>
299          context.reply(Reply.FINISHED, res + ident)
300        case Exn.Exn(exn) =>
301          val err = json_error(exn)
302          if (err.isEmpty) throw exn else context.reply(Reply.FAILED, err + ident)
303      }
304      progress.stop
305      context.remove_task(task)
306    }
307    def start { thread }
308    def join { thread.join }
309  }
310
311
312  /* server info */
313
314  val localhost_name: String = "127.0.0.1"
315  def localhost: InetAddress = InetAddress.getByName(localhost_name)
316
317  def print_address(port: Int): String = localhost_name + ":" + port
318
319  def print(port: Int, password: String): String =
320    print_address(port) + " (password " + quote(password) + ")"
321
322  object Info
323  {
324    private val Pattern =
325      ("""server "([^"]*)" = \Q""" + localhost_name + """\E:(\d+) \(password "([^"]*)"\)""").r
326
327    def parse(s: String): Option[Info] =
328      s match {
329        case Pattern(name, Value.Int(port), password) => Some(Info(name, port, password))
330        case _ => None
331      }
332
333    def apply(name: String, port: Int, password: String): Info =
334      new Info(name, port, password)
335  }
336
337  class Info private(val name: String, val port: Int, val password: String)
338  {
339    def address: String = print_address(port)
340
341    override def toString: String =
342      "server " + quote(name) + " = " + print(port, password)
343
344    def connection(): Connection =
345    {
346      val connection = Connection(new Socket(localhost, port))
347      connection.write_message(password)
348      connection
349    }
350
351    def active(): Boolean =
352      try {
353        using(connection())(connection =>
354          {
355            connection.set_timeout(Time.seconds(2.0))
356            connection.read_message() match {
357              case Some(Reply(Reply.OK, _)) => true
358              case _ => false
359            }
360          })
361      }
362      catch {
363        case _: IOException => false
364        case _: SocketException => false
365        case _: SocketTimeoutException => false
366      }
367  }
368
369
370  /* per-user servers */
371
372  val default_name = "isabelle"
373
374  object Data
375  {
376    val database = Path.explode("$ISABELLE_HOME_USER/servers.db")
377
378    val name = SQL.Column.string("name").make_primary_key
379    val port = SQL.Column.int("port")
380    val password = SQL.Column.string("password")
381    val table = SQL.Table("isabelle_servers", List(name, port, password))
382  }
383
384  def list(db: SQLite.Database): List[Info] =
385    if (db.tables.contains(Data.table.name)) {
386      db.using_statement(Data.table.select())(stmt =>
387        stmt.execute_query().iterator(res =>
388          Info(
389            res.string(Data.name),
390            res.int(Data.port),
391            res.string(Data.password))).toList.sortBy(_.name))
392    }
393    else Nil
394
395  def find(db: SQLite.Database, name: String): Option[Info] =
396    list(db).find(server_info => server_info.name == name && server_info.active)
397
  /* Find the active server registered under the given name, or start a new
     one and register it.

     name: server name (registry key)
     port: port for a new server; 0 presumably lets the OS pick a free port
       (java.net.ServerSocket semantics) -- the actual port is server.port
     existing_server: fail instead of starting a new server
     log: logger for a newly started server

     Returns the server info, together with the Server object if it was
     started by this call (None for an already running server).  Raises an
     error if existing_server is set and no such server is active. */
  def init(
    name: String = default_name,
    port: Int = 0,
    existing_server: Boolean = false,
    log: Logger = No_Logger): (Info, Option[Server]) =
  {
    using(SQLite.open_database(Data.database))(db =>
      {
        /* restrict access to the registry (it stores passwords), and purge
           entries of servers that no longer respond */
        db.transaction {
          Isabelle_System.chmod("600", Data.database)
          db.create_table(Data.table)
          list(db).filterNot(_.active).foreach(server_info =>
            db.using_statement(Data.table.delete(Data.name.where_equal(server_info.name)))(
              _.execute))
        }
        db.transaction {
          find(db, name) match {
            case Some(server_info) => (server_info, None)
            case None =>
              if (existing_server) error("Isabelle server " + quote(name) + " not running")

              /* the listening socket is bound here; the accept loop only runs after start */
              val server = new Server(port, log)
              val server_info = Info(name, server.port, server.password)

              /* replace any remaining entry of the same name */
              db.using_statement(Data.table.delete(Data.name.where_equal(name)))(_.execute)
              db.using_statement(Data.table.insert())(stmt =>
              {
                stmt.string(1) = server_info.name
                stmt.int(2) = server_info.port
                stmt.string(3) = server_info.password
                stmt.execute()
              })

              server.start
              (server_info, Some(server))
          }
        }
      })
  }
437
  /* Ask the active server with the given name to shut down, and wait until
     it no longer responds: polling every 50ms without a time limit, while
     the registry transaction remains open.  Returns false if no active
     server of that name is registered. */
  def exit(name: String = default_name): Boolean =
  {
    using(SQLite.open_database(Data.database))(db =>
      db.transaction {
        find(db, name) match {
          case Some(server_info) =>
            /* connection() sends the password first, then the command */
            using(server_info.connection())(_.write_message("shutdown"))
            while(server_info.active) { Thread.sleep(50) }
            true
          case None => false
        }
      })
  }
451
452
453  /* Isabelle tool wrapper */
454
  /* Command-line tool "isabelle server".

     Default operation: use the active server of the given name, or start a
     new one (unless -s), and print its info line on stdout; with -c also run
     an interactive console on a new connection; finally wait for a server
     that was started here.

     Alternative operations: -l prints the info lines of all active servers;
     -x asks the named server to exit (return code 2 if there is none). */
  val isabelle_tool =
    Isabelle_Tool("server", "manage resident Isabelle servers", args =>
    {
      var console = false
      var log_file: Option[Path] = None
      var operation_list = false
      var operation_exit = false
      var name = default_name
      var port = 0
      var existing_server = false

      val getopts =
        Getopts("""
Usage: isabelle server [OPTIONS]

  Options are:
    -L FILE      logging on FILE
    -c           console interaction with specified server
    -l           list servers (alternative operation)
    -n NAME      explicit server name (default: """ + default_name + """)
    -p PORT      explicit server port
    -s           assume existing server, no implicit startup
    -x           exit specified server (alternative operation)

  Manage resident Isabelle servers.
""",
          "L:" -> (arg => log_file = Some(Path.explode(File.standard_path(arg)))),
          "c" -> (_ => console = true),
          "l" -> (_ => operation_list = true),
          "n:" -> (arg => name = arg),
          "p:" -> (arg => port = Value.Int.parse(arg)),
          "s" -> (_ => existing_server = true),
          "x" -> (_ => operation_exit = true))

      /* no positional arguments are accepted */
      val more_args = getopts(args)
      if (more_args.nonEmpty) getopts.usage()

      if (operation_list) {
        for {
          server_info <- using(SQLite.open_database(Data.database))(list(_))
          if server_info.active
        } Output.writeln(server_info.toString, stdout = true)
      }
      else if (operation_exit) {
        val ok = Server.exit(name)
        sys.exit(if (ok) 0 else 2)
      }
      else {
        val log = Logger.make(log_file)
        val (server_info, server) =
          init(name, port = port, existing_server = existing_server, log = log)
        Output.writeln(server_info.toString, stdout = true)
        if (console) {
          using(server_info.connection())(connection => connection.tty_loop().join)
        }
        /* only a server started by this invocation is joined */
        server.foreach(_.join)
      }
    })
513}
514
/* A resident server: it listens on the loopback address, authenticates each
   client connection by password, and dispatches input messages to the
   commands of Server.Command.  Sessions started by clients are kept in a
   registry keyed by UUID.  The constructor is private: instances are created
   by Server.init. */
class Server private(_port: Int, val log: Logger)
{
  server =>

  /* listening socket on Server.localhost with a connection backlog of 50;
     the actual port is available via "port" below */
  private val server_socket = new ServerSocket(_port, 50, Server.localhost)

  /* session registry */
  private val _sessions = Synchronized(Map.empty[UUID.T, Headless.Session])
  def err_session(id: UUID.T): Nothing = error("No session " + Library.single_quote(id.toString))
  def the_session(id: UUID.T): Headless.Session = _sessions.value.getOrElse(id, err_session(id))
  def add_session(entry: (UUID.T, Headless.Session)) { _sessions.change(_ + entry) }
  /* unregister and return the session; error if it is not registered */
  def remove_session(id: UUID.T): Headless.Session =
  {
    _sessions.change_result(sessions =>
      sessions.get(id) match {
        case Some(session) => (session, sessions - id)
        case None => err_session(id)
      })
  }

  /* Stop accepting connections, then stop all registered sessions.  Session
     failures are logged, not propagated (only ERROR exceptions are caught).
     Closing the socket makes a pending accept fail, which ends the loop of
     server_thread. */
  def shutdown()
  {
    server_socket.close

    val sessions = _sessions.change_result(sessions => (sessions, Map.empty))
    for ((_, session) <- sessions) {
      try {
        val result = session.stop()
        if (!result.ok) log("Session shutdown failed: return code " + result.rc)
      }
      catch { case ERROR(msg) => log("Session shutdown failed: " + msg) }
    }
  }

  /* actual port of the listening socket */
  def port: Int = server_socket.getLocalPort
  /* random password that each client must send as its first message */
  val password: String = UUID.random_string()

  override def toString: String = Server.print(port, password)

  /* Serve one client connection: check the password, reply OK with the
     Isabelle id and version, then process input messages until no further
     message can be read.  A connection with a wrong password is dropped
     without reply.  Closing the context at the end cancels and joins the
     tasks that were started via this connection. */
  private def handle(connection: Server.Connection)
  {
    using(new Server.Context(server, connection))(context =>
    {
      if (connection.read_password(password)) {
        connection.reply_ok(
          JSON.Object(
            "isabelle_id" -> Isabelle_System.isabelle_id(),
            "isabelle_version" -> Distribution.version))

        var finished = false
        while (!finished) {
          connection.read_message() match {
            case None => finished = true
            case Some("") => context.notify("Command 'help' provides list of commands")
            case Some(msg) =>
              val (name, argument) = Server.Argument.split(msg)
              name match {
                case Server.Command(cmd) =>
                  argument match {
                    case Server.Argument(arg) =>
                      if (cmd.isDefinedAt((context, arg))) {
                        Exn.capture { cmd((context, arg)) } match {
                          /* asynchronous command: acknowledge with the task id, then run it */
                          case Exn.Res(task: Server.Task) =>
                            connection.reply_ok(JSON.Object(task.ident))
                            task.start
                          case Exn.Res(res) => connection.reply_ok(res)
                          /* exceptions without JSON representation propagate out of handle */
                          case Exn.Exn(exn) =>
                            val err = Server.json_error(exn)
                            if (err.isEmpty) throw exn else connection.reply_error(err)
                        }
                      }
                      else {
                        /* well-formed argument, but not accepted by this command */
                        connection.reply_error_message(
                          "Bad argument for command " + Library.single_quote(name),
                          "argument" -> argument)
                      }
                    case _ =>
                      connection.reply_error_message(
                        "Malformed argument for command " + Library.single_quote(name),
                        "argument" -> argument)
                  }
                /* NOTE(review): unlike the cases above, this error reply carries a
                   plain string rather than a message object */
                case _ => connection.reply_error("Bad command " + Library.single_quote(name))
              }
          }
        }
      }
    })
  }

  /* Accept loop: each client connection is served by handle on a thread of
     its own; the loop ends when accept fails (e.g. after shutdown has closed
     the socket).  The thread is forked on first access (see start, join). */
  private lazy val server_thread: Thread =
    Standard_Thread.fork("server") {
      var finished = false
      while (!finished) {
        Exn.capture(server_socket.accept) match {
          case Exn.Res(socket) =>
            Standard_Thread.fork("server_connection")
              { using(Server.Connection(socket))(handle(_)) }
          case Exn.Exn(_) => finished = true
        }
      }
    }

  /* fork the accept loop by forcing the lazy server_thread */
  def start { server_thread }

  /* wait for the accept loop to finish, then shut down (closing an already
     closed socket again, and stopping any sessions still registered) */
  def join { server_thread.join; shutdown() }
}
620