1(* ========================================================================== *)
2(* FILE          : tttSearch.sml                                              *)
3(* DESCRIPTION   : Search algorithm for TacticToe.                            *)
4(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
5(* DATE          : 2017                                                       *)
6(* ========================================================================== *)
7
8structure tttSearch :> tttSearch =
9struct
10
11open HolKernel boolLib Abbrev tttTools tttTimeout tttFeature tttPredict
12tttExec tttLexer tttMinimize tttThmData tttTacticData tttLearn tttSetup
13
(* Error constructor for this structure. *)
val ERR = mk_HOL_ERR "tttSearch"

(* Text of the last tactic string handed to apply_stac. *)
val last_stac = ref ""

(* Log the message s (prefixed by Error) through debug, then raise a
   generic ERR; never returns normally. *)
fun debug_err s =
  let val _ = debug ("Error: " ^ s) in
    raise ERR "standard" "error"
  end

(* --------------------------------------------------------------------------
   Exceptions
   -------------------------------------------------------------------------- *)

(* The global search time limit (ttt_search_time) has been exceeded. *)
exception SearchTimeOut

(* No active node is left to expand (raised by node_find). *)
exception NoNextTac
24
25(* --------------------------------------------------------------------------
26   Asynchronous calls to provers
27   -------------------------------------------------------------------------- *)
28
(* Status of the asynchronous prover call attached to a node id. *)
datatype async_result_t =
    HSuccess of (string * goal)  (* tactic text returned for the goal *)
  | HFailure                     (* the call returned nothing or raised *)
  | HRunning of Thread.thread    (* the call is running in this thread *)
  | HVoid                        (* nothing to collect *)

(* Call counter, incremented for every forked prover thread. *)
val hammer_ref = ref 0

(* One result slot per node id; 100000 is the maximum number of nodes. *)
val async_result = Array.array (100000, HVoid)

(* Node ids waiting for a prover thread, mapped to their goal. *)
val install_async = ref (dempty Int.compare)

(* Node ids with a forked prover thread, mapped to that thread. *)
val running_async = ref (dempty Int.compare)
41
(* Start of a search: clear the asynchronous-call state (call counter,
   install queue, running threads and result slots).
   Does nothing unless ttt_eprover_flag is set. *)
fun init_async () =
  if not (!ttt_eprover_flag) then ()
  else
    let
      val _ = hammer_ref := 0
      val _ = install_async := dempty Int.compare
      val _ = running_async := dempty Int.compare
    in
      Array.modify (fn _ => HVoid) async_result
    end
53
(* Interrupt thread (belonging to node pid) again and again until it is no
   longer active.
   NOTE(review): this busy-waits until the thread reacts to the interrupt;
   it relies on the thread's interrupt handling, so confirm it cannot spin
   forever. *)
fun terminate_thread pid thread =
  if Thread.isActive thread
  then
    (debug_search ("terminate thread " ^ int_to_string pid);
     Thread.interrupt thread;
     terminate_thread pid thread)
  else ()
58
(* Cancel the asynchronous call attached to node pid:
   - drop pid from the install queue and from the running table;
   - interrupt its thread if one was forked;
   - clear its result slot.
   No-op unless ttt_eprover_flag is set. *)
fun terminate_async_pid pid =
  if !ttt_eprover_flag then
    let
      (* Look the thread up BEFORE removing pid from running_async;
         otherwise the membership test is always false and the thread is
         never interrupted. *)
      val threado =
        if dmem pid (!running_async)
        then SOME (dfind pid (!running_async))
        else NONE
    in
      install_async := drem pid (!install_async);
      running_async := drem pid (!running_async);
      (case threado of
         SOME thread => terminate_thread pid thread
       | NONE => ());
      (* Cleared after the thread has stopped, so a result written while it
         was being interrupted does not survive. Ids beyond the result
         array have no slot. *)
      if pid < Array.length async_result
      then Array.update (async_result,pid,HVoid)
      else ()
    end
  else ()
70
(* Cancel every call recorded in running_async (only with
   ttt_eprover_flag). *)
fun terminate_async () =
  if not (!ttt_eprover_flag) then ()
  else List.app terminate_async_pid (dkeys (!running_async))
75
(* Queue an asynchronous prover call on goal g for node pid, cancelling any
   previous call for that node. Node ids outside async_result have no result
   slot: queuing them would make later Array.update calls raise Subscript,
   so they are skipped (with a debug message).
   No-op unless ttt_eprover_flag is set. *)
fun queue_async pid g =
  if !ttt_eprover_flag andalso pid < Array.length async_result then
    (
    terminate_async_pid pid;
    debug_search ("install thread " ^ int_to_string pid);
    install_async := dadd pid g (!install_async)
    )
  else if !ttt_eprover_flag
    then debug_search ("install thread skipped " ^ int_to_string pid)
  else ()
84
85
86(* -------------------------------------------------------------------------
87   Tell if a node is active or not
88   -------------------------------------------------------------------------- *)
89
(* Ids of the nodes that have been deactivated. *)
val notactivedict = ref (dempty Int.compare)

(* A node is inactive exactly when it is recorded in notactivedict. *)
fun is_notactive pid = dmem pid (!notactivedict)

fun is_active pid = not (is_notactive pid)
93
(* Mark node x as inactive, cancelling its asynchronous call first. *)
fun deactivate x =
  let
    val _ = debug_search ("deactivate " ^ int_to_string x)
    val _ = terminate_async_pid x
  in
    notactivedict := dadd x () (!notactivedict)
  end
100
101(* -------------------------------------------------------------------------
102   Search references
103   -------------------------------------------------------------------------- *)
104
(* Real-time timer of the current search, started by init_search. *)
val glob_timer = ref NONE

(* Every node of the current search tree, indexed by node id. *)
val proofdict = ref (dempty Int.compare)

(* Predictors and hammer installed by init_search, kept in references so
   they do not have to be passed through every function. *)
val thmpredictor_glob = ref (fn _ => (fn _ => []))
val tacpredictor_glob = ref (fn _ => [])
val glpredictor_glob = ref (fn _ => 0.0)
val hammer_glob = ref (fn _ => (fn _ => NONE))

(* --------------------------------------------------------------------------
   Caching tactic applications on goals
   -------------------------------------------------------------------------- *)

(* Raw results of applying a tactic string to a goal, keyed by
   (tactic text, goal). *)
val stacgoal_cache = ref (dempty (cpl_compare String.compare goal_compare))
119
120(* --------------------------------------------------------------------------
121   Statistics
122   -------------------------------------------------------------------------- *)
123
(* Number of tactic applications attempted (incremented by apply_stac). *)
val stac_counter = ref 0

(* Render a list of tactic strings as [s1,s2,...]. *)
fun string_of_pred pred =
  String.concat ["[", String.concatWith "," pred, "]"]
128
(* Accumulated running times (most are printed by end_search). *)
val tactime = ref 0.0
val thmtime = ref 0.0
val gltime = ref 0.0
val inst_time = ref 0.0
val terminst_time = ref 0.0
val infstep_time = ref 0.0
val node_create_time = ref 0.0
val node_find_time = ref 0.0
val tot_time = ref 0.0

(* Wrappers built with total_time over the counters above. The two fun
   declarations stay polymorphic under the value restriction. *)
val tactimer = total_time tactime
val thmtimer = total_time thmtime
val gltimer = total_time gltime
val inst_timer = total_time inst_time
val infstep_timer = total_time infstep_time
val node_find_timer = total_time node_find_time
fun node_create_timer f x = total_time node_create_time f x
fun total_timer f x = total_time tot_time f x
150
(* Zero the timing counters, in a fixed order (terminst_time is not
   touched). *)
fun reset_timers () =
  List.app (fn counter => counter := 0.0)
    [tactime, thmtime, gltime, inst_time, infstep_time,
     node_create_time, node_find_time, tot_time]
162
163(* --------------------------------------------------------------------------
164   Special tactics
165   -------------------------------------------------------------------------- *)
166
(* Placeholder tactic names:
   - metis_spec is expanded into a metis call by stac_to_tac;
   - eprover_spec is turned into an asynchronous call by apply_next_stac. *)
val metis_spec = "tactictoe_metis"
val eprover_spec = "tactictoe_eprover"

(* Put eprover_spec in front of the predictions when eprover is enabled. *)
fun add_eprover pred =
  case !ttt_eprover_flag of
    true => eprover_spec :: pred
  | false => pred

(* Put metis_spec in front of the predictions when metis is enabled. *)
fun add_metis pred =
  case !ttt_metis_flag of
    true => metis_spec :: pred
  | false => pred
175
176(* --------------------------------------------------------------------------
177   MCTS: Priors
178   -------------------------------------------------------------------------- *)
179
(* The elements of an array as a list, in index order. *)
fun array_to_list arr = Array.foldr (op ::) [] arr
182
(* Initialise the Monte Carlo statistics of node pid:
   - prior policy pripol;
   - one visit;
   - an evaluation that is 0.0 with ttt_mcevnone_flag, 1.0 with
     ttt_mcevtriv_flag, and otherwise the goal-list predictor applied to the
     node's goals. *)
fun init_eval pripol pid =
  let
    val _ = debug_search "mcts evaluation"
    val node = dfind pid (!proofdict)
    val eval =
      if !ttt_mcevnone_flag then 0.0
      else if !ttt_mcevtriv_flag then 1.0
      else !glpredictor_glob (array_to_list (#goalarr node))
  in
    #priorpolicy node := pripol;
    #visit node := 1.0;
    #prioreval node := eval;
    #cureval node := [eval]
  end
198
199(* --------------------------------------------------------------------------
200   MCTS: Backpropagation
201   -------------------------------------------------------------------------- *)
202
(* Walk from node cid up to the root. Every node on the way gets one more
   visit; when beval holds, it also records eval in cureval. *)
fun backup_loop beval eval cid =
  let
    val {parid,visit,cureval,...} = dfind cid (!proofdict)
    val _ = if beval then cureval := eval :: !cureval else ()
    val _ = visit := !visit + 1.0
  in
    case parid of
      NONE => ()
    | SOME pid => backup_loop beval eval pid
  end
215
(* Back-propagate the prior evaluation of node cid to all its ancestors. *)
fun backup cid =
  let
    val _ = debug_search "mcts backpropagation"
    val {parid,prioreval,...} = dfind cid (!proofdict)
  in
    case parid of
      NONE => ()
    | SOME pid => backup_loop true (!prioreval) pid
  end
226
(* Back-propagate a failure (evaluation 0.0) from node cid. Ancestors always
   get a visit; the 0.0 is recorded only with ttt_mcevfail_flag. *)
fun backup_fail cid =
  let
    val _ = debug_search "backup fail"
    val {parid,...} = dfind cid (!proofdict)
  in
    case parid of
      NONE => ()
    | SOME pid => backup_loop (!ttt_mcevfail_flag) 0.0 pid
  end
237
(* Back-propagate a success (evaluation 1.0) from node cid to all its
   ancestors. *)
fun backup_success cid =
  let
    val _ = debug_search "backup success"
    val {parid,...} = dfind cid (!proofdict)
  in
    case parid of
      NONE => ()
    | SOME pid => backup_loop true 1.0 pid
  end
248
249(* --------------------------------------------------------------------------
250   Node creation and deletion
251   -------------------------------------------------------------------------- *)
252
(* Deepest node depth reached in the current search. *)
val max_depth_mem = ref 0

(* Next node id to hand out; reset to 0 by init_search. *)
val pid_counter = ref 0

(* Return a fresh node id and advance the counter. *)
fun next_pid () =
  let val fresh = !pid_counter in
    pid_counter := fresh + 1;
    fresh
  end
263
(* Create the root node for goal, with tactic predictions pred. Its id comes
   from next_pid. The node is registered in proofdict and its Monte Carlo
   statistics are initialised with prior policy 0.0. *)
fun root_create goal pred =
  let
    val selfid = next_pid ()
    val selfrec =
      {
      selfid = selfid,
      parid = NONE,
      parstac = NONE,
      pargn = NONE,
      parg = NONE,
      goalarr = Array.fromList [goal],
      predarr = Array.fromList [pred],
      depth = 0,
      (* the single goal, number 0, is pending *)
      pending = ref [0],
      children = ref [],
      (* proof saved for reconstruction + children of each goal *)
      proofl = ref [],
      childrena = Array.fromList [ref []],
      (* preventing loops and parallel steps *)
      pardict = dempty goal_compare,
      trydict = ref (dempty (list_compare goal_compare)),
      (* monte carlo: dummy values, set by init_eval *)
      priorpolicy = ref 0.0,
      visit = ref 0.0,
      prioreval = ref 0.0,
      cureval = ref []
      }
  in
    debug_search "Root";
    debug_search ("  goal: " ^ string_of_goal goal);
    debug_search ("  pred: \n  " ^ string_of_pred (first_n 2 pred));
    proofdict := dadd selfid selfrec (!proofdict);
    init_eval 0.0 selfid
  end
302
(* Create the root for goal g. Its predictions are the tactic predictor's
   output, with metis/eprover specs added in front when enabled. *)
fun root_create_wrap g =
  let val pred = add_eprover (add_metis (!tacpredictor_glob g)) in
    root_create g pred
  end
305
(* Create a child of node parid, produced by tactic parstac on goal parg
   (goal number pargn of the parent).
   - goallist are the new goals and predlist their tactic predictions.
   - pending lists the goal numbers still to be proved.
   - pardict holds the goals of the ancestors.
   The node is registered in proofdict, max_depth_mem is updated, and its
   Monte Carlo statistics are initialised with prior policy pripol.
   Returns the new id. tactime is not used here. *)
fun node_create pripol tactime parid parstac pargn parg goallist
    predlist pending pardict =
  let
    val selfid = next_pid ()
    fun init_empty _ = ref []
    val selfrec =
    {
      selfid   = selfid,
      parid    = SOME parid,
      parstac  = SOME parstac,
      pargn    = SOME pargn,
      parg     = SOME parg,
      goalarr  = Array.fromList goallist,
      predarr  = Array.fromList predlist,
      depth    = #depth (dfind parid (!proofdict)) + 1,
      (* goals still to be proved, by goal number *)
      pending  = ref pending,
      children = ref [],
      (* proof saved for reconstruction + children of each goal *)
      proofl = ref [],
      childrena = Array.fromList (map init_empty goallist),
      (* preventing loops and parallel steps *)
      pardict  = pardict,
      trydict  = ref (dempty (list_compare goal_compare)),
      (* monte carlo: dummy values changed by init_eval *)
      priorpolicy = ref 0.0,
      visit = ref 0.0,
      prioreval = ref 0.0,
      cureval = ref []
    }
    val cdepth = #depth selfrec
  in
    if cdepth > !max_depth_mem then max_depth_mem := cdepth else ();
    (* The record still holds the dummy 0.0 here, so log the prior policy
       that init_eval is about to store. *)
    debug_search
       ("Node " ^ int_to_string selfid ^ " " ^ int_to_string parid ^ " " ^
        Real.toString pripol);
    debug_search
       ("  goals: " ^ String.concatWith "," (map string_of_goal goallist));
    debug_search ("  predictions: " ^
       String.concatWith ",\n  " (map (string_of_pred o (first_n 2)) predlist));
    proofdict := dadd selfid selfrec (!proofdict);
    init_eval pripol selfid;
    selfid
  end
350
(* Deleting a node only deactivates it; it stays in proofdict. *)
fun node_delete pid =
  let val _ = debug_search ("node_delete " ^ int_to_string pid) in
    deactivate pid
  end
353
354(* --------------------------------------------------------------------------
355   Change the name of the tactic that has been applied
356   -------------------------------------------------------------------------- *)
357
(* Replace the head prediction of the first pending goal of node pid with
   newstac. Any failure (unknown node, no pending goal, empty prediction
   list) is reported through debug_err. *)
fun update_curstac newstac pid =
  let
    val {pending,predarr,...} = dfind pid (!proofdict)
    val gn = hd (!pending)
  in
    case Array.sub (predarr, gn) of
      [] => raise Empty
    | _ :: rest => Array.update (predarr, gn, newstac :: rest)
  end
  handle _ => debug_err ("update_curstac :" ^ newstac)
368
369(* --------------------------------------------------------------------------
370   Trying multiple terms.
371   -------------------------------------------------------------------------- *)
372
(* Apply tac (text stac) to g. A result is not productive when the tactic
   fails or g is among its subgoals. If the first result is not productive,
   try the quoted tactic qtac with each term of termknn n g otm in turn.
   The first productive instantiation also becomes the current prediction
   of node pid. Returns the subgoals, or NONE. *)
fun try_nqtm pid n (stac,tac) (otm,qtac) g =
  let
    fun productive glo =
      case glo of
        NONE => false
      | SOME gl => not (mem g gl)
    val first_try = SOME (fst (tac g)) handle _ => NONE
    fun try_terms [] = NONE
      | try_terms (tm :: rest) =
        let
          val attempt = SOME (fst (qtac [ANTIQUOTE tm] g)) handle _ => NONE
        in
          if not (productive attempt) then try_terms rest
          else
            (update_curstac (inst_timer (inst_termarg stac) tm) pid;
             attempt)
        end
  in
    if productive first_try then first_try
    else try_terms (termknn n g otm)
  end
393
394(* --------------------------------------------------------------------------
   Transforming code into a tactic, making the necessary predictions.
396   -------------------------------------------------------------------------- *)
397
(* Per-search caches, reset by init_search:
   thml_dict : (goal, radius) -> predicted theorem strings
   inst_dict : (tactic text, goal) -> (final text, tactic, timeout)
   tac_dict  : tactic text -> parsed tactic *)
val thml_dict = ref (dempty (cpl_compare goal_compare Int.compare))
val inst_dict = ref (dempty (cpl_compare String.compare goal_compare))
val tac_dict = ref (dempty String.compare)
401
(* Theorem predictions for goal g with radius n, memoised in thml_dict. *)
fun pred_sthml thmpredictor thml_dict n g =
  let val key = (g,n) in
    dfind key (!thml_dict)
    handle NotFound =>
      let val predicted = thmpredictor n g in
        thml_dict := dadd key predicted (!thml_dict);
        predicted
      end
  end
408
(* Turn the tactic string stac into (final text, tactic, timeout) for g.
   - With ttt_thmlarg_flag, a tactic with an abstracted theorem argument is
     filled with theorems predicted for g (radius ttt_thmlarg_radius);
     cached in inst_dict.
   - metis_spec becomes a metis call on theorems predicted for g (radius
     ttt_metis_radius) with timeout ttt_metis_time; cached in inst_dict.
   - Any other string is parsed as is; cached in tac_dict.
   Any failure is logged and yields NO_TAC. *)
fun stac_to_tac thmpred (tac_dict,inst_dict,thml_dict) stac g =
  let
    (* store an instantiation for (stac,g) and return it *)
    fun remember (entry as (newstac,_,_)) =
      (inst_dict := dadd (stac,g) entry (!inst_dict);
       debug_search ("to: " ^ newstac);
       entry)
    fun cached build =
      dfind (stac,g) (!inst_dict) handle NotFound => build ()
    (* only reached when ttt_thmlarg_flag holds *)
    fun with_thml () =
      let
        val _ = debug_search ("instantiating: " ^ stac)
        val sl = pred_sthml thmpred thml_dict (!ttt_thmlarg_radius) g
        val thmls = String.concatWith " , " (map dbfetch_of_string sl)
        val newstac = inst_stac thmls g stac
        val newtac = tactic_of_sml newstac
          handle _ =>
            (debug ("Warning: stac_to_tac: " ^ newstac);
             raise ERR "stac_to_tac" "")
      in
        remember (newstac, newtac, !ttt_tactic_time)
      end
    fun with_metis () =
      let
        val sl = pred_sthml thmpred thml_dict (!ttt_metis_radius) g
        val newstac = mk_metis_call sl
      in
        remember (newstac, tactic_of_sml newstac, !ttt_metis_time)
      end
    fun parsed () =
      let
        val tac =
          dfind stac (!tac_dict) handle NotFound =>
            let val tac' = tactic_of_sml stac in
              tac_dict := dadd stac tac' (!tac_dict);
              tac'
            end
      in
        (stac, tac, !ttt_tactic_time)
      end
  in
    (if !ttt_thmlarg_flag andalso is_absarg_stac stac then cached with_thml
     else if stac = metis_spec then cached with_metis
     else parsed ())
    handle _ =>
      (debug ("Warning: stac_to_tac: " ^ stac);
       ("Tactical.NO_TAC", NO_TAC, !ttt_tactic_time))
  end
458
459(* --------------------------------------------------------------------------
460   Application of a tactic.
461   -------------------------------------------------------------------------- *)
462
(* Keep the subgoals only if they bring progress. Rejected when:
   - g is among them;
   - a subgoal is an ancestor goal (pardict);
   - this node has already produced the same list (trydict). *)
fun glob_productive pardict trydict g glo =
  case glo of
    NONE => NONE
  | SOME gl =>
      if mem g gl then NONE
      else if exists (fn x => dmem x pardict) gl then NONE
      else if dmem gl trydict then NONE
      else glo
472
(* Apply the tactic string stac to goal g on behalf of node pid.
   - The string is turned into a tactic by stac_to_tac, and pid's head
     prediction is updated with the resulting text.
   - With ttt_termarg_flag, a tactic with a term argument is run through
     try_nqtm, which may try other terms.
   Raw results are memoised in stacgoal_cache; the returned value is
   filtered by glob_productive.
   NOTE(review): on a cache hit try_nqtm does not run, so a term
   instantiation it chose earlier is not written back to pid's predictions;
   confirm this is intended. *)
fun apply_stac pid pardict trydict stac g =
  let
    val _ = (last_stac := stac; stac_counter := !stac_counter + 1)
    (* instantiation of theorems and parsing *)
    val (newstac,newtac,tim) =
      stac_to_tac (!thmpredictor_glob) (tac_dict,inst_dict,thml_dict) stac g
    val _ = update_curstac newstac pid
    (* execution, possibly trying other term arguments *)
    fun run () =
      case (if !ttt_termarg_flag then abs_termarg newstac else NONE) of
        NONE => app_tac tim newtac g
      | SOME (otm,qtac) =>
          app_qtac tim
            (try_nqtm pid (!ttt_termarg_radius) (newstac,newtac) (otm,qtac))
            g
    val glo = dfind (newstac,g) (!stacgoal_cache) handle NotFound => run ()
    (* testing for loops *)
    val newglo = glob_productive pardict trydict g glo
  in
    stacgoal_cache := dadd (newstac,g) glo (!stacgoal_cache);
    newglo
  end
504
(* Try the head prediction on the first pending goal of node pid.
   eprover_spec is not applied directly: it queues an asynchronous call and
   gives NONE. Otherwise the result of apply_stac is returned. *)
fun apply_next_stac pid =
  let
    val _ = debug_search "apply_next_stac"
    val {pending,goalarr,predarr,trydict,pardict,...} = dfind pid (!proofdict)
    val gn = hd (!pending)
      handle _ => debug_err "apply_next_stac: empty pending"
    val g = Array.sub (goalarr, gn)
    val pred = Array.sub (predarr, gn)
    val tried = !trydict
    val stac = hd pred
      handle _ => debug_err "apply_next_stac: empty pred"
  in
    if stac <> eprover_spec
    then infstep_timer (apply_stac pid pardict tried stac) g
    else (queue_async pid g; NONE)
  end
522
523(* ----------------------------------------------------------------------
524   Searching for a node (goal list) to explore.
525   ---------------------------------------------------------------------- *)
526
(* Deactivate node pid when the predictions of its first pending goal are
   exhausted; returns true in that case.
   Any failure to reach those predictions is reported through debug_err:
   unknown node, no pending goal, or bad goal number. Previously only the
   Array.sub was covered, so an empty pending list escaped as a bare
   Empty. *)
fun has_empty_pred pid =
  let
    val pred =
      let
        val prec = dfind pid (!proofdict)
        val gn = hd (!(#pending prec))
      in
        Array.sub (#predarr prec, gn)
      end
      handle _ => debug_err ("find_next_tac: " ^ int_to_string pid)
  in
    if null pred then (deactivate pid; true) else false
  end
536
(* Monte Carlo selection from node pid. Each step compares two options:
   - trying a new tactic on the current node;
   - descending into one of its children.
   A child's score is its mean evaluation plus
   ttt_mcev_coeff * prior / ((visits + 1) / sqrt (parent visits)).
   Stops when the node itself scores best and returns it together with the
   prior policy of its next child. Raises SearchTimeOut past the global
   time limit. *)
fun mc_node_find pid =
  if Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
  then (debug "Warning: mc_node_find: loop"; raise SearchTimeOut)
  else
    let
      val {children,visit,...} = dfind pid (!proofdict)
      val pdenom = Math.sqrt (!visit)
      (* try a new tactic on the node itself *)
      val coeff = !ttt_mcpol_coeff
      val self_pripol =
        Math.pow (1.0 - coeff, Real.fromInt (length (!children))) * coeff
      val self_score = !ttt_mcev_coeff * (self_pripol / (1.0 / pdenom))
      (* or explore deeper existing partial proofs *)
      fun child_score cid =
        let
          val crec = dfind cid (!proofdict)
          val curpol = (!(#visit crec) + 1.0) / pdenom
        in
          (cid, average_real (!(#cureval crec)) +
                !ttt_mcev_coeff * (!(#priorpolicy crec) / curpol))
        end
      (* sort and select the node with the best selection score *)
      val ranked =
        dict_sort compare_rmax ((pid, self_score) :: map child_score (!children))
      val (selid,_) = hd ranked
    in
      if selid = pid then (pid, self_pripol) else mc_node_find selid
    end
570
(* Select a node with mc_node_find from the root (id 0). An inactive
   selection is backed up as a failure and the selection is retried.
   Raises SearchTimeOut past the global time limit. *)
fun try_mc_find () =
  if Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
  then (debug "Warning: try_mc_find"; raise SearchTimeOut)
  else
    let
      val _ = debug_search "mc_node_find"
      val (pid,pripol) = mc_node_find 0
    in
      if is_active pid
      then (debug_search ("Find " ^ int_to_string pid); (pid,pripol))
      else (backup_fail pid; try_mc_find ())
    end
583
584(* ---------------------------------------------------------------------------
585   Closing proofs
586   -------------------------------------------------------------------------- *)
587
(* Current children of node pid. *)
fun children_of pid = !(#children (dfind pid (!proofdict)))
590
(* Every node below pid: its children first, then their descendants. *)
fun descendant_of pid =
  let
    val cidl = children_of pid
    val deeper = List.concat (map descendant_of cidl)
  in
    cidl @ deeper
  end
595
(* Deactivate every descendant of pid (pid itself is left untouched). *)
fun close_descendant pid = List.app node_delete (descendant_of pid)

(* Raised by close_proof once every goal of the root is proved. *)
exception ProofFound
599
(* Node cid proved the first pending goal of its parent pid. Then:
   - record (goal number, tactic, cid) in pid's proofl;
   - deactivate every descendant of pid;
   - reset pid's children and tried lists, and move to the next pending
     goal.
   When no goal is pending any more, pid's parent is closed in turn; for the
   root, ProofFound is raised. *)
fun close_proof cid pid =
  let
    val {pargn,parstac,...} = dfind cid (!proofdict)
    val {proofl,pending,parid,children,trydict,priorpolicy,...} =
      dfind pid (!proofdict)
  in
    (* sanity checks *)
    if null (!pending) then debug_err "close_proof: pending" else ();
    if valOf pargn <> hd (!pending) then debug_err "close_proof" else ();
    (* remember which child gave the proof of which goal *)
    proofl := (valOf pargn, valOf parstac, cid) :: !proofl;
    (* close all current children *)
    close_descendant pid;
    (* switch to the next pending goal, erasing previous statistics *)
    children := [];
    trydict := dempty (list_compare goal_compare);
    pending := tl (!pending);
    (* optional reinitialization of the evaluation *)
    if !ttt_mcevinit_flag then init_eval (!priorpolicy) pid else ();
    (* all goals of pid solved: close upwards *)
    case (!pending, parid) of
      ([], NONE) => (debug "proof found"; node_delete pid; raise ProofFound)
    | ([], SOME ppid) => close_proof pid ppid
    | _ => ()
  end
628
629(* --------------------------------------------------------------------------
630   Creating new nodes
631   -------------------------------------------------------------------------- *)
632
(* Create a child of node pid for the productive subgoals gl. The subgoals
   come from applying the head prediction of pid's first pending goal.
   - Each subgoal gets its tactic predictions, with metis/eprover specs
     added when enabled.
   - Every subgoal is pending, in reverse numbering order (assuming
     number_list numbers from 0 upward).
   - The parent goal is added to the ancestors dictionary.
   The child is recorded in pid's children and in the children of that
   goal. Returns the new id. *)
fun node_create_gl pripol tactime gl pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (!(#pending prec))
    val goal = Array.sub (#goalarr prec, gn)
    val stac = hd (Array.sub (#predarr prec, gn))
    val parchildren = #children prec
    val parchildrensave = Array.sub (#childrena prec, gn)
    val predlist = map (add_eprover o add_metis o !tacpredictor_glob) gl
    val pending = rev (map fst (number_list 0 predlist))
    (* updating the list of ancestor goals *)
    val new_pardict = dadd goal () (#pardict prec)
    val selfid =
      node_create pripol tactime pid stac gn goal gl predlist pending
        new_pardict
  in
    parchildren := selfid :: !parchildren;
    parchildrensave := selfid :: !parchildrensave;
    selfid
  end
656
(* Fake a goal-less child of pid, standing for a proof of pid's first
   pending goal that was found without searching from a node. The tactic
   text is the one in staco when given, otherwise that goal's head
   prediction. Returns the new id. *)
fun node_create_empty staco tactime pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (!(#pending prec))
    val goal = Array.sub (#goalarr prec, gn)
    val pred = Array.sub (#predarr prec, gn)
    val stac =
      case staco of
        SOME s => s
      | NONE => hd pred
    val siblings = #children prec
    val goal_siblings = Array.sub (#childrena prec, gn)
    val childid =
      node_create 0.0 tactime pid stac gn goal [] [] [] (dempty goal_compare)
  in
    siblings := childid :: !siblings;
    goal_siblings := childid :: !goal_siblings;
    childid
  end
677
(* Record a proof of the first pending goal of node pid and close that goal.
   The tactic is the one in staco, or the current prediction. pid should be
   active and the goal should match. *)
fun close_proof_wrap staco tactime pid =
  let
    val cid = node_create_timer (node_create_empty staco tactime) pid
    val _ = backup cid
  in
    close_proof cid pid
  end
684
685
686(* --------------------------------------------------------------------------
   Handling asynchronous calls
688   -------------------------------------------------------------------------- *)
689
(* First pending goal of node pid. *)
fun current_goal pid =
  let val {pending,goalarr,...} = dfind pid (!proofdict) in
    Array.sub (goalarr, hd (!pending))
  end
697
(* Run the hammer on goal g with call index !hammer_ref and store the
   outcome in slot pid of async_result: HSuccess (tactic, g) when a tactic
   is returned, HFailure otherwise. An exception raised by the hammer also
   gives HFailure. *)
fun hammer_call pid g =
  let
    val outcome =
      (case !hammer_glob (!hammer_ref) g of
         SOME stac => HSuccess (stac,g)
       | NONE => HFailure)
      handle _ => HFailure
  in
    Array.update (async_result,pid,outcome)
  end
(* TODO: add a debug message here *)
707
(* Start a prover thread for the first node waiting in install_async and
   record the thread in running_async. The thread stores HSuccess
   (tactic, g) or HFailure in slot pid of async_result; an exception in the
   prover also gives HFailure.
   Everything the thread needs is captured before forking:
   - the goal queued by queue_async, rather than re-reading proofdict from
     the thread (a failure there escaped the thread and left the slot
     unresolved);
   - the call index and the prover, since another fork may bump hammer_ref
     before this thread runs. *)
fun fork_hammer () =
  if null (dkeys (!install_async)) then () else
  let
    val pid = hd (dkeys (!install_async))
    val g = dfind pid (!install_async)
    val _ = install_async := drem pid (!install_async)
    val _ = incr hammer_ref
    val _ = debug_search ("new thread " ^ int_to_string pid)
    val hid = !hammer_ref
    val hammer = !hammer_glob
    fun run () =
      let
        val outcome =
          (case hammer hid g of
             NONE => HFailure
           | SOME stac => HSuccess (stac,g))
          handle _ => HFailure
      in
        Array.update (async_result,pid,outcome)
      end
    val thread = Thread.fork (run, [])
  in
    (* The slot is deliberately not overwritten with HRunning here. A
       prover that returns immediately could otherwise have its result
       replaced, leaving the slot stuck. close_async treats the HVoid set
       by queue_async like HRunning. *)
    running_async := dadd pid thread (!running_async)
  end
722
(* Fork one more prover thread when fewer than ttt_eprover_async threads
   are recorded in running_async (the counts are only logged). *)
fun open_async () =
  let val running = dlength (!running_async) in
    if running >= !ttt_eprover_async then ()
    else
      let
        val active =
          length (filter (Thread.isActive o snd) (dlist (!running_async)))
      in
        debug_search (int_to_string running ^ " running thread");
        debug_search (int_to_string active ^ " active thread");
        fork_hammer ()
      end
  end
736
(* Collect finished prover threads, in the key order of running_async
   (intended to be increasing pid).
   - A success on a node that is still active, and whose current goal is
     the proved goal, closes that goal.
   - Failures are simply forgotten. *)
fun close_async () =
  let
    fun collect pid =
      case Array.sub (async_result,pid) of
        HSuccess (stac,g) =>
          let
            val _ = debug_search ("success thread " ^ int_to_string pid)
            val _ = running_async := drem pid (!running_async)
            val _ = Array.update (async_result,pid,HVoid)
          in
            if is_active pid andalso current_goal pid = g
            then close_proof_wrap (SOME stac) 0.0 pid
            else ()
          end
      | HFailure =>
          (debug_search ("failure thread " ^ int_to_string pid);
           Array.update (async_result,pid,HVoid);
           running_async := drem pid (!running_async))
      | _ => ()
  in
    List.app collect (dkeys (!running_async))
  end
762
763(* ---------------------------------------------------------------------------
764   Search function. Modifies the proof state.
765   -------------------------------------------------------------------------- *)
766
(* Reset all per-search state, start the global timer and install the
   (timed) predictors and the hammer. The goal argument g is not used. *)
fun init_search thmpred tacpred glpred hammer g =
  let
    (* async *)
    val _ = init_async ()
    (* global time-out *)
    val _ = glob_timer := SOME (Timer.startRealTimer ())
    (* caching *)
    val _ = stacgoal_cache := dempty (cpl_compare String.compare goal_compare)
    val _ = thml_dict := dempty (cpl_compare goal_compare Int.compare)
    val _ = inst_dict := dempty (cpl_compare String.compare goal_compare)
    val _ = tac_dict := dempty String.compare
    (* proof states *)
    val _ = pid_counter := 0
    val _ = notactivedict := dempty Int.compare
    val _ = proofdict := dempty Int.compare
    (* predictors wrapped by their timers, and the hammer *)
    val _ = tacpredictor_glob := tactimer tacpred
    val _ = thmpredictor_glob := thmtimer thmpred
    val _ = glpredictor_glob := gltimer glpred
    val _ = hammer_glob := hammer
    (* statistics *)
    val _ = reset_timers ()
    val _ = stac_counter := 0
  in
    max_depth_mem := 0
  end
792
(* Drop the prediction just tried on the first pending goal of pid. If it
   was the last one, or none remain, deactivate pid instead. Nodes without
   pending goals are left alone. *)
fun get_next_pred pid =
  let
    val _ = debug_search "get_next_pred"
    val {pending,predarr,...} = dfind pid (!proofdict)
  in
    case !pending of
      [] => ()
    | gn :: _ =>
      (case Array.sub (predarr, gn) of
         _ :: (rest as _ :: _) => Array.update (predarr, gn, rest)
       | _ => deactivate pid)
  end
808
(* Check that some node can still be expanded, then select one with
   try_mc_find. Side effects:
   - active nodes with exhausted predictions are deactivated;
   - with ttt_eprover_flag, finished prover calls are collected and a new
     one may be started.
   Raises NoNextTac when no candidate node remains. *)
fun node_find () =
  let
    val _ = debug_search "node_find"
    val active = filter (is_active o fst) (dlist (!proofdict))
    (* also deactivates nodes with empty predictions *)
    val expandable = filter (fn x => not (has_empty_pred (fst x))) active
    val _ = if !ttt_eprover_flag then (close_async (); open_async ()) else ()
    val candidates =
      if !ttt_eprover_flag then filter (is_active o fst) expandable
      else expandable
  in
    if null candidates
    then (debug_search "nonexttac"; raise NoNextTac)
    else try_mc_find ()
  end
823
824
(* One search iteration: select a node and apply its next prediction.
   - Failure: back up a failure and move to the next prediction.
   - No subgoals: back up a success and close the goal.
   - New subgoals: remember them as tried, create a child for them, back up
     its evaluation and move to the next prediction. *)
fun search_step () =
  let
    val (pid,pripol) = node_find_timer node_find ()
    val trydict = #trydict (dfind pid (!proofdict))
    val (glo,tactime) = add_time apply_next_stac pid
  in
    case glo of
      NONE => (backup_fail pid; get_next_pred pid)
    | SOME [] => (backup_success pid; close_proof_wrap NONE tactime pid)
    | SOME gl =>
      let
        val _ = trydict := dadd gl () (!trydict)
        val cid = node_create_timer (node_create_gl pripol tactime gl) pid
      in
        backup cid;
        get_next_pred pid
      end
  end
851
(* Outcome of a search: an error, no node left to expand, time-out, or a
   proof given as tactic text. *)
datatype proof_status_t =
    ProofError
  | ProofSaturated
  | ProofTimeOut
  | Proof of string
854
(* Run search steps until one of:
   - a proof is found (Proof with empty text);
   - no node can be expanded (ProofSaturated);
   - the time limit is reached (ProofTimeOut).
   Other exceptions propagate. *)
fun search_loop () =
  let
    fun timed_out () =
      Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
  in
    if timed_out () then ProofTimeOut
    else (search_step (); debug_search "search step"; search_loop ())
  end
  handle NoNextTac => (debug "proof: saturated"; ProofSaturated)
       | SearchTimeOut => (debug "proof: timeout"; ProofTimeOut)
       | ProofFound => (debug "proof: found"; Proof "")
865
(* Rebuild the proofs recorded below node pid, one per entry of its proofl,
   in goal-number order. Each proof is the recorded tactic on its goal,
   followed (Then / Thenl) by the proofs found at the child node. *)
fun proofl_of pid =
  let
    val prec = dfind pid (!proofdict) handle _ => debug_err "proofl_of"
    fun by_goal ((gn1,_,_),(gn2,_,_)) = Int.compare (gn1,gn2)
    fun step (gn,stac,cid) =
      let
        val tac = Tactic (stac, Array.sub (#goalarr prec, gn))
      in
        case proofl_of cid of
          [] => tac
        | [single] => Then (tac, single)
        | contl => Thenl (tac, contl)
      end
  in
    map step (dict_sort by_goal (!(#proofl prec)))
  end
885
(* Print the search statistics and release the per-search tables.
   thml_dict is released too, like the other caches, instead of being kept
   alive until the next init_search. *)
fun end_search () =
  (
  debug_proof ("Statistics");
  debug_proof ("  infstep : " ^ int_to_string (!stac_counter));
  debug_proof ("  nodes   : " ^ int_to_string (!pid_counter));
  debug_proof ("  maxdepth: " ^ int_to_string (!max_depth_mem));
  debug_proof ("Time: " ^ Real.toString (!tot_time));
  debug_proof ("  inferstep: " ^ Real.toString (!infstep_time));
  debug_proof ("  node_find: " ^ Real.toString (!node_find_time));
  debug_proof ("  node_crea: " ^ Real.toString (!node_create_time));
  debug_proof ("  thminst  : " ^ Real.toString (!inst_time));
  debug_proof ("  tacpred  : " ^ Real.toString (!tactime));
  debug_proof ("  thmpred  : " ^ Real.toString (!thmtime));
  debug_proof ("  glpred   : " ^ Real.toString (!gltime));
  proofdict      := dempty Int.compare;
  tac_dict       := dempty String.compare;
  inst_dict      := dempty (cpl_compare String.compare goal_compare);
  thml_dict      := dempty (cpl_compare goal_compare Int.compare);
  stacgoal_cache := dempty (cpl_compare String.compare goal_compare)
  )
905
906(* ---------------------------------------------------------------------------
907   Self learning
908   -------------------------------------------------------------------------- *)
909
(* Replay each tactic of a proof on its goal, timing it, and pass
   (tactic, time, goal, subgoals) to update_tacdata. A step that fails is
   only reported. *)
fun selflearn_aux (Tactic (stac,g)) =
      (let
         val ((gl,_),t) = add_time (tactic_of_sml stac) g
       in
         update_tacdata (stac,t,g,gl)
       end
       handle _ => debug_search ("Error: selflearn: " ^ stac))
  | selflearn_aux (Then (p1,p2)) = (selflearn_aux p1; selflearn_aux p2)
  | selflearn_aux (Thenl (p,pl)) = (selflearn_aux p; List.app selflearn_aux pl)
923
(* Render a proof as text, using THEN / THENL; the branches after THENL are
   separated by spaces. *)
fun string_of_proof (Tactic (stac,_)) = stac
  | string_of_proof (Then (p1,p2)) =
      String.concat [string_of_proof p1, " THEN ", string_of_proof p2]
  | string_of_proof (Thenl (p,pl)) =
      String.concat [string_of_proof p, " THENL ",
                     String.concatWith " " (map string_of_proof pl)]
929
(* Learn from a found proof, only when ttt_selflearn_flag is set. *)
fun selflearn proof =
  if not (!ttt_selflearn_flag) then ()
  else debug_t "selflearn" selflearn_aux proof
934
935(* ---------------------------------------------------------------------------
936   Main
937   -------------------------------------------------------------------------- *)
938
(* Search for a proof of goal using the given theorem, tactic and goal-list
   predictors and hammer.
   Returns Proof with the reconstructed tactic text, or the status reported
   by search_loop. If an exception escapes the search, prover threads are
   still terminated and the per-search tables released before re-raising,
   so hammer threads do not outlive a failed search. *)
fun search thmpred tacpred glpred hammer goal =
  let
    (* turn the result of search_loop into the final status *)
    fun finalize r =
      case r of
        Proof _ =>
        let
          val proofl = proofl_of 0 handle _ => debug_err "SNH0"
          val proof0 = hd proofl handle Empty => debug_err "SNH1"
          val _ = selflearn proof0
          val proof1 = debug_t "minimize" minimize_proof proof0
          val sproof = debug_t "reconstruct" reconstruct goal proof1
        in
          Proof sproof
        end
      | _ => r
    (* build the root, search, stop the prover threads, then finalize *)
    fun run () =
      let
        val _ = total_timer (node_create_timer root_create_wrap) goal
        val r = total_timer search_loop ()
        val _ = debug_search "End search loop"
        val _ = terminate_async ()
        val _ = debug_search "After termination"
      in
        finalize r
      end
    val _ = init_search thmpred tacpred glpred hammer goal
    val proof_status =
      run () handle e => (terminate_async (); end_search (); raise e)
  in
    end_search ();
    proof_status
  end
965
966end (* struct *)
967