(* ========================================================================= *)
(* FILE          : mleArithData.sml                                          *)
(* DESCRIPTION   : Data for elementary arithmetic problems                   *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2019                                                      *)
(* ========================================================================= *)

structure mleArithData :> mleArithData =
struct

open HolKernel Abbrev boolLib aiLib psTermGen mlTacticData
mlTreeNeuralNetwork mlFeature numSyntax

val ERR = mk_HOL_ERR "mleArithData"

(* Directory in which the data files of this module are read and written. *)
val arithdir = HOLDIR ^ "/examples/AI_TNN/data_arith"

(* -------------------------------------------------------------------------
   Arithmetic
   ------------------------------------------------------------------------- *)

(* Applies mk_suc n times to zero_tm. *)
fun mk_sucn n = funpow n mk_suc zero_tm

(* Evaluates tm with computeLib.EVAL_CONV and returns the right-hand side of
   the resulting theorem as an integer. The numeral is read directly with
   numSyntax.int_of_term instead of being pretty-printed and re-parsed, so
   the result does not depend on pretty-printer settings. *)
fun eval_numtm tm =
  (numSyntax.int_of_term o rhs o concl o computeLib.EVAL_CONV) tm

(* -------------------------------------------------------------------------
   Arithmetical term generation
   ------------------------------------------------------------------------- *)

(* Operators given to the term generator: successor, addition,
   multiplication, and the terms mk_sucn 0 up to mk_sucn n. *)
fun num_operl n =
  [``SUC``,``$+``,``$*``] @ map mk_sucn (List.tabulate (n+1,I))

(* Random term of type num built from num_operl nsuc; nsize is forwarded to
   random_term as its size argument. *)
fun random_numtm (nsuc,nsize) =
  random_term (num_operl nsuc) (nsize,``:num``)

(* -------------------------------------------------------------------------
   Set of arithmetical examples
   ------------------------------------------------------------------------- *)

(* Collects up to nex distinct random terms for the parameters (nsuc,nsize).
     notset : reference to a dictionary of terms that must not be produced;
              every term accepted here is also added to it (side effect).
   At most nex * 10 random terms are drawn, so fewer than nex terms may be
   returned. The result is the list of keys of the local dictionary. *)
fun create_exset_class notset nex (nsuc,nsize) =
  let
    val d = ref (dempty Term.compare)
    fun random_exl n =
      if n <= 0 orelse dlength (!d) >= nex then () else
      let val tm = random_numtm (nsuc,nsize) in
        (* a term already excluded or already collected is skipped, but the
           attempt still counts against the budget n *)
        if dmem tm (!notset) orelse dmem tm (!d) then ()
        else (d := dadd tm () (!d); notset := dadd tm () (!notset));
        random_exl (n - 1)
      end
  in
    random_exl (nex * 10); dkeys (!d)
  end

(* Runs create_exset_class for every pair in the cartesian product of
   1..nsuc and 1..nsize, and associates each pair with its list of terms.
   All classes share the same notset reference. *)
fun create_exset_table notset nex (nsuc,nsize) =
  let
    fun f x = x + 1
    val l = cartesian_product (List.tabulate (nsuc,f))
                              (List.tabulate (nsize,f))
  in
    map_assoc (create_exset_class notset nex) l
  end

(* Flat list of the terms of all classes of create_exset_table; the terms of
   tml are excluded from generation. *)
fun create_exset tml nex (nsuc,nsize) =
  let val notset = ref (dset Term.compare tml) in
    List.concat (map snd (create_exset_table notset nex (nsuc,nsize)))
  end

(* -------------------------------------------------------------------------
   Creation of training set, validation set and test set and export
   ------------------------------------------------------------------------- *)

(* All three sets use the parameters (10,10); tml lists the terms that the
   new set must not contain. *)
fun create_train nex = create_exset [] nex (10,10)
fun create_valid tml nex = create_exset tml nex (10,10)
fun create_test tml nex = create_exset tml nex (10,10)

(* -------------------------------------------------------------------------
   Export/Import functions
   ------------------------------------------------------------------------- *)

(* Generates the train, valid and test sets (nex = 200 for each; the valid
   set excludes the train terms and the test set excludes both) and writes
   them to the files train, valid and test in arithdir, one line per term of
   the form <term>,<value>.
   NOTE(review): import_arithdata below reads lines in pairs (term line, then
   value line) and parses the term line with lisp_parser, which does not look
   compatible with the lines written here -- confirm which on-disk format the
   files in data_arith are expected to have. *)
