1(* ========================================================================== *)
2(* FILE          : tttSyntEval.sml                                            *)
3(* DESCRIPTION   : Synthesis of terms for conjecturing lemmas                 *)
4(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
5(* DATE          : 2018                                                       *)
6(* ========================================================================== *)
7
8structure tttSyntEval :> tttSyntEval =
9struct
10
11open HolKernel boolLib Abbrev tttTools tttSynt tttPredict
12  tttExec tttTermData
13
(* Error constructor tagged with the name of this structure. *)
val ERR = mk_HOL_ERR "tttSyntEval"

(* A process id paired with a term and a dictionary from strings to terms.
   NOTE(review): presumably the dictionary maps names returned by the prover
   back to terms, as read_result uses such a dictionary; confirm. *)
type idict_t = int * (term * (string, term) Redblackmap.dict)

(* Directory holding the external provers, located under HOLDIR. *)
val provers_dir = String.concat [HOLDIR, "/src/holyhammer/provers"]
19
20(* --------------------------------------------------------------------------
21   Globals
22   -------------------------------------------------------------------------- *)
23
(* Number of predictions requested from tmknn by prove_predict and
   eval_predict. Mutable so that scripts can tune it. *)
val nb_premises : int ref = ref 128
25
26(* --------------------------------------------------------------------------
27   Features of term
28   -------------------------------------------------------------------------- *)
29
(* Memo table from terms to the features computed for them. *)
val fea_cache = ref (dempty Term.compare)

(* Features of tm, seen as a goal without assumptions.
   The result is looked up in fea_cache first; when the lookup fails the
   features are computed with tttFeature.fea_of_goal, stored in the cache and
   returned. *)
fun fea_of_term_cached tm =
  let
    fun compute_and_store () =
      let val newfea = tttFeature.fea_of_goal ([],tm) in
        fea_cache := dadd tm newfea (!fea_cache);
        newfea
      end
  in
    dfind tm (!fea_cache) handle NotFound => compute_and_store ()
  end
38
(* Pairs every term of tml with its (cached) features, preserving order. *)
fun assoc_fea tml =
  let fun with_fea tm = (tm, fea_of_term_cached tm) in
    map with_fea tml
  end
40
41(* --------------------------------------------------------------------------
42   Initialization
43   -------------------------------------------------------------------------- *)
44
(* Role attached to every term stored in the dictionaries. *)
datatype role =
    Axiom
  | Theorem
  | Conjecture

(* Ordering on identifiers, which are lists of integers, obtained by lifting
   Int.compare with list_compare. *)
fun id_compare (id1,id2) = list_compare Int.compare (id1,id2)
48
(* Registers the conjuncts of a theorem in the three dictionary references.
   The theorem is specialised with SPEC_ALL and split with CONJUNCTS; each
   conjunct is closed again with DISCH_ALL and GEN_ALL and its conclusion is
   the term that gets stored. A term that is already a key of term_dict is
   left untouched; otherwise it receives the identifier
   [nthy, nthm, nconj] (theory number, dependency number of the theorem,
   index of the conjunct) and the given role.
   The name component of the pair is not used. *)
fun init_dicts_thm (order_dict,term_dict,role_dict) nthy role (name,thm) =
  let
    val nthm = depnumber_of_thm thm
    fun register (nconj,conj) =
      let val stmt = concl (GEN_ALL (DISCH_ALL conj)) in
        if dmem stmt (!term_dict) then ()
        else
          let val id = [nthy,nthm,nconj] in
            order_dict := dadd id stmt (!order_dict);
            term_dict  := dadd stmt id (!term_dict);
            role_dict  := dadd stmt role (!role_dict)
          end
      end
  in
    app register (number_list 0 (CONJUNCTS (SPEC_ALL thm)))
  end
69
(* Registers the content of theory thy (numbered nthy) in the dictionaries
   of state: its theorems first, with role Theorem, then its axioms followed
   by its definitions, with role Axiom. *)
fun init_dicts_thy state (nthy,thy) =
  let
    fun add role named_thm = init_dicts_thm state nthy role named_thm
  in
    app (add Theorem) (DB.theorems thy);
    app (add Axiom) (DB.axioms thy @ DB.definitions thy)
  end
75
(* Builds the initial dictionaries from the ancestry of the current theory.
   The directory !ttt_synt_dir is cleaned first. The ancestors are ordered by
   sort_thyl, only the first n are kept, and they are numbered from 0.
   Returns (order_dict, term_dict, role_dict):
     order_dict : identifier -> term
     term_dict  : term -> identifier
     role_dict  : term -> role *)
fun init_dicts n =
  let
    val _ = clean_dir (!ttt_synt_dir)
    val state as (order_dict,term_dict,role_dict) =
      (ref (dempty id_compare),
       ref (dempty Term.compare),
       ref (dempty Term.compare))
    val numbered_thyl =
      number_list 0 (first_n n (sort_thyl (ancestry (current_theory ()))))
  in
    app (init_dicts_thy state) numbered_thyl;
    (!order_dict, !term_dict, !role_dict)
  end
89
90(* --------------------------------------------------------------------------
   Inserting conjectures in dictionaries
92   -------------------------------------------------------------------------- *)
93
(* Smallest extension id @ [m] with m >= n that is not yet a key of odict. *)
fun after_id_aux odict n id =
  let val candidate = id @ [n] in
    if dmem candidate odict
    then after_id_aux odict (n + 1) id
    else candidate
  end

(* Fresh identifier extending id by one component, searched from 0. *)
fun after_id order_dict id = after_id_aux order_dict 0 id
100
(* Inserts the conjecture tm, proven from lemmas, in the dictionary
   references. Nothing happens when tm is already a key of term_dict.
   Otherwise tm receives a fresh identifier extending the greatest identifier
   (w.r.t. id_compare) among its lemmas, and the role Conjecture.
   Every lemma must be a key of term_dict, and `last` is applied to the list
   of their identifiers, so lemmas is expected to be non-empty. *)
fun insert_cj (order_dict,term_dict,role_dict) (tm,lemmas) =
  if dmem tm (!term_dict) then ()
  else
    let
      fun id_of lemma = dfind lemma (!term_dict)
      val sorted_ids = dict_sort id_compare (map id_of lemmas)
      val id = after_id (!order_dict) (last sorted_ids)
    in
      order_dict := dadd id tm (!order_dict);
      term_dict  := dadd tm id (!term_dict);
      role_dict  := dadd tm Conjecture (!role_dict)
    end
112
(* Inserts a list of (conjecture, lemmas) pairs in the given dictionaries and
   returns the updated (order, term, role) dictionaries. The pairs are
   processed in list order with insert_cj. *)
fun insert_cjl (odict,tdict,rdict) cjlp =
  let
    val _ = msg_synt cjlp "to be inserted in the dicts"
    val order_dict = ref odict
    val term_dict  = ref tdict
    val role_dict  = ref rdict
  in
    app (insert_cj (order_dict,term_dict,role_dict)) cjlp;
    (!order_dict, !term_dict, !role_dict)
  end
121
122(* --------------------------------------------------------------------------
123   Reading the result of launch_eprover_parallel
124   -------------------------------------------------------------------------- *)
125
(* Combines the exported problems with the answers of the provers.
   pid_idict_list  : (pid, (target, idict)) for every exported problem
   pid_result_list : (pid, answer) where answer is an optional list of keys
                     of idict
   Returns (pid, (target, answer')) where the keys of a SOME answer have been
   replaced by the values they map to in idict. Every pid of pid_idict_list
   must occur in pid_result_list. *)
fun read_result pid_idict_list pid_result_list =
  let
    val results = dnew Int.compare pid_result_list
    fun decode idict keys = map (fn key => dfind key idict) keys
    fun read_result_one (pid,(cj,idict)) =
      (pid, (cj, Option.map (decode idict) (dfind pid results)))
  in
    map read_result_one pid_idict_list
  end
135
(* Prints tm prefixed by the name of the role it has in rdict,
   e.g. "Theorem: <term>". *)
fun roleterm_to_string rdict tm =
  let
    val tag =
      case dfind tm rdict of
        Axiom      => "Axiom"
      | Theorem    => "Theorem"
      | Conjecture => "Conjecture"
  in
    tag ^ ": " ^ term_to_string tm
  end
144
(* Writes one entry per element of l to file. An entry shows the pid, the
   target term and either the terms of the proof (each annotated with its
   role from rdict) or "Failure" when there is no proof. *)
fun write_result rdict file l =
  let
    val _ = log_synt ("writing result in " ^ file)
    fun proof_to_string ro =
      case ro of
        NONE => "Failure\n"
      | SOME lemmas =>
        "Proof:\n  " ^
        String.concatWith "\n  " (map (roleterm_to_string rdict) lemmas) ^
        "\n"
    fun entry_to_string (pid,(cj,ro)) =
      String.concat
        ["Pid: ", int_to_string pid, "\n",
         "Target: ", term_to_string cj, "\n",
         proof_to_string ro]
  in
    writel file (map entry_to_string l)
  end
161
(* A result is nontrivial when it is a proof that uses at least one lemma;
   failures (NONE) and proofs without lemmas (SOME []) are not. *)
fun is_nontrivial (SOME (_ :: _)) = true
  | is_nontrivial _ = false
166
167(* --------------------------------------------------------------------------
168   Proving conjectures
169   -------------------------------------------------------------------------- *)
170
(* Selects, with tmknn, !nb_premises terms among tmfea for the conjecture cj
   and returns them paired with cj. *)
fun prove_predict (symweight,tmfea) cj =
  let
    val k = !nb_premises
    val predl = tmknn k (symweight,tmfea) (fea_of_term_cached cj)
  in
    (predl, cj)
  end
173
(* Prepares the proving problems for the conjectures cjl.
   Features of tml are generated and weighted with learn_tfidf, premises are
   predicted for every conjecture, transf is applied to every term of tml
   (its results are discarded), and each problem is exported with
   exportf pdir through mapi. Every phase is timed with time_synt.
   Returns the list produced by the export phase. *)
fun prove_write pdir transf exportf tml cjl =
  let
    val tmfea = time_synt "generating features" assoc_fea tml
    val symweight = learn_tfidf tmfea
    fun predict_all l = map (prove_predict (symweight,tmfea)) l
    val pbl = time_synt "predict" predict_all cjl
    val _ = time_synt "translate" (map transf) tml
  in
    time_synt "export" (mapi (exportf pdir)) pbl
  end
183
(* Post-processes the results of the proving phase.
   Reports the number of proven and of nontrivial conjectures, writes the
   nontrivial ones to !ttt_synt_dir ^ "/prove_nontrivial", and returns them
   as (conjecture, lemmas) pairs. *)
fun prove_result rdict pl0 =
  let
    fun outcome (_,(_,ro)) = ro
    val proven = filter (isSome o outcome) pl0
    val _ = msg_synt proven "proven conjectures"
    val nontrivial = filter (is_nontrivial o outcome) pl0
    val _ = msg_synt nontrivial "nontrivial conjectures"
    val _ = write_result rdict (!ttt_synt_dir ^ "/prove_nontrivial") nontrivial
  in
    map (fn (_,(cj,ro)) => (cj, valOf ro)) nontrivial
  end
194
(* Tries to prove the conjectures cjl from premises selected in tml.
   The directory pdir is cleaned, the problems are written with prove_write,
   launchf is run on their pids with the given number of cores and time
   limit, and the answers are decoded and filtered by prove_result.
   Returns the nontrivially proven conjectures with their lemmas. *)
fun prove_main rdict pdir ncores timelimit
    transf exportf launchf tml cjl =
  let
    val _ = cleanDir_rec pdir
    val _ = msg_synt tml "terms to select from"
    val tasks = prove_write pdir transf exportf tml cjl
    val pidl = map fst tasks
    val _ = msg_synt pidl "proving tasks"
    val answers =
      time_synt "launchf" (launchf pdir ncores pidl) timelimit
  in
    prove_result rdict (read_result tasks answers)
  end
209
210(* --------------------------------------------------------------------------
211   Evaluating conjectures
212   -------------------------------------------------------------------------- *)
213
(* Predicts premises for re-proving the term tm.
   Only terms whose identifier in tdict is strictly smaller (id_compare) than
   the identifier of tm are candidates, so that tm is never proven from
   itself or from later terms. The candidates are weighted with learn_tfidf
   and !nb_premises of them are selected with tmknn.
   Always returns SOME (predictions, tm); the option type is kept because
   eval_write consumes this function through List.mapPartial.
   odict and rdict are not used here; the triple is kept so that callers can
   pass the dictionaries as a whole. (The previous version also filtered the
   conjectures out of the predictions into a local that was never used.) *)
fun eval_predict (odict,tdict,rdict) tm =
  let
    val tmid = dfind tm tdict
    fun is_older x = id_compare (dfind x tdict,tmid) = LESS
    val tmfea = assoc_fea (filter is_older (dkeys tdict))
    val symweight = learn_tfidf tmfea
    val predl = tmknn (!nb_premises) (symweight,tmfea) (fea_of_term_cached tm)
  in
    SOME (predl,tm)
  end
227
(* Prepares the evaluation problems for the terms tml.
   Premises are predicted for every term with eval_predict, transf is applied
   to every term of tml (results discarded; the original comment states this
   updates the translation cache), and the problems are exported with
   exportf pdir through mapi. Returns the list produced by the export. *)
fun eval_write pdir dicts transf exportf tml =
  let
    fun predict_all l = List.mapPartial (eval_predict dicts) l
    val pbl = time_synt "predict" predict_all tml
    val _ = time_synt "translate" (map transf) tml
  in
    time_synt "export" (mapi (exportf pdir)) pbl
  end
235
(* Post-processes the results of the evaluation phase.
   Reports the number of proven and of nontrivial theorems and writes the
   results to !ttt_synt_dir ^ "/eval_nontrivial".
   Returns a pair:
     - (theorem, conjectures used in its proof) for every nontrivially
       proven theorem whose proof uses at least one conjecture;
     - the list of all proven theorems. *)
fun eval_result rdict el0 =
  let
    fun outcome (_,(_,ro)) = ro
    val proven = filter (isSome o outcome) el0
    val _ = msg_synt proven "proven theorems"
    val nontrivial = filter (is_nontrivial o outcome) el0
    val _ = msg_synt nontrivial "nontrivial theorems"
    (* NOTE(review): every result (el0), failures included, is written to
       the file named eval_nontrivial, whereas prove_result only writes the
       nontrivial ones; confirm this is intended. *)
    val _ = write_result rdict (!ttt_synt_dir ^ "/" ^ "eval_nontrivial") el0
    fun is_conjecture x = dfind x rdict = Conjecture
    fun keep_cj (i,(thm,ro)) = (i,(thm, filter is_conjecture (valOf ro)))
    val with_cj = filter (not o null o snd o snd) (map keep_cj nontrivial)
    val _ = msg_synt with_cj "theorems proven using at least one conjecture"
  in
    (map snd with_cj, map (fn (_,(thm,_)) => thm) proven)
  end
251
(* Writes to !ttt_synt_dir ^ "/useful_conjectures" every conjecture occurring
   in el together with its number of occurrences, the count on one line and
   the conjecture on the next. Entries are ordered by compare_imax
   (presumably by decreasing count; confirm against tttTools). *)
fun write_usefulcj el =
  let
    val used_cjl = List.concat (map snd el)
    val freq = count_dict (dempty Term.compare) used_cjl
    val ranked = dict_sort compare_imax (dlist freq)
    fun string_of_ecjl (tm,n) =
      String.concat [int_to_string n, "\n", term_to_string tm, "\n"]
  in
    writel (!ttt_synt_dir ^ "/useful_conjectures") (map string_of_ecjl ranked)
  end
262
(* Evaluates how many theorems of the dictionaries can be re-proven.
   The directory pdir is cleaned, every term with role Theorem in rdict is
   exported as a problem (eval_write), launchf is run on the pids with the
   given number of cores and time limit, and the answers are analysed by
   eval_result. The success rate (in percent) is logged and the conjectures
   used in proofs are written by write_usefulcj.
   Returns (el1, proven, unproven):
     el1      : (theorem, conjectures used in its proof) pairs
     proven   : theorems for which a proof was found
     unproven : the remaining theorems *)
fun eval_main pdir ncores timelimit (odict,tdict,rdict)
    transf exportf launchf =
  let
    val _ = cleanDir_rec pdir
    val tml0 = dkeys tdict
    val tml1 = filter (fn x => dfind x rdict = Theorem) tml0
    val _ = msg_synt tml1 "theorems to be proven"
    val pid_idict_list =
      eval_write pdir (odict,tdict,rdict) transf exportf tml1
    val pidl = map fst pid_idict_list
    val pid_result_list = time_synt "launchf"
      (launchf pdir ncores pidl) timelimit
    val el0 = read_result pid_idict_list pid_result_list
    val (el1, proven) = eval_result rdict el0
    (* Guard: without any theorem the ratio below would be 0 / 0. *)
    val rate =
      if null tml1 then 0.0
      else approx 2 (int_div (length proven) (length tml1) * 100.0)
    val _ = log_synt (Real.toString rate ^ " success rate")
    val provendict = count_dict (dempty Term.compare) proven
    val unproven = filter (fn x => not (dmem x provendict)) tml1
  in
    write_usefulcj el1; (el1, proven, unproven)
  end
284
285end (* struct *)
286
287(*
288load "tttSyntEval"; load "holyHammer";
289open tttPredict tttTools tttSynt tttSyntEval holyHammer;
290
291(* Fixed parameters *)
292val provers_dir = HOLDIR ^ "/src/holyhammer/provers"
293val pdir_eval = provers_dir ^ "/parallel_eval";
294val pdir_prove = provers_dir ^ "/parallel_prove";
295val pdir_baseline = provers_dir ^ "/parallel_baseline";
296val exportf  = export_pb;
297val launchf  = eprover_parallel;
298val transf   = hhTranslate.cached_translate;
299
300(* Parameters *)
301val _ = show_types := true;
302  (* conjecturing *)
303val _ = conjecture_limit := 100000;
304val _ = patsub_flag := false;
305val _ = concept_flag := false;
306val _ = concept_threshold := 4;
307val ncj_max  = 1000;
308  (* proving *)
309val nthy_max = 1000;
310val _ = nb_premises := 128;
311val ncores    = 40;
312val timelimit = 5;
313  (* output *)
314val run_id = "baseline";
315val _ = ttt_synt_dir := tactictoe_dir ^ "/log_synt/" ^ run_id;
316val _ = mkDir_err (tactictoe_dir ^ "/log_synt");
317
318(* Initialization *)
319val dicts_org as (odict_org, tdict_org, rdict_org)  = init_dicts nthy_max;
320val tmlorg = dkeys tdict_org;
321
322(* Baseline *)
323val (eluseful, proven, unproven) =
324  eval_main pdir_eval ncores timelimit dicts_org transf exportf launchf;
325val _ = export_tml (!ttt_synt_dir ^ "/proven_thms") proven;
326val _ = export_tml (!ttt_synt_dir ^ "/unproven_thms") unproven;
327
328(* Generating conjectures *)
329val cjl0 = conjecture tmlorg;
330val cjl1 = first_n ncj_max cjl0;
331
332(* Proving conjectures *)
val pl = prove_main rdict_org pdir_prove
  ncores timelimit transf exportf launchf tmlorg cjl1;
335
(* Updating the dictionaries *)
337val dicts_new as (odict_new, tdict_new, rdict_new) = insert_cjl dicts_org pl;
338
339(* Evaluate conjectures *)
val el = eval_main pdir_eval ncores timelimit dicts_new transf exportf launchf;
341
342*)
343
344