(* ========================================================================= *)
(* FILE          : mlNearestNeighbor.sml                                     *)
(* DESCRIPTION   : Predictions of tactics, theorems, terms                   *)
(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck            *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure mlNearestNeighbor :> mlNearestNeighbor =
struct

open HolKernel Abbrev aiLib mlFeature mlThmData mlTacticData

val ERR = mk_HOL_ERR "mlNearestNeighbor"

(* Map from feature identifier to its weight. *)
type symweight = (int, real) Redblackmap.dict
(* Association list pairing each labelled object with its feature vector. *)
type 'a afea = ('a * fea) list

(* Accumulated wall-clock counters for the three phases of knn_dist.
   They are only ever added to in this file (through total_time). *)
val inter_time = ref 0.0
val dfind_time = ref 0.0
val sum_time = ref 0.0

(* ------------------------------------------------------------------------
   Distance
   ------------------------------------------------------------------------ *)

(* knn_dist symweight feao feap
   Sums the weights (looked up in symweight) of the features that feao and
   feap have in common, as computed by inter_increasing.
   Raises ERR "knn_dist" when a shared feature has no entry in symweight.
   NOTE(review): despite the name, a larger result appears to mean "closer"
   (the result is fed to best_n_rmaxu) -- confirm against aiLib. *)
fun knn_dist symweight feao feap =
  let
    (* The function handed to total_time must be the partial application
       (inter_increasing feao): total_time times "f x", so passing
       inter_increasing and feao separately would time only the partial
       application and leave the real intersection outside the timer. *)
    val feai = total_time inter_time (inter_increasing feao) feap
    fun wf n = dfind n symweight handle NotFound => raise ERR "knn_dist" ""
    val weightl = total_time dfind_time (map wf) feai
  in
    total_time sum_time sum_real weightl
  end

(* ------------------------------------------------------------------------
   Sorting feature vectors according to the distance
   ------------------------------------------------------------------------ *)

(* knn_sortu cmp n (symweight,feav) feao
   Scores every (label, features) entry of feav against the query features
   feao with knn_dist and passes the scored list to best_n_rmaxu cmp n.
   Only the weights of features occurring in feao are needed, so a smaller
   dictionary restricted to those features is built first; features of feao
   that are absent from symweight are silently dropped at this stage. *)
fun knn_sortu cmp n (symweight,feav) feao =
  let
    fun g x = SOME (x, dfind x symweight) handle NotFound => NONE
    val feaosymweight = dnew Int.compare (List.mapPartial g feao)
    fun f (x,feap) = (x, knn_dist feaosymweight feao feap)
  in
    best_n_rmaxu cmp n (map f feav)
  end

(* ------------------------------------------------------------------------
   Term predictions
   ------------------------------------------------------------------------ *)

(* Nearest-neighbour query over terms (labels compared with Term.compare). *)
fun termknn (symweight,termfea) n fea =
  knn_sortu Term.compare n (symweight,termfea) fea

(* ------------------------------------------------------------------------
   Theorem predictions
   ------------------------------------------------------------------------ *)

(* Nearest-neighbour query over theorem identifiers (strings). *)
fun thmknn (symweight,thmfea) n fea =
  knn_sortu String.compare n (symweight,thmfea) fea

(* ----------------------------------------------------------------------
   Adding theorem dependencies
   ---------------------------------------------------------------------- *)

(* add_thmdep n predl
   Follows each predicted theorem by the result of validdep_of_thmid on it,
   removes duplicates with mk_sameorder_set, and keeps the first n entries. *)
fun add_thmdep n predl =
  let
    fun f pred = pred :: validdep_of_thmid pred
    val predl0 = List.concat (map f predl)
    val predl1 = mk_sameorder_set String.compare predl0
  in
    first_n n predl1
  end

(* thmknn followed by add_thmdep, both limited to n results. *)
fun thmknn_wdep (symweight,feavdict) n fea =
  add_thmdep n (thmknn (symweight,feavdict) n fea)

(* ------------------------------------------------------------------------
   Tactic predictions
   ------------------------------------------------------------------------ *)

(* Nearest-neighbour query over tactic strings. *)
fun tacknn (symweight,tacfea) n fea =
  knn_sortu String.compare n (symweight,tacfea) fea

