(* ========================================================================== *)
(* FILE          : tttSearch.sml                                              *)
(* DESCRIPTION   : Search algorithm for TacticToe.                            *)
(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
(* DATE          : 2017                                                       *)
(* ========================================================================== *)

structure tttSearch :> tttSearch =
struct

open HolKernel boolLib Abbrev tttTools tttTimeout tttFeature tttPredict
tttExec tttLexer tttMinimize tttThmData tttTacticData tttLearn tttSetup

val ERR = mk_HOL_ERR "tttSearch"

(* Tactic string most recently passed to apply_stac. *)
val last_stac = ref ""

(* Log "Error: s" through debug, then raise a generic HOL_ERR. *)
fun debug_err s = (debug ("Error: " ^ s); raise ERR "standard" "error")

(* --------------------------------------------------------------------------
   Exceptions
   -------------------------------------------------------------------------- *)

(* Raised during node selection once !ttt_search_time is exceeded. *)
exception SearchTimeOut

(* Raised by node_find when no active node is left to expand. *)
exception NoNextTac

(* --------------------------------------------------------------------------
   Asynchronous calls to provers
   -------------------------------------------------------------------------- *)

(* Per-node result slot of an asynchronous prover call.
   HSuccess (stac,g): the prover returned tactic string stac for goal g.
   HRunning is kept in the datatype, but fork_hammer does not write it
   (see the comment on fork_hammer). *)
datatype async_result_t =
    HSuccess of (string * goal)
  | HFailure
  | HRunning of Thread.thread
  | HVoid

(* 100000 is the maximum number of nodes: async_result is indexed by node id,
   so a node id >= 100000 would make Array.update raise Subscript. *)
val hammer_ref = ref 0                        (* prover calls forked so far *)
val async_result = Array.array (100000, HVoid)
val install_async = ref (dempty Int.compare)  (* node id -> queued goal *)
val running_async = ref (dempty Int.compare)  (* node id -> forked thread *)

(* Start and end of search *)

(* Reset all asynchronous state. No-op unless !ttt_eprover_flag. *)
fun init_async () =
  if !ttt_eprover_flag
  then
    (
    hammer_ref := 0;
    install_async := dempty Int.compare;
    running_async := dempty Int.compare;
    Array.modify (fn _ => HVoid) async_result
    )
  else ()

(* Repeatedly interrupt thread until it is no longer active.
   NOTE(review): this only returns once the thread honours the interrupt;
   hammer_call_id catches every exception and then finishes. *)
fun terminate_thread pid thread =
  if Thread.isActive thread
  then
    (
    debug_search ("terminate thread " ^ int_to_string pid);
    while Thread.isActive thread do Thread.interrupt thread
    )
  else ()

(* Cancel the prover call attached to node pid, whether it is still queued
   or already running. The running thread is looked up BEFORE pid is
   removed from !running_async: doing it afterwards made the lookup always
   fail, so threads were never interrupted. The result slot is cleared
   last, so that a result written by the interrupted thread does not
   remain. *)
fun terminate_async_pid pid =
  if !ttt_eprover_flag then
    (
    install_async := drem pid (!install_async);
    if dmem pid (!running_async)
    then terminate_thread pid (dfind pid (!running_async))
    else ();
    running_async := drem pid (!running_async);
    Array.update (async_result,pid,HVoid)
    )
  else ()

(* Cancel every running prover call. *)
fun terminate_async () =
  if !ttt_eprover_flag
  then app terminate_async_pid (dkeys (!running_async))
  else ()

(* Queue a prover call on goal g for node pid, first cancelling any previous
   call for pid. The thread itself is forked later by fork_hammer. *)
fun queue_async pid g =
  if !ttt_eprover_flag then
    (
    terminate_async_pid pid;
    debug_search ("install thread " ^ int_to_string pid);
    install_async := dadd pid g (!install_async)
    )
  else ()

(* -------------------------------------------------------------------------
   Tell if a node is active or not
   -------------------------------------------------------------------------- *)

(* Set of deactivated node ids (the values are unit). *)
val notactivedict = ref (dempty Int.compare)
fun is_notactive x = dmem x (!notactivedict)
fun is_active x = not (is_notactive x)

(* Mark node x as inactive and cancel its prover call. *)
fun deactivate x =
  (
  debug_search ("deactivate " ^ int_to_string x);
  terminate_async_pid x;
  notactivedict := dadd x () (!notactivedict)
  )

(* -------------------------------------------------------------------------
   Search references
   -------------------------------------------------------------------------- *)

(* Started by init_search and compared against !ttt_search_time. *)
val glob_timer = ref NONE

(* Node id -> node record. The root is node 0. *)
val proofdict = ref (dempty Int.compare)

(* global values to prevent many arguments in functions *)
val thmpredictor_glob = ref (fn _ => (fn _ => []))
val tacpredictor_glob = ref (fn _ => [])
val glpredictor_glob = ref (fn _ => 0.0)
val hammer_glob = ref (fn _ => (fn _ => NONE))

(* --------------------------------------------------------------------------
   Caching tactic applications on goals
   -------------------------------------------------------------------------- *)

(* (tactic string, goal) -> raw result of applying the tactic *)
val stacgoal_cache = ref (dempty (cpl_compare String.compare goal_compare))

(* --------------------------------------------------------------------------
   Statistics
   -------------------------------------------------------------------------- *)

val stac_counter = ref 0

fun string_of_pred pred =
  "[" ^ String.concatWith "," pred ^ "]"

val tactime = ref 0.0
val thmtime = ref 0.0
val gltime = ref 0.0

val tactimer = total_time tactime
val thmtimer = total_time thmtime
val gltimer = total_time gltime

val inst_time = ref 0.0
val terminst_time = ref 0.0
val infstep_time = ref 0.0
val node_create_time = ref 0.0
val node_find_time = ref 0.0

val inst_timer = total_time inst_time
val infstep_timer = total_time infstep_time
fun node_create_timer f x = total_time node_create_time f x
val node_find_timer = total_time node_find_time

val tot_time = ref 0.0
fun total_timer f x = total_time tot_time f x

(* Reset every timing accumulator declared above, terminst_time included. *)
fun reset_timers () =
  (
  tactime := 0.0;
  thmtime := 0.0;
  gltime := 0.0;
  inst_time := 0.0;
  terminst_time := 0.0;
  infstep_time := 0.0;
  node_create_time := 0.0;
  node_find_time := 0.0;
  tot_time := 0.0
  )

(* --------------------------------------------------------------------------
   Special tactics
   -------------------------------------------------------------------------- *)

val metis_spec = "tactictoe_metis"
val eprover_spec = "tactictoe_eprover"

(* Prepend the eprover placeholder to a prediction list when enabled. *)
fun add_eprover pred =
  if !ttt_eprover_flag then eprover_spec :: pred else pred

(* Prepend the metis placeholder to a prediction list when enabled. *)
fun add_metis pred =
  if !ttt_metis_flag then metis_spec :: pred else pred

(* --------------------------------------------------------------------------
   MCTS: Priors
   -------------------------------------------------------------------------- *)

(* Elements of an array, in index order. *)
fun array_to_list a = Array.foldr (op ::) [] a

(* Initialise the Monte Carlo statistics of node pid:
   - prior policy pripol;
   - one visit;
   - an evaluation that is 0.0 under ttt_mcevnone_flag, 1.0 under
     ttt_mcevtriv_flag, and otherwise the goal-list predictor applied to
     the node's goals. *)
fun init_eval pripol pid =
  let
    val _ = debug_search "mcts evaluation"
    val prec = dfind pid (!proofdict)
    val {visit,pending,goalarr,prioreval,cureval,priorpolicy,...} = prec
    val eval =
      if !ttt_mcevnone_flag then 0.0
      else if !ttt_mcevtriv_flag then 1.0
      else (!glpredictor_glob) (array_to_list goalarr)
  in
    priorpolicy := pripol;
    visit := 1.0;
    prioreval := eval;
    cureval := [eval]
  end

(* --------------------------------------------------------------------------
   MCTS: Backpropagation
   -------------------------------------------------------------------------- *)

(* Walk from node cid up to the root. On every node on the way, increment
   the visit count and, when beval holds, push eval onto cureval. *)
fun backup_loop beval eval cid =
  let
    val crec = dfind cid (!proofdict)
    val {parid,visit,cureval,...} = crec
  in
    if beval
    then cureval := eval :: !cureval
    else ()
    ;
    visit := !visit + 1.0;
    if parid = NONE then () else backup_loop beval eval (valOf parid)
  end

(* Propagate the prior evaluation of cid to its ancestors. *)
fun backup cid =
  let
    val _ = debug_search "mcts backpropagation"
    val crec = dfind cid (!proofdict)
    val {parid,prioreval,...} = crec
  in
    if parid = NONE
    then ()
    else backup_loop true (!prioreval) (valOf parid)
  end

(* Propagate a failure (evaluation 0.0) to the ancestors of cid. The
   evaluation is recorded only under ttt_mcevfail_flag; visits always are. *)
fun backup_fail cid =
  let
    val _ = debug_search "backup fail"
    val crec = dfind cid (!proofdict)
    val {parid,...} = crec
  in
    if parid = NONE
    then ()
    else backup_loop (!ttt_mcevfail_flag) 0.0 (valOf parid)
  end

(* Propagate a success (evaluation 1.0) to the ancestors of cid. *)
fun backup_success cid =
  let
    val _ = debug_search "backup success"
    val crec = dfind cid (!proofdict)
    val {parid,...} = crec
  in
    if parid = NONE
    then ()
    else backup_loop true 1.0 (valOf parid)
  end

(* --------------------------------------------------------------------------
   Node creation and deletion
   -------------------------------------------------------------------------- *)

val max_depth_mem = ref 0
val pid_counter = ref 0

(* Return the next fresh node id and advance the counter. *)
fun next_pid () =
  let
    val r = !pid_counter
    val _ = pid_counter := !pid_counter + 1
  in
    r
  end

(* Create the root node for goal with prediction list pred. *)
fun root_create goal pred =
  let
    fun init_empty _ = ref []
    val selfid = next_pid ()
    val selfrec =
      {
      selfid = selfid,
      parid = NONE,
      parstac = NONE,
      pargn = NONE,
      parg = NONE,
      goalarr = Array.fromList [goal],
      predarr = Array.fromList [pred],
      depth = 0,
      (* *)
      pending = ref [0],
      children = ref [],
      (* proof saved for reconstruction + children *)
      proofl = ref [],
      childrena = Array.fromList (map init_empty [goal]),
      (* preventing loop and parallel steps *)
      pardict = dempty goal_compare,
      trydict = ref (dempty (list_compare goal_compare)),
      (* monte carlo *)
      priorpolicy = ref 0.0,
      visit = ref 0.0,
      prioreval = ref 0.0,
      cureval = ref []
      }
  in
    debug_search "Root";
    debug_search (" goal: " ^
      String.concatWith "," (map string_of_goal [goal]));
    debug_search (" pred: \n " ^
      String.concatWith ",\n " (map (string_of_pred o (first_n 2)) [pred]));
    proofdict := dadd selfid selfrec (!proofdict);
    init_eval 0.0 selfid
  end

fun root_create_wrap g =
  root_create g ((add_eprover o add_metis o !tacpredictor_glob) g)

(* Create a child of parid for goals goallist, obtained by applying
   tactic parstac to goal parg (goal number pargn of parid). Returns the
   new node id. The tactime argument is not used. *)
fun node_create pripol tactime parid parstac pargn parg goallist
    predlist pending pardict =
  let
    val selfid = next_pid ()
    fun init_empty _ = ref []
    val selfrec =
      {
      selfid = selfid,
      parid = SOME parid,
      parstac = SOME parstac,
      pargn = SOME pargn,
      parg = SOME parg,
      goalarr = Array.fromList goallist,
      predarr = Array.fromList predlist,
      depth = #depth (dfind parid (!proofdict)) + 1,
      (* goal considered *)
      pending = ref pending,
      children = ref [],
      (* proof saved for reconstruction + children *)
      proofl = ref [],
      childrena = Array.fromList (map init_empty goallist),
      (* preventing loop and parallel steps *)
      pardict = pardict,
      trydict = ref (dempty (list_compare goal_compare)),
      (* monte carlo: dummy values changed by init_eval *)
      priorpolicy = ref 0.0,
      visit = ref 0.0,
      prioreval = ref 0.0,
      cureval = ref []
      }
    val cdepth = #depth selfrec
  in
    if cdepth > !max_depth_mem then max_depth_mem := cdepth else ();
    debug_search
      ("Node " ^ int_to_string selfid ^ " " ^ int_to_string parid ^ " " ^
       Real.toString (! (#priorpolicy selfrec)));
    debug_search
      (" goals: " ^ String.concatWith "," (map string_of_goal goallist));
    debug_search (" predictions: " ^
      String.concatWith ",\n " (map (string_of_pred o (first_n 2)) predlist));
    proofdict := dadd selfid selfrec (!proofdict);
    init_eval pripol selfid;
    selfid
  end

fun node_delete pid =
  (debug_search ("node_delete " ^ int_to_string pid); deactivate pid)

(* --------------------------------------------------------------------------
   Change the name of the tactic that has been applied
   -------------------------------------------------------------------------- *)

(* Replace the head prediction of the current goal of pid with newstac. *)
fun update_curstac newstac pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (!(#pending prec))
    val pred = Array.sub (#predarr prec, gn)
    val newpred = newstac :: tl pred
  in
    Array.update (#predarr prec, gn, newpred)
  end
  handle _ => debug_err ("update_curstac :" ^ newstac)

(* --------------------------------------------------------------------------
   Trying multiple terms.
   -------------------------------------------------------------------------- *)

(* Apply tac to g. If that is productive (it succeeds without reproducing
   g), return its goals. Otherwise try qtac instantiated with each of the
   terms returned by termknn n g otm, in order. On the first productive
   one, record the instantiated tactic string for pid and return its
   goals. Returns NONE if nothing is productive. *)
fun try_nqtm pid n (stac,tac) (otm,qtac) g =
  let
    val glo = SOME (fst (tac g)) handle _ => NONE
    fun locprod x = case x of SOME gl => not (mem g gl) | NONE => false
  in
    if locprod glo then glo else
    let fun loop qtac tml =
      case tml of [] => NONE | tm :: m =>
      let val glo' = SOME (fst (qtac [ANTIQUOTE tm] g)) handle _ => NONE in
        if locprod glo'
        then
          let val newstac = inst_timer (inst_termarg stac) tm in
            update_curstac newstac pid; glo'
          end
        else loop qtac m
      end
    in
      loop qtac (termknn n g otm)
    end
  end

(* --------------------------------------------------------------------------
   Transforming code into a tactic. Doing necessary predictions.
   -------------------------------------------------------------------------- *)

val thml_dict = ref (dempty (cpl_compare goal_compare Int.compare))
val inst_dict = ref (dempty (cpl_compare String.compare goal_compare))
val tac_dict = ref (dempty String.compare)

(* Theorem predictions for (g,n), memoised in thml_dict. *)
fun pred_sthml thmpredictor thml_dict n g =
  dfind (g,n) (!thml_dict) handle NotFound =>
  let val sl = thmpredictor n g in
    thml_dict := dadd (g,n) sl (!thml_dict);
    sl
  end

(* Turn the tactic string stac into (instantiated string, tactic, timeout).
   - Abstracted theorem arguments (under ttt_thmlarg_flag) are instantiated
     with predicted theorems.
   - metis_spec becomes a metis call on predicted theorems.
   - Any other string is parsed and cached in tac_dict.
   Any failure yields NO_TAC. *)
fun stac_to_tac thmpred (tac_dict,inst_dict,thml_dict) stac g =
  (
  if !ttt_thmlarg_flag andalso is_absarg_stac stac then
    (
    dfind (stac,g) (!inst_dict) handle NotFound =>
    let
      val _ = debug_search ("instantiating: " ^ stac)
      (* ttt_thmlarg_flag is known to hold in this branch *)
      val sl = pred_sthml thmpred thml_dict (!ttt_thmlarg_radius) g
      val thmls = String.concatWith " , " (map dbfetch_of_string sl)
      val newstac = inst_stac thmls g stac
      val newtac = tactic_of_sml newstac
        handle _ =>
          (debug ("Warning: stac_to_tac: " ^ newstac);
           raise ERR "stac_to_tac" "")
    in
      inst_dict := dadd (stac,g) (newstac,newtac, !ttt_tactic_time) (!inst_dict);
      debug_search ("to: " ^ newstac);
      (newstac, newtac, !ttt_tactic_time)
    end
    )
  else if stac = metis_spec then
    (
    dfind (stac,g) (!inst_dict) handle NotFound =>
    let
      val sl = pred_sthml thmpred thml_dict (!ttt_metis_radius) g
      val newstac = mk_metis_call sl
      val newtac = tactic_of_sml newstac
    in
      inst_dict := dadd (stac,g) (newstac,newtac,!ttt_metis_time) (!inst_dict);
      debug_search ("to: " ^ newstac);
      (newstac,newtac,!ttt_metis_time)
    end
    )
  else
    let fun find_stac stac =
      dfind stac (!tac_dict) handle NotFound =>
      let val tac = tactic_of_sml stac in
        tac_dict := dadd stac tac (!tac_dict);
        tac
      end
    in
      (stac, find_stac stac, !ttt_tactic_time)
    end
  )
  handle _ =>
    (debug ("Warning: stac_to_tac: " ^ stac);
     ("Tactical.NO_TAC", NO_TAC, !ttt_tactic_time))

(* --------------------------------------------------------------------------
   Application of a tactic.
   -------------------------------------------------------------------------- *)

(* Keep a tactic result gl only if:
   - it does not contain g;
   - it contains no goal of pardict (goals of the ancestors);
   - it was not already produced at this node (trydict). *)
fun glob_productive pardict trydict g glo =
  case glo of
    NONE => NONE
  | SOME gl =>
    (
    if mem g gl orelse exists (fn x => dmem x pardict) gl orelse
       dmem gl trydict
    then NONE
    else SOME gl
    )

(* Instantiate and apply tactic string stac to goal g for node pid.
   The raw result is cached in stacgoal_cache; the productive part is
   returned. *)
fun apply_stac pid pardict trydict stac g =
  let
    val _ = last_stac := stac
    val _ = stac_counter := !stac_counter + 1
    (* instantiation of theorems and reading *)
    val (newstac,newtac,tim) =
      stac_to_tac (!thmpredictor_glob) (tac_dict,inst_dict,thml_dict) stac g
    val _ = update_curstac newstac pid
    (* execution, possibly with instantiation of term arguments *)
    val glo = dfind (newstac,g) (!stacgoal_cache) handle NotFound =>
      let val cpo = if !ttt_termarg_flag then abs_termarg newstac else NONE in
        case cpo of
          NONE => app_tac tim newtac g
        | SOME (otm,qtac) =>
          app_qtac tim
            (try_nqtm pid (!ttt_termarg_radius) (newstac,newtac) (otm,qtac) g)
            g
      end
    (* testing for loops *)
    val newglo = glob_productive pardict trydict g glo
  in
    stacgoal_cache := dadd (newstac,g) glo (!stacgoal_cache);
    newglo
  end

(* Apply the head prediction of the current goal of pid. The eprover
   placeholder is queued asynchronously and counts as NONE here. *)
fun apply_next_stac pid =
  let
    val _ = debug_search "apply_next_stac"
    val prec = dfind pid (!proofdict)
    val gn = hd (! (#pending prec))
      handle _ => debug_err "apply_next_stac: empty pending"
    val g = Array.sub (#goalarr prec, gn)
    val pred = Array.sub (#predarr prec, gn)
    val trydict = !(#trydict prec)
    val pardict = (#pardict prec)
    val stac = hd pred
      handle _ => debug_err "apply_next_stac: empty pred"
  in
    if stac = eprover_spec
    then (queue_async pid g; NONE)
    else infstep_timer (apply_stac pid pardict trydict stac) g
  end

(* ----------------------------------------------------------------------
   Searching for a node (goal list) to explore.
   ---------------------------------------------------------------------- *)

(* Deactivate pid and return true if its current goal has no predictions
   left. *)
fun has_empty_pred pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (!(#pending prec))
    val pred = Array.sub (#predarr prec, gn)
      handle _ => debug_err ("find_next_tac: " ^ int_to_string pid)
  in
    if null pred then (deactivate pid; true) else false
  end

(* Descend from pid by selection score and return (selected id, prior
   policy of a new tactic at that node).
   - The node itself scores
     ttt_mcev_coeff * self_pripol * sqrt(visit pid),
     where self_pripol = (1-c)^n * c, c = ttt_mcpol_coeff and n is the
     number of children.
   - A child scores
     mean(cureval) + ttt_mcev_coeff * pripol * sqrt(visit pid) / (visit+1).
   Raises SearchTimeOut once the global time limit is exceeded. *)
fun mc_node_find pid =
  if Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
  then (debug "Warning: mc_node_find: loop"; raise SearchTimeOut)
  else
  let
    val prec = dfind pid (!proofdict)
    val {children,visit,...} = prec
    val pvisit = !(#visit prec)
    val pdenom = Math.sqrt pvisit
    (* try new tactic on the node itself *)
    val n = length (!children)
    val self_pripol =
      Math.pow (1.0 - !ttt_mcpol_coeff, Real.fromInt n) * !ttt_mcpol_coeff
    val self_curpol = 1.0 / pdenom
    val self_selsc = (pid, (!ttt_mcev_coeff) * (self_pripol / self_curpol))
    (* or explore deeper existing partial proofs *)
    fun f cid =
      let
        val crec = dfind cid (!proofdict)
        val pripol = !(#priorpolicy crec)
        val meaneval = average_real (!(#cureval crec))
        val visit = !(#visit crec)
        val curpol = (visit + 1.0) / pdenom
      in
        (cid, meaneval + (!ttt_mcev_coeff) * (pripol / curpol))
      end
    (* sort and select node with best selection score *)
    val l0 = self_selsc :: List.map f (!children)
    val l1 = dict_sort compare_rmax l0
    val (selid,_) = hd l1
  in
    if pid = selid then (pid,self_pripol) else mc_node_find selid
  end

(* Select a node from the root. An inactive selection is backed up as a
   failure and the selection is retried. *)
fun try_mc_find () =
  if Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
  then (debug "Warning: try_mc_find"; raise SearchTimeOut)
  else
  let
    val _ = debug_search "mc_node_find"
    val (pid,pripol) = mc_node_find 0
  in
    if is_notactive pid
    then (backup_fail pid; try_mc_find ())
    else (debug_search ("Find " ^ int_to_string pid); (pid,pripol))
  end

(* ---------------------------------------------------------------------------
   Closing proofs
   -------------------------------------------------------------------------- *)

fun children_of pid =
  let val prec = dfind pid (!proofdict) in !(#children prec) end

(* All descendants of pid (children first, then their descendants). *)
fun descendant_of pid =
  let val cidl = children_of pid in
    cidl @ List.concat (map descendant_of cidl)
  end

fun close_descendant pid = app node_delete (descendant_of pid)

exception ProofFound

(* Record that child cid proved the current goal of pid, delete pid's
   descendants and move pid to its next pending goal. When pid has no goal
   left, close its parent recursively; at the root, raise ProofFound. *)
fun close_proof cid pid =
  let
    val crec = dfind cid (!proofdict)
    val prec = dfind pid (!proofdict)
    val {pargn = gn, parstac = stac,...} = crec
    val {proofl,pending,parid,children,visit,trydict,priorpolicy,...} = prec
  in
    (* checking some assertions *)
    if !pending <> [] then () else debug_err "close_proof: pending";
    if valOf gn = hd (!pending) then () else debug_err "close_proof";
    (* remember which child gave the proof of which goal *)
    proofl := (valOf gn, valOf stac, cid) :: !proofl;
    (* close all current children *)
    close_descendant pid;
    (* switching to next pending goal, erasing previous statistics *)
    children := [];
    trydict := dempty (list_compare goal_compare);
    pending := tl (!pending);
    (* optional reinitialization of the evaluation function *)
    if !ttt_mcevinit_flag then init_eval (!priorpolicy) pid else ();
    (* check if the goal was solved and recursively close *)
    if null (!pending)
    then
      if parid = NONE (* special case when it's root *)
      then (debug "proof found"; node_delete pid; raise ProofFound)
      else close_proof pid (valOf parid)
    else ()
  end

(* --------------------------------------------------------------------------
   Creating new nodes
   -------------------------------------------------------------------------- *)

(* Create a child of pid for the goals gl produced on pid's current goal
   by its head prediction. Returns the child id. *)
fun node_create_gl pripol tactime gl pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (! (#pending prec))
    val goal = Array.sub (#goalarr prec, gn)
    val prev_predl = Array.sub (#predarr prec, gn)
    val stac = hd prev_predl
    val parchildren = #children prec
    val parchildrensave = Array.sub (#childrena prec,gn)
    val predlist = map (add_eprover o add_metis o !tacpredictor_glob) gl
    val pending = rev (map fst (number_list 0 predlist))
    (* Updating list of parents *)
    val new_pardict = dadd goal () (#pardict prec)
    (* New node *)
    val selfid =
      node_create pripol
        tactime pid stac gn goal gl predlist pending new_pardict
  in
    parchildren := selfid :: (!parchildren);
    parchildrensave := selfid :: (!parchildrensave);
    selfid
  end

(* fake a node when a proof is found but no search is performed on this node *)
fun node_create_empty staco tactime pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (! (#pending prec))
    val goal = Array.sub (#goalarr prec, gn)
    val pred = Array.sub (#predarr prec, gn)
    val stac =
      case staco of
        NONE => hd pred
      | SOME s => s
    val parchildren = #children prec
    val parchildrensave = Array.sub (#childrena prec,gn)
    val selfid = node_create 0.0 tactime pid stac gn goal [] [] []
      (dempty goal_compare)
  in
    parchildren := selfid :: (!parchildren);
    parchildrensave := selfid :: (!parchildrensave);
    selfid
  end

(* pid should be active and the goal should match *)
fun close_proof_wrap staco tactime pid =
  let val cid = node_create_timer (node_create_empty staco tactime) pid in
    backup cid;
    close_proof cid pid
  end

(* --------------------------------------------------------------------------
   Handling asynchronous calls
   -------------------------------------------------------------------------- *)

(* Current (first pending) goal of node pid. *)
fun current_goal pid =
  let
    val prec = dfind pid (!proofdict)
    val gn = hd (!(#pending prec))
  in
    Array.sub (#goalarr prec, gn)
  end

(* Run the prover on goal g for node pid with call number hid, and store
   the outcome in async_result. Any exception is recorded as HFailure
   (presumably including the interrupt sent by terminate_thread). *)
fun hammer_call_id hid pid g =
  (
  case !hammer_glob hid g of
    NONE => Array.update (async_result,pid,HFailure)
  | SOME stac => Array.update (async_result,pid,HSuccess (stac,g))
  )
  handle _ => Array.update (async_result,pid,HFailure)

(* Same as hammer_call_id, using the current value of hammer_ref. *)
fun hammer_call pid g = hammer_call_id (!hammer_ref) pid g

(* Fork a prover thread for the first queued node id (in key order).
   The goal and the call number are read here, in the searching thread.
   Reading !proofdict and !hammer_ref from inside the forked thread raced
   with the search: two calls could share a number, and a failing
   current_goal escaped the thread's handler.
   HRunning is not written after forking: a fast thread may already have
   stored its result, and overwriting it would leave the node stuck in
   !running_async. *)
fun fork_hammer () =
  case dkeys (!install_async) of
    [] => ()
  | pid :: _ =>
    let
      val _ = install_async := drem pid (!install_async)
      val goalo = SOME (current_goal pid) handle _ => NONE
    in
      case goalo of
        NONE => ()
      | SOME g =>
        let
          val _ = incr hammer_ref
          val hid = !hammer_ref
          val _ = debug_search ("new thread " ^ int_to_string pid)
          val thread = Thread.fork (fn () => hammer_call_id hid pid g, [])
        in
          running_async := dadd pid thread (!running_async)
        end
    end

(* Fork one queued prover call if fewer than !ttt_eprover_async calls are
   running. *)
fun open_async () =
  if dlength (!running_async) < !ttt_eprover_async
  then
    let
      val n = dlength (!running_async)
      val m = length (filter (Thread.isActive o snd)
        (dlist (!running_async)))
    in
      debug_search (int_to_string n ^ " running thread");
      debug_search (int_to_string m ^ " active thread");
      fork_hammer ()
    end
  else ()

(* Collect finished prover calls in increasing order of pid.
   A success closes the proof of its node, provided the node is still
   active and its current goal is the goal the prover was called on. *)
fun close_async () =
  let
    val pidl = dkeys (!running_async)
    fun f pid = case Array.sub (async_result,pid) of
        HSuccess(stac,g) =>
        (
        debug_search ("success thread " ^ int_to_string pid);
        running_async := drem pid (!running_async);
        Array.update (async_result,pid,HVoid);
        if is_active pid andalso current_goal pid = g
        then close_proof_wrap (SOME stac) 0.0 pid
        else ()
        )
      | HFailure =>
        (
        debug_search ("failure thread " ^ int_to_string pid);
        Array.update (async_result,pid,HVoid);
        running_async := drem pid (!running_async)
        )
      | _ => ()
  in
    app f pidl
  end

(* ---------------------------------------------------------------------------
   Search function. Modifies the proof state.
   -------------------------------------------------------------------------- *)

(* Reset all search state and install the predictors for a new search.
   The goal argument is not used. *)
fun init_search thmpred tacpred glpred hammer g =
  (
  (* async *)
  init_async ();
  (* global time-out *)
  glob_timer := SOME (Timer.startRealTimer ());
  (* caching *)
  stacgoal_cache := dempty (cpl_compare String.compare goal_compare);
  thml_dict := dempty (cpl_compare goal_compare Int.compare);
  inst_dict := dempty (cpl_compare String.compare goal_compare);
  tac_dict := dempty String.compare;
  (* proof states *)
  pid_counter := 0;
  notactivedict := dempty Int.compare;
  proofdict := dempty Int.compare;
  (* easier access to values *)
  tacpredictor_glob := tactimer tacpred;
  thmpredictor_glob := thmtimer thmpred;
  glpredictor_glob := gltimer glpred;
  hammer_glob := hammer;
  (* statistics *)
  reset_timers ();
  stac_counter := 0;
  max_depth_mem := 0
  )

(* Drop the head prediction of pid's current goal. If it was the last
   prediction, deactivate pid instead. *)
fun get_next_pred pid =
  let
    val _ = debug_search "get_next_pred"
    val prec = dfind pid (!proofdict)
  in
    if null (!(#pending prec)) then () else
    let
      val gn = hd (!(#pending prec))
      val pred = Array.sub (#predarr prec, gn)
    in
      if null pred orelse null (tl pred)
      then deactivate pid
      else Array.update (#predarr prec, gn, tl pred)
    end
  end

(* Select the next node to expand.
   Steps:
   1. Deactivate nodes without predictions.
   2. Collect and start prover calls.
   3. Raise NoNextTac when no active node remains; otherwise select with
      try_mc_find. *)
fun node_find () =
  let
    val _ = debug_search "node_find"
    val l0 = filter (fn x => is_active (fst x)) (dlist (!proofdict))
    (* also deactivate node with empty predictions *)
    val l1 = filter (fn x => not (has_empty_pred (fst x))) l0
    val _ = if !ttt_eprover_flag then (close_async (); open_async ()) else ()
    val l2 = if !ttt_eprover_flag
             then filter (fn x => is_active (fst x)) l1
             else l1
    val _ = if null l2 then (debug_search "nonexttac"; raise NoNextTac) else ()
  in
    try_mc_find ()
  end

(* One search step: select a node, apply its next tactic, and update the
   tree and the MCTS statistics with the result. *)
fun search_step () =
  let
    val (pid,pripol) = node_find_timer node_find ()
    val prec = dfind pid (!proofdict)
    val trydict = #trydict prec
    val (glo,tactime) = add_time apply_next_stac pid
    fun f0 () = (backup_fail pid; get_next_pred pid)
    fun f1 gl =
      if gl = []
      then
        (backup_success pid;
         close_proof_wrap NONE tactime pid)
      else
        (
        trydict := dadd gl () (!trydict);
        let val cid =
          node_create_timer (node_create_gl pripol tactime gl) pid
        in
          backup cid; get_next_pred pid
        end
        )
  in
    case glo of
      NONE => f0 ()
    | SOME gl => f1 gl
  end

datatype proof_status_t =
  ProofError | ProofSaturated | ProofTimeOut | Proof of string

(* Run search steps until a proof is found, the search saturates, or the
   time limit is reached. The loop runs inside a single handler, so each
   recursive call is a tail call. In the previous version the recursion
   sat inside the handler and the stack grew with every step. *)
fun search_loop () =
  let
    fun loop () =
      if Timer.checkRealTimer (valOf (!glob_timer)) > (!ttt_search_time)
      then ProofTimeOut
      else (search_step (); debug_search "search step"; loop ())
  in
    loop ()
    handle NoNextTac => (debug "proof: saturated"; ProofSaturated)
         | SearchTimeOut => (debug "proof: timeout"; ProofTimeOut)
         | ProofFound => (debug "proof: found"; Proof "")
  end

(* Rebuild the proofs recorded under pid, ordered by goal number. *)
fun proofl_of pid =
  let
    val prec = dfind pid (!proofdict) handle _ => debug_err "proofl_of"
    fun compare_gn ((gn1,_,_),(gn2,_,_)) = Int.compare (gn1,gn2)
    val proofl = !(#proofl prec)
    val new_proofl = dict_sort compare_gn proofl
    fun f (gn,stac,cid) =
      let
        val g = Array.sub (#goalarr prec, gn)
        val contl = proofl_of cid
        val tac = Tactic (stac,g)
      in
        if null contl then tac
        else if List.length contl = 1 then Then (tac, hd contl)
        else Thenl (tac, contl)
      end
  in
    map f new_proofl
  end

(* Print statistics and release the per-search caches. *)
fun end_search () =
  (
  debug_proof ("Statistics");
  debug_proof (" infstep : " ^ int_to_string (!stac_counter));
  debug_proof (" nodes : " ^ int_to_string (!pid_counter));
  debug_proof (" maxdepth: " ^ int_to_string (!max_depth_mem));
  debug_proof ("Time: " ^ Real.toString (!tot_time));
  debug_proof (" inferstep: " ^ Real.toString (!infstep_time));
  debug_proof (" node_find: " ^ Real.toString (!node_find_time));
  debug_proof (" node_crea: " ^ Real.toString (!node_create_time));
  debug_proof (" thminst : " ^ Real.toString (!inst_time));
  debug_proof (" tacpred : " ^ Real.toString (!tactime));
  debug_proof (" thmpred : " ^ Real.toString (!thmtime));
  debug_proof (" glpred : " ^ Real.toString (!gltime));
  proofdict := dempty Int.compare;
  tac_dict := dempty String.compare;
  inst_dict := dempty (cpl_compare String.compare goal_compare);
  stacgoal_cache := dempty (cpl_compare String.compare goal_compare)
  )

(* ---------------------------------------------------------------------------
   Self learning
   -------------------------------------------------------------------------- *)

(* Re-run every tactic of proof, timing it, and record it with
   update_tacdata. Failures are logged and skipped. *)
fun selflearn_aux proof = case proof of
    Tactic (stac,g) =>
    (
    let
      val ((gl,_),t) = add_time (tactic_of_sml stac) g
      val lbl = (stac,t,g,gl)
    in
      update_tacdata lbl
    end
    handle _ => debug_search ("Error: selflearn: " ^ stac)
    )
  | Then (p1,p2) => (selflearn_aux p1; selflearn_aux p2)
  | Thenl (p,pl) => (selflearn_aux p; app selflearn_aux pl)

fun string_of_proof proof = case proof of
    Tactic (stac,g) => stac
  | Then (p1,p2) => string_of_proof p1 ^ " THEN " ^ string_of_proof p2
  | Thenl (p,pl) => string_of_proof p ^ " THENL " ^
    String.concatWith " " (map string_of_proof pl)

fun selflearn proof =
  if !ttt_selflearn_flag
  then debug_t "selflearn" selflearn_aux proof
  else ()

(* ---------------------------------------------------------------------------
   Main
   -------------------------------------------------------------------------- *)

(* Search for a proof of goal.
   On success, the proof is learned (when enabled), minimized and
   reconstructed into a string. All prover threads are cancelled before
   returning. *)
fun search thmpred tacpred glpred hammer goal =
  (
  init_search thmpred tacpred glpred hammer goal;
  total_timer (node_create_timer root_create_wrap) goal;
  let
    val r = total_timer search_loop ()
    val _ = debug_search "End search loop"
    val _ = terminate_async ()
    val _ = debug_search "After termination"
    val proof_status = case r of
      Proof _ =>
      let
        val proofl = proofl_of 0 handle _ => debug_err "SNH0"
        val proof0 = hd proofl handle Empty => debug_err "SNH1"
        val _ = selflearn proof0
        val proof1 = debug_t "minimize" minimize_proof proof0
        val sproof = debug_t "reconstruct" reconstruct goal proof1
      in
        Proof sproof
      end
    | _ => r
  in
    end_search ();
    proof_status
  end
  )

end (* struct *)