(* ========================================================================= *)
(* FILE          : mleDiophSynt.sml                                          *)
(* DESCRIPTION   : Specification of term synthesis on Diophantine equations  *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2020                                                      *)
(* ========================================================================= *)

structure mleDiophSynt :> mleDiophSynt =
struct

open HolKernel Abbrev boolLib aiLib smlParallel psMCTS psTermGen
  mlNeuralNetwork mlTreeNeuralNetwork mlTacticData
  mlReinforce arithmeticTheory numLib numSyntax mleDiophLib

val ERR = mk_HOL_ERR "mleDiophSynt"
val version = 2
val selfdir = HOLDIR ^ "/examples/AI_tasks"

(* -------------------------------------------------------------------------
   Board
   A board is a triple (poly, graph, timer):
   - poly  : the polynomial built so far (list of monomials),
   - graph : the target set as a boolean membership vector over numberl
             (see graph_to_bl below),
   - timer : the number of moves that may still be played.
   ------------------------------------------------------------------------- *)

type board = int list list * bool list * int

fun string_of_board (poly,graph,timer) =
  let val sep = " -- " in
    poly_to_string poly ^ sep ^ graph_to_string graph ^ sep ^ its timer
  end

(* Lexicographic on (poly, graph); the remaining timer is ignored. *)
fun board_compare ((poly1,graph1,_),(poly2,graph2,_)) =
  case poly_compare (poly1,poly2) of
    EQUAL => graph_compare (graph1,graph2)
  | ord => ord

(* Lexicographic on (timer, poly, graph). *)
fun fullboard_compare ((poly1,graph1,timer1),(poly2,graph2,timer2)) =
  case Int.compare (timer1,timer2) of
    EQUAL => board_compare ((poly1,graph1,timer1),(poly2,graph2,timer2))
  | ord => ord

(* Win as soon as the polynomial matches the target set (dioph_match);
   otherwise the game is lost once the move budget is exhausted. *)
fun status_of (poly,graph,n) =
  if dioph_match poly graph then Win
  else if n > 0 then Undecided
  else Lose

(* -------------------------------------------------------------------------
   Move
   Add c starts a new monomial with coefficient c.
   Exp e appends the exponent e to the last monomial.
   ------------------------------------------------------------------------- *)

datatype move = Add of int | Exp of int

val movel =
  let val exponents = List.tabulate (maxexponent + 1, fn i => i) in
    map Add numberl @ map Exp exponents
  end

fun string_of_move (Add i) = "A" ^ its i
  | string_of_move (Exp i) = "E" ^ its i

fun move_compare (m1,m2) =
  String.compare (string_of_move m1, string_of_move m2)

(* Raises ERR when the move is not allowed:
   - Add: at most maxmonomial monomials;
   - Exp: there must be a last monomial holding fewer than maxvar exponents,
     and on their common prefix the exponents of the previous monomial must
     not compare GREATER than those of the extended last monomial. *)
fun apply_move_poly move poly =
  case move of
    Add c =>
      if length poly < maxmonomial then poly @ [[c]]
      else raise ERR "apply_move_poly" "plus"
  | Exp c =>
      let
        fun full_monomial () = length (last poly) >= maxvar + 1
        fun decreasing () =
          let
            val prev = tl (last (butlast poly))
            val cur = tl (last poly) @ [c]
            val k = Int.min (length cur, length prev)
          in
            list_compare Int.compare (first_n k prev, first_n k cur) = GREATER
          end
      in
        if null poly orelse full_monomial ()
          then raise ERR "apply_move_poly" "mult"
        else if length poly >= 2 andalso decreasing ()
          then raise ERR "apply_move_poly" "non-increasing"
        else butlast poly @ [last poly @ [c]]
      end

(* Each move consumes one unit of the timer; the tree is passed through. *)
fun apply_move (tree,_) move (poly,graph,n) =
  let val newpoly = apply_move_poly move poly in
    ((newpoly, graph, n - 1), tree)
  end

