1(* ========================================================================= *)
2(* FILE          : mleDiophSynt.sml                                          *)
3(* DESCRIPTION   : Specification of term synthesis on Diophantine equations  *)
4(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
5(* DATE          : 2020                                                      *)
6(* ========================================================================= *)
7
8structure mleDiophSynt :> mleDiophSynt =
9struct
10
11open HolKernel Abbrev boolLib aiLib smlParallel psMCTS psTermGen
12  mlNeuralNetwork mlTreeNeuralNetwork mlTacticData
13  mlReinforce arithmeticTheory numLib numSyntax mleDiophLib
14
(* Error constructor whose origin structure is fixed to "mleDiophSynt". *)
val ERR = mk_HOL_ERR "mleDiophSynt"

(* Experiment version number; it is embedded in the experiment name
   built in rlparam. *)
val version = 2

(* Directory of the AI_tasks examples, relative to HOLDIR. *)
val selfdir = String.concat [HOLDIR, "/examples/AI_tasks"]
18
19(* -------------------------------------------------------------------------
20   Board
21   ------------------------------------------------------------------------- *)
22
(* A board is a triple (polynomial, graph, counter).  The counter is
   decremented by apply_move and tested against 0 by status_of. *)
type board = int list list * bool list * int

(* Prints the three components of a board separated by " -- ". *)
fun string_of_board (poly,graph,timer) =
  String.concat
    [poly_to_string poly, " -- ", graph_to_string graph, " -- ", its timer]
26
(* Compares two boards on their (polynomial, graph) pairs only;
   the third component of each board is not looked at. *)
fun board_compare ((poly1,graph1,_),(poly2,graph2,_)) =
  cpl_compare poly_compare graph_compare ((poly1,graph1),(poly2,graph2))
29
(* Compares two boards on all three components, passed to triple_compare
   in the order (counter, polynomial, graph). *)
fun fullboard_compare ((poly1,graph1,timer1),(poly2,graph2,timer2)) =
  triple_compare Int.compare poly_compare graph_compare
    ((timer1,poly1,graph1),(timer2,poly2,graph2))
32
(* Game status of a board:
     Win       when dioph_match accepts the polynomial for the graph,
     Lose      when there is no match and the counter n is <= 0,
     Undecided otherwise. *)
fun status_of (poly,graph,n) =
  case (dioph_match poly graph, n <= 0) of
    (true, _) => Win
  | (false, true) => Lose
  | (false, false) => Undecided
37
38(* -------------------------------------------------------------------------
39   Move
40   ------------------------------------------------------------------------- *)
41
(* The two kinds of moves; apply_move_poly gives their effect on a
   polynomial. *)
datatype move = Add of int | Exp of int

(* All moves: one Add per element of numberl, followed by
   Exp 0 ... Exp maxexponent. *)
val movel =
  let
    val addl = map Add numberl
    val expl = List.tabulate (maxexponent + 1, Exp)
  in
    addl @ expl
  end
44
(* Short textual name of a move: "A<i>" for Add i, "E<i>" for Exp i. *)
fun string_of_move (Add i) = "A" ^ its i
  | string_of_move (Exp i) = "E" ^ its i
48
(* Orders moves by comparing their string_of_move representations. *)
fun move_compare (m1,m2) =
  let val (s1,s2) = (string_of_move m1, string_of_move m2) in
    String.compare (s1,s2)
  end
50
(* Applies a move to a polynomial, represented as a list of monomials
   (each monomial an int list).
     Add c : appends the new monomial [c];
             fails with "plus" when the polynomial already has
             maxmonomial monomials or more.
     Exp c : appends c to the last monomial;
             fails with "mult" when the polynomial is empty or the last
             monomial already has maxvar + 1 entries or more;
             fails with "non-increasing" when, on their common prefix,
             the tail of the previous monomial compares GREATER than the
             tail of the extended last monomial.
   All failures are raised through ERR "apply_move_poly". *)
fun apply_move_poly move poly =
  let
    fun is_full mono = length mono >= maxvar + 1
    (* Only called on a non-empty polynomial (guarded by the null test). *)
    fun breaks_order c =
      length poly >= 2 andalso
      let
        val prevexp = tl (last (butlast poly))
        val curexp = tl (last poly) @ [c]
        val k = Int.min (length curexp,length prevexp)
        val order =
          list_compare Int.compare (first_n k prevexp, first_n k curexp)
      in
        order = GREATER
      end
  in
    case move of
      Add c =>
        if length poly >= maxmonomial
        then raise ERR "apply_move_poly" "plus"
        else poly @ [[c]]
    | Exp c =>
        if null poly orelse is_full (last poly)
          then raise ERR "apply_move_poly" "mult"
        else if breaks_order c
          then raise ERR "apply_move_poly" "non-increasing"
        else butlast poly @ [last poly @ [c]]
  end
69
(* Game-level move application: updates the polynomial with
   apply_move_poly, keeps the graph, decrements the counter, and returns
   the given tree unchanged alongside the new board.  The second
   component of the first argument is not used. *)
fun apply_move (tree,_) move (poly,graph,n) =
  let val newboard = (apply_move_poly move poly, graph, n - 1) in
    (newboard, tree)
  end
72
(* Moves of movel for which apply_move_poly succeeds on the polynomial. *)
fun available_movel_poly poly =
  let fun is_legal move = can (apply_move_poly move) poly in
    filter is_legal movel
  end
75
(* Legal moves of a board; depends only on its polynomial component. *)
fun available_movel board =
  let val (poly,_,_) = board in available_movel_poly poly end
77
78(* -------------------------------------------------------------------------
79   Game
80   ------------------------------------------------------------------------- *)
81
(* Game record bundling the board and move operations defined above. *)
val game : (board,move) game =
  {
  (* rules *)
  status_of = status_of,
  available_movel = available_movel,
  apply_move = apply_move,
  movel = movel,
  (* comparison functions *)
  board_compare = board_compare,
  move_compare = move_compare,
  (* printing *)
  string_of_board = string_of_board,
  string_of_move = string_of_move
  }
93
94(* -------------------------------------------------------------------------
95   Parallelization
96   ------------------------------------------------------------------------- *)
97
(* Writes a list of boards to three files sharing the prefix `file`:
   file_poly, file_graph and file_timer, one line per board in each. *)
fun write_boardl file boardl =
  let
    val (polyl,graphl,timerl) = split_triple boardl
    fun dump suffix linel = writel (file ^ suffix) linel
  in
    dump "_poly" (map poly_to_string polyl);
    dump "_graph" (map graph_to_string graphl);
    dump "_timer" (map its timerl)
  end
104
(* Reads back the boards stored by write_boardl under the prefix `file`.
   The polynomial file is read with readl_empty whereas the other two are
   read with readl.
   NOTE(review): presumably because an empty polynomial is stored as an
   empty line, which readl would not return -- confirm against aiLib. *)
fun read_boardl file =
  let
    fun path suffix = file ^ suffix
    val polyl = map string_to_poly (readl_empty (path "_poly"))
    val graphl = map string_to_graph (readl (path "_graph"))
    val timerl = map string_to_int (readl (path "_timer"))
  in
    combine_triple (polyl,graphl,timerl)
  end
113
(* Board reader and writer, packaged as the gameio field of rlobj. *)
val gameio = {read_boardl = read_boardl, write_boardl = write_boardl}
115
116(* -------------------------------------------------------------------------
117   Targets
118   ------------------------------------------------------------------------- *)
119
(* Directory in which target files are stored and looked up. *)
val targetdir = String.concat [selfdir, "/dioph_target"]
121
(* Membership vector of `graph` over numberl: one boolean per element of
   numberl, true exactly when that element occurs in `graph`. *)
fun graph_to_bl graph =
  let fun is_member n = mem n graph in map is_member numberl end
123
(* Splits a shuffled list of (graph,poly) pairs with part_pct (10.0/11.0),
   exports both parts through export_data, and turns each pair into a
   start board: empty polynomial, membership vector of the graph, and a
   counter equal to twice the size of the pair's polynomial.
   Returns the two board lists, each sorted by fullboard_compare. *)
fun create_targetl l =
  let
    val (train,test) = part_pct (10.0/11.0) (shuffle l)
    val _ = export_data (train,test)
    fun to_board (graph,poly) = ([], graph_to_bl graph, 2 * poly_size poly)
    fun sorted_boards exl = dict_sort fullboard_compare (map to_board exl)
  in
    (sorted_boards train, sorted_boards test)
  end
133
(* Calls mkDir_err on targetdir, then writes the targets under
   targetdir/name with write_boardl. *)
fun export_targetl name targetl =
  let
    val _ = mkDir_err targetdir
    val prefix = targetdir ^ "/" ^ name
  in
    write_boardl prefix targetl
  end
138
(* Reads the targets stored under targetdir/name with read_boardl. *)
fun import_targetl name =
  let val prefix = targetdir ^ "/" ^ name in read_boardl prefix end
140
(* Builds a dictionary keyed by board (ordered by board_compare) in which
   each target is mapped to its number, as assigned by number_snd 0,
   paired with an empty list. *)
fun mk_targetd l1 =
  let fun entry (board,i) = (board,(i,[])) in
    dnew board_compare (map entry (number_snd 0 l1))
  end
148
149(* -------------------------------------------------------------------------
150   Neural network representation of the board
151   ------------------------------------------------------------------------- *)
152
(* Encodes a boolean graph as an embedding variable of type bool whose
   vector holds 1.0 for true and ~1.0 for false. *)
fun term_of_graph graph =
  let
    fun real_of_bool b = if b then 1.0 else ~1.0
    val embedding = Vector.fromList (map real_of_bool graph)
  in
    mk_embedding_var (embedding, bool)
  end
156
(* Head variables: head_eval and head_poli wrap a boolean term,
   graph_tag maps a boolean term to a num. *)
val head_eval = mk_var ("head_eval", ``:bool -> 'a``)
val head_poli = mk_var ("head_poli", ``:bool -> 'a``)
val graph_tag = mk_var ("graph_tag", ``:bool -> num``)

(* Application of each of the variables above to a term. *)
fun tag_heval tm = mk_comb (head_eval,tm)
fun tag_hpoli tm = mk_comb (head_poli,tm)
fun tag_graph tm = mk_comb (graph_tag,tm)
163
(* Terms representing a board: the equation between the polynomial term
   and the tagged graph term, wrapped once under head_eval and once under
   head_poli.  The counter is ignored. *)
fun tob1 (poly,graph,_) =
  let
    val polytm = term_of_poly poly
    val graphtm = tag_graph (term_of_graph graph)
    val eqtm = mk_eq (polytm,graphtm)
  in
    map (fn tag => tag eqtm) [tag_heval, tag_hpoli]
  end
171
(* Same as tob1, except that the right-hand side of the equation is the
   given term embedv instead of a term built from the board's graph. *)
fun tob2 embedv (poly,_,_) =
  let val eqtm = mk_eq (term_of_poly poly, embedv) in
    [tag_heval eqtm, tag_hpoli eqtm]
  end
179
(* Selects the board-to-term function.
     NONE                 : tob1, which rebuilds the graph term per board.
     SOME (target, tnn)   : tob2 specialised to the result of
                            precomp_embed on the target's tagged graph
                            term; precomp_embed is called once, when
                            pretob is applied, not once per board. *)
fun pretob NONE = tob1
  | pretob (SOME ((_,graph,_),tnn)) =
    let val graphtm = tag_graph (term_of_graph graph) in
      tob2 (precomp_embed tnn graphtm)
    end
184
185(* -------------------------------------------------------------------------
186   Player
187   ------------------------------------------------------------------------- *)
188
(* Training schedule: a single phase. *)
val schedule =
  [{
   ncore = 4,
   verbose = true,
   nepoch = 10,
   batch_size = 16,
   learning_rate = 0.02
   }]
192
(* Operators given a dimension in tnndim: equality on num, graph_tag,
   addition, multiplication, the variable "start", one variable "n<i>"
   per element i of numberl, and one variable "v<v>p<p>" for each
   v < maxvar and p <= maxexponent. *)
val dioph_operl =
  let
    fun num_var name = mk_var (name,``:num``)
    val fixedl =
      [``$= : num -> num -> bool``,
       graph_tag,``$+``,``$*``,num_var "start"]
    val numberl_varl = map (fn i => num_var ("n" ^ its i)) numberl
    fun power_var v p = num_var ("v" ^ its v ^ "p" ^ its p)
    val power_varl =
      List.concat
        (List.tabulate (maxvar, fn v =>
           List.tabulate (maxexponent + 1, power_var v)))
  in
    fixedl @ numberl_varl @ power_varl
  end
201
(* Dimension used for all entries of tnndim. *)
val dim = 16

(* Two-element dimension list [dim,n]. *)
fun dim_head_poli n = dim :: [n]
204
(* Dimensions per operator: dim_std (1,dim) for every operator of
   dioph_operl, then the two heads; the policy head ends with one output
   per move of movel and the evaluation head with a single output. *)
val tnndim =
  let
    val operdiml = map_assoc (dim_std (1,dim)) dioph_operl
    val headdiml =
      [(head_eval,[dim,dim,1]),(head_poli,[dim,dim,length movel])]
  in
    operdiml @ headdiml
  end

(* Player description: term encoding, network dimensions and schedule. *)
val dplayer = {schedule = schedule, tnndim = tnndim, pretob = pretob}
208
209(* -------------------------------------------------------------------------
210   Interface
211   ------------------------------------------------------------------------- *)
212
(* Reinforcement learning parameters; the experiment name embeds the
   version number. *)
val rlparam =
  {
  expname = "mleDiophSynt-" ^ its version,
  ncore = 30,
  ntarget = 200,
  nsim = 32000,
  exwindow = 200000,
  decay = 1.0
  }
216
(* Reinforcement learning object assembling the parameters, the game, its
   I/O functions and the player; infobs does nothing. *)
val rlobj : (board,move) rlobj =
  {
  rlparam = rlparam,
  game = game,
  gameio = gameio,
  dplayer = dplayer,
  infobs = fn _ => ()
  }

(* External search built from rlobj under the name
   "mleDiophSynt.extsearch". *)
val extsearch = mk_extsearch "mleDiophSynt.extsearch" rlobj
222
223(* -------------------------------------------------------------------------
224   Final test
225   ------------------------------------------------------------------------- *)
226
227(*
228val ft_extsearch_uniform =
229  ft_mk_extsearch "mleDiophSynt.ft_extsearch_uniform" rlobj
230    (uniform_player game)
231
232fun graph_distance bl1 bl2 =
233  let
234    val bbl1 = combine (bl1,bl2)
235    val bbl2 = filter (fn (a,b) => a = b) bbl1
236  in
237    int_div (length bbl2) (length bl1)
238  end
239
240fun distance_player (board as (poly,set,_)) =
241  let
242    val e = if null poly
243            then 0.0
244            else graph_distance (graph_to_bl (dioph_set poly)) set
245  in
246    (e, map (fn x => (x,1.0)) (available_movel board))
247  end
248
249val ft_extsearch_distance =
250  ft_mk_extsearch "mleDiophSynt.ft_extsearch_distance" rlobj distance_player
251
252val fttnn_extsearch =
253  fttnn_mk_extsearch "mleDiophSynt.fttnn_extsearch" rlobj
254
255val fttnnbs_extsearch =
256  fttnnbs_mk_extsearch "mleDiophSynt.fttnnbs_extsearch" rlobj
257*)
258
259(*
260load "aiLib"; open aiLib;
261load "mlReinforce"; open mlReinforce;
262load "mleDiophLib"; open mleDiophLib;
263load "mleDiophSynt"; open mleDiophSynt;
264
265val (dfull,ntry) = gen_diophset 0 2200 (dempty (list_compare Int.compare));
266val (train,test) = create_targetl (dlist dfull); length train; length test;
267val _ = (export_targetl "train" train; export_targetl "test" test);
268val targetl = import_targetl "train"; length targetl;
269val _ = rl_start (rlobj,extsearch) (mk_targetd targetl);
270
271val targetd = retrieve_targetd rlobj 75;
272val _ = rl_restart 75 (rlobj,extsearch) targetd;
273*)
274
275(* -------------------------------------------------------------------------
276   MCTS test for inspection of the results
277   ------------------------------------------------------------------------- *)
278
(* Runs one MCTS on `target` and returns the first component of the
   second component of the result of mcts (used as the search tree by
   solve_diophset).
     unib : when true the uniform player is used and tnn is never
            touched; otherwise the player is derived from tnn
     tim  : value stored in the timer field of the MCTS parameters
     tnn  : tree neural network for the non-uniform player
   The search has no simulation limit (nsim = NONE), stops at the first
   win and uses no noise. *)
fun solve_target (unib,tim,tnn) target =
  let
    val game = #game rlobj
    val mctsparam =
      {
      (* stopping criteria *)
      timer = SOME tim,
      nsim = (NONE : int option),
      stopatwin_flag = true,
      (* exploration *)
      decay = 1.0,
      explo_coeff = 2.0,
      (* noise *)
      noise_all = false,
      noise_root = false,
      noise_coeff = 0.25,
      noise_gen = random_real,
      (* remaining flags *)
      noconfl = false,
      avoidlose = false,
      evalwin = false
      }
    (* Delayed behind a unit argument so that the term encoding for the
       target is only computed when the tnn player is selected. *)
    fun tnn_player () =
      let val tob = #pretob (#dplayer rlobj) (SOME (target,tnn)) in
        fn board => mlReinforce.player_from_tnn tnn tob game board
      end
    val player = if unib then uniform_player game else tnn_player ()
    val mctsobj = {mctsparam = mctsparam, game = game, player = player}
    val result = mcts mctsobj (starttree_of mctsobj target)
  in
    fst (snd result)
  end
307
(* Searches for a polynomial matching `diophset`, starting from the empty
   polynomial with a counter of 40.  Prints the number of entries of the
   search tree, followed by either the polynomial found (via
   human_of_poly) when the root status is Win, or "Time out". *)
fun solve_diophset (unib,tim,tnn) diophset =
  let
    val target = ([]:poly,graph_to_bl diophset,40)
    val tree = solve_target (unib,tim,tnn) target
    fun print_size () = print_endline (its (dlength tree))
  in
    case #status (dfind [] tree) of
      Win =>
        (* The winning trace is extracted before anything is printed. *)
        let val nodel = trace_win tree [] in
          print_size ();
          print_endline (human_of_poly (#1 (#board (last nodel))))
        end
    | _ => (print_size (); print_endline "Time out")
  end
322
323(*
324load "aiLib"; open aiLib;
325load "mleDiophSynt"; open mleDiophSynt;
326val tnn = mlReinforce.retrieve_tnn rlobj 197;
327val diophset = [0,1,2,4,8];
328solve_diophset (false,60.0,tnn) diophset;
329*)
330
331(* -------------------------------------------------------------------------
332   Final testing
333   ------------------------------------------------------------------------- *)
334
335(*
336load "aiLib"; open aiLib;
337load "smlParallel"; open smlParallel;
338load "mlTreeNeuralNetwork"; open mlTreeNeuralNetwork;
339load "mleDiophSynt"; open mleDiophSynt;
340
341val dir1 = HOLDIR ^ "/examples/AI_tasks/dioph_results_nolimit";
342val _ = mkDir_err dir1;
343fun store_result dir (a,i) =
344  #write_result ft_extsearch_uniform (dir ^ "/" ^ its i) a;
345
346(*** Testing set ***)
347val dataset = "test";
348val pretargetl = import_targetl dataset;
349val targetl = map (fn (a,b,_) => (a,b,1000000)) pretargetl;
350length targetl;
351(* uniform *)
352val (l1',t) = add_time (parmap_queue_extern 20 ft_extsearch_uniform ()) targetl;
353val winb = filter I (map #1 l1'); length winb;
354val dir2 = dir1 ^ "/" ^ dataset ^ "_uniform";
355val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l1');
356(* distance *)
357val (l2',t) =
358  add_time (parmap_queue_extern 20 ft_extsearch_distance ()) targetl;
359val winb = filter I (map #1 l2'); length winb;
360val dir2 = dir1 ^ "/" ^ dataset ^ "_distance";
361val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l2');
362(* tnn *)
363val tnn = mlReinforce.retrieve_tnn rlobj 197;
364val (l3',t) = add_time (parmap_queue_extern 20 fttnn_extsearch tnn) targetl;
365val winb = filter I (map #1 l3'); length winb;
366val dir2 = dir1 ^ "/" ^ dataset ^ "_tnn";
367val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l3');
368
369(*** Training set ***)
370val dataset = "train";
371val pretargetl = import_targetl dataset;
372val targetl = map (fn (a,b,_) => (a,b,1000000)) pretargetl;
373length targetl;
374(* uniform *)
375val (l1,t) = add_time (parmap_queue_extern 20 ft_extsearch_uniform ()) targetl;
376val winb = filter I (map #1 l1); length winb;
377val dir2 = dir1 ^ "/" ^ dataset ^ "_uniform";
378val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l1);
379(* distance *)
380val (l2,t) =
381  add_time (parmap_queue_extern 20 ft_extsearch_distance ()) targetl;
382val winb = filter I (map #1 l2); length winb;
383val dir2 = dir1 ^ "/" ^ dataset ^ "_distance";
384val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l2);
385(* tnn *)
386val tnn = mlReinforce.retrieve_tnn rlobj 197;
387val (l3,t) = add_time (parmap_queue_extern 20 fttnn_extsearch tnn) targetl;
388val winb = filter I (map #1 l3); length winb;
389val dir2 = dir1 ^ "/" ^ dataset ^ "_tnn";
390val _ = mkDir_err dir2; app (store_result dir2) (number_snd 0 l3);
391*)
392
393(* -------------------------------------------------------------------------
394   Final testing statistics
395   ------------------------------------------------------------------------- *)
396
397(*
398load "aiLib"; open aiLib;
399load "mleDiophLib"; open mleDiophLib;
400load "mleDiophSynt"; open mleDiophSynt;
401
402val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_tnn_nolimit";
403fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
404val l1 = List.tabulate (200,g);
405val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_tnn";
406fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
407val l2 = List.tabulate (2000,g);
408
409val (l3,l3') = partition #1 (l1 @ l2);
410val nsim_tnn = average_int (map #2 l3');
411val l4 = map (valOf o #3) l3;
412val l5 = map (fn (a,b,c) => veri_of_poly a) l4;
413val l6 = map (fn (a,b,c) => ((graph_to_il b, veri_of_poly a), poly_size a))
414l4;
415val l7 = dict_sort compare_imax l6;
416hd l7;
417val d = dnew (list_compare Int.compare) l6;
418
419val l6 = map (fn (a,b,c) => (graph_to_il b, veri_of_poly a, c)) l4;
420
421val longest =
422  let fun cmp (a,b) = Int.compare (#2 b, #2 a) in
423    dict_sort cmp l3
424  end;
425
426val (a,b,c) = valOf (#3 (hd longest));
427veri_of_poly a;
428graph_to_il b;
429
430
431val monol = List.concat (map numSyntax.strip_plus l5);
432val monofreq = dlist (count_dict (dempty Term.compare) monol);
433val monostats = dict_sort compare_imax monofreq;
434
435mleDiophLib.dioph_set (fst (hd l4));
436
437
438val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_uniform";
439fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
440val l1 = List.tabulate (200,g);
441val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_uniform";
442fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
443val l2 = List.tabulate (2000,g);
444
445val (l3,l3') = partition #1 (l1 @ l2);
446val nsim_uniform = average_int (map #2 l3');
447
448val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/test_distance";
449fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
450val l1 = List.tabulate (200,g);
451val dir2 = HOLDIR ^ "/examples/AI_tasks/dioph_results/train_distance";
452fun g i = #read_result ft_extsearch_uniform (dir2 ^ "/" ^ its i);
453val l2 = List.tabulate (2000,g);
454
455val (l3,l3') = partition #1 (l1 @ l2);
456val nsim_distance = average_int (map #2 l3');
457
458*)
459
460(* -------------------------------------------------------------------------
461   Training graph
462   ------------------------------------------------------------------------- *)
463
464(*
465load "aiLib"; open aiLib;
466load "mleDiophLib"; open mleDiophLib;
467load "mleDiophSynt"; open mleDiophSynt;
468
469val targetdl = List.tabulate (230,
470  fn x => mlReinforce.retrieve_targetd rlobj (x+1));
471val l1 = map dlist targetdl;
472val l2 = map (map (snd o snd)) l1;
473
474fun btr b = if b then 1.0 else 0.0
475
476fun expectancy_one bl =
477  if null bl then 0.0 else average_real (map btr (first_n 5 bl))
478fun expectancy bll = sum_real (map expectancy_one bll);
479val expectl = map expectancy l2;
480
481fun exists_one bl = btr (exists I bl);
482fun existssol bll = sum_real (map exists_one bll);
483val esoll = map existssol l2;
484
485val graph = number_fst 0 (combine (expectl,esoll));
486fun graph_to_string (i,(r1,r2)) = its i ^ " " ^ rts r1 ^ " " ^ rts r2;
487writel "dioph_graph" ("gen exp sol" :: map graph_to_string graph);
488
489*)
490
491end (* struct *)
492