1(* ========================================================================= *)
2(* FILE          : mlReinforce.sml                                           *)
(* DESCRIPTION   : Environment for reinforcement learning                    *)
4(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
5(* DATE          : 2019                                                      *)
6(* ========================================================================= *)
7
8structure mlReinforce :> mlReinforce =
9struct
10
11open HolKernel Abbrev boolLib aiLib psMCTS psBigSteps
12  mlNeuralNetwork mlTreeNeuralNetwork smlParallel
13
(* Build a HOL_ERR exception tagged with this structure's name.
   [fname] is the name of the raising function, [msg] the message. *)
fun ERR fname msg = mk_HOL_ERR "mlReinforce" fname msg
15
16(* -------------------------------------------------------------------------
17   Logs
18   ------------------------------------------------------------------------- *)
19
(* Root directory under which each experiment gets its own sub-directory. *)
val eval_dir = HOLDIR ^ "/examples/AI_tasks/eval"

(* Append line [s] to the log file of the experiment named by
   #expname (#rlparam rlobj). *)
fun log_in_eval rlobj s =
  let
    val expname = #expname (#rlparam rlobj)
    val logfile = eval_dir ^ "/" ^ expname ^ "/log"
  in
    append_endline logfile s
  end

(* Record [s] in the experiment log, then echo it on standard output. *)
fun log rlobj s =
  let val _ = log_in_eval rlobj s in print_endline s end
24
25(* -------------------------------------------------------------------------
26   Types
27   ------------------------------------------------------------------------- *)
28
(* Per-target list of search outcomes (one bool per exploration). *)
type 'a targetd = ('a, bool list) Redblackmap.dict

(* Game-specific (de)serialization of board lists to/from a file name. *)
type 'a gameio =
  {
  write_boardl : string -> 'a list -> unit,
  read_boardl : string -> 'a list
  }

(* Player configuration shipped to each search worker:
   - unib   : use the uniform random player instead of the network;
   - tnn    : network guiding the search;
   - noiseb : add noise at the MCTS root;
   - nsim   : number of MCTS simulations. *)
type splayer =
  {
  unib : bool,
  tnn : tnn,
  noiseb : bool,
  nsim : int
  }

(* Network side of a game:
   - pretob   : encodes a board into the TNN input terms, optionally
                specialised to a (target, tnn) pair;
   - schedule : training schedule passed to train_tnn;
   - tnndim   : dimensions used to build a fresh random_tnn. *)
type 'a dplayer =
  {
  pretob : ('a * tnn) option -> 'a -> term list,
  schedule : schedule,
  tnndim : tnndim
  }

(* External-parallelization spec: splayer as shared parameter, one target
   board per job, and (win flag, new examples) as the job result. *)
type 'a es = (splayer, 'a, bool * 'a rlex) smlParallel.extspec

(* Experiment-level parameters:
   - expname  : sub-directory of eval_dir holding logs and checkpoints;
   - exwindow : maximum number of training examples kept;
   - ncore    : number of parallel search workers;
   - ntarget  : target budget handed to select_from_targetd;
   - nsim     : MCTS simulations per search;
   - decay    : MCTS decay coefficient. *)
type rlparam =
  {
  expname : string,
  exwindow : int,
  ncore : int,
  ntarget : int,
  nsim : int,
  decay : real
  }

(* Everything the reinforcement-learning loop needs about one game.
   infobs receives the boards of the nodes returned by each big-steps run. *)
type ('a,'b) rlobj =
  {
  rlparam : rlparam,
  game : ('a,'b) psMCTS.game,
  gameio : 'a gameio,
  dplayer : 'a dplayer,
  infobs : 'a list -> unit
  }
49
50(* -------------------------------------------------------------------------
51   Big steps
52   ------------------------------------------------------------------------- *)
53
(* MCTS parameters for one big-steps search driven by [splayer].
   Only the simulation count, the root-noise flag and the decay vary;
   every other setting is fixed here. *)
fun mk_mctsparam splayer rlobj =
  let
    val nsim = #nsim splayer
    val noiseb = #noiseb splayer
    val decay = #decay (#rlparam rlobj)
  in
    {
    timer = NONE,
    nsim = SOME nsim,
    stopatwin_flag = false,
    decay = decay,
    explo_coeff = 2.0,
    noise_root = noiseb,
    noise_all = false,
    noise_coeff = 0.25,
    noise_gen = random_real,
    noconfl = false,
    avoidlose = false,
    evalwin = true (* todo : reflect this parameter in rlparam *)
    }
  end
63
(* Turn a TNN into an MCTS player for [board].
   infer_tnn on the encoded board must yield exactly two outputs:
   - the first must be a singleton (presumably the board evaluation);
   - the second gives one score per move of #movel game, in that order.
   The result pairs that singleton with a score for every available move.
   Raises ERR if an available move is missing from #movel game. *)
fun player_from_tnn tnn tob game board =
  let
    val legal = (#available_movel game) board
    val outputs = map snd (infer_tnn tnn (tob board))
    val (evall, policy) = pair_of_list outputs
    val scored = dnew (#move_compare game) (combine (#movel game, policy))
    fun score_of move =
      dfind move scored handle NotFound => raise ERR "player_from_tnn" ""
  in
    (singleton_of_list evall, map_assoc score_of legal)
  end
73
(* Big-steps search object for [splayer]:
   - when unib is set, moves come from the uniform random player;
   - otherwise the TNN-guided player is used, with the board encoder
     specialised once per target. *)
fun mk_bsobj rlobj (splayer as {unib = uniform, tnn = net,
                                noiseb = _, nsim = _}) =
  let
    val game = #game rlobj
    val pretob = #pretob (#dplayer rlobj)
    fun tnn_preplayer target =
      let val tob = pretob (SOME (target, net)) in
        fn board => player_from_tnn net tob game board
      end
    fun random_preplayer _ board = random_player game board
    val chosen = if uniform then random_preplayer else tnn_preplayer
  in
    {
    verbose = false,
    temp_flag = false,
    preplayer = chosen,
    game = game,
    mctsparam = mk_mctsparam splayer rlobj
    }
  end
91
92(* -------------------------------------------------------------------------
93   I/O for external parallelization
94   ------------------------------------------------------------------------- *)
95
(* Save examples under [file] as two files:
   - boards via the game writer in file_boardl;
   - one line of reals per example in file_rll. *)
fun write_rlex gameio file rlex =
  let
    val boardl = map fst rlex
    val rll = map snd rlex
  in
    #write_boardl gameio (file ^ "_boardl") boardl;
    writel (file ^ "_rll") (map reall_to_string rll)
  end
101
(* Inverse of write_rlex: zip the boards of file_boardl with the
   real lists of file_rll. *)
fun read_rlex gameio file =
  combine (#read_boardl gameio (file ^ "_boardl"),
           map string_to_reall (readl (file ^ "_rll")))
109
(* Save a search player as three files:
   - file_tnn   : the network;
   - file_flags : the line "unib noiseb";
   - file_nsim  : the simulation count. *)
fun write_splayer file {unib,tnn,noiseb,nsim} =
  let
    val flags = String.concatWith " " [bts unib, bts noiseb]
  in
    write_tnn (file ^ "_tnn") tnn;
    writel (file ^ "_flags") [flags];
    writel (file ^ "_nsim") [its nsim]
  end
116
(* Inverse of write_splayer. Each auxiliary file must hold exactly one
   line; the flags line must hold exactly two booleans. *)
fun read_splayer file =
  let
    fun read_line suffix = singleton_of_list (readl (file ^ suffix))
    val tnn = read_tnn (file ^ "_tnn")
    val flagl =
      map string_to_bool (String.tokens Char.isSpace (read_line "_flags"))
    val (unib, noiseb) = pair_of_list flagl
    val nsim = string_to_int (read_line "_nsim")
  in
    {unib = unib, tnn = tnn, noiseb = noiseb, nsim = nsim}
  end
128
(* Save a worker result (win flag, examples):
   - file_bstatus holds the flag;
   - the examples go under file_rlex. *)
fun write_result gameio file (b,rlex) =
  let val _ = writel (file ^ "_bstatus") [bts b] in
    write_rlex gameio (file ^ "_rlex") rlex
  end

(* Inverse of write_result. *)
fun read_result gameio file =
  let
    val status = singleton_of_list (readl (file ^ "_bstatus"))
    val b = string_to_bool status
    val rlex = read_rlex gameio (file ^ "_rlex")
  in
    (b, rlex)
  end
136
(* Save a single target board in file_target. A Subscript raised by the
   game writer is reported as a HOL_ERR. *)
fun write_target gameio file target =
  let val write = #write_boardl gameio in
    write (file ^ "_target") [target]
  end
  handle Subscript => raise ERR "write_target" ""

(* Inverse of write_target; the file must contain exactly one board. *)
fun read_target gameio file =
  let val boardl = #read_boardl gameio (file ^ "_target") in
    singleton_of_list boardl
  end
143
144(* -------------------------------------------------------------------------
145   I/O for storage and restart
146   ------------------------------------------------------------------------- *)
147
(* Training examples *)
(* Base path of the examples produced at generation [n]. *)
fun rlex_file rlobj n =
  eval_dir ^ "/" ^ (#expname (#rlparam rlobj)) ^ "/rlex" ^ its n

(* Save the examples produced at generation [n]. *)
fun store_rlex rlobj n rlex =
  write_rlex (#gameio rlobj) (rlex_file rlobj n) rlex

(* Rebuild the training window when restarting from generation [n].
   [acc] holds examples already gathered, newest generation first.
   Generations n, n-1, ... are appended after it, each in its stored
   order, until at least exwindow examples are present or generation 0
   has been read. The result is therefore the first exwindow examples of
     rlex_n @ rlex_(n-1) @ ... @ rlex_0,
   the same window rl_explore_cont maintains incrementally through
   first_n exwindow (new @ old).
   The previous version accumulated oldest-first and then reversed the
   whole list. That flipped the order inside each generation and, at the
   window boundary, kept the tail of the oldest generation instead of its
   head, so a restarted run trained on different data. *)
fun gather_ex rlobj acc n =
  let val exwindow = #exwindow (#rlparam rlobj) in
    if n < 0 orelse length acc >= exwindow
    then first_n exwindow acc
    else
      let val rlex = read_rlex (#gameio rlobj) (rlex_file rlobj n) in
        gather_ex rlobj (acc @ rlex) (n - 1)
      end
  end

(* Training window for a restart from generation [n]. *)
fun retrieve_rlex rlobj n = gather_ex rlobj [] n
165
166(* TNN *)
(* Path of the network trained at generation [n]. *)
fun tnn_file rlobj n =
  String.concat [eval_dir, "/", #expname (#rlparam rlobj), "/tnn", its n]

(* Save / load the network of generation [n]. *)
fun store_tnn rlobj n tnn =
  let val file = tnn_file rlobj n in write_tnn file tnn end
fun retrieve_tnn rlobj n =
  let val file = tnn_file rlobj n in read_tnn file end
171
172(* Target *)
(* Base path of the target dictionary saved at generation [n]. *)
fun targetd_file rlobj n =
  String.concat [eval_dir, "/", #expname (#rlparam rlobj), "/targetd", its n]

(* Encode a bool list as space-separated words (blts) and decode it
   back (stbl). *)
val blts = String.concatWith " " o map bts
val stbl = map string_to_bool o String.tokens Char.isSpace
178
(* Save the target dictionary of generation [n]:
   - boards via the game writer in file_boardl;
   - outcome histories one line per board, in the same order, in file_bl.
   NOTE(review): an empty history is written as an empty line; if readl
   skips empty lines, retrieve_targetd would misalign boards and
   histories. Confirm readl's behavior. *)
fun store_targetd rlobj n targetd =
  let
    val file = targetd_file rlobj n
    val entries = dlist targetd
    val boardl = map fst entries
    val bll = map snd entries
  in
    #write_boardl (#gameio rlobj) (file ^ "_boardl") boardl;
    writel (file ^ "_bl") (map blts bll)
  end
187
(* Inverse of store_targetd: rebuild the dictionary ordered by the
   game's board comparison. *)
fun retrieve_targetd rlobj n =
  let
    val file = targetd_file rlobj n
    val boardl = #read_boardl (#gameio rlobj) (file ^ "_boardl")
    val bll = map stbl (readl (file ^ "_bl"))
    val cmp = #board_compare (#game rlobj)
  in
    dnew cmp (combine (boardl, bll))
  end
196
197(* -------------------------------------------------------------------------
198   External parallelization
199   ------------------------------------------------------------------------- *)
200
(* Worker job: run a big-steps search on [target], pass the visited
   node boards to #infobs, and return (win flag, new examples).
   A Subscript escaping the job is reported as a HOL_ERR. *)
fun extsearch_fun rlobj splayer target =
  let
    val bsobj = mk_bsobj rlobj splayer
    val (win, rlex, nodel) = run_bigsteps bsobj target
    val _ = #infobs rlobj (map #board nodel)
  in
    (win, rlex)
  end
  handle Subscript => raise ERR "extsearch_fun" "subscript"
209
(* Spec for parmap_queue_extern:
   - workers run extsearch_fun with a splayer parameter;
   - they take target boards as arguments and return (win, examples);
   - all of it is serialized through [gameio]. *)
fun mk_extsearch self (rlobj as {rlparam,gameio,...}) =
  let val dir = default_parallel_dir ^ "_search" in
    {
    self = self, parallel_dir = dir,
    reflect_globals = fn () => "()",
    function = extsearch_fun rlobj,
    write_param = write_splayer, read_param = read_splayer,
    write_arg = write_target gameio, read_arg = read_target gameio,
    write_result = write_result gameio, read_result = read_result gameio
    }
  end
223
224(* -------------------------------------------------------------------------
225   Training
226   ------------------------------------------------------------------------- *)
227
(* Train a freshly initialised TNN on [rlex] and save it as generation
   [ngen]. Each example (board, reals) yields two training pairs:
   - the first encoded term with [hd reals];
   - the second with [tl reals].
   The encoding must therefore produce exactly two terms, and the real
   list must be non-empty. Example counts and training time are logged. *)
fun rl_train ngen rlobj rlex =
  let
    val {pretob,schedule,tnndim} = #dplayer rlobj
    fun tob board = pretob NONE board
    fun to_tnnex (board, reall) =
      combine (tob board, [[hd reall], tl reall])
    val tnnex = map to_tnnex rlex
    val inputl = map (fn (board, _) => tob board) rlex
    val uex = mk_fast_set (list_compare Term.compare) inputl
    val _ = log rlobj ("Training examples: " ^ its (length rlex))
    val _ = log rlobj ("Training unique  : " ^ its (length uex))
    val initial = random_tnn tnndim
    val (trained, elapsed) = add_time (train_tnn schedule initial) (tnnex,[])
    val _ = log rlobj ("Training time: " ^ rts elapsed)
    val _ = store_tnn rlobj ngen trained
  in
    trained
  end
244
245(* -------------------------------------------------------------------------
246   Exploration
247   ------------------------------------------------------------------------- *)
248
(* Search every target of [targetl] in parallel with the given player
   flags and network. Returns:
   - the concatenated new examples;
   - each target paired with its win flag.
   Time, win count and example count are logged. *)
fun rl_explore_targetl (unib,noiseb) (rlobj,es) tnn targetl =
  let
    val {ncore,nsim,...} = #rlparam rlobj
    val splayer = {unib = unib, tnn = tnn, noiseb = noiseb, nsim = nsim}
    val (resl, elapsed) = add_time (parmap_queue_extern ncore es splayer) targetl
    val _ = log rlobj ("Exploration time: " ^ rts elapsed)
    val winl = map fst resl
    val resultl = combine (targetl, winl)
    val nwin = length (filter (fn won => won) winl)
    val _ = log rlobj ("Exploration wins: " ^ its nwin)
    val rlex = List.concat (map snd resl)
    val _ = log rlobj ("Exploration new examples: " ^ its (length rlex))
  in
    (rlex, resultl)
  end
263
(* Same as rl_explore_targetl but with root noise disabled. *)
fun rl_compete_targetl unib (rlobj,es) tnn targetl =
  let val noiseb = false in
    rl_explore_targetl (unib, noiseb) (rlobj,es) tnn targetl
  end
266
267(*
268fun row_win l =
269  case l of [] => 0 | a :: m => if a then 1 + row_win m else 0
270fun row_lose l =
271  case l of [] => 0 | a :: m => if not a then 1 + row_lose m else 0
272fun row_either l = Int.max (row_lose l, row_win l)
273fun exists_win l = exists I l
274
275fun stats_select_one rlobj (s,targetl) =
276  let
277    val il = map (row_either o snd o snd) targetl
278    fun f (a,b) = its a ^ "-" ^ its b
279    val l = dlist (count_dict (dempty Int.compare) il)
280  in
281    log rlobj ("  " ^ s ^ " tot-" ^ its (length il) ^ "  " ^
282      (String.concatWith " " (map f l)))
283  end
284
285fun stats_select rlobj nfin nwin (neg,pos,negsel,possel) =
286  let
287    val l = [("neg:",neg),("ns :",negsel),("pos:",pos),("ps :",possel)]
288  in
289    log rlobj ("Exploration: " ^ its nfin ^ " targets ");
290    log rlobj ("Exploration: " ^ its nwin ^ " targets proven at least once");
291    app (stats_select_one rlobj) l
292  end
293*)
294
295(*
296fun select_from_targetd rlobj ntot targetd =
297  let
298    val targetwinl = map (fn (a,(b,c,_)) => (a,(b,c))) (dlist targetd)
299    fun f x = 1.0 / (1.0 + Real.fromInt x)
300    fun g x =
301      let
302        val y = random_real ()
303        val y' = if y < epsilon then epsilon else y
304      in
305        x / y'
306      end
307    fun h (a,(b,winl)) = ((a,(b,winl)), (g o f o row_either) winl)
308    fun test (_,(_,winl)) = null winl orelse not (hd winl)
309    val (neg,pos) = partition test targetwinl
310    val negsel = first_n (ntot div 2) (dict_sort compare_rmax (map h neg))
311    val possel = first_n (ntot div 2) (dict_sort compare_rmax (map h pos))
312    val lfin = map (fst o fst) (rev negsel @ possel)
313    val lwin = filter exists_win (map (snd o snd) targetwinl)
314
315  in
316    stats_select rlobj (length lfin)
317       (length lwin) (neg,pos, map fst negsel, map fst possel);
318    lfin
319  end
320*)
321
(* Targets to explore next: currently every key of [targetd], in key
   order. [rlobj] and the budget [ntot] are ignored for now. *)
fun select_from_targetd rlobj ntot targetd = map fst (dlist targetd)
  (* map (#modify_board rlobj targetd) *)
324
(* Prepend outcome [b] to the history of [board]; an unknown board
   starts with an empty history. Shaped for foldl. *)
fun update_targetd ((board,b),targetd) =
  let
    val oldbl = dfind board targetd handle NotFound => []
    val newbl = b :: oldbl
  in
    dadd board newbl targetd
  end
329
(* One exploration round with root noise on the selected targets.
   Returns the new examples and [targetd] updated with the outcomes. *)
fun rl_explore_targetd unib (rlobj,es) (tnn,targetd) =
  let
    val ntarget = #ntarget (#rlparam rlobj)
    val targetl = select_from_targetd rlobj ntarget targetd
    val (rlex, resultl) =
      rl_explore_targetl (unib,true) (rlobj,es) tnn targetl
  in
    (rlex, foldl update_targetd targetd resultl)
  end
340
(* Generation-[ngen] bootstrap: explore with the uniform random player
   (a random network is built only to fill the splayer record), without
   noise, then save the examples and the updated target dictionary. *)
fun rl_explore_init ngen (rlobj,es) targetd =
  let
    val _ = log rlobj "Exploration: initialization"
    val dummy = random_tnn (#tnndim (#dplayer rlobj))
    val ntarget = #ntarget (#rlparam rlobj)
    val targetl = select_from_targetd rlobj ntarget targetd
    val (rlex, resultl) =
      rl_explore_targetl (true,false) (rlobj,es) dummy targetl
    val newtargetd = foldl update_targetd targetd resultl
    val _ = store_rlex rlobj ngen rlex
    val _ = store_targetd rlobj ngen newtargetd
  in
    (rlex, newtargetd)
  end
355
(* Generation-[ngen] exploration with the trained network. The new
   examples are saved, then put in front of the previous ones, and the
   window is cut to exwindow examples. *)
fun rl_explore_cont ngen (rlobj,es) (tnn,rlex,targetd) =
  let
    val (fresh, newtargetd) =
      rl_explore_targetd false (rlobj,es) (tnn,targetd)
    val window = first_n (#exwindow (#rlparam rlobj)) (fresh @ rlex)
    val _ = store_rlex rlobj ngen fresh
    val _ = store_targetd rlobj ngen newtargetd
  in
    (window, newtargetd)
  end
365
366(* -------------------------------------------------------------------------
367   Reinforcement learning loop
368   ------------------------------------------------------------------------- *)
369
(* Main loop: at each generation, train on the current window, then
   explore. It never returns on its own. *)
fun rl_loop ngen (rlobj,es) (rlex,targetd) =
  let
    val _ = log rlobj ("\nGeneration " ^ its ngen)
    val tnn = rl_train ngen rlobj rlex
    val next = rl_explore_cont ngen (rlobj,es) (tnn,rlex,targetd)
  in
    rl_loop (ngen + 1) (rlobj,es) next
  end
378
(* Create eval_dir and the experiment's sub-directory. *)
fun create_expdir rlobj =
  let val expdir = eval_dir ^ "/" ^ #expname (#rlparam rlobj) in
    app mkDir_err [eval_dir,expdir]
  end

(* Fresh run: random-player bootstrap at generation 0, then the loop
   from generation 1. *)
fun rl_start (rlobj,es) targetd =
  let
    val _ = create_expdir rlobj
    val (rlex, newtargetd) = rl_explore_init 0 (rlobj,es) targetd
  in
    rl_loop 1 (rlobj,es) (rlex,newtargetd)
  end

(* Resume after generation [ngen]: reload the example window from disk,
   take [targetd] from the caller, and continue at generation ngen + 1. *)
fun rl_restart ngen (rlobj,es) targetd =
  let
    val _ = create_expdir rlobj
    val rlex = retrieve_rlex rlobj ngen
  in
    rl_loop (ngen + 1) (rlobj,es) (rlex,targetd)
  end
396
397(* -------------------------------------------------------------------------
398   Final MCTS Testing
399   ------------------------------------------------------------------------- *)
400
401(*
402type 'a ftes = (unit, 'a, bool * int * 'a option) smlParallel.extspec
403type 'a fttnnes = (tnn, 'a, bool * int * 'a option) smlParallel.extspec
404
405fun option_to_list vo = if isSome vo then [valOf vo] else []
406fun list_to_option l =
407  case l of [] => NONE | [a] => SOME a | _ => raise ERR "list_to_option" ""
408
409fun ft_write_result gameio file (b,nstep,boardo) =
410  (
411  writel (file ^ "_bstatus") [bts b];
412  writel (file ^ "_nstep") [its nstep];
413  (#write_boardl gameio) (file ^ "_boardo") (option_to_list boardo)
414  )
415
416fun ft_read_result gameio file =
417  let
418    val s1 = singleton_of_list (readl (file ^ "_bstatus"))
419    val s2 = singleton_of_list (readl (file ^ "_nstep"))
420  in
421    (
422    string_to_bool s1,
423    string_to_int s2,
424    list_to_option ((#read_boardl gameio) (file ^ "_boardo"))
425    )
426  end
427
428val ft_mctsparam =
429  {
430  timer = SOME 60.0,
431  nsim = NONE, stopatwin_flag = true,
432  decay = 1.0, explo_coeff = 2.0,
433  noise_all = false, noise_root = false,
434  noise_coeff = 0.25, noise_gen = random_real,
435  noconfl = false, avoidlose = false
436  }
437
438fun mk_ft_mctsparam tim =
439  {
440  timer = SOME tim,
441  nsim = NONE, stopatwin_flag = true,
442  decay = 1.0, explo_coeff = 2.0,
443  noise_all = false, noise_root = false,
444  noise_coeff = 0.25, noise_gen = random_real,
445  noconfl = false, avoidlose = false
446  }
447
448fun ft_extsearch_fun rlobj player (_:unit) target =
449  let
450    val mctsobj =
451      {mctsparam = ft_mctsparam, game = #game rlobj, player = player}
452    val (_,(tree,_)) = mcts mctsobj (starttree_of mctsobj target)
453    val b = is_win (#status (dfind [] tree))
454    val boardo = if not b then NONE else
455      let val nodel = trace_win tree [] in
456        SOME (#board (last nodel))
457      end
458  in
459    (b,dlength tree-1,boardo)
460  end
461
462fun ft_mk_extsearch self (rlobj as {rlparam,gameio,...}) player =
463  {
464  self = self,
465  parallel_dir = default_parallel_dir ^ "_finaltest",
466  reflect_globals = fn () => "()",
467  function = ft_extsearch_fun rlobj player,
468  write_param = fn _ => (fn _ => ()),
469  read_param = fn _ => (),
470  write_arg = write_target gameio,
471  read_arg = read_target gameio,
472  write_result = ft_write_result gameio,
473  read_result = ft_read_result gameio
474  }
475
476
477fun fttnn_extsearch_fun rlobj tnn target =
478  let
479    val pretob = #pretob (#dplayer rlobj)
480    val game = #game rlobj
481    fun preplayer target' =
482      let val tob = pretob (SOME (target',tnn)) in
483        fn board => player_from_tnn tnn tob game board
484      end
485    val mctsobj =
486      {mctsparam = ft_mctsparam, game = #game rlobj,
487       player = preplayer target}
488    val (_,(tree,_)) = mcts mctsobj (starttree_of mctsobj target)
489    val b = #status (dfind [] tree) = Win
490    val boardo = if not b then NONE else
491      let val nodel = trace_win tree [] in
492        SOME (#board (last nodel))
493      end
494  in
495    (b,dlength tree-1,boardo)
496  end
497
498fun fttnn_mk_extsearch self (rlobj as {rlparam,gameio,...}) =
499  {
500  self = self,
501  parallel_dir = default_parallel_dir ^ "_finaltest",
502  reflect_globals = fn () => "()",
503  function = fttnn_extsearch_fun rlobj,
504  write_param = (fn file => (fn tnn => write_tnn (file ^ "_tnn") tnn)),
505  read_param = (fn file => read_tnn (file ^ "_tnn")),
506  write_arg = write_target gameio,
507  read_arg = read_target gameio,
508  write_result = ft_write_result gameio,
509  read_result = ft_read_result gameio
510  }
511
512
513fun mk_dis tree =
514  let
515    val pol = #pol (dfind [] tree)
516    val _ = if null pol then raise ERR "mk_dis" "pol" else ()
517    fun f (_,cid) = #vis (dfind cid tree) handle NotFound => 0.0
518    val dis = map_assoc f pol
519    val tot = sum_real (map snd dis)
520    val _ = if tot < 0.5 then raise ERR "mk_dis" "tot" else ()
521  in
522    (dis,tot)
523  end
524
525fun select_bigstep tree = snd (best_in_distrib (fst (mk_dis tree)))
526
527fun fttnnbs_extsearch_fun rlobj tnn target =
528  let
529    val pretob = #pretob (#dplayer rlobj)
530    val game = #game rlobj
531    fun preplayer target' =
532      let val tob = pretob (SOME (target',tnn)) in
533        fn board => player_from_tnn tnn tob game board
534      end
535    val timerl = List.tabulate (30, fn _ => int_div 60 30)
536    val startmctsobj =
537      {mctsparam = mk_ft_mctsparam (hd timerl), game = #game rlobj,
538       player = preplayer target}
539    val (starttree,startcache) = starttree_of startmctsobj target
540    fun loop timl (tree,cache) =
541      if null timl then (false, 8, NONE) else
542      let
543        val mctsobj =
544          {mctsparam = mk_ft_mctsparam (hd timl), game = #game rlobj,
545           player = preplayer target}
546        val (_,(endtree,_)) = mcts mctsobj (tree,cache)
547        val cid = select_bigstep endtree
548        val status = #status (dfind [] endtree)
549      in
550        if status = Undecided
551          then
552            let
553              val newtree = cut_tree cid endtree
554              val newcache = build_cache game newtree
555            in
556              loop (tl timl) (newtree,newcache)
557            end
558        else if status = Win
559          then
560            let val nodel = trace_win endtree [] in
561              (status = Win, 8 - length timl, SOME (#board (last nodel)))
562            end
563        else (status = Win, 8 - length timl, NONE)
564      end
565  in
566    loop timerl (starttree,startcache)
567  end
568
569fun fttnnbs_mk_extsearch self (rlobj as {rlparam,gameio,...}) =
570  {
571  self = self,
572  parallel_dir = default_parallel_dir ^ "_finaltest",
573  reflect_globals = fn () => "()",
574  function = fttnnbs_extsearch_fun rlobj,
575  write_param = (fn file => (fn tnn => write_tnn (file ^ "_tnn") tnn)),
576  read_param = (fn file => read_tnn (file ^ "_tnn")),
577  write_arg = write_target gameio,
578  read_arg = read_target gameio,
579  write_result = ft_write_result gameio,
580  read_result = ft_read_result gameio
581  }
582*)
583
584
585end (* struct *)
586