1(* ========================================================================= *)
2(* FILE          : mleArithData.sml                                          *)
3(* DESCRIPTION   : Data for elementary arithmetic problems                   *)
4(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
5(* DATE          : 2019                                                      *)
6(* ========================================================================= *)
7
8structure mleArithData :> mleArithData =
9struct
10
11open HolKernel Abbrev boolLib aiLib psTermGen mlTacticData
12mlTreeNeuralNetwork mlFeature numSyntax
13
(* Build a HOL_ERR exception tagged with this structure's name:
   [ERR func msg] reports [msg] as raised by function [func]. *)
fun ERR func msg = mk_HOL_ERR "mleArithData" func msg

(* Directory where the arithmetic datasets are written and read. *)
val arithdir = String.concat [HOLDIR, "/examples/AI_TNN/data_arith"]
16
17(* -------------------------------------------------------------------------
18   Arithmetic
19   ------------------------------------------------------------------------- *)
20
(* [mk_sucn n] is the term SUC (SUC ... (SUC 0)) with n applications of SUC,
   built by iterating [mk_suc] n times starting from [zero_tm]. *)
fun mk_sucn n =
  let val base = zero_tm in
    funpow n mk_suc base
  end
22
(* Evaluate the arithmetic term [tm] with [computeLib.EVAL_CONV] and return
   the value on the right-hand side of the resulting theorem as an ML int.
   The numeral is decoded directly with [numSyntax.int_of_term] instead of
   being printed and re-parsed, so the result does not depend on global
   pretty-printer settings (e.g. show_types or show_numeral_types, which add
   type annotations or suffixes to the printed numeral).
   Raises an exception if the right-hand side is not a numeral or does not
   fit in an ML int (per numSyntax.int_of_term). *)
fun eval_numtm tm =
  numSyntax.int_of_term (rhs (concl (computeLib.EVAL_CONV tm)))
25
26(* -------------------------------------------------------------------------
27   Arithmetical term generation
28   ------------------------------------------------------------------------- *)
29
(* Operator list for random term generation: SUC, addition, multiplication,
   followed by the constants SUC^i 0 for i = 0, 1, ..., n. *)
fun num_operl n =
  let
    val ops = [``SUC``, ``$+``, ``$*``]
    val consts = List.tabulate (n + 1, mk_sucn)
  in
    ops @ consts
  end
(* Draw a random term of type :num and size [nsize] with [random_term], using
   the operators of [num_operl nsuc]. *)
fun random_numtm (nsuc,nsize) =
  let val operl = num_operl nsuc in
    random_term operl (nsize, ``:num``)
  end
34
35(* -------------------------------------------------------------------------
36   Set of arithmetical examples
37   ------------------------------------------------------------------------- *)
38
(* Collect up to [nex] distinct random terms for the class (nsuc,nsize).
   At most [10 * nex] candidates are drawn. A candidate is kept only if it
   is neither in [!notset] nor already collected. Every kept term is also
   added to the shared reference [notset], so later calls avoid it.
   Returns the collected terms as the keys of a Term.compare dictionary. *)
fun create_exset_class notset nex (nsuc,nsize) =
  let
    val found = ref (dempty Term.compare)
    fun is_new tm = not (dmem tm (!notset) orelse dmem tm (!found))
    fun record tm =
      (found := dadd tm () (!found); notset := dadd tm () (!notset))
    fun loop budget =
      if budget > 0 andalso dlength (!found) < nex then
        let val tm = random_numtm (nsuc,nsize) in
          if is_new tm then record tm else ();
          loop (budget - 1)
        end
      else ()
  in
    loop (10 * nex);
    dkeys (!found)
  end
52
(* For every class (a,b) with 1 <= a <= nsuc and 1 <= b <= nsize, in
   [cartesian_product] order, pair the class with the terms generated for it
   by [create_exset_class]. The reference [notset] is threaded through all
   classes, so terms are distinct across classes. *)
fun create_exset_table notset nex (nsuc,nsize) =
  let
    fun one_to k = List.tabulate (k, fn i => i + 1)
    val classes = cartesian_product (one_to nsuc) (one_to nsize)
  in
    map_assoc (create_exset_class notset nex) classes
  end
61
(* Generate examples for every class up to (nsuc,nsize), excluding the terms
   of [tml], and return all of them as one list in class order. *)
fun create_exset tml nex (nsuc,nsize) =
  let
    val notset = ref (dset Term.compare tml)
    val table = create_exset_table notset nex (nsuc,nsize)
  in
    List.foldr (fn ((_, exl), acc) => exl @ acc) [] table
  end
66
67(* -------------------------------------------------------------------------
68   Creation of training set, validation set and test set and export
69   ------------------------------------------------------------------------- *)
70
(* Dataset builders over the classes (nsuc,nsize) with 1 <= nsuc, nsize <= 10.
   [create_train] excludes nothing; [create_valid] and [create_test] exclude
   the terms passed in [tml]. *)
local
  val dims = (10,10)
in
  fun create_train nex = create_exset [] nex dims
  fun create_valid tml nex = create_exset tml nex dims
  fun create_test tml nex = create_exset tml nex dims
end
74
75(* -------------------------------------------------------------------------
76   Export/Import functions
77   ------------------------------------------------------------------------- *)
78
(* Create pairwise-disjoint train/valid/test sets and write them to
   [arithdir]/train, [arithdir]/valid and [arithdir]/test. Each set holds up
   to 200 examples per (nsuc,nsize) class.

   Each example occupies two consecutive lines, which is the layout read by
   [import_arithdata] (it pairs consecutive lines with [mk_batch_full 2]):
     1. the term in the prefix syntax parsed by [lisp_to_arith]:
        "0", "(S t)", "(+ t1 t2)" or "(* t1 t2)";
     2. the value of the term as computed by [eval_numtm].
   Keep this layout in sync with [import_arithdata]. *)
fun create_export_arithdata () =
  let
    val _ = mkDir_err arithdir
    val tmltrain = create_train 200
    val tmlvalid = create_valid tmltrain 200
    val tmltest = create_test (tmltrain @ tmlvalid) 200
    (* Inverse of [lisp_to_arith] on the terms built from [num_operl]:
       only 0, SUC, + and * can occur. *)
    fun to_lisp tm =
      if is_suc tm then "(S " ^ to_lisp (dest_suc tm) ^ ")"
      else if is_plus tm then
        let val (a,b) = dest_plus tm in
          "(+ " ^ to_lisp a ^ " " ^ to_lisp b ^ ")"
        end
      else if is_mult tm then
        let val (a,b) = dest_mult tm in
          "(* " ^ to_lisp a ^ " " ^ to_lisp b ^ ")"
        end
      else if aconv tm zero_tm then "0"
      else raise ERR "create_export_arithdata" (term_to_string tm)
    fun lines_of tm = [to_lisp tm, its (eval_numtm tm)]
    fun export name tml =
      writel (arithdir ^ "/" ^ name) (List.concat (map lines_of tml))
  in
    export "train" tmltrain;
    export "valid" tmlvalid;
    export "test" tmltest
  end
91
92(*
93load "mleArithData"; open mleArithData;
94val _ = create_export_arithdata ();
95*)
96
(* Convert a parsed prefix expression into an arithmetic HOL term:
   "0" -> 0, (S a) -> SUC a, (+ a b) -> a + b, (* a b) -> a * b.
   Any other shape raises ERR "lisp_to_arith". *)
fun lisp_to_arith (Lstring "0") = zero_tm
  | lisp_to_arith (Lterm [Lstring "S", a]) = mk_suc (lisp_to_arith a)
  | lisp_to_arith (Lterm [Lstring "+", a, b]) =
    mk_plus (lisp_to_arith a, lisp_to_arith b)
  | lisp_to_arith (Lterm [Lstring "*", a, b]) =
    mk_mult (lisp_to_arith a, lisp_to_arith b)
  | lisp_to_arith _ = raise ERR "lisp_to_arith" ""
104
(* Read [arithdir]/[dataname] as a sequence of line pairs
   (term in prefix syntax, integer value) and return the list of
   (term, value) examples. All values are converted before any term is
   parsed. *)
fun import_arithdata dataname =
  let
    val lines = readl (arithdir ^ "/" ^ dataname)
    val raw_pairs = map pair_of_list (mk_batch_full 2 lines)
    val labelled = map (fn (s, v) => (s, string_to_int v)) raw_pairs
    fun parse_tm "0" = zero_tm
      | parse_tm s = lisp_to_arith (singleton_of_list (lisp_parser s))
  in
    map (fn (s, v) => (parse_tm s, v)) labelled
  end
115
116(*
117load "mleArithData"; open mleArithData;
118load "aiLib"; open aiLib;
119val train = import_arithdata "train";
120*)
121
122(* -------------------------------------------------------------------------
123   Write subterm features for feature-based predictors
124   ------------------------------------------------------------------------- *)
125
(* Write [arithdir]/<dataname>_fea, with one line per example (tm,i) of the
   dataset [dataname]:
     +<i mod 16> <idx>:1 <idx>:1 ...
   The indices number, from 0 upward, the integer features returned by
   [fea_of_term_mod fea_modulus true] on the terms of the "test" and "train"
   datasets. On each line they are sorted and duplicate-free.
   Features of [dataname] that never occur in "test" or "train" (possible
   for e.g. "valid") have no index and are omitted rather than raising.
   NOTE(review): indices start at 0, but LIBSVM-style readers expect indices
   starting from 1. Confirm what the consumer of these files expects. *)
fun export_arithfea dataname =
  let
    val fea_modulus = 1299827
    fun fea_of tm = fea_of_term_mod fea_modulus true tm
    fun sorted_unique l = dict_sort Int.compare (mk_fast_set Int.compare l)
    val tml' = map fst (List.concat (map import_arithdata ["test","train"]))
    fun all_features x =
      let val l = List.concat (map fea_of x) in
        dnew Int.compare (number_snd 0 (mk_fast_set Int.compare l))
      end
    val tml = import_arithdata dataname
    val d = all_features tml'
    (* features outside the test/train universe have no index *)
    fun index_of fea = if dmem fea d then SOME (dfind fea d) else NONE
    fun f (tm,i) =
      let
        (* fea_of_term_mod is not known here to return a duplicate-free list
           (features are reduced modulo fea_modulus, so collisions are
           plausible). Deduplicate so indices are strictly increasing. *)
        val il1 = sorted_unique (List.mapPartial index_of (fea_of tm))
        val il2 = map (fn x => its x ^ ":1") il1
      in
        "+" ^ its (i mod 16) ^ " " ^ String.concatWith " " il2
      end
  in
    writel (arithdir ^ "/" ^ dataname ^ "_fea") (map f tml)
  end
146
147(*
148load "mleArithData"; open mleArithData;
app export_arithfea ["train","test"];
150*)
151
152(* -------------------------------------------------------------------------
153   Statistics
154   ------------------------------------------------------------------------- *)
155
(* Group the terms of [tml] by the integer metric [f] and return, for each
   metric value in [dlist] order, the number of terms having that value. *)
fun regroup_by_metric f tml =
  let
    val keyed = map (fn tm => (f tm, tm)) tml
    val groups = dlist (dregroup Int.compare keyed)
  in
    map (fn (k, members) => (k, length members)) groups
  end
160
161
162end (* struct *)
163