(* Nearest-neighbour query over (location, call) pairs; labels are compared
   on their second component with call_compare. *)
fun callknn (symweight,callfea) n fea =
  knn_sortu (snd_compare call_compare) n (symweight,callfea) fea

(* ----------------------------------------------------------------------
   Adding tactic dependencies
   --------------------------------------------------------------------- *)

(* dep_call_g rl lookup loopd gn
   Depth-first walk starting at goal number gn. Each reachable call is
   pushed onto the reference rl (so rl ends up in reverse visiting order).
   loopd holds the goal numbers on the current path and stops the walk from
   cycling; it is extended functionally, so it is not shared between
   sibling branches. Goal numbers for which lookup fails are skipped. *)
fun dep_call_g rl lookup loopd gn =
  if HOLset.member (loopd,gn) orelse not (can lookup gn) then () else
  let
    val newloopd = HOLset.add (loopd,gn)
    val (loc,call) = lookup gn
    val _ = rl := (loc,call) :: (!rl)
    val gnl = #ogl call
  in
    app (dep_call_g rl lookup newloopd) gnl
  end

(* dep_call calld ((thy,thmn,gn),call)
   Collects the calls reachable from the ogl field of the given call,
   looking goal numbers up in calld under the same (thy,thmn) prefix.
   The result is in visiting order with duplicates removed. *)
fun dep_call calld ((thy,thmn,gn),{stac,ogl,fea}) =
  let
    val rl = ref []
    fun lookup x = ((thy,thmn,x), dfind (thy,thmn,x) calld)
    (* Seed the path set with the starting goal so it is never revisited. *)
    val loopd = HOLset.fromList Int.compare [gn]
  in
    app (dep_call_g rl lookup loopd) ogl;
    mk_sameorder_set (snd_compare call_compare) (rev (!rl))
  end

(* add_calldep calld n calls
   Follows each call by its dependencies (dep_call), removes duplicates
   with mk_sameorder_set, and keeps the first n entries. *)
fun add_calldep calld n calls =
  let
    val l1 = List.concat (map (fn x => x :: dep_call calld x) calls)
    val l2 = mk_sameorder_set (snd_compare call_compare) l1
  in
    first_n n l2
  end

(* ----------------------------------------------------------------------
   Training from a dataset of term-value pairs for comparison with
   tree neural networks.
   --------------------------------------------------------------------- *)

(* A trained predictor: the feature weights and per-term features used for
   the nearest-neighbour query, plus the term -> value table. *)
type 'a knnpred = (symweight * term afea) * (term, 'a) Redblackmap.dict

(* train_knn trainset
   trainset is a list of (term, value) pairs. Features are computed with
   fea_of_term true and weights are learned with learn_tfidf. *)
fun train_knn trainset =
  let
    val termfea = map_assoc (fea_of_term true) (map fst trainset);
    val symweight = learn_tfidf termfea;
  in
    ((symweight,termfea), dnew Term.compare trainset)
  end

(* infer_knn knnpred tm
   Returns the value stored for the single nearest training term of tm.
   hd raises Empty if termknn returns no neighbour. *)
fun infer_knn ((symweight,termfea),d) tm =
  let val neartm = hd (termknn (symweight,termfea) 1 (fea_of_term true tm)) in
    dfind neartm d
  end

(* is_accurate_knn knnpred (tm,rlo)
   True when every predicted real is within 0.5 of the corresponding
   expected real in rlo (the two lists are paired with combine). *)
fun is_accurate_knn knnpred (tm,rlo) =
  let
    val rl1 = infer_knn knnpred tm
    val rl2 = combine (rl1, rlo)
    fun test (x,y) = Real.abs (x - y) < 0.5
  in
    all test rl2
  end

(* knn_accuracy knnpred exset
   Fraction of examples in exset for which is_accurate_knn holds.
   With an empty exset this evaluates 0.0 / 0.0. *)
fun knn_accuracy knnpred exset =
  let val correct = filter I (map (is_accurate_knn knnpred) exset) in
    Real.fromInt (length correct) / Real.fromInt (length exset)
  end


end (* struct *)