1(* =========================================================================  *)
2(* FILE          : tttPredictor.sml                                           *)
3(* DESCRIPTION   : Predictions of tactics, theorems, terms and lists of goals *)
4(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck             *)
5(* DATE          : 2018                                                       *)
6(* ========================================================================== *)
7
8structure tttPredict :> tttPredict =
9struct
10
11open HolKernel Abbrev tttTools tttSetup tttFeature tttExec
12
(* Error builder for this structure: partially applies mk_HOL_ERR to the
   structure name "tttPredict"; the remaining arguments are presumably the
   function name and message (TODO confirm against Feedback). *)
fun ERR fname = mk_HOL_ERR "tttPredict" fname
14
15(* --------------------------------------------------------------------------
16   TFIDF: weight of symbols (power of 6 comes from the distance)
17   -------------------------------------------------------------------------- *)
18
(* Weight of each symbol occurring in symsl (a list of feature lists):
   (ln n - ln freq) ^ 6, where n is the number of feature lists and freq is
   the count that count_dict records for the symbol over the concatenation
   of all lists (presumably its number of occurrences). *)
fun weight_tfidf symsl =
  let
    val ndocs       = Real.fromInt (length symsl)
    val occurrences = count_dict (dempty Int.compare) (List.concat symsl)
    fun idf_pow6 (_, freq) =
      Math.pow (Math.ln ndocs - Math.ln (Real.fromInt freq), 6.0)
  in
    Redblackmap.map idf_pow6 occurrences
  end
29
(* Symbol weights computed by weight_tfidf from the feature component of a
   list of (label, features) pairs; labels are ignored. *)
fun learn_tfidf feavl =
  let val feature_lists = map (fn (_, fea) => fea) feavl in
    weight_tfidf feature_lists
  end
31
32(* --------------------------------------------------------------------------
33   Distance
34   -------------------------------------------------------------------------- *)
35
(* Elements of l that are keys of dict, in the order (and with the
   multiplicity) they have in l. *)
fun inter_dict dict l =
  let fun is_key k = dmem k dict in filter is_key l end

(* Keys of dict after inserting every element of l with value (). *)
fun union_dict dict l =
  let val extended = daddl (map (fn k => (k, ())) l) dict in
    dkeys extended
  end
38
(* Generator shared by knn_sim1/2/3 when ttt_randdist_flag replaces the
   similarity by a random number. *)
val random_gen : Random.generator = Random.newgen ()
40
(* Similarity 1: sum of the weights of the features of fea_p that are keys
   of dict_o (a feature repeated in fea_p is counted each time).
   dfind_err raises if such a feature has no entry in symweight.
   If ttt_randdist_flag is set, a random real is returned instead. *)
fun knn_sim1 symweight dict_o fea_p =
  if not (!ttt_randdist_flag) then
    let
      val shared = inter_dict dict_o fea_p
    in
      sum_real (map (fn sym => dfind_err "knn_sim1" sym symweight) shared)
    end
  else Random.random random_gen
52
(* Similarity 2: the sim1 score (weights of shared features) divided by
   ln (e + size), where size = number of keys of dict_o + length of fea_p;
   larger feature sets are thus penalised logarithmically.
   If ttt_randdist_flag is set, a random real is returned instead. *)
fun knn_sim2 symweight dict_o fea_p =
  if not (!ttt_randdist_flag) then
    let
      val shared   = inter_dict dict_o fea_p
      val weights  = map (fn sym => dfind_err "knn_sim2" sym symweight) shared
      val size_sum = dlength dict_o + length fea_p
    in
      sum_real weights / Math.ln (Math.e + Real.fromInt size_sum)
    end
  else Random.random random_gen
65
(* Similarity 3 (weighted Jaccard-like): total weight of the features of
   fea_p shared with dict_o, divided by (total weight of the union + 1).
   Symbols missing from symweight weigh 0.0 (any exception from dfind is
   mapped to 0.0). If ttt_randdist_flag is set, a random real is returned. *)
