1(*  Title:      Pure/Concurrent/future.ML
2    Author:     Makarius
3
4Value-oriented parallel execution via futures and promises.
5*)
6
signature FUTURE =
sig
  (* Task and group identities, shared with Task_Queue. *)
  type task = Task_Queue.task
  type group = Task_Queue.group
  (* Create a fresh group, optionally below a parent group. *)
  val new_group: group option -> group
  (* Task / group of the calling worker thread (NONE outside a worker task). *)
  val worker_task: unit -> task option
  val worker_group: unit -> group option
  (* Like worker_group, but raises Fail outside a worker thread context. *)
  val the_worker_group: unit -> group
  (* Fresh group whose parent is worker_group (). *)
  val worker_subgroup: unit -> group
  (* Abstract future values. *)
  type 'a future
  val task_of: 'a future -> task
  (* The result if it is already available; never blocks. *)
  val peek: 'a future -> 'a Exn.result option
  val is_finished: 'a future -> bool
  (* Apply f to x with interrupts enabled (private interrupts inside a worker
     task, public ones otherwise), then expose any pending interrupt. *)
  val interruptible_task: ('a -> 'b) -> 'a -> 'b
  (* Cancel a group, or the group of a future's task. *)
  val cancel_group: group -> unit
  val cancel: 'a future -> unit
  (* Fork parameters: task name, group (NONE = fresh worker subgroup),
     dependencies, priority, and whether jobs run with interrupts enabled. *)
  type params = {name: string, group: group option, deps: task list, pri: int, interrupts: bool}
  val default_params: params
  (* Enqueue one task per thunk, all in the same group. *)
  val forks: params -> (unit -> 'a) list -> 'a future list
  val fork: (unit -> 'a) -> 'a future
  (* Wait for futures: join_results/join_result return the outcomes,
     joins/join release them (raising on failure). *)
  val join_results: 'a future list -> 'a Exn.result list
  val join_result: 'a future -> 'a Exn.result
  val joins: 'a future list -> 'a list
  val join: 'a future -> 'a
  (* Fork the thunks in a fresh group and join them; the group is canceled
     if the join is interrupted. *)
  val forked_results: {name: string, deps: Task_Queue.task list} ->
    (unit -> 'a) list -> 'a Exn.result list
  (* Run f x in the calling thread as a task enrolled in the given group. *)
  val task_context: string -> group -> ('a -> 'b) -> 'a -> 'b
  (* Futures that are finished from the start. *)
  val value_result: 'a Exn.result -> 'a future
  val value: 'a -> 'a future
  (* forks when multithreading is enabled, immediate evaluation otherwise. *)
  val cond_forks: params -> (unit -> 'a) list -> 'a future list
  (* Apply a function to the value of a future, yielding a new future. *)
  val map: ('a -> 'b) -> 'a future -> 'b future
  (* Promises: futures fulfilled by external means; the argument is an abort
     callback run by the promise's passive job. *)
  val promise_name: string -> (unit -> unit) -> 'a future
  val promise: (unit -> unit) -> 'a future
  (* Fulfill a promise; raises Fail for futures not created by promise. *)
  val fulfill_result: 'a future -> 'a Exn.result -> unit
  val fulfill: 'a future -> 'a -> unit
  (* Tasks currently held in the queue for the given groups. *)
  val snapshot: group list -> task list
  (* Request scheduler shutdown and wait until it is no longer active;
     raises Fail when called from a worker thread. *)
  val shutdown: unit -> unit
end;
45
46structure Future: FUTURE =
47struct
48
49open Portable
50
51(** future values **)
52
(* Task and group identities are those of the underlying task queue. *)
type task = Task_Queue.task;
type group = Task_Queue.group;

(* Create a fresh group, optionally below the given parent group. *)
fun new_group parent = Task_Queue.new_group parent;
56
57
58(* identifiers *)
59
local
  (* Per-thread slot holding the task the calling thread is executing. *)
  val current_task = Thread_Data.var () : task Thread_Data.var;
in
  (* The task run by the calling thread, or NONE outside of any task. *)
  fun worker_task () = Thread_Data.get current_task;

  (* Apply f to x with the calling thread's task slot set to SOME task
     (via Thread_Data.setmp; presumably restored afterwards). *)
  fun setmp_worker_task task f x =
    Thread_Data.setmp current_task (SOME task) f x;
end;
66
(* Group of the task run by the calling thread, if any. *)
fun worker_group () = Option.map Task_Queue.group_of_task (worker_task ());
68
(* Group of the calling worker thread; raises Fail outside a worker context. *)
fun the_worker_group () =
  let
    val group_opt = worker_group ()
  in
    if isSome group_opt then valOf group_opt
    else raise Fail "Missing worker thread context"
  end;
73
(* Fresh group whose parent is the calling thread's group (NONE outside a worker). *)
fun worker_subgroup () =
  let val parent = worker_group ()
  in new_group parent end;
75
(* Run e (); inside a worker task, do so via Task_Queue.joining for that task. *)
fun worker_joining e =
  let val current = worker_task ()
  in
    if isSome current then Task_Queue.joining (valOf current) e
    else e ()
  end;
80
(* Run e (); inside a worker task, do so via Task_Queue.waiting on deps. *)
fun worker_waiting deps e =
  let val current = worker_task ()
  in
    if isSome current then Task_Queue.waiting (valOf current) deps e
    else e ()
  end;
85
86
87(* datatype future *)
88
(* Storage for the outcome of a future: a single-assignment variable. *)
type 'a result = 'a Exn.result Single_Assignment.var;

(* A future is either an immediate value or a queued task together with its
   result slot. promised = true marks futures created by promise_name, which
   are fulfilled externally via fulfill_result. *)
datatype 'a future =
  Value of 'a Exn.result |
  Future of
   {promised: bool,
    task: task,
    result: 'a result};
97
(* Task behind a future; plain values map to Task_Queue.dummy_task. *)
fun task_of x =
  (case x of
    Future {task, ...} => task
  | Value _ => Task_Queue.dummy_task);
100
(* Current outcome of a future, without blocking. *)
fun peek x =
  (case x of
    Future {result, ...} => Single_Assignment.peek result
  | Value res => SOME res);
103
(* True once the future has an outcome. *)
fun is_finished x =
  (case peek x of
    NONE => false
  | SOME _ => true);
105
106(** scheduling **)
107
108(* synchronization *)
109
(* Wakes the scheduler thread (cancel_later, shutdown) and is broadcast at
   the end of every scheduler round. *)
val scheduler_event = ConditionVar.conditionVar ();
(* Idle workers sleep on this until new work may be available. *)
val work_available = ConditionVar.conditionVar ();
(* Broadcast whenever a task finishes; joining threads wait on it. *)
val work_finished = ConditionVar.conditionVar ();
113
local
  (* The single mutex guarding the scheduler state of this structure. *)
  val scheduler_lock = Mutex.mutex ();
in

(* Evaluate a thunk while holding scheduler_lock; name is passed through to
   Multithreading.synchronized as a label. *)
fun SYNCHRONIZED name = Multithreading.synchronized name scheduler_lock;

(* Wait on condition variable cv without deadline (via
   Multithreading.sync_wait, presumably releasing the lock while blocked). *)
fun wait cv = (*requires SYNCHRONIZED*)
  Multithreading.sync_wait NONE cv scheduler_lock;

(* Like wait, but with a deadline of now + delay. *)
fun wait_timeout delay cv = (*requires SYNCHRONIZED*)
  let
    val deadline = Time.now () + delay
  in
    Multithreading.sync_wait (SOME deadline) cv scheduler_lock
  end;

(* Wake one thread waiting on cv. *)
fun signal cv = (*requires SYNCHRONIZED*)
  ConditionVar.signal cv;

(* Wake all threads waiting on cv. *)
fun broadcast cv = (*requires SYNCHRONIZED*)
  ConditionVar.broadcast cv;

end;
133
134
135(* global state *)
136
(* Queue of all tasks; replaced functionally under SYNCHRONIZED. *)
val queue = Unsynchronized.ref Task_Queue.empty;
(* Counter for the labels passed to worker_start ("worker N"). *)
val next = Unsynchronized.ref 0;
(* The scheduler thread, once started by scheduler_check. *)
val scheduler = Unsynchronized.ref (NONE: Thread.thread option);
(* Groups whose cancellation is retried by the scheduler (see cancel_later). *)
val canceled = Unsynchronized.ref ([]: group list);
(* Set by shutdown, cleared again by scheduler_check. *)
val do_shutdown = Unsynchronized.ref false;
(* Pool limits recomputed in each scheduler round:
   max_active = Multithreading.max_threads () (0 during shutdown),
   max_workers = 2 * max_active. *)
val max_workers = Unsynchronized.ref 0;
val max_active = Unsynchronized.ref 0;

(* Count of scheduler ticks (status reporting is currently disabled). *)
val status_ticks = Unsynchronized.ref 0;
(* Time of the last scheduler tick; next_round is the minimal interval
   between ticks and also the scheduler's wait timeout (0.05 s). *)
val last_round = Unsynchronized.ref Time.zeroTime;
val next_round = Time.fromReal 0.05;

(* Registered worker threads, each with a mutable state cell. *)
datatype worker_state = Working | Waiting | Sleeping;
val workers = Unsynchronized.ref ([]: (Thread.thread * worker_state Unsynchronized.ref) list);
151
(* Number of registered workers whose state cell currently equals state. *)
fun count_workers state = (*requires SYNCHRONIZED*)
  length (List.filter (fn (_, state_ref) => ! state_ref = state) (! workers));
155
156
157
158(* cancellation primitives *)
159
(* Cancel group in the queue and interrupt its running threads, except the
   calling thread itself; returns the threads reported by Task_Queue.cancel. *)
fun cancel_now group = (*requires SYNCHRONIZED*)
  let
    fun interrupt_other thread =
      if Standard_Thread.is_self thread then ()
      else Standard_Thread.interrupt_unsynchronized thread;
    val running = Task_Queue.cancel (! queue) group;
  in
    List.app interrupt_other running;
    running
  end;
167
(* Cancel everything in the queue, interrupt the affected threads, and
   return the groups reported by Task_Queue.cancel_all. *)
fun cancel_all () = (*requires SYNCHRONIZED*)
  let
    val (groups, threads) = Task_Queue.cancel_all (! queue)
  in
    List.app Standard_Thread.interrupt_unsynchronized threads;
    groups
  end;
173
(* Remember group for repeated cancellation by the scheduler and wake it. *)
fun cancel_later group = (*requires SYNCHRONIZED*)
  let
    val add_group = op_insert (curry Task_Queue.eq_group) group
  in
    Unsynchronized.change canceled add_group;
    broadcast scheduler_event
  end;
177
(* Apply f to x with interrupts enabled: private interrupts inside a worker
   task, public interrupts otherwise. After a normal return, any pending
   interrupt is exposed before the value is handed back. *)
fun interruptible_task f x =
  let
    val atts =
      (case worker_task () of
        SOME _ => Thread_Attributes.private_interrupts
      | NONE => Thread_Attributes.public_interrupts);
    val y = Thread_Attributes.with_attributes atts (fn _ => f x);
  in
    Thread_Attributes.expose_interrupt ();
    y
  end;
185
186
187(* worker threads *)
188
(* Execute the jobs of task in the calling thread, then mark it finished.
   Each job receives "valid", which is false if the task's group was already
   canceled before execution. Every job runs even after an earlier one
   returned false, because "job valid" is evaluated before "andalso ok".
   If any job returned false, or an interrupt is pending afterwards, the
   group is canceled now, and also queued via cancel_later when cancel_now
   still reports running threads. *)
fun worker_exec (task, jobs) =
  let
    val group = Task_Queue.group_of_task task;
    val valid = not (Task_Queue.is_canceled group);
    (* Run all jobs with this thread's worker task set to task. *)
    val ok =
      Task_Queue.running task (fn () =>
        setmp_worker_task task (fn () =>
          foldl' (fn job => fn ok => job valid andalso ok) jobs true) ());
(* Disabled task statistics tracing:
    val _ =
      if ! Multithreading.trace >= 2 then
        Output.try_protocol_message (Markup.task_statistics :: Task_Queue.task_statistics task) []
      else ();
*)
    val _ = SYNCHRONIZED "finish" (fn () =>
      let
        val maximal = Unsynchronized.change_result queue (Task_Queue.finish task);
        (* Capture (rather than raise) an interrupt that arrived meanwhile. *)
        val test = Exn.capture Thread_Attributes.expose_interrupt ();
        val _ =
          if ok andalso not (Exn.is_interrupt_exn test) then ()
          else if null (cancel_now group) then ()
          else cancel_later group;
        (* Wake threads joining on finished tasks. *)
        val _ = broadcast work_finished;
        (* Unless Task_Queue.finish reported "maximal" (presumably: no task
           was waiting on this one), wake a worker for newly enabled work. *)
        val _ = if maximal then () else signal work_available;
      in () end);
  in () end;
214
(* Wait on cond. If the calling thread is a registered worker, its state
   cell holds worker_state for the duration of the wait (Unsynchronized.setmp). *)
fun worker_wait worker_state cond = (*requires SYNCHRONIZED*)
  let
    val slot = AList.lookup Thread.equal (! workers) (Thread.self ())
  in
    if isSome slot then Unsynchronized.setmp (valOf slot) worker_state wait cond
    else wait cond
  end;
219
(* Obtain the next work item for the calling worker thread.
   - When the pool exceeds max_workers, the caller deregisters itself and
     returns NONE, so its worker_loop terminates.
   - Otherwise it dequeues work; when more than max_active workers are
     Working, it passes urgent_only = true to Task_Queue.dequeue.
   - With nothing to dequeue, it sleeps on work_available and retries.
   Both exits signal work_available once more, passing the wakeup on. *)
fun worker_next () = (*requires SYNCHRONIZED*)
  if length (! workers) > ! max_workers then
    (Unsynchronized.change workers (AList.delete Thread.equal (Thread.self ()));
     signal work_available;
     NONE)
  else
    let val urgent_only = count_workers Working > ! max_active in
      (case Unsynchronized.change_result queue (Task_Queue.dequeue (Thread.self ()) urgent_only) of
        NONE => (worker_wait Sleeping work_available; worker_next ())
      | some => (signal work_available; some))
    end;
231
(* Main loop of a worker thread: take work under the lock (labelled name)
   and execute it outside the lock, until worker_next returns NONE. *)
fun worker_loop name =
  let
    fun next () = SYNCHRONIZED name (fn () => worker_next ());
    fun loop () =
      (case next () of
        SOME work => (worker_exec work; loop ())
      | NONE => ());
  in loop () end;
236
val threads_stack_limit = 0.25 (* should be user-config option *)

(* Fork one worker thread running worker_loop name and register it in
   workers as Working. threads_stack_limit is scaled by 1024^3 (presumably
   GiB to bytes); a non-positive result means no explicit stack limit.
   A Fail raised here is reported through tracing instead of propagated. *)
fun worker_start name = (*requires SYNCHRONIZED*)
  let
    val limit = Real.floor (threads_stack_limit * 1024.0 * 1024.0 * 1024.0);
    val stack_limit = if limit > 0 then SOME limit else NONE;
    val thread =
      Standard_Thread.fork {name = "worker", stack_limit = stack_limit, interrupts = false}
        (fn () => worker_loop name);
    val entry = (thread, Unsynchronized.ref Working);
  in Unsynchronized.change workers (cons entry) end
  handle Fail msg => Multithreading.tracing 0 (fn () => "SCHEDULER: " ^ msg);
248
249
250(* scheduler *)
251
(* One scheduler round, run under the lock. Housekeeping happens at most
   once per next_round interval ("tick"). Returns false once shutdown was
   requested and no workers remain, which ends scheduler_loop; true
   otherwise. On an interrupt, it cancels everything (with retries queued
   via cancel_later) and keeps running; other exceptions are re-raised. *)
fun scheduler_next () = (*requires SYNCHRONIZED*)
  let
    val now = Time.now ();
    val tick = ! last_round + next_round <= now;
    val _ = if tick then last_round := now else ();


    (* runtime status *)

    val _ =
      if tick then Unsynchronized.change status_ticks (fn i => i + 1) else ();
    (*
    val _ =
      if tick andalso ! status_ticks mod (if ! Multithreading.trace >= 1 then 2 else 10) = 0
      then report_status () else ();
      *)

    (* On a tick, drop worker threads that are no longer active. *)
    val _ =
      if not tick orelse List.all (Thread.isActive o #1) (! workers) then ()
      else
        let
          val (alive, dead) = List.partition (Thread.isActive o #1) (! workers);
          val _ = workers := alive;
        in
          Multithreading.tracing 0 (fn () =>
            "SCHEDULER: disposed " ^ Int.toString (length dead) ^
            " dead worker threads")
        end;


    (* worker pool adjustments *)

    val max_active0 = ! max_active;
    val max_workers0 = ! max_workers;

    (* During shutdown, once Task_Queue.all_passive holds, the limits drop
       to 0 so that surplus workers leave (see worker_next). *)
    val m =
      if ! do_shutdown andalso Task_Queue.all_passive (! queue) then 0
      else Multithreading.max_threads ();
    val _ = max_active := m;
    val _ = max_workers := 2 * m;

    (* Top up the pool to max_workers threads. *)
    val missing = ! max_workers - length (! workers);
    val _ =
      if missing > 0 then
        funpow missing (fn () =>
          ignore (worker_start
                    ("worker " ^ Int.toString (Unsynchronized.inc next)))) ()
      else ();

    (* Changed limits: wake a worker so it can re-evaluate them. *)
    val _ =
      if ! max_active = max_active0 andalso ! max_workers = max_workers0 then ()
      else signal work_available;


    (* canceled groups *)

    (* Retry cancellation; a group is dropped once cancel_now reports no
       running threads for it. *)
    val _ =
      if null (! canceled) then ()
      else
       (Multithreading.tracing 1 (fn () =>
          Int.toString (length (! canceled)) ^ " canceled groups");
        Unsynchronized.change canceled (filter_out (null o cancel_now));
        signal work_available);


    (* delay loop *)

    (* Sleep up to next_round, or until scheduler_event is signalled. *)
    val _ = Exn.release (wait_timeout next_round scheduler_event);


    (* shutdown *)

    val continue = not (! do_shutdown andalso null (! workers));
    val _ = if continue then () else ((* report_status (); *)scheduler := NONE)

    (* Wake threads waiting for the scheduler, e.g. shutdown. *)
    val _ = broadcast scheduler_event;
  in continue end
  handle exn =>
    if Exn.is_interrupt exn then
     (Multithreading.tracing 1 (fn () => "SCHEDULER: Interrupt");
      List.app cancel_later (cancel_all ());
      signal work_available; true)
    else Exn.reraise exn;
335
(* Body of the scheduler thread: run scheduler_next rounds under the lock,
   with synchronous public interrupts, until one returns false; then reset
   last_round so a later scheduler starts with a fresh tick. *)
fun scheduler_loop () =
  let
    fun round () =
      Thread_Attributes.with_attributes
        (Thread_Attributes.sync_interrupts Thread_Attributes.public_interrupts)
        (fn _ => SYNCHRONIZED "scheduler" (fn () => scheduler_next ()));
    fun loop () = if round () then loop () else ();
  in
    loop ();
    last_round := Time.zeroTime
  end;
342
(* True if a scheduler thread exists and is still active. *)
fun scheduler_active () = (*requires SYNCHRONIZED*)
  let val current = ! scheduler
  in isSome current andalso Thread.isActive (valOf current) end;
345
(* Cancel any pending shutdown request and make sure a scheduler thread runs. *)
fun scheduler_check () = (*requires SYNCHRONIZED*)
  let
    fun start_scheduler () =
      Standard_Thread.fork {name = "scheduler", stack_limit = NONE, interrupts = false}
        scheduler_loop;
  in
    do_shutdown := false;
    if scheduler_active () then ()
    else scheduler := SOME (start_scheduler ())
  end;
353
354
355
356(** futures **)
357
358(* cancel *)
359
(* Cancel group now. If threads are still running, queue it for retries by
   the scheduler. Then wake a worker and ensure the scheduler is running. *)
fun cancel_group_unsynchronized group = (*requires SYNCHRONIZED*)
  let
    val still_running = cancel_now group;
  in
    if null still_running then () else cancel_later group;
    signal work_available;
    scheduler_check ()
  end;
366
(* Cancel group, taking the scheduler lock. *)
fun cancel_group group =
  let fun cancel_it () = cancel_group_unsynchronized group
  in SYNCHRONIZED "cancel_group" cancel_it end;
369
(* Cancel the whole group of the future's task. *)
fun cancel x =
  let val group = Task_Queue.group_of_task (task_of x)
  in cancel_group group end;
371
372
373(* results *)
374
(* Store res into result and report whether the stored outcome is a value.
   If Single_Assignment.assign raises Fail (presumably because result is
   already assigned) and the existing outcome is an interrupt, that
   interrupt is re-raised instead of the Fail; otherwise the Fail propagates.
   If the stored outcome is an exception, group is canceled with it (under
   the lock) and false is returned. *)
fun assign_result group result res =
  let
    val _ = Single_Assignment.assign result res
      handle exn as Fail _ =>
        (case Single_Assignment.peek result of
          SOME (Exn.Exn e) => Exn.reraise (if Exn.is_interrupt e then e else exn)
        | _ => Exn.reraise exn);
    val ok =
      (case valOf (Single_Assignment.peek result) of
        Exn.Exn exn =>
          (SYNCHRONIZED "cancel" (fn () => Task_Queue.cancel_group group exn); false)
      | Exn.Res _ => true);
  in ok end;
388
389
390(* future jobs *)
391
(* Create a fresh result variable together with the job that fills it.
   job ok evaluates e () under attributes atts when ok is true; otherwise it
   stores an interrupt. Either way it returns assign_result's verdict. *)
fun future_job group atts (e: unit -> 'a) =
  let
    val result = Single_Assignment.var "future" : 'a result;
    fun run () = Thread_Attributes.with_attributes atts (fn _ => e ());
    fun job ok =
      let
        val outcome = if ok then Exn.capture run () else Exn.interrupt_exn;
      in assign_result group result outcome end;
  in (result, job) end;
406
407
408(* fork *)
409
(* Fork parameters. group = NONE means a fresh worker subgroup (see forks);
   interrupts selects private interrupts (true) or none (false) for jobs. *)
type params = {name: string, group: group option, deps: task list, pri: int, interrupts: bool};
(* Unnamed, no explicit group, no dependencies, priority 0, interruptible. *)
val default_params: params = {name = "", group = NONE, deps = [], pri = 0, interrupts = true};
412
(* Enqueue one task per thunk in es, all in the same group (the given one,
   or a fresh worker subgroup), and return their futures in order.
   If none of deps is still known to the queue, a worker is woken
   immediately. Also ensures the scheduler is running. *)
fun forks ({name, group, deps, pri, interrupts}: params) es =
  (case es of
    [] => []
  | _ =>
      let
        val grp = (case group of SOME g => g | NONE => worker_subgroup ());
        val atts =
          if interrupts then Thread_Attributes.private_interrupts
          else Thread_Attributes.no_interrupts;
        (* Add one job to queue q; yields the new queue and the future. *)
        fun enqueue (q, e) =
          let
            val (result, job) = future_job grp atts e;
            val (task, q') = Task_Queue.enqueue name grp deps pri job q;
          in (q', Future {promised = false, task = task, result = result}) end;
      in
        SYNCHRONIZED "enqueue" (fn () =>
          let
            val (queue', futures) = foldl_map enqueue (! queue, es);
            val _ = queue := queue';
            val minimal = List.all (not o Task_Queue.known_task queue') deps;
          in
            if minimal then signal work_available else ();
            scheduler_check ();
            futures
          end)
      end);
442
(* Fork a single interruptible task named "fork" in a fresh worker subgroup. *)
fun fork e =
  let
    val fork_params: params =
      {name = "fork", group = NONE, deps = [], pri = 0, interrupts = true};
  in singleton (forks fork_params) e end;
445
446
447(* join *)
448
(* Outcome of a future that should be finished. An interrupt outcome is
   replaced by the exceptions recorded for the task's group, if there are
   any (combined via Par_Exn.make). *)
fun get_result x =
  (case peek x of
    SOME res =>
      if not (Exn.is_interrupt_exn res) then res
      else
        let
          val exns = Task_Queue.group_status (Task_Queue.group_of_task (task_of x))
        in
          if null exns then res else Exn.Exn (Par_Exn.make exns)
        end
  | NONE => Exn.Exn (Fail "Unfinished future"));
458
local

(* Under the lock, find work for a thread that joins on deps.
   Task_Queue.dequeue_deps yields an optional work item plus the
   dependencies still outstanding (deps'):
   - (NONE, []): nothing left to wait for;
   - (NONE, deps'): nothing runnable now -- wait on work_finished (with the
     caller's attributes atts, marked as waiting on deps') and retry;
   - (SOME work, deps'): hand the work to the caller. *)
fun join_next atts deps = (*requires SYNCHRONIZED*)
  if null deps then NONE
  else
    (case Unsynchronized.change_result queue (Task_Queue.dequeue_deps (Thread.self ()) deps) of
      (NONE, []) => NONE
    | (NONE, deps') =>
       (worker_waiting deps' (fn () =>
          Thread_Attributes.with_attributes atts (fn _ =>
            Exn.release (worker_wait Waiting work_finished)));
        join_next atts deps')
    | (SOME work, deps') => SOME (work, deps'));

(* Repeatedly take work from join_next and execute it in the calling thread,
   outside the lock and marked as joining, until nothing is left. *)
fun join_loop atts deps =
  (case SYNCHRONIZED "join" (fn () => join_next atts deps) of
    NONE => ()
  | SOME (work, deps') => (worker_joining (fn () => worker_exec work); join_loop atts deps'));

in

(* Wait until all futures are finished and return their outcomes in order.
   Inside a worker thread the caller helps by running pending tasks itself
   (join_loop), with interrupts disabled except while blocked, where the
   original attributes apply. Outside a worker it blocks on each result. *)
fun join_results xs =
  let
    val _ =
      if List.all is_finished xs then ()
      else if isSome (worker_task ()) then
        Thread_Attributes.with_attributes Thread_Attributes.no_interrupts
          (fn orig_atts => join_loop orig_atts (map task_of xs))
      else
        xs |> List.app
          (fn Value _ => ()
            | Future {result, ...} => ignore (Single_Assignment.await result));
  in map get_result xs end;

end;
494
(* Join a single future and return its outcome. *)
fun join_result x = singleton join_results x;

(* Join all futures and release their values (Par_Exn.release_all). *)
fun joins xs =
  let val results = join_results xs
  in Par_Exn.release_all results end;

(* Join one future and release its value, re-raising a failure. *)
fun join x =
  let val res = join_result x
  in Exn.release res end;
498
499
500(* forked results: nested parallel evaluation *)
501
(* Fork es in a fresh group and join them. Inside a worker task, the group
   is a child of the task's group and inherits its priority; otherwise it is
   a new top-level group with priority 0. The body runs via
   Thread_Attributes.uninterruptible, and only the join runs with the
   caller's attributes restored. If the join raises an interrupt, the group
   is canceled before the exception is re-raised. *)
fun forked_results {name, deps} es =
  Thread_Attributes.uninterruptible (fn restore_attributes => fn () =>
    let
      val (group, pri) =
        (case worker_task () of
          SOME task =>
            (new_group (SOME (Task_Queue.group_of_task task)), Task_Queue.pri_of_task task)
        | NONE => (new_group NONE, 0));
      val futures =
        forks {name = name, group = SOME group, deps = deps, pri = pri, interrupts = true} es;
    in
      restore_attributes join_results futures
        handle exn => (if Exn.is_interrupt exn then cancel_group group else (); Exn.reraise exn)
    end) ();
516
517
518(* task context for running thread *)
519
(* Run f x in the calling thread as a task enrolled in group under name.
   Enrollment and bookkeeping run with interrupts disabled; f x itself runs
   with the caller's original attributes. Failures of f x are re-raised. *)
fun task_context name group f x =
  Thread_Attributes.with_attributes Thread_Attributes.no_interrupts (fn orig_atts =>
    let
      val (result, job) = future_job group orig_atts (fn () => f x);
      fun enroll () =
        Unsynchronized.change_result queue (Task_Queue.enroll (Thread.self ()) name group);
      val task = SYNCHRONIZED "enroll" enroll;
      val _ = worker_exec (task, [job]);
      val res =
        (case Single_Assignment.peek result of
          SOME r => r
        | NONE => raise Fail "Missing task context result");
    in Exn.release res end);
533
534
535(* fast-path operations -- bypass task queue if possible *)
536
(* A future that is finished from the start with outcome res, attached to
   Task_Queue.dummy_task. The outcome is stored via assign_result, so an
   exception outcome also cancels the dummy task's group. *)
fun value_result (res: 'a Exn.result) =
  let
    val dummy = Task_Queue.dummy_task;
    val dummy_group = Task_Queue.group_of_task dummy;
    val slot = Single_Assignment.var "value" : 'a result;
    val _ = assign_result dummy_group slot res;
  in Future {promised = false, task = dummy, result = slot} end;
544
(* A future that is finished from the start with value x. *)
fun value x =
  let val outcome = Exn.Res x
  in value_result outcome end;
546
(* Like forks when multithreading is enabled. Otherwise each thunk is
   evaluated immediately (interruptibly) into a finished future. *)
fun cond_forks args es =
  if not (Multithreading.enabled ()) then
    map (fn e => value_result (Exn.interruptible_capture e ())) es
  else forks args es;
550
(* Future for f applied to the value of x. A finished x is mapped at once.
   Otherwise the computation is appended to x's task if Task_Queue.extend
   allows it (sharing that task); if not, a new task depending on x's task
   is forked in the same group with the same priority. *)
fun map_future f x =
  if is_finished x then value_result (Exn.interruptible_capture (f o join) x)
  else
    let
      val task = task_of x;
      val group = Task_Queue.group_of_task task;
      fun compute () = f (join x);
      val (result, job) =
        future_job group Thread_Attributes.private_interrupts compute;
      fun try_extend () =
        (case Task_Queue.extend task job (! queue) of
          NONE => false
        | SOME queue' => (queue := queue'; true));
    in
      if SYNCHRONIZED "extend" try_extend then
        Future {promised = false, task = task, result = result}
      else
        singleton
          (cond_forks
            {name = "map_future", group = SOME group, deps = [task],
              pri = Task_Queue.pri_of_task task, interrupts = true})
          compute
    end;
572
573
574(* promised futures -- fulfilled by external means *)
575
(* Create a promised future in a fresh worker subgroup, registered in the
   queue as a passive task under name.
   The passive job (presumably run when the task is canceled) executes with
   interrupts disabled: it tries to store an interrupt as the result, then
   always calls abort, and finally releases the outcome of the attempt.
   In assign, a Fail from assign_result (presumably because the result is
   already assigned) counts as success. An interrupt raised by it is
   reported as Fail "Concurrent attempt to fulfill promise". *)
fun promise_name name abort : 'a future =
  let
    val group = worker_subgroup ();
    val result = Single_Assignment.var "promise" : 'a result;
    fun assign () = assign_result group result Exn.interrupt_exn
      handle Fail _ => true
        | exn =>
            if Exn.is_interrupt exn
            then raise Fail "Concurrent attempt to fulfill promise"
            else Exn.reraise exn;
    fun job () =
      Thread_Attributes.with_attributes Thread_Attributes.no_interrupts
        (fn _ => Exn.release (Exn.capture assign () before abort ()));
    val task = SYNCHRONIZED "enqueue_passive" (fn () =>
      Unsynchronized.change_result queue (Task_Queue.enqueue_passive group name job));
  in Future {promised = true, task = task, result = result} end;
592
(* promise_name with the default task name "passive". *)
fun promise abort =
  let val default_name = "passive"
  in promise_name default_name abort end;
594
(* Fulfill a promised future with outcome res.
   With interrupts disabled, the passive task is taken from the queue via
   Task_Queue.dequeue_passive:
   - SOME true: run the fulfilling job as that task (worker_exec);
   - SOME false: do nothing here;
   - NONE: assign directly, as an interrupt if the group is already canceled.
   If the result is still unassigned afterwards, block until it is assigned,
   marked as waiting on task when called from a worker thread.
   Raises Fail for futures not created by promise. *)
fun fulfill_result (Future {promised = true, task, result}) res =
      let
        val group = Task_Queue.group_of_task task;
        fun job ok =
          assign_result group result (if ok then res else Exn.interrupt_exn)
        val _ =
          Thread_Attributes.with_attributes Thread_Attributes.no_interrupts (fn _ =>
            let
              val passive_job =
                SYNCHRONIZED "fulfill_result" (fn () =>
                  Unsynchronized.change_result queue
                    (Task_Queue.dequeue_passive (Thread.self ()) task));
            in
              (case passive_job of
                SOME true => worker_exec (task, [job])
              | SOME false => ()
              | NONE => ignore (job (not (Task_Queue.is_canceled group))))
            end);
        val _ =
          if isSome (Single_Assignment.peek result) then ()
          else worker_waiting [task] (fn () => ignore (Single_Assignment.await result));
      in () end
  | fulfill_result _ _ = raise Fail "Not a promised future";
618
(* Fulfill a promised future with a regular value. *)
fun fulfill x res =
  let val outcome = Exn.Res res
  in fulfill_result x outcome end;
620
621
622(* snapshot: current tasks of groups *)
623
(* Tasks of the given groups currently held in the queue (read under the lock). *)
fun snapshot groups =
  let fun current_tasks () = Task_Queue.group_tasks (! queue) groups
  in SYNCHRONIZED "snapshot" current_tasks end;
627
628
629(* shutdown *)
630
(* Request scheduler shutdown and wait until the scheduler thread is gone.
   Must not be called from a worker thread (raises Fail). The flag is
   re-asserted before every wait, since scheduler_check clears it. *)
fun shutdown () =
  (case worker_task () of
    SOME _ => raise Fail "Cannot shutdown while running as worker thread"
  | NONE =>
      SYNCHRONIZED "shutdown" (fn () =>
        let
          fun await_scheduler () =
            if scheduler_active () then
             (do_shutdown := true;
              Multithreading.tracing 1 (fn () => "SHUTDOWN: wait");
              ignore (wait scheduler_event);
              await_scheduler ())
            else ();
        in await_scheduler () end));
640
641
642(*final declarations of this structure!*)
(* Exported as Future.map. Bound last because it shadows the list map used
   by earlier definitions in this structure (e.g. cond_forks, join_results). *)
val map = map_future;
644
645end;
646
647(* type 'a future = 'a Future.future; *)
648