fun available_movel_poly poly =
  List.filter (fn m => can (apply_move_poly m) poly) movel

fun available_movel (p,_,_) = available_movel_poly p

(* -------------------------------------------------------------------------
   Game
   ------------------------------------------------------------------------- *)

val game : (board,move) game =
  {
  movel = movel,
  move_compare = move_compare,
  board_compare = board_compare,
  string_of_move = string_of_move,
  string_of_board = string_of_board,
  available_movel = available_movel,
  apply_move = apply_move,
  status_of = status_of
  }

(* -------------------------------------------------------------------------
   Parallelization
   Boards are stored as three parallel files: file_poly, file_graph and
   file_timer (one board component per line).
   ------------------------------------------------------------------------- *)

fun write_boardl file boardl =
  let
    val (polyl,graphl,timerl) = split_triple boardl
    fun output suffix f l = writel (file ^ suffix) (map f l)
  in
    output "_poly" poly_to_string polyl;
    output "_graph" graph_to_string graphl;
    output "_timer" its timerl
  end

(* The empty polynomial is written as an empty line, hence readl_empty. *)
fun read_boardl file =
  let
    fun path suffix = file ^ suffix
    val polyl = map string_to_poly (readl_empty (path "_poly"))
    val graphl = map string_to_graph (readl (path "_graph"))
    val timerl = map string_to_int (readl (path "_timer"))
  in
    combine_triple (polyl,graphl,timerl)
  end

val gameio = {read_boardl = read_boardl, write_boardl = write_boardl}

(* -------------------------------------------------------------------------
   Targets
   ------------------------------------------------------------------------- *)

val targetdir = selfdir ^ "/dioph_target"

fun graph_to_bl graph = map (fn n => mem n graph) numberl

(* Splits the (graph,poly) examples into train/test (10/11 ratio), exports
   the split with export_data, and builds starting boards: empty polynomial,
   target vector, and a budget of twice the size of the witness polynomial. *)
fun create_targetl l =
  let
    val (train,test) = part_pct (10.0 / 11.0) (shuffle l)
    val _ = export_data (train,test)
    fun to_board (graph,poly) = ([], graph_to_bl graph, 2 * poly_size poly)
    fun prepare exl = dict_sort fullboard_compare (map to_board exl)
  in
    (prepare train, prepare test)
  end

fun export_targetl name targetl =
  let
    val _ = mkDir_err targetdir
    val file = targetdir ^ "/" ^ name
  in
    write_boardl file targetl
  end

fun import_targetl name =
  let val file = targetdir ^ "/" ^ name in read_boardl file end

(* Numbers the targets from 0 and attaches an empty result list to each. *)
fun mk_targetd boardl =
  let
    val numbered = number_snd 0 boardl
    fun entry (board,i) = (board,(i,[]))
  in
    dnew board_compare (map entry numbered)
  end

(* -------------------------------------------------------------------------
   Neural network representation of the board
   ------------------------------------------------------------------------- *)

fun term_of_graph graph =
  let
    fun to_real b = if b then 1.0 else ~1.0
    val v = Vector.fromList (map to_real graph)
  in
    mk_embedding_var (v, bool)
  end