fun create_export_arithdata () =
  let
    val _ = mkDir_err arithdir
    val tmltrain = create_train 200
    val tmlvalid = create_valid tmltrain 200
    val tmltest = create_test (tmltrain @ tmlvalid) 200
    fun f tm = tts tm ^ "," ^ its (eval_numtm tm)
  in
    writel (arithdir ^ "/train") (map f tmltrain);
    writel (arithdir ^ "/valid") (map f tmlvalid);
    writel (arithdir ^ "/test") (map f tmltest)
  end

(*
load "mleArithData"; open mleArithData;
val _ = create_export_arithdata ();
*)

(* Converts a parsed lisp expression into an arithmetic term:
   0 is zero_tm, (S a) is a successor, (+ a b) a sum and a two-argument
   term headed by the multiplication symbol a product.
   Raises ERR "lisp_to_arith" on any other shape. *)
fun lisp_to_arith lisp =
  case lisp of
    Lstring "0" => zero_tm
  | Lterm [Lstring "S",a] => mk_suc (lisp_to_arith a)
  | Lterm [Lstring "+",a,b] => mk_plus (lisp_to_arith a, lisp_to_arith b)
  | Lterm [Lstring "*",a,b] => mk_mult (lisp_to_arith a, lisp_to_arith b)
  | _ => raise ERR "lisp_to_arith" ""

(* Reads the file dataname in arithdir and returns a list of
   (term, integer) pairs. The lines are consumed two at a time: the first
   line of a pair is parsed as a term, the second as an integer. *)
fun import_arithdata dataname =
  let
    val l1 = readl (arithdir ^ "/" ^ dataname)
    val l2 = map pair_of_list (mk_batch_full 2 l1)
    val l3 = map_snd string_to_int l2
    (* the bare string 0 is mapped to zero_tm without calling the parser *)
    fun f x = if x = "0" then zero_tm else
      (lisp_to_arith o singleton_of_list o lisp_parser) x
  in
    map_fst f l3
  end

(*
load "mleArithData"; open mleArithData;
load "aiLib"; open aiLib;
val train = import_arithdata "train";
*)

(* -------------------------------------------------------------------------
   Write subterm features for feature-based predictors
   ------------------------------------------------------------------------- *)

(* Writes the file dataname ^ "_fea" in arithdir, one line per example of the
   dataset dataname: "+" followed by the value modulo 16, then the sorted
   feature indices of the term, each printed as <index>:1.
   Feature indices are assigned by numbering (from 0) the set of features
   occurring in the test and train datasets. A feature of dataname that does
   not occur in that set has no index and is left out of the line; looking it
   up unconditionally would abort the export for datasets other than test
   and train. *)
fun export_arithfea dataname =
  let
    (* 1299827 and true are forwarded unchanged to fea_of_term_mod;
       presumably the integer is the modulus of the feature hash -- confirm
       against mlFeature *)
    fun feaf tm = fea_of_term_mod 1299827 true tm
    val tml' = map fst (List.concat (map import_arithdata ["test","train"]))
    fun all_features x =
      let val l = List.concat (map feaf x) in
        dnew Int.compare (number_snd 0 (mk_fast_set Int.compare l))
      end
    val tml = import_arithdata dataname
    val d = all_features tml'
    fun f (tm,i) =
      let
        (* keep only the features that received an index *)
        val known = List.filter (fn x => dmem x d) (feaf tm)
        val il1 = dict_sort Int.compare (map (fn x => dfind x d) known)
        val il2 = map (fn x => (its x ^ ":1")) il1
      in
        ("+" ^ its (i mod 16) ^ " " ^ String.concatWith " " il2)
      end
  in
    writel (arithdir ^ "/" ^ dataname ^ "_fea") (map f tml)
  end

(*
load "mleArithData"; open mleArithData;
app export_arithfea ["train","test"];
*)

(* -------------------------------------------------------------------------
   Statistics
   ------------------------------------------------------------------------- *)

(* Groups the elements of tml by the integer f assigns to them and returns,
   for each such integer, the number of elements in its group. *)
fun regroup_by_metric f tml =
  let val d = dregroup Int.compare (map swap (map_assoc f tml)) in
    map_snd length (dlist d)
  end


end (* struct *)