1(* ========================================================================= *)
2(* FILE          : tacticToe.sml                                             *)
3(* DESCRIPTION   : Automated theorem prover based on tactic selection        *)
4(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck            *)
5(* DATE          : 2017                                                      *)
6(* ========================================================================= *)
7
8structure tacticToe :> tacticToe =
9struct
10
11open HolKernel Abbrev boolLib aiLib
12  smlLexer smlParser smlExecute smlRedirect smlInfix
13  mlFeature mlThmData mlTacticData mlNearestNeighbor mlTreeNeuralNetwork
14  psMinimize tttSetup tttToken tttLearn tttSearch
15
(* Error constructor for this structure, used as ERR function_name message;
   every error raised here is tagged with the structure name tacticToe. *)
val ERR = mk_HOL_ERR "tacticToe"
17
18(* -------------------------------------------------------------------------
19   Global parameters
20   ------------------------------------------------------------------------- *)
21
(* Sets the proof-search time budget stored in ttt_search_time. *)
fun set_timeout r = ttt_search_time := r

(* Metis tactic whose theorem-list argument is left as a placeholder, to be
   filled in with predicted theorems during search. *)
val metis_stac =
  String.concat ["metisTools.METIS_TAC ", thmlarg_placeholder]

(* Tactics parsed on every call and offered before the predicted ones at
   each search node; initially just metis_stac. *)
val prioritize_stacl = ref [metis_stac]
25
26(* -------------------------------------------------------------------------
27   Preparsing theorems and tactics
28   ------------------------------------------------------------------------- *)
29
(* Converts each theorem identifier with dbfetch_of_thmid (presumably to the
   SML expression that fetches it) and evaluates the whole list in a single
   thml_of_sml call. NONE signals that the batch could not be evaluated. *)
fun thml_of_thmidl thmidl =
  let val sml_exprs = map dbfetch_of_thmid thmidl in
    thml_of_sml sml_exprs
  end
31
(* Parses a list of theorem identifiers while isolating the ones that fail.
   The whole list is first parsed in one call. If that fails, the list is
   bisected recursively, so only the unparsable identifiers are reported
   (Could not parse: ...) and dropped.
   Returns (identifier, parsed value) pairs for the identifiers that parsed.
   The empty list returns [] immediately. Without this guard, a failing
   parse of [] would split into ([],[]) and recurse forever. *)
fun preparse_thmidl [] = []
  | preparse_thmidl thmidl =
    (case thml_of_thmidl thmidl of
       SOME rl => combine (thmidl,rl)
     | NONE =>
       if is_singleton thmidl
       then (print_endline ("Could not parse: " ^ singleton_of_list thmidl); [])
       else
         let val (l1,l2) = part_n (length thmidl div 2) thmidl in
           preparse_thmidl l1 @ preparse_thmidl l2
         end)
41
(* Parses a list of tactic strings with pretacl_of_sml while isolating the
   ones that fail. The argument 1.0 is presumably a time limit; TODO confirm.
   The whole list is first parsed in one call. If that fails, the list is
   bisected recursively, so only the unparsable strings are reported
   (Could not parse: ...) and dropped.
   Returns (tactic string, parsed value) pairs for the strings that parsed.
   The empty list returns [] immediately. Without this guard, a failing
   parse of [] would split into ([],[]) and recurse forever. *)
fun preparse_stacl [] = []
  | preparse_stacl stacl =
    (case pretacl_of_sml 1.0 stacl of
       SOME rl => combine (stacl,rl)
     | NONE =>
       if is_singleton stacl
       then (print_endline ("Could not parse: " ^ singleton_of_list stacl); [])
       else
         let val (l1,l2) = part_n (length stacl div 2) stacl in
           preparse_stacl l1 @ preparse_stacl l2
         end)
51
52(* -------------------------------------------------------------------------
53   Preselection of theorems and tactics
54   ------------------------------------------------------------------------- *)
55
(* Restricts the theorem feature table to the theorems that thmknn_wdep
   selects for the goal features gfea, within radius !ttt_presel_radius.
   The relative order of thmfea is kept, and the symbol weights are
   returned unchanged. *)
fun select_thmfea (symweight,thmfea) gfea =
  let
    val selected =
      dset String.compare
        (thmknn_wdep (symweight,thmfea) (!ttt_presel_radius) gfea)
    fun is_selected (thmid,_) = dmem thmid selected
  in
    (symweight, filter is_selected thmfea)
  end
64
(* Preselects recorded tactic calls relevant to the goal features gfea.
   Symbol weights are learned by learn_tfidf_symfreq_nofilter from the
   number of recorded calls and the symbol frequencies in tacdata. Calls are
   chosen with callknn, and their dependencies are added with add_calldep,
   both using radius !ttt_presel_radius.
   Returns the weights with (tactic string, features) pairs of the
   selected calls. *)