val head_eval = mk_var ("head_eval", ``:bool -> 'a``)
val head_poli = mk_var ("head_poli", ``:bool -> 'a``)
fun tag_heval tm = mk_comb (head_eval, tm)
fun tag_hpoli tm = mk_comb (head_poli, tm)
val graph_tag = mk_var ("graph_tag", ``:bool -> num``)
fun tag_graph tm = mk_comb (graph_tag, tm)

(* Equation "poly = embedv" fed to both the evaluation and policy heads. *)
fun tob2 embedv (poly,_,_) =
  let val eq = mk_eq (term_of_poly poly, embedv) in
    [tag_heval eq, tag_hpoli eq]
  end

(* Same as tob2, with the graph side given by its tagged embedding term. *)
fun tob1 (board as (_,graph,_)) =
  tob2 (tag_graph (term_of_graph graph)) board

(* With a (target,tnn) pair, the graph embedding is computed once here by
   precomp_embed and shared by every board the resulting function sees. *)
fun pretob NONE = tob1
  | pretob (SOME ((_,graph,_),tnn)) =
      tob2 (precomp_embed tnn (tag_graph (term_of_graph graph)))

(* -------------------------------------------------------------------------
   Player
   ------------------------------------------------------------------------- *)

val schedule =
  [{batch_size = 16, learning_rate = 0.02, nepoch = 10,
    ncore = 4, verbose = true}]

(* Operators: equality, graph tag, + and *, a start symbol, one numeral
   variable per element of numberl, and one variable v<i>p<j> per variable
   index i < maxvar and exponent j <= maxexponent. *)
val dioph_operl =
  let
    val fixed =
      [``$= : num -> num -> bool``, graph_tag, ``$+``, ``$*``,
       mk_var ("start",``:num``)]
    val numerals = map (fn i => mk_var ("n" ^ its i, ``:num``)) numberl
    fun varpow v p = mk_var ("v" ^ its v ^ "p" ^ its p, ``:num``)
    val powers =
      List.concat
        (List.tabulate (maxvar, fn v =>
           List.tabulate (maxexponent + 1, fn p => varpow v p)))
  in
    fixed @ numerals @ powers
  end

val dim = 16
fun dim_head_poli n = [dim, n]

val tnndim =
  let
    val heads =
      [(head_eval, [dim,dim,1]), (head_poli, [dim,dim,length movel])]
  in
    map_assoc (dim_std (1,dim)) dioph_operl @ heads
  end

val dplayer = {schedule = schedule, tnndim = tnndim, pretob = pretob}

(* -------------------------------------------------------------------------
   Interface
   ------------------------------------------------------------------------- *)

val rlparam =
  {expname = "mleDiophSynt-" ^ its version, exwindow = 200000,
   ncore = 30, ntarget = 200, nsim = 32000, decay = 1.0}

val rlobj : (board,move) rlobj =
  {rlparam = rlparam, game = game, gameio = gameio, dplayer = dplayer,
   infobs = fn _ => ()}

val extsearch = mk_extsearch "mleDiophSynt.extsearch" rlobj

(* -------------------------------------------------------------------------
   Final test
   ------------------------------------------------------------------------- *)

(*
val ft_extsearch_uniform =
  ft_mk_extsearch "mleDiophSynt.ft_extsearch_uniform" rlobj
  (uniform_player game)

fun graph_distance bl1 bl2 =
  let
    val bbl1 = combine (bl1,bl2)
    val bbl2 = filter (fn (a,b) => a = b) bbl1
  in
    int_div (length bbl2) (length bl1)
  end

fun distance_player (board as (poly,set,_)) =
  let
    val e = if null poly
      then 0.0
      else graph_distance (graph_to_bl (dioph_set poly)) set
  in
    (e, map (fn x => (x,1.0)) (available_movel board))
  end

val ft_extsearch_distance =
  ft_mk_extsearch "mleDiophSynt.ft_extsearch_distance" rlobj distance_player

val fttnn_extsearch =
  fttnn_mk_extsearch "mleDiophSynt.fttnn_extsearch" rlobj

val fttnnbs_extsearch =
  fttnnbs_mk_extsearch "mleDiophSynt.fttnnbs_extsearch" rlobj
*)

(*
load "aiLib"; open aiLib;
load "mlReinforce"; open mlReinforce;
load "mleDiophLib"; open mleDiophLib;
load "mleDiophSynt"; open mleDiophSynt;

val (dfull,ntry) = gen_diophset 0 2200 (dempty (list_compare Int.compare));
val (train,test) = create_targetl (dlist dfull); length train; length test;
val _ = (export_targetl "train" train; export_targetl "test" test);
val targetl = import_targetl "train"; length targetl;
val _ = rl_start (rlobj,extsearch) (mk_targetd targetl);

val targetd = retrieve_targetd rlobj 75;
val _ = rl_restart 75 (rlobj,extsearch) targetd;
*)

(* -------------------------------------------------------------------------
   MCTS test for inspection of the results
   ------------------------------------------------------------------------- *)

(* Runs a timed MCTS search from [target] and returns the resulting tree.
   With unib = true the uniform player is used; otherwise a player built
   from tnn (with the target's graph embedding precomputed). *)
fun solve_target (unib,tim,tnn) target =
  let
    val mctsparam =
      {timer = SOME tim, nsim = (NONE : int option), stopatwin_flag = true,
       decay = 1.0, explo_coeff = 2.0,
       noise_all = false, noise_root = false, noise_coeff = 0.25,
       noise_gen = random_real,
       noconfl = false, avoidlose = false, evalwin = false}
    val gm = #game rlobj
    val pretob = #pretob (#dplayer rlobj)
    fun preplayer tgt =
      let val tob = pretob (SOME (tgt,tnn)) in
        fn board => mlReinforce.player_from_tnn tnn tob gm board
      end
    val player = if unib then uniform_player gm else preplayer target
    val mctsobj = {mctsparam = mctsparam, game = gm, player = player}
  in
    (fst o snd) (mcts mctsobj (starttree_of mctsobj target))
  end

(* Searches for a polynomial for [diophset] with a budget of 40 moves and
   prints the tree size followed by the found polynomial or "Time out". *)
fun solve_diophset (unib,tim,tnn) diophset =
  let
    val target = ([]:poly, graph_to_bl diophset, 40)
    val tree = solve_target (unib,tim,tnn) target
    val won = #status (dfind [] tree) = Win
  in
    if not won then
      (print_endline (its (dlength tree)); print_endline "Time out")
    else
      let val nodel = trace_win tree [] in
        print_endline (its (dlength tree));
        print_endline (human_of_poly (#1 (#board (last nodel))))
      end
  end

(*
load "aiLib"; open aiLib;
load "mleDiophSynt"; open mleDiophSynt;
val tnn = mlReinforce.retrieve_tnn rlobj 197;
val diophset = [0,1,2,4,8];
solve_diophset (false,60.0,tnn) diophset;
*)

(* -------------------------------------------------------------------------
   Final testing
   ------------------------------------------------------------------------- *)

(*
load "aiLib"; open aiLib;
load "smlParallel"; open smlParallel;
load "mlTreeNeuralNetwork"; open mlTreeNeuralNetwork;
load "mleDiophSynt"; open mleDiophSynt;

val dir1 = HOLDIR ^ "/examples/AI_tasks/dioph_results_nolimit";
val _ = mkDir_err dir1;
fun store_result dir (a,i) =
  #write_result ft_extsearch_uniform (dir ^ "/" ^ its i) a;

(*** Testing set ***)
val dataset = "test";
val pretargetl = import_targetl dataset;
val targetl = map (fn (a,b,_) => (a,b,1000000)) pretargetl;
length targetl;
(* uniform *)
val (l1',t) = add_time (parmap_queue_extern 20 ft_extsearch_uniform ()) targetl;
val winb = filter I (map #1 l1'); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_uniform";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l1');
(* distance *)
val (l2',t) =
  add_time (parmap_queue_extern 20 ft_extsearch_distance ()) targetl;
val winb = filter I (map #1 l2'); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_distance";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l2');
(* tnn *)
val tnn = mlReinforce.retrieve_tnn rlobj 197;
val (l3',t) = add_time (parmap_queue_extern 20 fttnn_extsearch tnn) targetl;
val winb = filter I (map #1 l3'); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_tnn";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l3');

(*** Training set ***)
val dataset = "train";
val pretargetl = import_targetl dataset;
val targetl = map (fn (a,b,_) => (a,b,1000000)) pretargetl;
length targetl;
(* uniform *)
val (l1,t) = add_time (parmap_queue_extern 20 ft_extsearch_uniform ()) targetl;
val winb = filter I (map #1 l1); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_uniform";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l1);
(* distance *)
val (l2,t) =
  add_time (parmap_queue_extern 20 ft_extsearch_distance ()) targetl;
val winb = filter I (map #1 l2); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_distance";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l2);
(* tnn *)
val tnn = mlReinforce.retrieve_tnn rlobj 197;
val (l3,t) = add_time (parmap_queue_extern 20 fttnn_extsearch tnn) targetl;
val winb = filter I (map #1 l3); length winb;
val dir2 = dir1 ^ "/" ^ dataset ^ "_tnn";
val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l3);
*)

(* -------------------------------------------------------------------------
   Final testing statistics
   ------------------------------------------------------------------------- *)

(*
load "aiLib"; open aiLib;
load "mleDiophLib"; open mleDiophLib;
load "mleDiophSynt"; open mleDiophSynt;

val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_tnn_nolimit";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l1 = List.tabulate (200,g);
val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_tnn";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l2 = List.tabulate (2000,g);

val (l3,l3') = partition #1 (l1 @ l2);
val nsim_tnn = average_int (map #2 l3');
val l4 = map (valOf o #3) l3;
val l5 = map (fn (a,b,c) => veri_of_poly a) l4;
val l6 = map (fn (a,b,c) => ((graph_to_il b, veri_of_poly a), poly_size a))
  l4;
val l7 = dict_sort compare_imax l6;
hd l7;
val d = dnew (list_compare Int.compare) l6;

val l6 = map (fn (a,b,c) => (graph_to_il b, veri_of_poly a, c)) l4;

val longest =
  let fun cmp (a,b) = Int.compare (#2 b, #2 a) in
    dict_sort cmp l3
  end;

val (a,b,c) = valOf (#3 (hd longest));
veri_of_poly a;
graph_to_il b;


val monol = List.concat (map numSyntax.strip_plus l5);
val monofreq = dlist (count_dict (dempty Term.compare) monol);
val monostats = dict_sort compare_imax monofreq;

mleDiophLib.dioph_set (fst (hd l4));


val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_uniform";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l1 = List.tabulate (200,g);
val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_uniform";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l2 = List.tabulate (2000,g);

val (l3,l3') = partition #1 (l1 @ l2);
val nsim_uniform = average_int (map #2 l3');

val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_distance";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l1 = List.tabulate (200,g);
val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_distance";
fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
val l2 = List.tabulate (2000,g);

val (l3,l3') = partition #1 (l1 @ l2);
val nsim_distance = average_int (map #2 l3');

*)

(* -------------------------------------------------------------------------
   Training graph
   ------------------------------------------------------------------------- *)

(*
load "aiLib"; open aiLib;
load "mleDiophLib"; open mleDiophLib;
load "mleDiophSynt"; open mleDiophSynt;

val targetdl = List.tabulate (230,
  fn x => mlReinforce.retrieve_targetd rlobj (x+1));
val l1 = map dlist targetdl;
val l2 = map (map (snd o snd)) l1;

fun btr b = if b then 1.0 else 0.0

fun expectancy_one bl =
  if null bl then 0.0 else average_real (map btr (first_n 5 bl))
fun expectancy bll = sum_real (map expectancy_one bll);
val expectl = map expectancy l2;

fun exists_one bl = btr (exists I bl);
fun existssol bll = sum_real (map exists_one bll);
val esoll = map existssol l2;

val graph = number_fst 0 (combine (expectl,esoll));
fun graph_to_string (i,(r1,r2)) = its i ^ " " ^ rts r1 ^ " " ^ rts r2;
writel "dioph_graph" ("gen exp sol" :: map graph_to_string graph);

*)

end (* struct *)