fun knn_sim3 symweight dict_o fea_p =
  if not (!ttt_randdist_flag) then
    let
      fun weight_of sym = dfind sym symweight handle _ => 0.0
      fun total syms = sum_real (map weight_of syms)
      val num = total (inter_dict dict_o fea_p)
      val den = total (union_dict dict_o fea_p) + 1.0
    in
      num / den
    end
  else Random.random random_gen
79
80(* --------------------------------------------------------------------------
81   Ordering prediction with duplicates
82   -------------------------------------------------------------------------- *)
83
(* Orders (label, score) pairs by decreasing score; labels are ignored. *)
fun compare_score ((_, score1), (_, score2)) = Real.compare (score2, score1)
85
(* Scores each (label, features) pair of feal against the target features
   fea_o using dist, and returns the (label, score) pairs sorted by
   decreasing score. Duplicate labels are kept. *)
fun pre_pred dist symweight (feal: ('a * int list) list) fea_o =
  let
    val dict_o = dnew Int.compare (map (fn sym => (sym, ())) fea_o)
    val scored = map (fn (lbl, fea) => (lbl, dist symweight dict_o fea)) feal
  in
    dict_sort compare_score scored
  end
95
(* pre_pred specialised to each similarity. Kept as eta-expanded functions
   so they stay polymorphic in the label type (value restriction). *)
fun pre_sim1 weights feas target = pre_pred knn_sim1 weights feas target
fun pre_sim2 weights feas target = pre_pred knn_sim2 weights feas target
fun pre_sim3 weights feas target = pre_pred knn_sim3 weights feas target
99
100(* --------------------------------------------------------------------------
101   Tactic predictions
102   -------------------------------------------------------------------------- *)
103
(* Used for preselection.
   Ranks the labels of feal by sim1 similarity to fea_o. When
   ttt_covdist_flag is set, the ranking is re-sorted by decreasing coverage
   of the label's tactic string (!ttt_taccov, 0 if absent). Duplicate labels
   are then removed with mk_sameorder_set lbl_compare and the first n kept. *)
fun stacknn symweight n feal fea_o =
  let
    val ranked = map fst (pre_sim1 symweight feal fea_o)
    fun coverage stac = dfind stac (!ttt_taccov) handle _ => 0
    fun by_coverage (lbl1, lbl2) =
      Int.compare (coverage (#1 lbl2), coverage (#1 lbl1))
    val ordered =
      if not (!ttt_covdist_flag) then ranked
      else dict_sort by_coverage ranked
  in
    first_n n (mk_sameorder_set lbl_compare ordered)
  end
119
(* Used during search.
   Result of stacknn with labels sharing the same tactic string (first
   component) collapsed. Deduplication happens after stacknn truncates to n,
   so fewer than n labels may be returned. *)
fun stacknn_uniq symweight n feal fea_o =
  let
    fun same_stac (a, b) = String.compare (#1 a, #1 b)
  in
    mk_sameorder_set same_stac (stacknn symweight n feal fea_o)
  end
128
129(* --------------------------------------------------------------------------
130   Theorem predictions
131   -------------------------------------------------------------------------- *)
132
(* True if the identifier s ("<thy>Theory.<name>") refers to the namespace
   pseudo-theory, or to a theorem that DB.fetch can retrieve. *)
fun exists_tid s =
  let val (thy, name) = split_string "Theory." s in
    thy = namespace_tag orelse can (DB.fetch thy) name
  end
138
(* Theorem names from feav ranked by sim1 similarity to fea_o,
   deduplicated, keeping the first n that satisfy exists_tid. *)
fun thmknn (symweight,feav) n fea_o =
  let
    val ranked = map fst (pre_sim1 symweight feav fea_o)
  in
    first_test_n exists_tid n (mk_sameorder_set String.compare ranked)
  end
146
(* Memo table goal -> features, filled by add_fea so that fea_of_goal is not
   recomputed there for a goal it has already seen. *)
val add_fea_cache =
  let val empty_cache = dempty goal_compare in ref empty_cache end
148
(* Inserts (name, features) into !dict under the goal of thm (dest_thm),
   unless that goal is already a key or thm is not up to date.
   Features are taken from add_fea_cache; on a miss they are computed with
   fea_of_goal and stored in the cache. *)
fun add_fea dict (name,thm) =
  let
    val g = dest_thm thm
    fun features () =
      dfind g (!add_fea_cache) handle NotFound =>
        let val fresh = fea_of_goal g in
          add_fea_cache := dadd g fresh (!add_fea_cache);
          fresh
        end
  in
    if dmem g (!dict) orelse not (uptodate_thm thm) then ()
    else
      let val fea = features () in
        dict := dadd g (name, fea) (!dict)
      end
  end
165
(* Returns thmdict extended (via add_fea) with the theorems of
   namespace_thms (), each renamed "<namespace_tag>Theory.<name>". *)
fun insert_namespace thmdict =
  let
    val acc = ref thmdict
    fun tag (thmname, thm) = (namespace_tag ^ "Theory." ^ thmname, thm)
  in
    app (add_fea acc o tag) (namespace_thms ());
    !acc
  end
176
(* Collects the theorem feature data from !ttt_thmfea (plus namespace
   theorems when ttt_namespacethm_flag is set) and returns
   (symbol weights, list of (name, features), name -> (goal, features)). *)
fun all_thmfeav () =
  let
    val thmdict =
      if not (!ttt_namespacethm_flag) then !ttt_thmfea
      else insert_namespace (!ttt_thmfea)
    val entries = dlist thmdict
    val feav    = map snd entries
    val revdict =
      dnew String.compare (map (fn (g, (name, fea)) => (name, (g, fea))) entries)
  in
    (learn_tfidf feav, feav, revdict)
  end
190
(* n nearest theorem names for goal, using the full theorem database
   gathered by all_thmfeav. *)
fun thmknn_std n goal =
  let
    val (symweight, feav, _) = all_thmfeav ()
    val gfea = fea_of_goal goal
  in
    thmknn (symweight, feav) n gfea
  end
195
196(* ----------------------------------------------------------------------
197   Adding theorem dependencies in the predictions
198   ---------------------------------------------------------------------- *)
199
(* True for namespace identifiers; otherwise fetches the theorem named by
   s ("<thy>Theory.<name>") and checks uptodate_thm. DB.fetch may raise if
   the theorem does not exist. *)
fun uptodate_tid s =
  let val (thy, name) = split_string "Theory." s in
    if thy = namespace_tag then true
    else uptodate_thm (DB.fetch thy name)
  end
204
(* Follows each theorem of l0 by its partial dependencies
   (deplPartial_of_sthm), removes duplicates, and keeps the first n
   identifiers that exist, are up to date and are keys of revdict.
   The up-to-date check is probably already done elsewhere. *)
fun add_thmdep revdict n l0 =
  let
    val with_deps = List.concat (map (fn x => x :: deplPartial_of_sthm x) l0)
    fun keep x = exists_tid x andalso uptodate_tid x andalso dmem x revdict
  in
    first_test_n keep n (mk_sameorder_set String.compare with_deps)
  end
214
(* Nearest theorems for gfea (thmknn) extended with their dependencies and
   truncated again to n (add_thmdep). *)
fun thmknn_wdep (symweight,feav,revdict) n gfea =
  add_thmdep revdict n (thmknn (symweight,feav) n gfea)
219
220(* ----------------------------------------------------------------------
221   Adding stac descendants. Should be modified to work on labels instead.
222 ---------------------------------------------------------------------- *)
223
(* Depth-first walk over the descendants of lbl (includes lbl itself).
   Every visited label is pushed onto !rlist, so the list ends up in reverse
   visiting order. For each goal of lbl's goal list, the labels stored for
   that goal in ddict (none if lookup fails) are visited recursively.
   rdict holds the labels on the current path; a label already on it is
   recorded but not expanded, which cuts cycles. *)
fun desc_lbl_aux rlist rdict ddict (lbl as (_,_,_,gl)) =
  let
    val () = rlist := lbl :: (!rlist)
    fun children g = dfind g ddict handle _ => []
    fun visit_goal path g = app (desc_lbl_aux rlist path ddict) (children g)
  in
    if dmem lbl rdict
    then () (* debug ("Warning: descendant_of_feav: " ^ stac) *)
    else app (visit_goal (dadd lbl () rdict)) gl
  end
241
(* All labels reached from lbl (itself included) by desc_lbl_aux, in
   reverse visiting order, possibly with repetitions. *)
fun desc_lbl ddict lbl =
  let
    val acc = ref []
    val () = desc_lbl_aux acc (dempty lbl_compare) ddict lbl
  in
    !acc
  end
247
(* Descendants of every label of l, deduplicated with lbl_compare and
   truncated to the first n. *)
fun add_stacdesc ddict n l =
  let val descendants = List.concat (map (desc_lbl ddict) l) in
    first_n n (mk_sameorder_set lbl_compare descendants)
  end
255
256(* --------------------------------------------------------------------------
257   Term prediction.
258   Relies on mdict_glob to calculate symweight.
259   Predicts everything but the term itself.
260   -------------------------------------------------------------------------- *)
261
(* Up to n subterms of the goal (other than term itself) ranked by
   similarity to term. Symbol weights are computed over the features of
   term, of the candidate subterms and of all theorems in !ttt_thmfea.
   The similarity is chosen by !ttt_termarg_pint (1, 2 or 3; default 2). *)
fun termknn n ((asl,w):goal) term =
  let
    fun not_term tm = not (tm = term)
    val subterms =
      mk_sameorder_set Term.compare
        (List.concat (map (rev o find_terms not_term) (w :: asl)))
    val thmfeav   = map (snd o snd) (dlist (!ttt_thmfea))
    val feal      = map (fn tm => (tm, fea_of_goal ([], tm))) subterms
    val fea_o     = tttFeature.fea_of_goal ([], term)
    val symweight = weight_tfidf (fea_o :: (map snd feal @ thmfeav))
    val pre_sim =
      (case !ttt_termarg_pint of
         1 => pre_sim1
       | 3 => pre_sim3
       | _ => pre_sim2)
  in
    first_n n (map fst (pre_sim symweight feal fea_o))
  end
278
279(* --------------------------------------------------------------------------
280   Term prediction for conjecturing experiments.
281   Todo: add dependencies between deps.
282   -------------------------------------------------------------------------- *)
283
(* First n labels of tmfea ranked by sim1 similarity to fea_o. *)
fun tmknn n (symweight,tmfea) fea_o =
  first_n n (map fst (pre_sim1 symweight tmfea fea_o))
288
289(* --------------------------------------------------------------------------
290   Goal list prediction.
291   -------------------------------------------------------------------------- *)
292
(* Fraction of positive labels among the `radius` nearest neighbours of fea
   in feal (labels are (bool, _) pairs, true = positive); 0.0 when there
   are no neighbours. The similarity is chosen by !ttt_mcev_pint
   (1, 2 or 3; default 2). *)
fun mcknn symweight radius feal fea =
  let
    val pre_sim =
      (case !ttt_mcev_pint of
         1 => pre_sim1
       | 3 => pre_sim3
       | _ => pre_sim2)
    val nearest = map fst (first_n radius (pre_sim symweight feal fea))
  in
    case nearest of
      [] => 0.0
    | _ =>
      let
        val npos = Real.fromInt (length (filter (fn (b, _) => b) nearest))
        val nneg = Real.fromInt (length (filter (fn (b, _) => not b) nearest))
      in
        npos / (nneg + npos)
      end
  end
312
313end (* struct *)
314