(* ========================================================================= *)
(* FILE          : mlNearestNeighbor.sml                                     *)
(* DESCRIPTION   : Predictions of tactics, theorems, terms                   *)
(* AUTHOR        : (c) Thibault Gauthier, University of Innsbruck            *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)
7
structure mlNearestNeighbor :> mlNearestNeighbor =
struct

open HolKernel Abbrev aiLib mlFeature mlThmData mlTacticData

val ERR = mk_HOL_ERR "mlNearestNeighbor"

(* Weight of each feature, keyed by the feature's int id. *)
type symweight = (int, real) Redblackmap.dict
(* Objects paired with their feature vectors. *)
type 'a afea = ('a * fea) list

(* Profiling counters for the three phases of knn_dist
   (presumably accumulated by aiLib.total_time). *)
val inter_time = ref 0.0
val dfind_time = ref 0.0
val sum_time = ref 0.0

(* ------------------------------------------------------------------------
   Distance (actually a similarity score: the summed weight of the
   features two vectors share, so larger means closer, assuming
   non-negative weights)
   ------------------------------------------------------------------------ *)

(* Score of a candidate feature vector [feap] against the query [feao]:
   the sum of the [symweight] weights of the features common to both.
   The intersection is computed by inter_increasing, so both vectors are
   presumably expected in increasing order (TODO confirm at call sites).
   Raises HOL_ERR if a shared feature has no entry in [symweight]; the
   message names the missing feature id. *)
fun knn_dist symweight feao feap =
  let
    val feai = total_time inter_time inter_increasing feao feap
    fun wf n = dfind n symweight
      handle NotFound =>
        raise ERR "knn_dist" ("no weight for feature " ^ Int.toString n)
    val weightl = total_time dfind_time (map wf) feai
  in
    total_time sum_time sum_real weightl
  end

(* ------------------------------------------------------------------------
   Sorting feature vectors according to the distance
   ------------------------------------------------------------------------ *)

(* Score every (label, feature vector) pair in [feav] against the query
   [feao] with knn_dist and keep the [n] best via aiLib.best_n_rmaxu, which
   receives the label order [cmp].
   The weight table is first restricted to the query's features, dropping
   query features that [symweight] does not know.
   NOTE(review): a feature present in both the query and a candidate but
   absent from [symweight] therefore makes knn_dist raise. *)
fun knn_sortu cmp n (symweight,feav) feao =
  let
    fun g x = SOME (x, dfind x symweight) handle NotFound => NONE
    val feaosymweight = dnew Int.compare (List.mapPartial g feao)
    fun f (x,feap) = (x, knn_dist feaosymweight feao feap)
  in
    best_n_rmaxu cmp n (map f feav)
  end

(* ------------------------------------------------------------------------
   Term predictions
   ------------------------------------------------------------------------ *)

(* The [n] terms of [termfea] closest to the feature vector [fea]. *)
fun termknn (symweight,termfea) n fea =
  knn_sortu Term.compare n (symweight,termfea) fea

(* ------------------------------------------------------------------------
   Theorem predictions
   ------------------------------------------------------------------------ *)

(* The [n] theorem names of [thmfea] closest to the feature vector [fea]. *)
fun thmknn (symweight,thmfea) n fea =
  knn_sortu String.compare n (symweight,thmfea) fea

(* ----------------------------------------------------------------------
   Adding theorem dependencies
   ---------------------------------------------------------------------- *)

(* Follow each predicted theorem by its dependencies (validdep_of_thmid),
   remove repeated names with mk_sameorder_set (which presumably keeps the
   first occurrence), and keep the first [n] names. *)
fun add_thmdep n predl =
  let
    fun f pred = pred :: validdep_of_thmid pred
    val predl0 = List.concat (map f predl)
    val predl1 = mk_sameorder_set String.compare predl0
  in
    first_n n predl1
  end

(* thmknn followed by add_thmdep with the same bound [n]. *)
fun thmknn_wdep (symweight,feavdict) n fea =
  add_thmdep n (thmknn (symweight,feavdict) n fea)

(* ------------------------------------------------------------------------
   Tactic predictions
   ------------------------------------------------------------------------ *)

(* The [n] tactic strings of [tacfea] closest to the feature vector [fea]. *)
fun tacknn (symweight,tacfea) n fea =
  knn_sortu String.compare n (symweight,tacfea) fea

