(* ========================================================================= *)
(* FILE          : tttPredictor.sml                                          *)
(* DESCRIPTION   : Predictions of tactics, theorems, terms and lists of goals *)
(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck            *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure tttPredict :> tttPredict =
struct

open HolKernel Abbrev tttTools tttSetup tttFeature tttExec

val ERR = mk_HOL_ERR "tttPredict"

(* --------------------------------------------------------------------------
   TFIDF: weight of symbols (power of 6 comes from the distance).
   A symbol occurring freq times over n feature lists gets the weight
   (ln n - ln freq) ^ 6.
   NOTE(review): freq counts occurrences, not lists; this assumes each
   feature list is duplicate-free (otherwise freq > n is possible) -- confirm.
   -------------------------------------------------------------------------- *)

fun weight_tfidf symsl =
  let
    val dict = count_dict (dempty Int.compare) (List.concat symsl)
    (* invariant over all symbols: computed once instead of per symbol *)
    val ln_n = Math.ln (Real.fromInt (length symsl))
    fun weight (_,freq) = Math.pow (ln_n - Math.ln (Real.fromInt freq), 6.0)
  in
    Redblackmap.map weight dict
  end

(* feavl: list of (label, feature list); only the features are used. *)
fun learn_tfidf feavl = weight_tfidf (map snd feavl)

(* --------------------------------------------------------------------------
   Distance
   -------------------------------------------------------------------------- *)

(* Elements of l that are keys of dict (order of l preserved). *)
fun inter_dict dict l = filter (fn x => dmem x dict) l

(* Keys of dict together with the elements of l, as returned by dkeys. *)
fun union_dict dict l = dkeys (daddl (map (fn x => (x,())) l) dict)

val random_gen = Random.newgen ()

local
  (* Shared guard of the three similarities: when ttt_randdist_flag is set,
     a random real drawn from random_gen replaces the real computation. *)
  fun randdist_or sim =
    if !ttt_randdist_flag then Random.random random_gen else sim ()
in

(* Sum of the weights of the symbols of fea_p that are keys of dict_o.
   A shared symbol missing from symweight makes dfind_err raise. *)
fun knn_sim1 symweight dict_o fea_p =
  randdist_or (fn () =>
    let fun wf n = dfind_err "knn_sim1" n symweight in
      sum_real (map wf (inter_dict dict_o fea_p))
    end)

(* knn_sim1 divided by ln (e + |dict_o| + |fea_p|), which penalizes
   large feature sets. *)
fun knn_sim2 symweight dict_o fea_p =
  randdist_or (fn () =>
    let
      fun wf n = dfind_err "knn_sim2" n symweight
      val weightl = map wf (inter_dict dict_o fea_p)
      val tot = Real.fromInt (dlength dict_o + length fea_p)
    in
      sum_real weightl / Math.ln (Math.e + tot)
    end)

(* Weighted Jaccard-like ratio: weight of the intersection over
   (weight of the union + 1). Unknown symbols weigh 0.0.
   NOTE(review): `handle _` also catches Interrupt -- consider narrowing. *)
fun knn_sim3 symweight dict_o fea_p =
  randdist_or (fn () =>
    let
      fun wf n = dfind n symweight handle _ => 0.0
      val weightli = map wf (inter_dict dict_o fea_p)
      val weightlu = map wf (union_dict dict_o fea_p)
    in
      sum_real weightli / (sum_real weightlu + 1.0)
    end)

end (* local *)

(* --------------------------------------------------------------------------
   Ordering prediction with duplicates
   -------------------------------------------------------------------------- *)

(* Orders (label, score) pairs by decreasing score. *)
fun compare_score ((_,x),(_,y)) = Real.compare (y,x)

(* Scores every (label, features) entry of feal against fea_o with [dist]
   and returns the (label, score) list sorted by decreasing score.
   Labels are not deduplicated here. *)
fun pre_pred dist symweight (feal: ('a * int list) list) fea_o =
  let
    val dict_o = dnew Int.compare (map (fn x => (x,())) fea_o)
    fun score (lbl,fea) = (lbl, dist symweight dict_o fea)
  in
    dict_sort compare_score (map score feal)
  end

fun pre_sim1 symweight feal fea_o = pre_pred knn_sim1 symweight feal fea_o
fun pre_sim2 symweight feal fea_o = pre_pred knn_sim2 symweight feal fea_o
fun pre_sim3 symweight feal fea_o = pre_pred knn_sim3 symweight feal fea_o

(* Maps an integer option (e.g. !ttt_termarg_pint, !ttt_mcev_pint) to a
   ranking function; any value other than 1, 2 or 3 selects pre_sim2. *)
fun pre_sim_of_int i =
  case i of
    1 => pre_sim1
  | 2 => pre_sim2
  | 3 => pre_sim3
  | _ => pre_sim2

(* --------------------------------------------------------------------------
   Tactic predictions
   -------------------------------------------------------------------------- *)

(* Used for preselection.
   Ranks the labels of feal by knn_sim1 similarity to fea_o. When
   ttt_covdist_flag is set, the ranking is re-sorted by decreasing coverage
   (ttt_taccov entry of the label's first component, 0 when absent).
   The result is deduplicated with lbl_compare and cut to n labels. *)
fun stacknn symweight n feal fea_o =
  let
    val ranked = map fst (pre_sim1 symweight feal fea_o)
    (* NOTE(review): `handle _` also catches Interrupt -- consider narrowing *)
    fun coverage x = dfind x (!ttt_taccov) handle _ => 0
    fun compare_coverage (lbl1,lbl2) =
      Int.compare (coverage (#1 lbl2), coverage (#1 lbl1))
    val reranked =
      if !ttt_covdist_flag then dict_sort compare_coverage ranked else ranked
  in
    first_n n (mk_sameorder_set lbl_compare reranked)
  end

(* Used during search: stacknn, then deduplicated by the label's first
   component (the tactic string). *)
fun stacknn_uniq symweight n feal fea_o =
  let
    fun compare_stac (lbl1,lbl2) = String.compare (#1 lbl1, #1 lbl2)
  in
    mk_sameorder_set compare_stac (stacknn symweight n feal fea_o)
  end

(* --------------------------------------------------------------------------
   Theorem predictions
   -------------------------------------------------------------------------- *)

(* s is split at "Theory." into (theory, name). True for namespace theorems
   or when DB.fetch finds the theorem. *)
fun exists_tid s =
  let val (a,b) = split_string "Theory." s in
    a = namespace_tag orelse can (DB.fetch a) b
  end

(* Ranks theorem names by knn_sim1 similarity to fea_o, removes duplicate
   names and keeps up to n names accepted by exists_tid (via first_test_n). *)
fun thmknn (symweight,feav) n fea_o =
  let
    val names = map fst (pre_sim1 symweight feav fea_o)
  in
    first_test_n exists_tid n (mk_sameorder_set String.compare names)
  end

(* Memo table: goal -> features, shared by all calls to add_fea. *)
val add_fea_cache = ref (dempty goal_compare)

(* Inserts (name, features) under the goal of thm into !dict, unless the goal
   is already present or uptodate_thm thm fails. Features are memoized in
   add_fea_cache. *)
fun add_fea dict (name,thm) =
  let val g = dest_thm thm in
    if not (dmem g (!dict)) andalso uptodate_thm thm
    then
      let
        val fea = dfind g (!add_fea_cache)
          handle NotFound =>
            let val fea' = fea_of_goal g in
              add_fea_cache := dadd g fea' (!add_fea_cache);
              fea'
            end
      in
        dict := dadd g (name,fea) (!dict)
      end
    else ()
  end

(* Returns thmdict extended with the theorems of namespace_thms (), each
   named "<namespace_tag>Theory.<name>". *)
fun insert_namespace thmdict =
  let
    val dict = ref thmdict
    fun qualify (x,y) = (namespace_tag ^ "Theory." ^ x, y)
  in
    app (add_fea dict) (map qualify (namespace_thms ()));
    !dict
  end

(* Returns (symweight, feav, revdict):
   - feav: the (name, features) entries of ttt_thmfea, plus the namespace
     theorems when ttt_namespacethm_flag is set;
   - revdict: name -> (goal, features);
   - symweight: TFIDF weights learned from feav. *)
fun all_thmfeav () =
  let
    val newdict =
      if !ttt_namespacethm_flag
      then insert_namespace (!ttt_thmfea)
      else !ttt_thmfea
    val entries = dlist newdict
    val feav = map snd entries
    fun by_name (g,(name,fea)) = (name,(g,fea))
    val revdict = dnew String.compare (map by_name entries)
    val symweight = learn_tfidf feav
  in
    (symweight,feav,revdict)
  end

(* n nearest theorem names for a goal, over all known theorems. *)
fun thmknn_std n goal =
  let val (symweight,feav,_) = all_thmfeav () in
    thmknn (symweight,feav) n (fea_of_goal goal)
  end

(* ----------------------------------------------------------------------
   Adding theorem dependencies in the predictions
   ---------------------------------------------------------------------- *)

(* s is split at "Theory." into (theory, name); namespace theorems are
   considered up to date. DB.fetch raises if the theorem does not exist,
   so callers check exists_tid first. *)
fun uptodate_tid s =
  let val (a,b) = split_string "Theory." s in
    a = namespace_tag orelse uptodate_thm (DB.fetch a b)
  end

(* Uptodate-ness is probably already verified elsewhere.
   Each name in l0 is followed by its partial dependencies; duplicates are
   removed and up to n names that exist, are up to date and belong to
   revdict are kept. *)
fun add_thmdep revdict n l0 =
  let
    fun with_deps x = x :: deplPartial_of_sthm x
    val l1 = mk_sameorder_set String.compare (List.concat (map with_deps l0))
    fun keep x = exists_tid x andalso uptodate_tid x andalso dmem x revdict
  in
    first_test_n keep n l1
  end

(* thmknn followed by add_thmdep. *)
fun thmknn_wdep (symweight,feav,revdict) n gfea =
  add_thmdep revdict n (thmknn (symweight,feav) n gfea)

(* ----------------------------------------------------------------------
   Adding stac descendants (labels reachable through the goals they
   produce, according to ddict).
   ---------------------------------------------------------------------- *)

(* Includes itself: lbl is pushed on rlist before its descendants, so
   !rlist ends up in reverse visiting order.
   rdict holds the labels on the current path only (it is passed down, not
   shared between siblings): this stops cycles, but a label reachable by
   several paths is visited and recorded once per path. Duplicates are
   removed later by add_stacdesc. *)
fun desc_lbl_aux rlist rdict ddict (lbl as (stac,_,_,gl)) =
  (
  rlist := lbl :: (!rlist);
  if dmem lbl rdict
  then () (* debug ("Warning: descendant_of_feav: " ^ stac) *)
  else
    let
      val new_rdict = dadd lbl () rdict
      (* NOTE(review): `handle _` also catches Interrupt -- consider narrowing *)
      fun visit_goal g =
        let val lbls = dfind g ddict handle _ => [] in
          app (desc_lbl_aux rlist new_rdict ddict) lbls
        end
    in
      app visit_goal gl
    end
  )

(* All descendants of lbl (itself included, duplicates possible). *)
fun desc_lbl ddict lbl =
  let val rlist = ref [] in
    desc_lbl_aux rlist (dempty lbl_compare) ddict lbl;
    !rlist
  end

(* Descendants of every label of l, deduplicated with lbl_compare,
   cut to n labels. *)
fun add_stacdesc ddict n l =
  let
    val descendants = List.concat (map (desc_lbl ddict) l)
  in
    first_n n (mk_sameorder_set lbl_compare descendants)
  end

(* --------------------------------------------------------------------------
   Term prediction.
   Candidates are the subterms of the goal (conclusion and assumptions)
   different from [term], deduplicated with Term.compare and ranked by
   similarity to [term]. Symbol weights are computed by TFIDF over the
   features of [term], of the candidates and of all theorems in ttt_thmfea.
   The distance is chosen by !ttt_termarg_pint (see pre_sim_of_int).
   -------------------------------------------------------------------------- *)

fun termknn n ((asl,w):goal) term =
  let
    fun not_term tm = tm <> term
    fun with_fea x = (x, fea_of_goal ([],x))
    val l0 = List.concat (map (rev o find_terms not_term) (w :: asl))
    val l1 = mk_sameorder_set Term.compare l0
    val thmfeav = map (snd o snd) (dlist (!ttt_thmfea))
    val feal = map with_fea l1
    val fea_o = tttFeature.fea_of_goal ([],term)
    val symweight = weight_tfidf (fea_o :: (map snd feal) @ thmfeav)
    val pre_sim = pre_sim_of_int (!ttt_termarg_pint)
    val l3 = pre_sim symweight feal fea_o
  in
    first_n n (map fst l3)
  end

(* --------------------------------------------------------------------------
   Term prediction for conjecturing experiments.
   Todo: add dependencies between deps.
   -------------------------------------------------------------------------- *)

(* n labels of tmfea nearest to fea_o by knn_sim1. *)
fun tmknn n (symweight,tmfea) fea_o =
  first_n n (map fst (pre_sim1 symweight tmfea fea_o))

(* --------------------------------------------------------------------------
   Goal list prediction.
   -------------------------------------------------------------------------- *)

(* Labels of feal are pairs whose first component is a boolean. Among the
   [radius] nearest labels (distance chosen by !ttt_mcev_pint), returns the
   fraction whose boolean is true, or 0.0 when there is no neighbour. *)
fun mcknn symweight radius feal fea =
  let
    val pre_sim = pre_sim_of_int (!ttt_mcev_pint)
    val bnl = map fst (first_n radius (pre_sim symweight feal fea))
    val pos = Real.fromInt (length (filter (fn (b,_) => b) bnl))
    val neg = Real.fromInt (length (filter (fn (b,_) => not b) bnl))
  in
    if null bnl then 0.0 else pos / (neg + pos)
  end

end (* struct *)