(* ========================================================================== *)
(* FILE          : tttSynt.sml                                                *)
(* DESCRIPTION   : Synthesis of terms for conjecturing lemmas                 *)
(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
(* DATE          : 2018                                                       *)
(* ========================================================================== *)

structure tttSynt :> tttSynt =
struct

open HolKernel boolLib Abbrev tttTools

val ERR = mk_HOL_ERR "tttSynt"

(* --------------------------------------------------------------------------
   Globals
   -------------------------------------------------------------------------- *)

(* Upper bound on the number of entries stored in the conjecture dictionary. *)
val conjecture_limit = ref 100000
(* When true, [conjecture] uses pattern substitutions instead of concept
   substitutions. *)
val patsub_flag = ref false
(* Number of HOL_ERR failures met while building candidate terms. *)
val type_errors = ref 0

(* --------------------------------------------------------------------------
   Tools
   -------------------------------------------------------------------------- *)

(* Universally quantify [tm] over the variables returned by free_vars_lr. *)
fun my_gen_all tm = list_mk_forall (free_vars_lr tm, tm)

(* Same as [my_gen_all] but returns NONE (and counts a type error) when the
   quantification raises HOL_ERR. *)
fun my_gen_all_err tm = SOME (my_gen_all tm)
  handle HOL_ERR _ => (incr type_errors; NONE)

(* True when the generalizations of [tm] and [tm'] compare EQUAL, or when
   generalizing either of them fails with HOL_ERR.
   Only HOL_ERR is caught: a catch-all handler would also swallow Interrupt
   and other unrelated exceptions. *)
fun alpha_equal_or_error tm tm' =
  Term.compare (my_gen_all tm, my_gen_all tm') = EQUAL
  handle HOL_ERR _ => true

(* A change from [tm] to [tm'] is rejected when it produced nothing new
   (or an error), or when [tm'] is not of type bool (or has no type). *)
fun unvalid_change tm tm' =
  alpha_equal_or_error tm tm' orelse
  (type_of tm' <> bool handle HOL_ERR _ => true)

(* --------------------------------------------------------------------------
   Debugging
   -------------------------------------------------------------------------- *)

(* Directory receiving every log and data file written by this module. *)
val ttt_synt_dir = ref (tactictoe_dir ^ "/log_synt")

(* Append the line [s] to [file] inside the log directory. *)
fun log_synt_file file s =
  append_endline (!ttt_synt_dir ^ "/" ^ file) s

(* Print [s] and append it to the main log file. *)
fun log_synt s =
  (print_endline s; log_synt_file "log_main" s)

(* Log [s] prefixed by the length of the list [l]. *)
fun msg_synt l s =
  let val s' = int_to_string (length l) ^ " " ^ s in
    log_synt s'
  end

(* Log [s] prefixed by the size of the dictionary [d]. *)
fun msgd_synt d s =
  let val s' = int_to_string (dlength d) ^ " " ^ s in
    log_synt s'
  end

(* Run [f x], logging the label [s] before and "[s]: <time>" after, where the
   time is the second component returned by add_time. Returns the result. *)
fun time_synt s f x =
  let
    val _ = log_synt s
    val (r,t) = add_time f x
  in
    log_synt (s ^ ": " ^ Real.toString t);
    r
  end

(* Write the string list [sl] to file [s] inside the log directory. *)
fun writel_synt s sl = writel (!ttt_synt_dir ^ "/" ^ s) sl

(* --------------------------------------------------------------------------
   Statistics on conjecture generation.
   -------------------------------------------------------------------------- *)

(* Render a list of terms, one per line. *)
fun string_of_tml tml =
  (" " ^ String.concatWith "\n " (map term_to_string tml) ^ "\n")

(* Render a term substitution as "[(a, b), ...]". *)
fun string_of_subst sub =
  let fun f (a,b) = "(" ^ term_to_string a ^ ", " ^ term_to_string b ^ ")" in
    "[" ^ String.concatWith ", " (map f sub) ^ "]"
  end

(* Dump a dictionary mapping substitutions to (conjecture list, score) into
   the file "substitutions". *)
fun write_subdict subdict =
  let
    val _ = msgd_synt subdict "writing subdict"
    val l = dlist subdict
    fun f (sub, (cjl,score)) =
      Real.toString score ^ " " ^ int_to_string (length cjl) ^ ": " ^
      string_of_subst sub
  in
    writel_synt "substitutions" (map f l)
  end

(* Dump a dictionary mapping each conjecture to its list of
   (substitution, term) pairs into the file "origdict". *)
fun write_origdict origdict =
  let
    val _ = msgd_synt origdict "writing origdict"
    val l = dlist origdict
    fun g (sub,tm) = string_of_subst sub ^ ": " ^ term_to_string tm
    fun f (cj,subtml) = String.concatWith "\n"
      (["Conjecture:", term_to_string cj] @ map g subtml)
  in
    writel_synt "origdict" (map f l)
  end

(* --------------------------------------------------------------------------
   Stateful dictionaries
   -------------------------------------------------------------------------- *)

type psubst = (int * int) list
type tsubst = (term * term) list

(* dictionaries *)
(* term -> global index *)
val cdict_glob = ref (dempty Term.compare)
(* global index -> term (inverse of cdict_glob) *)
val icdict_glob = ref (dempty Int.compare)
(* global index -> local index, rebuilt for every term by patternify_one *)
val cdict_loc = ref (dempty Int.compare)
val cjinfo_glob = ref (dempty Term.compare)

(* Global index of [c]; a fresh index (the current dictionary size) is
   allocated and recorded in both directions on first sight. *)
fun fconst_glob c =
  dfind c (!cdict_glob) handle NotFound =>
  let val cglob = dlength (!cdict_glob) in
    cdict_glob := dadd c cglob (!cdict_glob);
    icdict_glob := dadd cglob c (!icdict_glob);
    cglob
  end

(* Local index of the global index [cglob]; allocated on first sight as the
   current size of the local dictionary. *)
fun fconst_loc cglob =
  dfind cglob (!cdict_loc) handle NotFound =>
  let val cloc = dlength (!cdict_loc) in
    cdict_loc := dadd cglob cloc (!cdict_loc);
    cloc
  end

(* Local index of the variable or constant [c]. *)
fun fconst c = fconst_loc (fconst_glob c)

(* Reset all the stateful dictionaries and the type error counter.
   cdict_loc is reset too so that no state from a previous run is retained
   (patternify_one resets it again before each use). *)
fun init_synt () =
  (
  cdict_glob := dempty Term.compare;
  icdict_glob := dempty Int.compare;
  cdict_loc := dempty Int.compare;
  cjinfo_glob := dempty Term.compare;
  type_errors := 0
  )

(* --------------------------------------------------------------------------
   Conceptualization
   -------------------------------------------------------------------------- *)

(* Minimal count a compound subterm needs to be selected as a concept. *)
val concept_threshold = ref 4
(* When true, patternify also considers the conceptualized variant of each
   term. *)
val concept_flag = ref false

fun is_varconst x = is_var x orelse is_const x

(* Associate [tm] with a fresh variable "C<n>" of the same type in the
   dictionary reference [d], unless it is already present. *)
fun save_concept d tm =
  if dmem tm (!d) then () else
  let val v = mk_var ("C" ^ int_to_string (dlength (!d)), type_of tm) in
    d := dadd tm v (!d)
  end

(* Collect the subterms of [tml] that are neither variables nor constants,
   keep those whose count (as given by count_dict) reaches
   !concept_threshold, write them to the file "concepts", and return a
   dictionary mapping each kept subterm to its concept variable. *)
fun concept_selection tml =
  let
    fun f x = find_terms (not o is_varconst) x
    val l0 = List.concat (map f tml)
    val freq = count_dict (dempty Term.compare) l0
    val l1 = dlist freq
    fun above_threshold x = snd x >= !concept_threshold
    val l2 = filter above_threshold l1
    val l3 = dict_sort compare_imax l2
    fun w (x,n) = int_to_string n ^ " :" ^ term_to_string x
    val _ = writel_synt "concepts" (map w l3)
    val _ = msg_synt l2 "selected concepts"
    val d = ref (dempty Term.compare)
  in
    app (save_concept d) (map fst l2);
    (!d)
  end

(* Replace the concept subterms of [tm] by their concept variables.
   Returns [tm] alone when nothing changed, and [tm, newtm] otherwise. *)
fun conceptualize_tm ceptdict tm =
  let
    fun is_cept x = dmem x ceptdict
    val redexl0 = find_terms is_cept tm
    (* larger redexes are listed first *)
    fun cmp (tm1,tm2) = Int.compare (term_size tm2, term_size tm1)
    val redexl1 = dict_sort cmp redexl0
    fun f t = {redex = t, residue = dfind t ceptdict}
    val sub = map f redexl1
    val newtm = Term.subst sub tm
  in
    if term_eq newtm tm then [tm] else [tm,newtm]
  end

(* Term denoted by the global index [c]: if that term is a concept variable
   found in [iceptdict], the concept it stands for; otherwise the term
   itself. *)
fun read_cept iceptdict c =
  let val tm = dfind c (!icdict_glob) in
    dfind tm iceptdict handle NotFound => tm
  end

(* Turn a substitution on global indices into a substitution on terms. *)
fun read_subst iceptdict sub =
  let fun f (a,b) = (read_cept iceptdict a, read_cept iceptdict b) in
    map f sub
  end

(* --------------------------------------------------------------------------
   Patterns
   -------------------------------------------------------------------------- *)

(* Skeleton of a term where every variable or constant is replaced by its
   local index. *)
datatype pattern =
    Pconst of int
  | Pcomb of pattern * pattern
  | Plamb of pattern * pattern

(* Pattern of [tm]; updates the global and local index dictionaries. *)
fun pattern_tm tm =
  case dest_term tm of
    VAR _   => Pconst (fconst tm)
  | CONST _ => Pconst (fconst tm)
  | COMB(Rator,Rand) => Pcomb (pattern_tm Rator, pattern_tm Rand)
  | LAMB(Var,Bod)    => Plamb (pattern_tm Var, pattern_tm Bod)

(* Pattern of [tm] together with the list of global indices of its variables
   and constants, sorted by local index (so that position i of the list is
   the global index carrying local index i). *)
fun patternify_one tm =
  let
    val _ = cdict_loc := dempty Int.compare
    fun cmp (a,b) = Int.compare (snd a, snd b)
    val p = pattern_tm tm
    val l1 = dlist (!cdict_loc)
    val l2 = dict_sort cmp l1
  in
    (p, map fst l2)
  end

(* Total order on patterns: Pconst < Pcomb < Plamb, then componentwise. *)
fun pattern_compare (p1,p2) = case (p1,p2) of
    (Pconst i1,Pconst i2) => Int.compare (i1,i2)
  | (Pconst _,_) => LESS
  | (_,Pconst _) => GREATER
  | (Pcomb(a1,b1),Pcomb(a2,b2)) =>
    cpl_compare pattern_compare pattern_compare ((a1,b1),(a2,b2))
  | (Pcomb _,_) => LESS
  | (_,Pcomb _) => GREATER
  | (Plamb(a1,b1),Plamb(a2,b2)) =>
    cpl_compare pattern_compare pattern_compare ((a1,b1),(a2,b2))

(* S-expression rendering: "A" marks an application, "L" an abstraction. *)
fun string_of_pattern p = case p of
    Pconst i => int_to_string i
  | Pcomb (p1,p2) =>
    "(" ^ String.concatWith " " ("A" :: map string_of_pattern [p1,p2]) ^ ")"
  | Plamb (p1,p2) =>
    "(" ^ String.concatWith " " ("L" :: map string_of_pattern [p1,p2]) ^ ")"

(* Write to the file "patterns" every pattern associated with more than one
   concept list, with the number of such lists.
   [ntot] is currently unused and kept for interface compatibility. *)
fun write_patceptdict ntot patceptdict =
  let
    val _ = log_synt "writing patceptdict"
    val l0 = dlist patceptdict
    val l1 = filter (fn (a,b) => length b > 1) l0
    val l2 = map (fn (a,b) => (a, length b)) l1
    val l3 = dict_sort compare_imax l2
    fun w (p,n) = int_to_string n ^ ": " ^ string_of_pattern p
    val _ = msg_synt l3 "patterns appearing at least twice"
  in
    writel_synt "patterns" (map w l3)
  end

(* Write to the file "concept_lists" every concept list associated with more
   than one pattern, with the number of such patterns. *)
fun write_ceptpatdict iceptdict ceptpatdict =
  let
    val _ = log_synt "writing ceptpatdict"
    val l0 = dlist ceptpatdict
    val l1 = filter (fn (a,b) => length b > 1) l0
    val l2 = map (fn (a,b) => (a, length b)) l1
    val l3 = dict_sort compare_imax l2
    fun w (cl,n) =
      int_to_string n ^ ": " ^
      String.concatWith "\n"
        (map (term_to_string o read_cept iceptdict) cl)
    val _ = msg_synt l3 "concept lists appearing at least twice"
  in
    writel_synt "concept_lists" (map w l3)
  end

(* Decompose every term of [tml] (and, when !concept_flag is set, its
   conceptualized variant) into a pattern and a concept list.
   Returns three dictionaries:
     pattern      -> concept lists seen with that pattern,
     concept list -> patterns seen with that list,
     term         -> (pattern, concept list) of each of its variants. *)
fun patternify ntot ceptdict iceptdict tml =
  let
    val patceptdict = ref (dempty pattern_compare)
    val ceptpatdict = ref (dempty (list_compare Int.compare))
    val thmpatdict = ref (dempty Term.compare)
    val tml1 = mk_fast_set Term.compare tml
    fun f tm =
      let
        val (p,cl) = patternify_one tm
        val cll = dfind p (!patceptdict) handle NotFound => []
        val pl = dfind cl (!ceptpatdict) handle NotFound => []
      in
        patceptdict := dadd p (cl :: cll) (!patceptdict);
        ceptpatdict := dadd cl (p :: pl) (!ceptpatdict);
        (p,cl)
      end
    fun g tm =
      let
        val variants =
          if !concept_flag then conceptualize_tm ceptdict tm else [tm]
        val patl = map f variants
      in
        thmpatdict := dadd tm patl (!thmpatdict)
      end
    val _ = app g tml1
    val _ = msgd_synt (!patceptdict) "patterns"
    val _ = msgd_synt (!ceptpatdict) "concept lists"
    val _ = write_patceptdict ntot (!patceptdict)
    val _ = write_ceptpatdict iceptdict (!ceptpatdict)
  in
    (!patceptdict, !ceptpatdict, !thmpatdict)
  end

(* Rebuild a term from the pattern [p] and the concept list [cl].
   The term constructors may raise HOL_ERR on ill-typed combinations;
   List.nth raises Subscript if an index exceeds the length of [cl]. *)
fun term_of_pat idict (p,cl) = case p of
    Pconst i => read_cept idict (List.nth (cl,i))
  | Pcomb (p1,p2) =>
    mk_comb (term_of_pat idict (p1,cl), term_of_pat idict (p2,cl))
  | Plamb (p1,p2) =>
    mk_abs (term_of_pat idict (p1,cl), term_of_pat idict (p2,cl))

(* --------------------------------------------------------------------------
   Concept substitutions.
   -------------------------------------------------------------------------- *)

fun compare_kimin (a,b) = Int.compare (fst a, fst b)

(* Drop the identity pairs and sort the remaining pairs by first component. *)
fun norm_sub l =
  let val l1 = filter (fn (x,y) => x <> y) l in
    dict_sort compare_kimin l1
  end

(* All the index substitutions obtained by pairing position by position two
   distinct concept lists of [cll]. *)
fun pair_sub cll =
  let
    val cll' = mk_fast_set (list_compare Int.compare) cll
    val cpl = cartesian_product cll' cll'
    val cpl' = filter (fn (x,y) => x <> y) cpl
  in
    map combine cpl'
  end

(* Term substitutions induced by the concept lists sharing a pattern,
   ordered by compare_imax on their number of occurrences. *)
fun create_sub iceptdict patceptdict =
  let
    fun f (p,cll) = pair_sub cll
    val l1 = List.concat (map f (dlist patceptdict))
    val l2 = map norm_sub l1
    val cmp = list_compare (cpl_compare Int.compare Int.compare)
    val dfreq = count_dict (dempty cmp) l2
    val _ = msgd_synt dfreq "concept substitutions"
    val l3 = dict_sort compare_imax (dlist dfreq)
  in
    (map (read_subst iceptdict)) (map fst l3)
  end

(* Replace top-down every subterm of [tm] equal to a redex of [sub] by the
   corresponding residue, without checking types beforehand: mk_comb and
   mk_abs may therefore raise HOL_ERR. *)
fun unsafe_sub sub tm =
  let val redreso = List.find (fn (red,res) => red = tm) sub in
    if isSome redreso then snd (valOf (redreso)) else
    (
    case dest_term tm of
      VAR _   => tm
    | CONST _ => tm
    | COMB(Rator,Rand) =>
      mk_comb (unsafe_sub sub Rator, unsafe_sub sub Rand)
    | LAMB(Var,Bod) =>
      mk_abs (unsafe_sub sub Var, unsafe_sub sub Bod)
    )
  end

(* Apply [sub] to [tm]. Returns the generalized result, or NONE when the
   change is invalid or a HOL_ERR occurred (counted as a type error). *)
fun apply_sub sub tm =
  let val tm' = unsafe_sub sub tm in
    if unvalid_change tm tm' then NONE else SOME (my_gen_all tm')
  end
  handle HOL_ERR _ => (incr type_errors; NONE)

(* --------------------------------------------------------------------------
   Pattern substitutions
   -------------------------------------------------------------------------- *)

(* All the ordered pairs of distinct patterns of [pl]. *)
fun pair_patsub pl =
  let
    val l1 = mk_fast_set pattern_compare pl
    val cpl = cartesian_product l1 l1
    val cpl' = filter (fn (x,y) => x <> y) cpl
  in
    cpl'
  end

(* Pattern substitutions induced by the patterns sharing a concept list,
   ordered by compare_imax on their number of occurrences. *)
fun create_patsub ceptpatdict =
  let
    fun f (cl,pl) = pair_patsub pl
    val cpl = List.concat (map f (dlist ceptpatdict))
    val cmp = cpl_compare pattern_compare pattern_compare
    val dfreq = count_dict (dempty cmp) cpl
    val _ = msgd_synt dfreq "pattern substitutions"
  in
    map fst (dict_sort compare_imax (dlist dfreq))
  end

(* If a variant of [tm] has pattern [p1], rebuild it with pattern [p2] and
   the same concept list. Returns NONE when no variant matches, when the
   change is invalid, or on HOL_ERR (counted as a type error).
   dfind raises NotFound if [tm] is not a key of [thmpatdict]. *)
fun apply_patsub thmpatdict iceptdict (p1,p2) tm =
  let
    val patl = dfind tm thmpatdict
    fun same_pat x (p,cl) = pattern_compare (p,x) = EQUAL
  in
    case List.find (same_pat p1) patl of
      NONE => NONE
    | SOME (p,cl) =>
      (
      let val tm' = term_of_pat iceptdict (p2,cl) in
        if unvalid_change tm tm' then NONE else SOME (my_gen_all tm')
      end
      handle HOL_ERR _ => (incr type_errors; NONE)
      )
  end

(* --------------------------------------------------------------------------
   Conjecturing
   -------------------------------------------------------------------------- *)

(* Record the first time the known term [x] is regenerated, together with
   the number of conjectures and of regenerated terms at that moment. *)
fun update_genthmdict gencjdict genthmdict x =
  if dmem x (!genthmdict) then () else
  genthmdict := dadd x (dlength (!gencjdict), dlength (!genthmdict))
    (!genthmdict)

(* Record the new conjecture [x] with its rank, unless the limit is reached
   or [x] is already present. *)
fun update_gencjdict gencjdict x =
  if dlength (!gencjdict) >= (!conjecture_limit) orelse dmem x (!gencjdict)
  then ()
  else gencjdict := dadd x (dlength (!gencjdict)) (!gencjdict)

(* Dispatch [x] to the dictionary of regenerated terms when it belongs to
   [covdict], and to the dictionary of conjectures otherwise. *)
fun update_gendict covdict genthmdict gencjdict x =
  if dmem x covdict
  then update_genthmdict gencjdict genthmdict x
  else update_gencjdict gencjdict x

(* Shared generation loop of conjecture_sub and conjecture_patsub.
   [apply s tm] must return SOME tm' when the substitution [s] produces a
   valid new term from [tm], and NONE otherwise.
   Each term carries the index of the next substitution to try. In one round,
   each term moves forward until one substitution succeeds, or 100 of them
   failed in a row, or no substitution is left. Rounds are repeated until the
   conjecture limit is reached or a round adds no conjecture.
   Returns the conjecture dictionary and the regenerated term dictionary. *)
fun conjecture_gen apply covdict tml subl =
  let
    val gencjdict = ref (dempty Term.compare)
    val genthmdict = ref (dempty Term.compare)
    val dsub = dnew Int.compare (number_list 0 subl)
    val tmnl = map (fn x => (x,0)) tml
    fun try_nsub n (tm,nsub) =
      if not (dmem nsub dsub) orelse n <= 0 then (tm,nsub) else
      (
      case apply (dfind nsub dsub) tm of
        NONE => try_nsub (n - 1) (tm, nsub + 1)
      | SOME tm' =>
        (
        update_gendict covdict genthmdict gencjdict tm';
        (tm, nsub + 1)
        )
      )
    (* number of conjectures at the start of the previous round *)
    val mem = ref (~1)
    fun loop tmnl =
      if dlength (!gencjdict) >= (!conjecture_limit) orelse
         !mem >= dlength (!gencjdict)
      then () else
      let
        val _ = mem := dlength (!gencjdict)
        val _ = print_endline (int_to_string (!mem) ^ " conjectures")
        val newtmnl = map (try_nsub 100) tmnl
      in
        loop newtmnl
      end
  in
    loop tmnl;
    (!gencjdict,!genthmdict)
  end

(* Generate conjectures from [tml] with the concept substitutions [subl]. *)
fun conjecture_sub covdict tml subl =
  conjecture_gen apply_sub covdict tml subl

(* Generate conjectures from [tml] with the pattern substitutions
   [patsubl]. *)
fun conjecture_patsub thmpatdict iceptdict covdict tml patsubl =
  conjecture_gen (apply_patsub thmpatdict iceptdict) covdict tml patsubl

(* Run gnuplot from tactictoe_dir to plot the data file [filein] as
   postscript into [fileout].
   NOTE(review): stdout is also redirected to [fileout], the file named in
   "set output"; this looks redundant but is kept unchanged. *)
fun gnuplotcmd filein fileout =
  let
    val plotcmd = "\"" ^ String.concatWith "; " [
      "set term postscript",
      "set output " ^ "'" ^ fileout ^ "'",
      "plot " ^ "'" ^ filein ^ "'"]
      ^ "\""
    val cmd = "gnuplot -p -e " ^ plotcmd ^ " > " ^ fileout
  in
    cmd_in_dir tactictoe_dir cmd
  end

(* Log the ratio of regenerated terms over [ntot], write the coverage data
   (for each number of conjectures, the highest ratio of regenerated terms
   recorded at that point) to "coverage_data" and plot it. *)
fun write_graph ntot genthmdict =
  let
    val _ = log_synt "writing graph"
    val rcov = int_div (dlength genthmdict) ntot
    val _ = log_synt (Real.toString rcov ^ " conjecture coverage")
    val l0 = map snd (dlist genthmdict)
    val d = ref (dempty Int.compare)
    (* keep the largest second component for every first component *)
    fun update_dict (a,b) =
      let val oldb = dfind a (!d) handle NotFound => 0 in
        if b > oldb then d := dadd a b (!d) else ()
      end
    val l1 = (app update_dict l0; dlist (!d))
    fun w (a,b) = int_to_string a ^ " " ^ (Real.toString (int_div b ntot))
    val header = "# miss match"
    val _ = writel_synt "coverage_data" (header :: map w l1)
    val filein = (!ttt_synt_dir) ^ "/coverage_data"
    val fileout = (!ttt_synt_dir) ^ "/coverage_graph.ps"
  in
    gnuplotcmd filein fileout
  end

(* Main entry point: generate conjectures from the list of terms [tml].
   The terms are stripped of their outer universal quantifiers (after
   renaming bound variables) and deduplicated; concepts and patterns are
   extracted; substitutions are created and applied (pattern substitutions
   when !patsub_flag is set, concept substitutions otherwise).
   Returns the generated conjectures listed by increasing rank. *)
fun conjecture tml =
  let
    val _ = init_synt ()
    val tml0 = mk_fast_set Term.compare tml
    val tml1 = map (snd o strip_forall o rename_bvarl (fn _ => "")) tml0
    val tml2 = mk_fast_set Term.compare tml1
    val tml3 = map (fn x => (my_gen_all x, 0)) tml2
    val _ = msg_synt tml3 "terms"
    val covdict = dnew Term.compare tml3
    val ntot = dlength covdict
    val ceptdict = concept_selection tml2
    val iceptdict = inv_dict Term.compare ceptdict
    val (patceptdict, ceptpatdict, thmpatdict) = time_synt "patternify"
      patternify ntot ceptdict iceptdict tml2
    val _ = msgd_synt (!cdict_glob) "constants or variables"

    (* conjecture generation from substitutions *)
    val (gencjdict,genthmdict) =
      if !patsub_flag
      then
        let val patsubl = create_patsub ceptpatdict in
          time_synt "conjecture_patsub"
          (conjecture_patsub thmpatdict iceptdict covdict tml2) patsubl
        end
      else
        let val subl = create_sub iceptdict patceptdict in
          time_synt "conjecture_sub"
          (conjecture_sub covdict tml2) subl
        end
    val _ = write_graph ntot genthmdict
    val _ = log_synt (int_to_string (!type_errors) ^ " type errors")
    val _ = msgd_synt gencjdict "generated conjectures"
    val igencjdict = inv_dict Int.compare gencjdict
  in
    map snd (dlist igencjdict)
  end

end (* struct *)