(* The [n] entries of [callfea] closest to [fea]; entries are pairs whose
   second component is compared with call_compare. *)
fun callknn (symweight,callfea) n fea =
  knn_sortu (snd_compare call_compare) n (symweight,callfea) fea

(* ----------------------------------------------------------------------
   Adding tactic dependencies
   --------------------------------------------------------------------- *)

(* Depth-first walk over recorded calls starting at goal number [gn].
   Each goal that [lookup] resolves is prepended to the accumulator [rl]
   before its output goals (#ogl) are visited, so rev (!rl) is visit order.
   Goals on which [lookup] raises are skipped.
   NOTE(review): [loopd] only holds the goals on the current path, so it
   prevents cycles, but a goal reachable along several paths is traversed
   again on each path.  Duplicates are removed afterwards by dep_call. *)
fun dep_call_g rl lookup loopd gn =
  if HOLset.member (loopd,gn) orelse not (can lookup gn) then () else
  let
    val newloopd = HOLset.add (loopd,gn)
    val (loc,call) = lookup gn
    val _ = rl := (loc,call) :: (!rl)
    val gnl = #ogl call
  in
    app (dep_call_g rl lookup newloopd) gnl
  end

(* Calls reachable from the output goals [ogl] of one recorded call, looked
   up in [calld] under keys (thy, thmn, goal number) of the same
   (presumably theory/theorem) pair.  The starting goal [gn] is never
   re-entered.  Result is in discovery order, deduplicated on the call
   part with call_compare. *)
fun dep_call calld ((thy,thmn,gn),{stac,ogl,fea}) =
  let
    val rl = ref []
    fun lookup x = ((thy,thmn,x), dfind (thy,thmn,x) calld)
    val loopd = HOLset.fromList Int.compare [gn]
  in
    app (dep_call_g rl lookup loopd) ogl;
    mk_sameorder_set (snd_compare call_compare) (rev (!rl))
  end

(* Follow each predicted call by the calls it depends on (dep_call),
   deduplicate on the call part, and keep the first [n] entries. *)
fun add_calldep calld n calls =
  let
    val l1 = List.concat (map (fn x => x :: dep_call calld x) calls)
    val l2 = mk_sameorder_set (snd_compare call_compare) l1
  in
    first_n n l2
  end

(* ----------------------------------------------------------------------
   Training from a dataset of term-value pairs for comparison with
   tree neural networks.
   --------------------------------------------------------------------- *)

(* Feature weights and training feature vectors, plus the training values
   indexed by term. *)
type 'a knnpred = (symweight * term afea) * (term, 'a) Redblackmap.dict

(* Build a knn predictor from (term, value) pairs: one feature vector per
   training term, weights learned from those vectors by learn_tfidf, and a
   dictionary from training terms to their values. *)
fun train_knn trainset =
  let
    val termfea = map_assoc (fea_of_term true) (map fst trainset)
    val symweight = learn_tfidf termfea
  in
    ((symweight,termfea), dnew Term.compare trainset)
  end

(* Value recorded for the training term nearest to [tm].
   hd raises Empty if termknn returns no neighbour (e.g. presumably when
   the training set is empty). *)
fun infer_knn ((symweight,termfea),d) tm =
  let val neartm = hd (termknn (symweight,termfea) 1 (fea_of_term true tm)) in
    dfind neartm d
  end

(* An example (tm, rlo) counts as accurate when every component of the
   value predicted for [tm] is within 0.5 of the matching component of
   [rlo].  combine expects lists of equal length (presumably raising
   HOL_ERR otherwise). *)
fun is_accurate_knn knnpred (tm,rlo) =
  let
    val rl1 = infer_knn knnpred tm
    fun test (x,y) = Real.abs (x - y) < 0.5
  in
    all test (combine (rl1, rlo))
  end

(* Fraction of the examples in [exset] that are accurately predicted.
   NOTE(review): an empty [exset] evaluates 0.0 / 0.0, which is NaN. *)
fun knn_accuracy knnpred exset =
  let
    val ncorrect = length (filter (is_accurate_knn knnpred) exset)
  in
    Real.fromInt ncorrect / Real.fromInt (length exset)
  end


end (* struct *)
156