(* ========================================================================= *)
(* FILE          : mlReinforce.sml                                           *)
(* DESCRIPTION   : Environment for reinforcement learning                    *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2019                                                      *)
(* ========================================================================= *)

structure mlReinforce :> mlReinforce =
struct

open HolKernel Abbrev boolLib aiLib psMCTS psBigSteps
  mlNeuralNetwork mlTreeNeuralNetwork smlParallel

val ERR = mk_HOL_ERR "mlReinforce"

(* -------------------------------------------------------------------------
   Logs
   ------------------------------------------------------------------------- *)

(* Root directory under which every experiment gets its own sub-directory. *)
val eval_dir = HOLDIR ^ "/examples/AI_tasks/eval"

(* Appends the line s to the file "log" of the experiment named by
   #expname (#rlparam rlobj). The experiment directory must already exist
   (rl_start and rl_restart create it). *)
fun log_in_eval rlobj s =
  append_endline (eval_dir ^ "/" ^ (#expname (#rlparam rlobj)) ^ "/log") s

(* Writes s to the experiment log file and also prints it. *)
fun log rlobj s = (log_in_eval rlobj s; print_endline s)

(* -------------------------------------------------------------------------
   Types
   ------------------------------------------------------------------------- *)

(* Maps a target board to the list of outcomes recorded for it; update_targetd
   puts the most recent outcome at the head of the list. *)
type 'a targetd = ('a, bool list) Redblackmap.dict

(* Serialization of a list of boards to and from a file. *)
type 'a gameio =
  {write_boardl : string -> 'a list -> unit, read_boardl : string -> 'a list}

(* Parameters sent to a search worker:
   unib   - use the random player instead of the tnn (see mk_bsobj),
   tnn    - the tree neural network guiding the search,
   noiseb - passed as noise_root to the MCTS parameters,
   nsim   - number of simulations (passed as nsim to the MCTS parameters). *)
type splayer =
  {unib : bool, tnn : tnn, noiseb : bool, nsim : int}

(* Description of how to build and train a player:
   pretob   - given an optional (target,tnn) pair, turns a board into the
              terms given to the tnn,
   schedule - training schedule passed to train_tnn,
   tnndim   - dimensions used to create a random tnn. *)
type 'a dplayer =
  {pretob : ('a * tnn) option -> 'a -> term list,
   schedule : schedule, tnndim : tnndim}

(* External search specification: parameter is a splayer, argument is a
   target board, result is a status flag paired with the produced examples. *)
type 'a es = (splayer, 'a, bool * 'a rlex) smlParallel.extspec

(* Global parameters of a reinforcement learning experiment. *)
type rlparam =
  {expname : string, exwindow : int, ncore : int,
   ntarget : int, nsim : int, decay : real}

(* Everything needed to run an experiment on a given game. *)
type ('a,'b) rlobj =
  {
  rlparam : rlparam,
  game : ('a,'b) psMCTS.game,
  gameio : 'a gameio,
  dplayer : 'a dplayer,
  infobs : 'a list -> unit
  }

(* -------------------------------------------------------------------------
   Big steps
   ------------------------------------------------------------------------- *)

(* Builds the MCTS parameters for one search. Only nsim and noise_root come
   from splayer and only decay comes from rlparam; the remaining fields are
   fixed constants. *)
fun mk_mctsparam splayer rlobj =
  {
  timer = NONE, nsim = SOME (#nsim splayer), stopatwin_flag = false,
  decay = #decay (#rlparam rlobj), explo_coeff = 2.0,
  noise_all = false, noise_root = (#noiseb splayer),
  noise_coeff = 0.25, noise_gen = random_real,
  noconfl = false, avoidlose = false,
  evalwin = true (* todo : reflect this parameter in rlparam *)
  }

(* Evaluates board with the tnn. The tnn is expected to return exactly two
   outputs (pair_of_list): the first one must be a singleton holding the
   evaluation, the second one holds one value per move of #movel game.
   Returns the evaluation together with the available moves, each paired
   with its policy value.
   Raises ERR "player_from_tnn" if an available move is not in #movel game. *)
fun player_from_tnn tnn tob game board =
  let
    val amovel = (#available_movel game) board
    val (e,p) = pair_of_list (map snd (infer_tnn tnn (tob board)))
    val d = dnew (#move_compare game) (combine (#movel game,p))
    fun f x = dfind x d handle NotFound => raise ERR "player_from_tnn" ""
  in
    (singleton_of_list e, map_assoc f amovel)
  end

(* Builds the big-steps object run by extsearch_fun. When unib is true the
   player is random_player and the tnn is not consulted. *)
fun mk_bsobj rlobj (splayer as {unib,tnn,noiseb,nsim}) =
  let
    val game = #game rlobj
    val pretob = #pretob (#dplayer rlobj)
    (* tob is computed once per target and shared by all board evaluations *)
    fun preplayer target =
      let val tob = pretob (SOME (target,tnn)) in
        fn board => player_from_tnn tnn tob game board
      end
    fun random_preplayer target board = random_player game board
  in
    {
    verbose = false, temp_flag = false,
    preplayer = if unib then random_preplayer else preplayer,
    game = game,
    mctsparam = mk_mctsparam splayer rlobj
    }
  end

(* -------------------------------------------------------------------------
   I/O for external parallelization
   ------------------------------------------------------------------------- *)

(* Writes examples to two files: boards in file ^ "_boardl" and the
   associated real lists, one per line, in file ^ "_rll". *)
fun write_rlex gameio file rlex =
  let val (boardl,rll) = split rlex in
    (#write_boardl gameio) (file ^ "_boardl") boardl;
    writel (file ^ "_rll") (map reall_to_string rll)
  end

(* Inverse of write_rlex. combine requires both files to describe the same
   number of examples. *)
fun read_rlex gameio file =
  let
    val boardl = (#read_boardl gameio) (file ^ "_boardl")
    val rll = map string_to_reall (readl (file ^ "_rll"))
  in
    combine (boardl,rll)
  end

(* Writes a splayer to three files: file ^ "_tnn", file ^ "_flags"
   (unib and noiseb separated by a space) and file ^ "_nsim". *)
fun write_splayer file {unib,tnn,noiseb,nsim} =
  (
  write_tnn (file ^ "_tnn") tnn;
  writel (file ^ "_flags") [String.concatWith " " (map bts [unib,noiseb])];
  writel (file ^ "_nsim") [its nsim]
  )

(* Inverse of write_splayer. Each of the "_flags" and "_nsim" files must
   contain exactly one line (singleton_of_list). *)
fun read_splayer file =
  let
    val tnn = read_tnn (file ^ "_tnn")
    val (unib,noiseb) =
      pair_of_list (map string_to_bool
        (String.tokens Char.isSpace
           (singleton_of_list (readl (file ^ "_flags")))))
    val nsim = string_to_int (singleton_of_list (readl (file ^ "_nsim")))
  in
    {unib=unib,tnn=tnn,noiseb=noiseb,nsim=nsim}
  end

(* Writes a search result: the status flag in file ^ "_bstatus" and the
   examples under the prefix file ^ "_rlex". *)
fun write_result gameio file (b,rlex) =
  (writel (file ^ "_bstatus") [bts b];
   write_rlex gameio (file ^ "_rlex") rlex)

(* Inverse of write_result. *)
fun read_result gameio file =
  (string_to_bool (singleton_of_list (readl (file ^ "_bstatus"))),
   read_rlex gameio (file ^ "_rlex"))

(* Writes a single target board to file ^ "_target". A Subscript exception
   raised by the board writer is reported as ERR "write_target". *)
fun write_target gameio file target =
  (#write_boardl gameio) (file ^ "_target") [target]
  handle Subscript => raise ERR "write_target" ""

(* Inverse of write_target; the file must contain exactly one board. *)
fun read_target gameio file =
  singleton_of_list ((#read_boardl gameio) (file ^ "_target"))

(* -------------------------------------------------------------------------
   I/O for storage and restart
   ------------------------------------------------------------------------- *)

(* Example *)

(* File prefix under which the examples of generation n are stored. *)
fun rlex_file rlobj n =
  eval_dir ^ "/" ^ (#expname (#rlparam rlobj)) ^ "/rlex" ^ its n

(* Stores the examples produced at generation n. *)
fun store_rlex rlobj n rlex =
  write_rlex (#gameio rlobj) (rlex_file rlobj n) rlex

(* Gathers stored examples from generation n down to generation 0, newest
   generation first, and returns the first exwindow of them. acc holds the
   examples gathered so far; older generations are appended behind it.
   Reading stops as soon as exwindow examples are available.

   Examples keep the order in which each generation was stored, so the
   result is the same window that rl_explore_cont maintains in memory with
   first_n exwindow (rlex1 @ rlex). The previous version returned
   first_n exwindow (rev acc) with acc in chronological order, which
   reversed the examples inside every generation and could therefore keep a
   different part of the oldest generation after a restart. It also read one
   more generation than needed when exactly exwindow examples had been
   gathered. *)
fun gather_ex rlobj acc n =
  let
    val exwindow = #exwindow (#rlparam rlobj)
    fun read_ex () = read_rlex (#gameio rlobj) (rlex_file rlobj n)
  in
    if n < 0 orelse length acc >= exwindow
    then first_n exwindow acc
    else gather_ex rlobj (acc @ read_ex ()) (n-1)
  end

(* Rebuilds the example window as it was after generation n. *)
fun retrieve_rlex rlobj n = gather_ex rlobj [] n

(* TNN *)

(* File in which the tnn trained at generation n is stored. *)
fun tnn_file rlobj n =
  eval_dir ^ "/" ^ (#expname (#rlparam rlobj)) ^ "/tnn" ^ its n
fun store_tnn rlobj n tnn = write_tnn (tnn_file rlobj n) tnn
fun retrieve_tnn rlobj n = read_tnn (tnn_file rlobj n)

(* Target *)

(* File prefix under which the target dictionary of generation n is stored. *)
fun targetd_file rlobj n =
  eval_dir ^ "/" ^ (#expname (#rlparam rlobj)) ^ "/targetd" ^ its n

(* Boolean list <-> space-separated string. *)
fun blts bl = String.concatWith " " (map bts bl)
fun stbl s = map string_to_bool (String.tokens Char.isSpace s)

(* Stores a target dictionary in two files: the boards in
   file ^ "_boardl" and their outcome lists, one per line, in file ^ "_bl".
   Both files list the entries in the order given by dlist. *)
fun store_targetd rlobj n targetd =
  let
    val file = targetd_file rlobj n
    val (l1,l2) = split (dlist targetd)
  in
    #write_boardl (#gameio rlobj) (file ^ "_boardl") l1;
    writel (file ^ "_bl") (map blts l2)
  end

(* Inverse of store_targetd; the dictionary is rebuilt with the board
   comparison function of the game. *)
fun retrieve_targetd rlobj n =
  let
    val file = targetd_file rlobj n
    val l1 = #read_boardl (#gameio rlobj) (file ^ "_boardl")
    val l2 = map stbl (readl (file ^ "_bl"))
  in
    dnew (#board_compare (#game rlobj)) (combine (l1,l2))
  end

(* -------------------------------------------------------------------------
   External parallelization
   ------------------------------------------------------------------------- *)

(* Function executed by a worker: runs the big-steps search on target with
   the given splayer, passes the boards of the returned nodes to #infobs,
   and returns the status flag together with the produced examples.
   A Subscript exception is reported as ERR "extsearch_fun". *)
fun extsearch_fun rlobj splayer target =
  let
    val bsobj = mk_bsobj rlobj splayer
    val (b1,rlex,nodel) = run_bigsteps bsobj target
  in
    (#infobs rlobj (map #board nodel); (b1,rlex))
  end
  handle Subscript => raise ERR "extsearch_fun" "subscript"

(* Builds the external specification used by parmap_queue_extern. self is
   stored unchanged in the specification. *)
fun mk_extsearch self (rlobj as {rlparam,gameio,...}) =
  {
  self = self,
  parallel_dir = default_parallel_dir ^ "_search",
  reflect_globals = fn () => "()",
  function = extsearch_fun rlobj,
  write_param = write_splayer,
  read_param = read_splayer,
  write_arg = write_target gameio,
  read_arg = read_target gameio,
  write_result = write_result gameio,
  read_result = read_result gameio
  }

(* -------------------------------------------------------------------------
   Training
   ------------------------------------------------------------------------- *)

(* Trains a fresh random tnn on the examples rlex, logs statistics, stores
   the result as the tnn of generation ngen and returns it.
   For each example, the head of the real list is paired with the first
   term produced by pretob NONE and the tail with the second one, so
   pretob NONE must return exactly two terms (combine). *)
fun rl_train ngen rlobj rlex =
  let
    val {pretob,schedule,tnndim} = #dplayer rlobj
    fun tob board = pretob NONE board
    fun f (a,b) = combine (tob a,[[hd b],tl b])
    val tnnex = map f rlex
    (* number of distinct tnn inputs, for logging only *)
    val uex = mk_fast_set (list_compare Term.compare) (map (tob o fst) rlex)
    val _ = log rlobj ("Training examples: " ^ its (length rlex))
    val _ = log rlobj ("Training unique  : " ^ its (length uex))
    val randtnn = random_tnn tnndim
    val (tnn,t) = add_time (train_tnn schedule randtnn) (tnnex,[])
  in
    log rlobj ("Training time: " ^ rts t);
    store_tnn rlobj ngen tnn;
    tnn
  end

(* -------------------------------------------------------------------------
   Exploration
   ------------------------------------------------------------------------- *)

(* Runs the external search on every target of targetl using ncore workers.
   Returns all produced examples (concatenated in target order) and the
   list of targets paired with the status flag of their search. *)
fun rl_explore_targetl (unib,noiseb) (rlobj,es) tnn targetl =
  let
    val {ncore,nsim,...} = #rlparam rlobj
    val splayer = {unib=unib,tnn=tnn,noiseb=noiseb,nsim=nsim}
    val (l,t) = add_time (parmap_queue_extern ncore es splayer) targetl
    val _ = log rlobj ("Exploration time: " ^ rts t)
    val resultl = combine (targetl, map fst l)
    val nwin = length (filter fst l)
    val _ = log rlobj ("Exploration wins: " ^ its nwin)
    val rlex = List.concat (map snd l)
    val _ = log rlobj ("Exploration new examples: " ^ its (length rlex))
  in
    (rlex,resultl)
  end

(* Same as rl_explore_targetl with noiseb set to false. *)
fun rl_compete_targetl unib (rlobj,es) tnn targetl =
  rl_explore_targetl (unib,false) (rlobj,es) tnn targetl

(*
fun row_win l =
  case l of [] => 0 | a :: m => if a then 1 + row_win m else 0
fun row_lose l =
  case l of [] => 0 | a :: m => if not a then 1 + row_lose m else 0
fun row_either l = Int.max (row_lose l, row_win l)
fun exists_win l = exists I l

fun stats_select_one rlobj (s,targetl) =
  let
    val il = map (row_either o snd o snd) targetl
    fun f (a,b) = its a ^ "-" ^ its b
    val l = dlist (count_dict (dempty Int.compare) il)
  in
    log rlobj ("  " ^ s ^ " tot-" ^ its (length il) ^ " " ^
               (String.concatWith " " (map f l)))
  end

fun stats_select rlobj nfin nwin (neg,pos,negsel,possel) =
  let
    val l = [("neg:",neg),("ns :",negsel),("pos:",pos),("ps :",possel)]
  in
    log rlobj ("Exploration: " ^ its nfin ^ " targets ");
    log rlobj ("Exploration: " ^ its nwin ^ " targets proven at least once");
    app (stats_select_one rlobj) l
  end
*)

(*
fun select_from_targetd rlobj ntot targetd =
  let
    val targetwinl = map (fn (a,(b,c,_)) => (a,(b,c))) (dlist targetd)
    fun f x = 1.0 / (1.0 + Real.fromInt x)
    fun g x =
      let
        val y = random_real ()
        val y' = if y < epsilon then epsilon else y
      in
        x / y'
      end
    fun h (a,(b,winl)) = ((a,(b,winl)), (g o f o row_either) winl)
    fun test (_,(_,winl)) = null winl orelse not (hd winl)
    val (neg,pos) = partition test targetwinl
    val negsel = first_n (ntot div 2) (dict_sort compare_rmax (map h neg))
    val possel = first_n (ntot div 2) (dict_sort compare_rmax (map h pos))
    val lfin = map (fst o fst) (rev negsel @ possel)
    val lwin = filter exists_win (map (snd o snd) targetwinl)

  in
    stats_select rlobj (length lfin)
      (length lwin) (neg,pos, map fst negsel, map fst possel);
    lfin
  end
*)

(* Selects the targets to explore. Currently every key of targetd is
   selected; rlobj and ntot are ignored. *)
fun select_from_targetd rlobj ntot targetd = dkeys targetd
  (* map (#modify_board rlobj targetd) *)

(* Records the outcome b for board in front of its existing outcome list
   (an empty list if board is not yet in targetd). *)
fun update_targetd ((board,b),targetd) =
  let val bl = dfind board targetd handle NotFound => [] in
    dadd board (b :: bl) targetd
  end

(* Explores the selected targets with root noise enabled and returns the new
   examples together with the updated target dictionary. *)
fun rl_explore_targetd unib (rlobj,es) (tnn,targetd) =
  let
    val rlparam = #rlparam rlobj
    val targetl = select_from_targetd rlobj (#ntarget rlparam) targetd
    val (rlex,resultl) =
      rl_explore_targetl (unib,true) (rlobj,es) tnn targetl
    val newtargetd = foldl update_targetd targetd resultl
  in
    (rlex,newtargetd)
  end

(* Initial exploration. unib is set to true, so mk_bsobj selects the random
   player and the random tnn created here only fills the tnn field of the
   splayer. Stores the examples and the target dictionary as generation
   ngen. *)
fun rl_explore_init ngen (rlobj,es) targetd =
  let
    val _ = log rlobj "Exploration: initialization"
    val dummy = random_tnn (#tnndim (#dplayer rlobj))
    val rlparam = #rlparam rlobj
    val targetl = select_from_targetd rlobj (#ntarget rlparam) targetd
    val (rlex,resultl) =
      rl_explore_targetl (true,false) (rlobj,es) dummy targetl
    val newtargetd = foldl update_targetd targetd resultl
  in
    store_rlex rlobj ngen rlex;
    store_targetd rlobj ngen newtargetd;
    (rlex,newtargetd)
  end

(* Exploration guided by tnn. Only the examples produced at this generation
   (rlex1) are stored; the returned window puts them in front of the
   previous examples and keeps at most exwindow of them. *)
fun rl_explore_cont ngen (rlobj,es) (tnn,rlex,targetd) =
  let
    val (rlex1,newtargetd) = rl_explore_targetd false (rlobj,es) (tnn,targetd)
    val rlex2 = first_n (#exwindow (#rlparam rlobj)) (rlex1 @ rlex)
  in
    store_rlex rlobj ngen rlex1;
    store_targetd rlobj ngen newtargetd;
    (rlex2, newtargetd)
  end

(* -------------------------------------------------------------------------
   Reinforcement learning loop
   ------------------------------------------------------------------------- *)

(* Alternates training and exploration, starting at generation ngen.
   There is no stopping condition: the function does not return. *)
fun rl_loop ngen (rlobj,es) (rlex,targetd) =
  let
    val _ = log rlobj ("\nGeneration " ^ its ngen)
    val tnn = rl_train ngen rlobj rlex
    val (newrlex,newtargetd) = rl_explore_cont ngen (rlobj,es) (tnn,rlex,targetd)
  in
    rl_loop (ngen + 1) (rlobj,es) (newrlex,newtargetd)
  end

(* Starts a new experiment: creates the experiment directory, runs the
   initial exploration as generation 0 and enters the loop at generation 1. *)
fun rl_start (rlobj,es) targetd =
  let
    val expdir = eval_dir ^ "/" ^ #expname (#rlparam rlobj)
    val _ = app mkDir_err [eval_dir,expdir]
    val (rlex,newtargetd) = rl_explore_init 0 (rlobj,es) targetd
  in
    rl_loop 1 (rlobj,es) (rlex,newtargetd)
  end

(* Resumes an experiment after generation ngen. The example window is
   rebuilt from the stored files; the target dictionary is the one given as
   argument (it is not read back from disk here, see retrieve_targetd). *)
fun rl_restart ngen (rlobj,es) targetd =
  let
    val expdir = eval_dir ^ "/" ^ #expname (#rlparam rlobj)
    val _ = app mkDir_err [eval_dir,expdir]
    val rlex = retrieve_rlex rlobj ngen
  in
    rl_loop (ngen + 1) (rlobj,es) (rlex,targetd)
  end

(* -------------------------------------------------------------------------
   Final MCTS Testing
   ------------------------------------------------------------------------- *)

(*
type 'a ftes = (unit, 'a, bool * int * 'a option) smlParallel.extspec
type 'a fttnnes = (tnn, 'a, bool * int * 'a option) smlParallel.extspec

fun option_to_list vo = if isSome vo then [valOf vo] else []
fun list_to_option l =
  case l of [] => NONE | [a] => SOME a | _ => raise ERR "list_to_option" ""

fun ft_write_result gameio file (b,nstep,boardo) =
  (
  writel (file ^ "_bstatus") [bts b];
  writel (file ^ "_nstep") [its nstep];
  (#write_boardl gameio) (file ^ "_boardo") (option_to_list boardo)
  )

fun ft_read_result gameio file =
  let
    val s1 = singleton_of_list (readl (file ^ "_bstatus"))
    val s2 = singleton_of_list (readl (file ^ "_nstep"))
  in
    (
    string_to_bool s1,
    string_to_int s2,
    list_to_option ((#read_boardl gameio) (file ^ "_boardo"))
    )
  end

val ft_mctsparam =
  {
  timer = SOME 60.0,
  nsim = NONE, stopatwin_flag = true,
  decay = 1.0, explo_coeff = 2.0,
  noise_all = false, noise_root = false,
  noise_coeff = 0.25, noise_gen = random_real,
  noconfl = false, avoidlose = false
  }

fun mk_ft_mctsparam tim =
  {
  timer = SOME tim,
  nsim = NONE, stopatwin_flag = true,
  decay = 1.0, explo_coeff = 2.0,
  noise_all = false, noise_root = false,
  noise_coeff = 0.25, noise_gen = random_real,
  noconfl = false, avoidlose = false
  }

fun ft_extsearch_fun rlobj player (_:unit) target =
  let
    val mctsobj =
      {mctsparam = ft_mctsparam, game = #game rlobj, player = player}
    val (_,(tree,_)) = mcts mctsobj (starttree_of mctsobj target)
    val b = is_win (#status (dfind [] tree))
    val boardo = if not b then NONE else
      let val nodel = trace_win tree [] in
        SOME (#board (last nodel))
      end
  in
    (b,dlength tree-1,boardo)
  end

fun ft_mk_extsearch self (rlobj as {rlparam,gameio,...}) player =
  {
  self = self,
  parallel_dir = default_parallel_dir ^ "_finaltest",
  reflect_globals = fn () => "()",
  function = ft_extsearch_fun rlobj player,
  write_param = fn _ => (fn _ => ()),
  read_param = fn _ => (),
  write_arg = write_target gameio,
  read_arg = read_target gameio,
  write_result = ft_write_result gameio,
  read_result = ft_read_result gameio
  }


fun fttnn_extsearch_fun rlobj tnn target =
  let
    val pretob = #pretob (#dplayer rlobj)
    val game = #game rlobj
    fun preplayer target' =
      let val tob = pretob (SOME (target',tnn)) in
        fn board => player_from_tnn tnn tob game board
      end
    val mctsobj =
      {mctsparam = ft_mctsparam, game = #game rlobj,
       player = preplayer target}
    val (_,(tree,_)) = mcts mctsobj (starttree_of mctsobj target)
    val b = #status (dfind [] tree) = Win
    val boardo = if not b then NONE else
      let val nodel = trace_win tree [] in
        SOME (#board (last nodel))
      end
  in
    (b,dlength tree-1,boardo)
  end

fun fttnn_mk_extsearch self (rlobj as {rlparam,gameio,...}) =
  {
  self = self,
  parallel_dir = default_parallel_dir ^ "_finaltest",
  reflect_globals = fn () => "()",
  function = fttnn_extsearch_fun rlobj,
  write_param = (fn file => (fn tnn => write_tnn (file ^ "_tnn") tnn)),
  read_param = (fn file => read_tnn (file ^ "_tnn")),
  write_arg = write_target gameio,
  read_arg = read_target gameio,
  write_result = ft_write_result gameio,
  read_result = ft_read_result gameio
  }


fun mk_dis tree =
  let
    val pol = #pol (dfind [] tree)
    val _ = if null pol then raise ERR "mk_dis" "pol" else ()
    fun f (_,cid) = #vis (dfind cid tree) handle NotFound => 0.0
    val dis = map_assoc f pol
    val tot = sum_real (map snd dis)
    val _ = if tot < 0.5 then raise ERR "mk_dis" "tot" else ()
  in
    (dis,tot)
  end

fun select_bigstep tree = snd (best_in_distrib (fst (mk_dis tree)))

fun fttnnbs_extsearch_fun rlobj tnn target =
  let
    val pretob = #pretob (#dplayer rlobj)
    val game = #game rlobj
    fun preplayer target' =
      let val tob = pretob (SOME (target',tnn)) in
        fn board => player_from_tnn tnn tob game board
      end
    val timerl = List.tabulate (30, fn _ => int_div 60 30)
    val startmctsobj =
      {mctsparam = mk_ft_mctsparam (hd timerl), game = #game rlobj,
       player = preplayer target}
    val (starttree,startcache) = starttree_of startmctsobj target
    fun loop timl (tree,cache) =
      if null timl then (false, 8, NONE) else
      let
        val mctsobj =
          {mctsparam = mk_ft_mctsparam (hd timl), game = #game rlobj,
           player = preplayer target}
        val (_,(endtree,_)) = mcts mctsobj (tree,cache)
        val cid = select_bigstep endtree
        val status = #status (dfind [] endtree)
      in
        if status = Undecided
        then
          let
            val newtree = cut_tree cid endtree
            val newcache = build_cache game newtree
          in
            loop (tl timl) (newtree,newcache)
          end
        else if status = Win
          then
            let val nodel = trace_win endtree [] in
              (status = Win, 8 - length timl, SOME (#board (last nodel)))
            end
        else (status = Win, 8 - length timl, NONE)
      end
  in
    loop timerl (starttree,startcache)
  end

fun fttnnbs_mk_extsearch self (rlobj as {rlparam,gameio,...}) =
  {
  self = self,
  parallel_dir = default_parallel_dir ^ "_finaltest",
  reflect_globals = fn () => "()",
  function = fttnnbs_extsearch_fun rlobj,
  write_param = (fn file => (fn tnn => write_tnn (file ^ "_tnn") tnn)),
  read_param = (fn file => read_tnn (file ^ "_tnn")),
  write_arg = write_target gameio,
  read_arg = read_target gameio,
  write_result = ft_write_result gameio,
  read_result = ft_read_result gameio
  }
*)


end (* struct *)