fun select_tacfea tacdata gfea =
  let
    val calld = #calld tacdata
    val callfea = map_assoc (#fea o snd) (dlist calld)
    val symweight =
      learn_tfidf_symfreq_nofilter (dlength calld) (#symfreq tacdata)
    val radius = !ttt_presel_radius
    val neighbours = callknn (symweight,callfea) radius gfea
    val with_deps = add_calldep calld radius neighbours
    fun stac_and_fea (_,call) = (#stac call, #fea call)
  in
    (symweight, map stac_and_fea with_deps)
  end
78
79(* -------------------------------------------------------------------------
80   Main function
81   ------------------------------------------------------------------------- *)
82
(* Builds the search object consumed by search (see main_tactictoe).
   Steps, in order:
   1. load infix_file through QUse.use, wrapped in hidef (presumably to
      suppress its output);
   2. preselect theorems and tactics by goal features (select_thmfea,
      select_tacfea);
   3. parse the prioritized tactics and the preselected candidates once,
      keeping only those that parse;
   4. define memoised, per-goal predictors for tactics (predtac) and for
      tactic arguments (predarg).
   vnno, pnno and anno are copied unchanged into the returned record
   (presumably optional value, policy and argument networks; TODO confirm). *)
fun build_searchobj (thmdata,tacdata) (vnno,pnno,anno) goal =
  let
    val _ = hidef QUse.use infix_file
    (* preselection *)
    val _ = print_endline "preselection"
    val goalf = fea_of_goal true goal
    val (thmsymweight,thmfea) = select_thmfea thmdata goalf
    val (tacsymweight,tacfea) = select_tacfea tacdata goalf
    (* parsing *)
    val _ = print_endline "parsing"
    (* Prioritized tactics (initially just metis_stac) are parsed on every
       call, independently of preselection. *)
    val pstacl = preparse_stacl (!prioritize_stacl)
    (* Theorem id to parsed value; ids that failed to parse are absent. *)
    val thm_parse_dict = dnew String.compare
      (preparse_thmidl (map fst thmfea))
    (* Tactic string to parsed value, covering prioritized and preselected
       tactics; strings that failed to parse are absent. *)
    val tac_parse_dict = dnew String.compare
      (pstacl @ preparse_stacl (map fst tacfea))
    (* Lookups turn NotFound into ERR naming the missing entries. *)
    fun parse_thmidl thmidl = map (fn x => dfind x thm_parse_dict) thmidl
      handle NotFound =>
        raise ERR "parse_thmidl" (String.concatWith " " thmidl)
    fun parse_stac stac = dfind stac tac_parse_dict
      handle NotFound => raise ERR "parse_stac" stac
    (* Keep only candidates that parsed, so the predictors below never
       propose a theorem or tactic missing from the parse dictionaries. *)
    val thmfea_filtered = filter (fn x => dmem (fst x) thm_parse_dict) thmfea
    val tacfea_filtered = filter (fn x => dmem (fst x) tac_parse_dict) tacfea
    (* Term arguments are wrapped as a single QUOTE fragment. *)
    val parsetoken =
      {parse_stac = parse_stac,
       parse_thmidl = parse_thmidl,
       parse_sterm = fn x => [QUOTE x]}
    (* predictors *)
    (* Argument types, from extract_atyl, of every tactic predtac can
       return. *)
    val stacl_filtered = map fst pstacl @ map fst tacfea_filtered
    val atyd = dnew String.compare (map_assoc extract_atyl stacl_filtered)
    (* Memoisation tables for predthml and predtac, keyed by goal. *)
    val thm_cache = ref (dempty goal_compare)
    val tac_cache = ref (dempty goal_compare)
    (* Parsable theorems selected by thmknn for g with radius
       !ttt_thmlarg_radius; the result is cached per goal. *)
    fun predthml g =
      dfind g (!thm_cache) handle NotFound =>
      let
        val gfea = fea_of_goal true g
        val thmidl = thmknn (thmsymweight,thmfea_filtered)
          (!ttt_thmlarg_radius) gfea
      in
        thm_cache := dadd g thmidl (!thm_cache); thmidl
      end
    (* Candidate values for an argument of type aty of tactic stac at goal g.
       Theorem-list arguments are built from predthml g with mk_batch_full,
       parameter !ttt_metis_radius for metis_stac and 1 otherwise
       (presumably batch sizes; confirm). Term arguments come from
       pred_svar 8 g. *)
    fun predarg stac aty g = case aty of
        Athml =>
        let val thml = predthml g in
          if stac = metis_stac
          then map Sthml (mk_batch_full (!ttt_metis_radius) thml)
          else map Sthml (mk_batch_full 1 thml)
        end
      | Aterm => map Sterm (pred_svar 8 g)
    (* Tactics for g: prioritized tactics first, then those selected by
       tacknn with radius !ttt_presel_radius, deduplicated by
       mk_sameorder_set. Each is paired with its argument types from atyd,
       and the result is cached per goal. *)
    fun predtac g =
      dfind g (!tac_cache) handle NotFound =>
      let
        val gfea = fea_of_goal true g
        val stacl1 = tacknn (tacsymweight,tacfea_filtered)
          (!ttt_presel_radius) gfea
        val stacl2 = mk_sameorder_set String.compare (map fst pstacl @ stacl1)
        val stacl3 = map_assoc (fn x => dfind x atyd) stacl2
      in
        tac_cache := dadd g stacl3 (!tac_cache); stacl3
      end
  in
    {predtac = predtac, predarg = predarg,
     parsetoken = parsetoken,
     vnno = vnno, pnno = pnno, anno = anno}
  end
147
(* Builds the search object for goal, prints search, then runs search on
   goal and returns its result unchanged. *)
fun main_tactictoe (thmdata,tacdata) nnol goal =
  let
    val searchobj = build_searchobj (thmdata,tacdata) nnol goal
    val _ = print_endline "search"
  in
    search searchobj goal
  end
152
153(* -------------------------------------------------------------------------
154   Return values
155   ------------------------------------------------------------------------- *)
156
(* Turns a search outcome into (proof script option, tactic).
   Saturation and timeout print a message and return NONE with a FAIL_TAC
   describing the failure. A found proof s is printed, then re-parsed with
   tactic_of_sml (given !ttt_search_time, presumably as a time limit)
   under hidef. *)
fun read_status ProofSaturated =
    (print_endline "saturated"; (NONE, FAIL_TAC "tactictoe: saturated"))
  | read_status ProofTimeout =
    (print_endline "timeout"; (NONE, FAIL_TAC "tactictoe: timeout"))
  | read_status (Proof s) =
    let
      val _ = print_endline ("  " ^ s)
      val tac = hidef (tactic_of_sml (!ttt_search_time)) s
    in
      (SOME s, tac)
    end
165
166(* -------------------------------------------------------------------------
167   Interface
168   ------------------------------------------------------------------------- *)
169
(* Tactic data cached by theory list (the current theory followed by its
   ancestry), so it is not rebuilt on every call. Call
   clean_ttt_tacdata_cache to empty the cache and force a rebuild. *)
local
  fun empty_tacdata_cache () = dempty (list_compare String.compare)
in
  val ttt_tacdata_cache = ref (empty_tacdata_cache ())
  fun clean_ttt_tacdata_cache () =
    ttt_tacdata_cache := empty_tacdata_cache ()
end
173
(* A term is acceptable when it has type bool. *)
fun has_boolty tm = (type_of tm = bool)

(* A goal is acceptable when its conclusion and every assumption has type
   bool; the conclusion is checked first. *)
fun has_boolty_goal (asl,w) = all has_boolty (w :: asl)
176
(* Runs TacticToe on goal and returns the resulting tactic. The tactic
   replays the proof found, or fails with a message on saturation/timeout
   (see read_status).
   Raises ERR if the conclusion or an assumption is not of type bool.
   Theorem data is recreated on every call. Tactic data is cached in
   ttt_tacdata_cache under the current theory followed by its ancestry.
   The search runs under hidef with (NONE,NONE,NONE) in the network slots
   of build_searchobj. *)
fun tactictoe_aux goal =
  if not (has_boolty_goal goal)
  then raise ERR "tactictoe" "type bool expected"
  else
  let
    (* Query the current theory once; it keys the tactic-data cache. *)
    val cthy = current_theory ()
    val cthyl = cthy :: ancestry cthy
    val thmdata = hidef create_thmdata ()
    val tacdata =
      dfind cthyl (!ttt_tacdata_cache) handle NotFound =>
      let val tacdata_aux = create_tacdata () in
        ttt_tacdata_cache := dadd cthyl tacdata_aux (!ttt_tacdata_cache);
        tacdata_aux
      end
    val (proofstatus,_) = hidef
      (main_tactictoe (thmdata,tacdata) (NONE,NONE,NONE)) goal
  in
    (* read_status also prints the outcome; only the tactic is needed. *)
    #2 (read_status proofstatus)
  end
196
(* Tactic interface: searches for a proof of goal, then applies the
   resulting tactic to that same goal. *)
fun ttt goal =
  let val tac = tactictoe_aux goal in
    tac goal
  end
198
(* Theorem interface: proves term with no assumptions by running TacticToe
   and checking the resulting tactic with TAC_PROOF. *)
fun tactictoe term =
  let
    val goal = ([], term)
    val tac = tactictoe_aux goal
  in
    TAC_PROOF (goal, tac)
  end
201
202
203end (* struct *)
204