1(* ========================================================================== *)
2(* FILE          : tttSynt.sml                                                *)
3(* DESCRIPTION   : Synthesis of terms for conjecturing lemmas                 *)
4(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
5(* DATE          : 2018                                                       *)
6(* ========================================================================== *)
7
8structure tttSynt :> tttSynt =
9struct
10
11open HolKernel boolLib Abbrev tttTools
12
(* Builds HOL_ERR exceptions tagged with this structure's name. *)
val ERR = mk_HOL_ERR "tttSynt"
14
15(* --------------------------------------------------------------------------
16   Globals
17   -------------------------------------------------------------------------- *)
18
(* Maximum number of conjectures kept; checked by update_gencjdict and by the
   generation loops of conjecture_sub / conjecture_patsub. *)
val conjecture_limit = ref 100000
(* When true, [conjecture] mutates statements with pattern substitutions
   instead of concept substitutions. *)
val patsub_flag = ref false
(* Number of HOL_ERR failures met while building candidate terms;
   reset by init_synt. *)
val type_errors = ref 0
22
23(* --------------------------------------------------------------------------
24   Tools
25   -------------------------------------------------------------------------- *)
26
(* Universally quantify [tm] over its free variables, in the order returned
   by free_vars_lr. *)
fun my_gen_all tm =
  let val fvl = free_vars_lr tm in
    list_mk_forall (fvl, tm)
  end
28
(* Total version of my_gen_all: NONE (and one more recorded type error)
   instead of a HOL_ERR exception. *)
fun my_gen_all_err tm =
  (let val gtm = my_gen_all tm in SOME gtm end)
  handle HOL_ERR _ => (incr type_errors; NONE)
31
(* True when the universal closures of [tm] and [tm'] are equal according to
   Term.compare, or when building either closure raises HOL_ERR (an
   ill-formed candidate is then treated like "no change").
   Only HOL_ERR is caught: the former catch-all also swallowed Interrupt and
   any unrelated exception. *)
fun alpha_equal_or_error tm tm' =
  Term.compare (my_gen_all tm, my_gen_all tm') = EQUAL
  handle HOL_ERR _ => true
35
(* A rewrite tm -> tm' is rejected when tm' is the same statement as tm
   (see alpha_equal_or_error) or when tm' is not of type bool (or its type
   cannot be computed). *)
fun unvalid_change tm tm' =
  if alpha_equal_or_error tm tm' then true
  else (not (type_of tm' = bool) handle HOL_ERR _ => true)
39
40(* --------------------------------------------------------------------------
41   Debugging
42   -------------------------------------------------------------------------- *)
43
(* Directory where all synthesis logs and data files are written. *)
val ttt_synt_dir = ref (tactictoe_dir ^ "/log_synt")
45
(* Append [s] (through append_endline) to the file [file] of !ttt_synt_dir. *)
fun log_synt_file file s =
  let val path = String.concat [!ttt_synt_dir, "/", file] in
    append_endline path s
  end
48
(* Echo [s] with print_endline and also append it to the "log_main" file. *)
fun log_synt s =
  let val _ = print_endline s in
    log_synt_file "log_main" s
  end
51
(* Log [s] prefixed by the length of the list [l]. *)
fun msg_synt l s =
  log_synt (String.concat [int_to_string (length l), " ", s])
56
(* Log [s] prefixed by the number of entries of the dictionary [d]. *)
fun msgd_synt d s =
  log_synt (String.concat [int_to_string (dlength d), " ", s])
61
(* Log [s], evaluate [f x] under add_time, log "<s>: <seconds>", and return
   the value of [f x]. *)
fun time_synt s f x =
  (
  log_synt s;
  let val (res, dt) = add_time f x in
    log_synt (String.concat [s, ": ", Real.toString dt]);
    res
  end
  )
70
(* Write the lines [sl] (through writel) to the file [s] of !ttt_synt_dir. *)
fun writel_synt s sl =
  let val path = String.concat [!ttt_synt_dir, "/", s] in
    writel path sl
  end
72
73(* --------------------------------------------------------------------------
74   Statistics on conjecture generation.
75   -------------------------------------------------------------------------- *)
76
(* Render terms one per line, each indented by two spaces, followed by a
   final newline. *)
fun string_of_tml tml =
  let val lines = map term_to_string tml in
    String.concat ["  ", String.concatWith "\n  " lines, "\n"]
  end
79
(* Render a list of term pairs as "[(a1, b1), (a2, b2), ...]". *)
fun string_of_subst sub =
  let
    fun pair_to_string (red, res) =
      String.concat ["(", term_to_string red, ", ", term_to_string res, ")"]
  in
    String.concat ["[", String.concatWith ", " (map pair_to_string sub), "]"]
  end
84
(* Dump [subdict] to the "substitutions" file, one line per entry:
   "<score> <number of conjectures>: <substitution>". *)
fun write_subdict subdict =
  let
    fun line (sub, (cjl, score)) =
      String.concat
        [Real.toString score, " ", int_to_string (length cjl), ": ",
         string_of_subst sub]
  in
    msgd_synt subdict "writing subdict";
    writel_synt "substitutions" (map line (dlist subdict))
  end
95
(* Dump [origdict] to the "origdict" file: each conjecture followed by the
   (substitution, term) pairs stored for it, one per line. *)
fun write_origdict origdict =
  let
    fun sub_line (sub, tm) = string_of_subst sub ^ ": " ^ term_to_string tm
    fun entry (cj, subtml) =
      String.concatWith "\n"
        ("Conjecture:" :: term_to_string cj :: map sub_line subtml)
  in
    msgd_synt origdict "writing origdict";
    writel_synt "origdict" (map entry (dlist origdict))
  end
106
107(* --------------------------------------------------------------------------
   Stateful dictionaries
109   -------------------------------------------------------------------------- *)
110
(* Integer pair list; presumably an index-level substitution (not referenced
   in the code of this file) -- TODO confirm. *)
type psubst = (int * int) list
(* Term pair list; presumably (redex, residue) pairs as consumed by
   unsafe_sub -- TODO confirm. *)
type tsubst = (term * term) list

(* Dictionaries *)
(* Global numbering of leaves (variables/constants): term -> index. *)
val cdict_glob = ref (dempty Term.compare)
(* Inverse of cdict_glob: index -> term. *)
val icdict_glob = ref (dempty Int.compare)
(* Per-term renumbering: global index -> local index (reset by
   patternify_one before each term). *)
val cdict_loc = ref (dempty Int.compare)
(* Only reset by init_synt within this file. *)
val cjinfo_glob =ref (dempty Term.compare)
119
120
(* Global index of the leaf [c]; the first time [c] is seen it receives the
   next free index and the inverse mapping is recorded in icdict_glob. *)
fun fconst_glob c =
  let
    fun fresh () =
      let val n = dlength (!cdict_glob) in
        cdict_glob := dadd c n (!cdict_glob);
        icdict_glob := dadd n c (!icdict_glob);
        n
      end
  in
    dfind c (!cdict_glob) handle NotFound => fresh ()
  end
128
(* Local index of the global index [cglob] in the current term; a new global
   index gets the next free local index. *)
fun fconst_loc cglob =
  let
    fun fresh () =
      let val n = dlength (!cdict_loc) in
        cdict_loc := dadd cglob n (!cdict_loc);
        n
      end
  in
    dfind cglob (!cdict_loc) handle NotFound => fresh ()
  end
135
(* Local index of the leaf [c], allocating global and local indices as needed. *)
fun fconst c =
  let val g = fconst_glob c in fconst_loc g end
137
(* Reset the global leaf numbering, cjinfo_glob and the type error counter. *)
fun init_synt () =
  let
    val _ = cdict_glob := dempty Term.compare
    val _ = icdict_glob := dempty Int.compare
    val _ = cjinfo_glob := dempty Term.compare
  in
    type_errors := 0
  end
145
146(* --------------------------------------------------------------------------
147   Conceptualization
148   -------------------------------------------------------------------------- *)
149
(* Minimum number of occurrences for a compound subterm to become a concept. *)
val concept_threshold = ref 4
(* When true, patternify also patterns the conceptualized variant of each
   statement. *)
val concept_flag = ref false
152
(* True for leaves of a term: variables and constants. *)
fun is_varconst x = if is_var x then true else is_const x
154
(* Bind [tm] in the dictionary ref [d] to a fresh variable "C<n>" of the
   same type, n being the current size of the dictionary; no-op when [tm]
   is already bound. *)
fun save_concept d tm =
  if not (dmem tm (!d)) then
    let val name = "C" ^ int_to_string (dlength (!d)) in
      d := dadd tm (mk_var (name, type_of tm)) (!d)
    end
  else ()
160
(* Select as concepts the non-leaf subterms of [tml] occurring at least
   !concept_threshold times (counted over all find_terms results), log them
   to the "concepts" file, and return a dictionary mapping each concept to
   its fresh concept variable. *)
fun concept_selection tml =
  let
    val subtml = List.concat (map (find_terms (not o is_varconst)) tml)
    val selected =
      filter (fn (_, n) => n >= !concept_threshold)
        (dlist (count_dict (dempty Term.compare) subtml))
    fun line (tm, n) = int_to_string n ^ " :" ^ term_to_string tm
    val cepts = ref (dempty Term.compare)
  in
    writel_synt "concepts" (map line (dict_sort compare_imax selected));
    msg_synt selected "selected concepts";
    app (save_concept cepts) (map fst selected);
    !cepts
  end
178
(* Replace every subterm of [tm] that is a key of [ceptdict] by its concept
   variable.  Returns [tm] alone when nothing changes, otherwise [tm]
   followed by its conceptualized version.
   Redexes are deduplicated and sorted largest first.  The sort order is made
   total (size, then Term.compare): with a size-only comparison, a sort that
   deduplicates w.r.t. its comparison function would silently drop distinct
   concepts of equal size.
   NOTE(review): confirm dict_sort's behaviour on ties in tttTools. *)
fun conceptualize_tm ceptdict tm =
  let
    fun is_cept x = dmem x ceptdict
    val redexl0 = mk_fast_set Term.compare (find_terms is_cept tm)
    fun cmp (tm1,tm2) =
      case Int.compare (term_size tm2, term_size tm1) of
        EQUAL => Term.compare (tm1,tm2)
      | r => r
    val redexl1 = dict_sort cmp redexl0
    val sub = map (fn x => {redex = x, residue = dfind x ceptdict}) redexl1
    val newtm = Term.subst sub tm
  in
    if term_eq newtm tm then [tm] else [tm,newtm]
  end
191
(* Term of the global index [c]; if that term is a key of [iceptdict]
   (a concept variable), the concept it stands for is returned instead. *)
fun read_cept iceptdict c =
  let
    val tm = dfind c (!icdict_glob)
  in
    (dfind tm iceptdict) handle NotFound => tm
  end
196
(* Translate an index substitution into a term substitution via read_cept. *)
fun read_subst iceptdict sub =
  let val rd = read_cept iceptdict in
    map (fn (a,b) => (rd a, rd b)) sub
  end
201
202(* --------------------------------------------------------------------------
203   Patterns
204   -------------------------------------------------------------------------- *)
205
(* Skeleton of a term: leaves carry local leaf indices (see patternify_one),
   inner nodes mirror combinations and abstractions. *)
datatype pattern =
    Pconst of int                  (* variable or constant, by local index *)
  | Pcomb  of pattern * pattern    (* application: rator, rand *)
  | Plamb  of pattern * pattern    (* abstraction: bound variable, body *)
210
(* Pattern of [tm]; leaves are numbered through fconst, visiting the rator
   before the rand and the bound variable before the body. *)
fun pattern_tm tm =
  case dest_term tm of
    COMB (f, x) => Pcomb (pattern_tm f, pattern_tm x)
  | LAMB (v, b) => Plamb (pattern_tm v, pattern_tm b)
  | _ => Pconst (fconst tm)
217
(* Pattern of [tm] under a fresh local numbering, together with the global
   indices of its leaves listed in local-index order. *)
fun patternify_one tm =
  let
    val _ = cdict_loc := dempty Int.compare
    val pat = pattern_tm tm
    val byloc =
      dict_sort (fn ((_, i), (_, j)) => Int.compare (i, j)) (dlist (!cdict_loc))
  in
    (pat, map fst byloc)
  end
228
(* Total order on patterns: Pconst < Pcomb < Plamb; equal constructors are
   compared on their components (through cpl_compare). *)
fun pattern_compare (p1,p2) =
  let
    fun rank (Pconst _) = 0
      | rank (Pcomb _)  = 1
      | rank (Plamb _)  = 2
  in
    case (p1,p2) of
      (Pconst i1, Pconst i2) => Int.compare (i1,i2)
    | (Pcomb (a1,b1), Pcomb (a2,b2)) =>
      cpl_compare pattern_compare pattern_compare ((a1,b1),(a2,b2))
    | (Plamb (a1,b1), Plamb (a2,b2)) =>
      cpl_compare pattern_compare pattern_compare ((a1,b1),(a2,b2))
    | _ => Int.compare (rank p1, rank p2)
  end
239
(* Prefix rendering: leaves print their index, applications "(A p1 p2)",
   abstractions "(L p1 p2)". *)
fun string_of_pattern p =
  let
    fun node tag (q1,q2) =
      "(" ^
      String.concatWith " " [tag, string_of_pattern q1, string_of_pattern q2] ^
      ")"
  in
    case p of
      Pconst i => int_to_string i
    | Pcomb qs => node "A" qs
    | Plamb qs => node "L" qs
  end
246
(* Write to the "patterns" file every pattern with more than one concept
   list, with that count, sorted by compare_imax.
   [ntot] is kept for interface compatibility: the ratio previously computed
   from it was never used, so that dead computation was removed. *)
fun write_patceptdict ntot patceptdict =
  let
    val _ = log_synt "writing patceptdict"
    val l0 = dlist patceptdict
    val l1 = filter (fn (_,b) => length b > 1) l0
    val l2 = map (fn (a,b) => (a, length b)) l1
    val l3 = dict_sort compare_imax l2
    fun w (p,n) = int_to_string n ^ ": " ^ string_of_pattern p
    val _ = msg_synt l3 "patterns appearing at least twice"
  in
    writel_synt "patterns" (map w l3)
  end
260
(* Write to the "concept_lists" file every concept list shared by more than
   one pattern, with that count, its concepts shown one per line. *)
fun write_ceptpatdict iceptdict ceptpatdict =
  let
    val _ = log_synt "writing ceptpatdict"
    val counted =
      List.mapPartial
        (fn (cl, pl) => if length pl > 1 then SOME (cl, length pl) else NONE)
        (dlist ceptpatdict)
    val sorted = dict_sort compare_imax counted
    fun line (cl, n) =
      int_to_string n ^ ": " ^
      String.concatWith "\n"
        (map (fn c => term_to_string (read_cept iceptdict c)) cl)
  in
    msg_synt sorted "concept lists appearing at least twice";
    writel_synt "concept_lists" (map line sorted)
  end
276
(* Pattern every distinct term of [tml] (and, when !concept_flag, its
   conceptualized variant).  Returns three dictionaries:
   pattern -> concept lists, concept list -> patterns, and
   term -> (pattern, concept list) of each variant.  Also logs and writes
   the first two. *)
fun patternify ntot ceptdict iceptdict tml =
  let
    val pcd = ref (dempty pattern_compare)
    val cpd = ref (dempty (list_compare Int.compare))
    val tpd = ref (dempty Term.compare)
    (* cons [v] onto the list stored at key [k] of dictionary ref [d] *)
    fun push d k v = d := dadd k (v :: (dfind k (!d) handle NotFound => [])) (!d)
    fun record tm =
      let val (p, cl) = patternify_one tm in
        push pcd p cl;
        push cpd cl p;
        (p, cl)
      end
    fun process tm =
      let
        val variants =
          if !concept_flag then conceptualize_tm ceptdict tm else [tm]
      in
        tpd := dadd tm (map record variants) (!tpd)
      end
  in
    app process (mk_fast_set Term.compare tml);
    msgd_synt (!pcd) "patterns";
    msgd_synt (!cpd) "concept lists";
    write_patceptdict ntot (!pcd);
    write_ceptpatdict iceptdict (!cpd);
    (!pcd, !cpd, !tpd)
  end
309
(* Rebuild a term from pattern [p]: leaf i reads the i-th global index of
   [cl] through read_cept.  May raise HOL_ERR (mk_comb/mk_abs) or Subscript
   (List.nth). *)
fun term_of_pat idict (p,cl) =
  let
    fun go (Pconst i) = read_cept idict (List.nth (cl,i))
      | go (Pcomb (a,b)) = mk_comb (go a, go b)
      | go (Plamb (a,b)) = mk_abs (go a, go b)
  in
    go p
  end
316
317(* --------------------------------------------------------------------------
318   Concept substitutions.
319   -------------------------------------------------------------------------- *)
320
(* Ascending order on the integer first component of pairs. *)
fun compare_kimin ((ka,_),(kb,_)) = Int.compare (ka,kb)
322
(* Drop identity pairs and sort the remaining pairs by their first index. *)
fun norm_sub l =
  dict_sort compare_kimin (List.filter (fn (x,y) => x <> y) l)
327
(* All index substitutions between two distinct concept lists of [cll],
   obtained by zipping the lists position-wise. *)
fun pair_sub cll =
  let
    val distinct = mk_fast_set (list_compare Int.compare) cll
    val pairs = cartesian_product distinct distinct
  in
    List.mapPartial
      (fn (x,y) => if x = y then NONE else SOME (combine (x,y))) pairs
  end
336
(* Concept substitutions derived from concept lists sharing a pattern,
   most frequent first, translated to term substitutions.
   Ties in frequency are broken by the substitution itself so that the sort
   order is total: if dict_sort deduplicates w.r.t. its comparison (as a
   dictionary-based sort does), equally frequent substitutions were being
   dropped.  NOTE(review): confirm dict_sort's tie behaviour in tttTools. *)
fun create_sub iceptdict patceptdict =
  let
    fun f (_,cll) = pair_sub cll
    val l1  = List.concat (map f (dlist patceptdict))
    val l2  = map norm_sub l1
    val cmp = list_compare (cpl_compare Int.compare Int.compare)
    val dfreq = count_dict (dempty cmp) l2
    val _   = msgd_synt dfreq "concept substitutions"
    fun cmp_freq (x,y) =
      case compare_imax (x,y) of
        EQUAL => cmp (fst x, fst y)
      | r => r
    val l3  = dict_sort cmp_freq (dlist dfreq)
  in
    map (read_subst iceptdict o fst) l3
  end
349
(* Top-down replacement: a subterm equal (term_eq) to the redex of some pair
   of [sub] is replaced by that pair's residue (first match wins) and is not
   traversed further; other subterms are rebuilt.  No type checking is done
   beforehand, so mk_comb/mk_abs may raise HOL_ERR (handled by apply_sub).
   Uses term_eq instead of polymorphic equality on terms, consistently with
   the rest of the file. *)
fun unsafe_sub sub tm =
  case List.find (fn (red,_) => term_eq red tm) sub of
    SOME (_,res) => res
  | NONE =>
    (
    case dest_term tm of
      VAR _   => tm
    | CONST _ => tm
    | COMB(Rator,Rand) =>
      mk_comb (unsafe_sub sub Rator, unsafe_sub sub Rand)
    | LAMB(Var,Bod) =>
      mk_abs (unsafe_sub sub Var, unsafe_sub sub Bod)
    )
363
(* Apply [sub] to [tm]: SOME of the closed result when unvalid_change accepts
   it, NONE otherwise; a HOL_ERR counts as a type error and yields NONE. *)
fun apply_sub sub tm =
  (let val tm' = unsafe_sub sub tm in
     if not (unvalid_change tm tm') then SOME (my_gen_all tm') else NONE
   end)
  handle HOL_ERR _ => (incr type_errors; NONE)
369
370(* --------------------------------------------------------------------------
371   Pattern substitutions
372   -------------------------------------------------------------------------- *)
373
(* All ordered pairs of distinct patterns from [pl]. *)
fun pair_patsub pl =
  let val ps = mk_fast_set pattern_compare pl in
    List.filter (fn (x,y) => x <> y) (cartesian_product ps ps)
  end
382
(* Pattern substitutions (p1,p2) between patterns sharing a concept list,
   most frequent first.  Frequency ties are broken by the pattern pair so the
   order is total (same reason as in create_sub: a deduplicating dict_sort
   would otherwise drop equally frequent pairs). *)
fun create_patsub ceptpatdict =
  let
    fun f (_,pl) = pair_patsub pl
    val cpl       = List.concat (map f (dlist ceptpatdict))
    val cmp       = cpl_compare pattern_compare pattern_compare
    val dfreq     = count_dict (dempty cmp) cpl
    val _         = msgd_synt dfreq "pattern substitutions"
    fun cmp_freq (x,y) =
      case compare_imax (x,y) of
        EQUAL => cmp (fst x, fst y)
      | r => r
  in
    map fst (dict_sort cmp_freq (dlist dfreq))
  end
393
(* If [tm] was patterned as [p1], rebuild it with pattern [p2] and the same
   concept list.  NONE when no variant of [tm] has pattern [p1], when
   unvalid_change rejects the result, or on HOL_ERR (counted as a type
   error).  Raises NotFound if [tm] is not a key of [thmpatdict]. *)
fun apply_patsub thmpatdict iceptdict (p1,p2) tm =
  let
    fun matches (p,_) = pattern_compare (p,p1) = EQUAL
    fun instantiate cl =
      (let val tm' = term_of_pat iceptdict (p2,cl) in
         if unvalid_change tm tm' then NONE else SOME (my_gen_all tm')
       end)
      handle HOL_ERR _ => (incr type_errors; NONE)
  in
    Option.mapPartial (instantiate o snd)
      (List.find matches (dfind tm thmpatdict))
  end
409
410(* --------------------------------------------------------------------------
411   Conjecturing
412   -------------------------------------------------------------------------- *)
413
414
(* Record the rediscovered theorem [x] (first time only) with the current
   (number of conjectures, number of rediscovered theorems). *)
fun update_genthmdict gencjdict genthmdict x =
  if not (dmem x (!genthmdict)) then
    let val stamp = (dlength (!gencjdict), dlength (!genthmdict)) in
      genthmdict := dadd x stamp (!genthmdict)
    end
  else ()
419
(* Record the new conjecture [x] with its generation index, unless it is
   already present or !conjecture_limit has been reached. *)
fun update_gencjdict gencjdict x =
  let val n = dlength (!gencjdict) in
    if n < !conjecture_limit andalso not (dmem x (!gencjdict))
    then gencjdict := dadd x n (!gencjdict)
    else ()
  end
424
(* Route a generated term: members of [covdict] are rediscovered theorems,
   anything else is a conjecture. *)
fun update_gendict covdict genthmdict gencjdict x =
  if not (dmem x covdict) then update_gencjdict gencjdict x
  else update_genthmdict gencjdict genthmdict x
429
(* Generate conjectures by applying the substitutions [subl] (in order) to
   each term of [tml].  Every term keeps a cursor into [subl]; each round
   advances every cursor until one substitution succeeds, trying at most 100.
   Rounds stop once !conjecture_limit is reached or a round adds no new
   conjecture.  Returns (conjectures, rediscovered theorems). *)
fun conjecture_sub covdict tml subl =
  let
    val cjd  = ref (dempty Term.compare)
    val thmd = ref (dempty Term.compare)
    val subd = dnew Int.compare (number_list 0 subl)
    fun step budget (tm, cursor) =
      if budget <= 0 orelse not (dmem cursor subd) then (tm, cursor)
      else
        case apply_sub (dfind cursor subd) tm of
          SOME tm' => (update_gendict covdict thmd cjd tm'; (tm, cursor + 1))
        | NONE => step (budget - 1) (tm, cursor + 1)
    val last = ref (~1)
    fun rounds state =
      let val ncj = dlength (!cjd) in
        if ncj >= !conjecture_limit orelse !last >= ncj then ()
        else
          (
          last := ncj;
          print_endline (int_to_string ncj ^ " conjectures");
          rounds (map (step 100) state)
          )
      end
  in
    rounds (map (fn tm => (tm, 0)) tml);
    (!cjd, !thmd)
  end
463
(* Same generation scheme as conjecture_sub, but each step applies a pattern
   substitution of [patsubl] through apply_patsub.
   Returns (conjectures, rediscovered theorems). *)
fun conjecture_patsub thmpatdict iceptdict covdict tml patsubl =
  let
    val cjd  = ref (dempty Term.compare)
    val thmd = ref (dempty Term.compare)
    val subd = dnew Int.compare (number_list 0 patsubl)
    fun step budget (tm, cursor) =
      if budget <= 0 orelse not (dmem cursor subd) then (tm, cursor)
      else
        case apply_patsub thmpatdict iceptdict (dfind cursor subd) tm of
          SOME tm' => (update_gendict covdict thmd cjd tm'; (tm, cursor + 1))
        | NONE => step (budget - 1) (tm, cursor + 1)
    val last = ref (~1)
    fun rounds state =
      let val ncj = dlength (!cjd) in
        if ncj >= !conjecture_limit orelse !last >= ncj then ()
        else
          (
          last := ncj;
          print_endline (int_to_string ncj ^ " conjectures");
          rounds (map (step 100) state)
          )
      end
  in
    rounds (map (fn tm => (tm, 0)) tml);
    (!cjd, !thmd)
  end
497
(* Plot the data file [filein] as PostScript into [fileout] by running
   gnuplot in tactictoe_dir.  The output file is selected only through
   gnuplot's "set output"; the former extra shell redirect "> fileout"
   opened (and truncated) the same file a second time, so anything gnuplot
   printed on stdout could overwrite the plot. *)
fun gnuplotcmd filein fileout =
  let
    val plotcmd = "\"" ^ String.concatWith "; " [
      "set term postscript",
      "set output " ^ "'" ^ fileout ^ "'",
      "plot " ^ "'" ^ filein ^ "'"]
      ^ "\""
    val cmd = "gnuplot -p -e " ^ plotcmd
  in
    cmd_in_dir tactictoe_dir cmd
  end
509
(* Log the coverage ratio (rediscovered theorems / ntot) and plot, for each
   conjecture count, the largest recorded theorem count divided by [ntot]. *)
fun write_graph ntot genthmdict =
  let
    val _ = log_synt "writing graph"
    val rcov = int_div (dlength genthmdict) ntot
    val _ = log_synt (Real.toString rcov ^ " conjecture coverage")
    (* keep the maximal theorem count per conjecture count (default 0) *)
    fun keep_max ((a,b), d) =
      let val oldb = dfind a d handle NotFound => 0 in
        if b > oldb then dadd a b d else d
      end
    val best =
      dlist (List.foldl keep_max (dempty Int.compare)
               (map snd (dlist genthmdict)))
    fun line (a,b) = int_to_string a ^ " " ^ (Real.toString (int_div b ntot))
    val dir = !ttt_synt_dir
  in
    writel_synt "coverage_data" ("# miss match" :: map line best);
    gnuplotcmd (dir ^ "/coverage_data") (dir ^ "/coverage_graph.ps")
  end
530
(* Main entry point: generate conjectures from the statements [tml].
   Statements are deduplicated, their bound variables renamed through
   rename_bvarl, and their outer universal quantifiers stripped.  They are
   then patterned and mutated either by pattern substitutions (when
   !patsub_flag) or by concept substitutions.  Returns the generated
   conjectures in generation order.
   Fix: the "patternify" timing used to pass only [patternify ntot] to
   time_synt, so it timed the creation of a closure and logged before the
   real work; the complete call is now timed, like the other time_synt uses
   below. *)
fun conjecture tml =
  let
    val _    = init_synt ()
    val tml0 = mk_fast_set Term.compare tml
    val tml1 = map (snd o strip_forall o rename_bvarl (fn _ => "")) tml0
    val tml2 = mk_fast_set Term.compare tml1
    val tml3 = map (fn x => (my_gen_all x, 0)) tml2
    val _    = msg_synt tml3 "terms"
    (* closed input statements: generated terms found here are theorems *)
    val covdict = dnew Term.compare tml3
    val ntot = dlength covdict
    val ceptdict = concept_selection tml2
    val iceptdict = inv_dict Term.compare ceptdict
    val (patceptdict, ceptpatdict, thmpatdict) =
      time_synt "patternify" (patternify ntot ceptdict iceptdict) tml2
    val _ = msgd_synt (!cdict_glob) "constants or variables"

    (* conjecture generation from substitutions *)
    val (gencjdict,genthmdict) =
      if !patsub_flag
      then
        let val patsubl = create_patsub ceptpatdict in
          time_synt "conjecture_patsub"
          (conjecture_patsub thmpatdict iceptdict covdict tml2) patsubl
        end
      else
        let val subl = create_sub iceptdict patceptdict in
          time_synt "conjecture_sub"
          (conjecture_sub covdict tml2) subl
        end
    val _ = write_graph ntot genthmdict
    val _ = log_synt (int_to_string (!type_errors) ^ " type errors")
    val _ = msgd_synt gencjdict "generated conjectures"
    (* conjectures are keyed by generation index once inverted *)
    val igencjdict = inv_dict Int.compare gencjdict
  in
    map snd (dlist igencjdict)
  end
567
568end (* struct *)
569