(* ========================================================================= *)
(* FILE          : tttBigSteps.sml                                           *)
(* DESCRIPTION   : Successions of non-backtrackable moves chosen after one   *)
(*                 MCTS call for each step                                   *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2020                                                      *)
(* ========================================================================= *)

structure tttBigSteps :> tttBigSteps =
struct

open HolKernel Abbrev boolLib aiLib smlRedirect psMCTS mlTreeNeuralNetwork
  tttToken tttTrain tttSearch tttRecord tacticToe

val ERR = mk_HOL_ERR "tttBigSteps"

(* -------------------------------------------------------------------------
   Global example datasets.
   exval, expol and exarg accumulate (term, target) pairs for the value,
   policy and argument predictors while run_bigsteps is running
   (best_stac and best_arg prepend to them).
   ------------------------------------------------------------------------- *)

type ex = (term * real) list

val exval = ref []
val expol = ref []
val exarg = ref []

(* Empty the three global example accumulators. *)
fun reset_examples () = (exval := []; expol := []; exarg := [])

datatype bstatus = BigStepsProved | BigStepsSaturated | BigStepsTimeout

fun string_of_bstatus x = case x of
    BigStepsProved => "BigStepsProved"
  | BigStepsSaturated => "BigStepsSaturated"
  | BigStepsTimeout => "BigStepsTimeout"

(* -------------------------------------------------------------------------
   Extracting examples
   ------------------------------------------------------------------------- *)

(* Value example for one goal record.
   No example if the goal has been visited less than 0.5 times; otherwise
   the target is 1.0 for a proved goal, 0.0 for a saturated goal and the
   average reward #gsum / #gvis in every other case. *)
fun example_val goalrec =
  if #gvis goalrec < 0.5
  then []
  else if #gstatus goalrec = GoalProved
  then [(nntm_of_stateval (#goal goalrec), 1.0)]
  else if #gstatus goalrec = GoalSaturated
  then [(nntm_of_stateval (#goal goalrec), 0.0)]
  else [(nntm_of_stateval (#goal goalrec), #gsum goalrec / #gvis goalrec)]

(* Target distribution shared by policy and argument examples.
   Keeps the first 8 non-saturated children (in the given order), paired
   with their visit counts #svis. Returns [] when the total visit count is
   at most 0.5 or when at most one child remains. Otherwise each child is
   paired with its share of the total visits, except proved children, whose
   target is 1.0. Also prints the total and the first three visit counts
   after sorting with compare_rmax. *)
fun visit_distrib argl1 =
  let
    fun test x = #sstatus x <> StacSaturated
    val argl2 = first_n 8 (map_assoc #svis (filter test argl1))
    val tot = sum_real (map snd argl2)
  in
    if tot <= 0.5 orelse length argl2 <= 1 then [] else
    let
      fun f (x,b) = if #sstatus x = StacProved then (x,1.0) else (x,b / tot)
      val rl = first_n 3 (map snd (dict_sort compare_rmax argl2))
      val _ = print_endline ("tot: " ^ rts tot ^ ", " ^
        String.concatWith " " (map rts rl))
    in
      map f argl2
    end
  end

(* Policy examples for the tactics tried on goal (see visit_distrib). *)
fun example_pol goal argl1 =
  let
    fun fpol x = hidef nntm_of_statepol (goal, dest_stac (#token x))
  in
    map_fst fpol (visit_distrib argl1)
  end

(* Argument examples for the candidate arguments of tactic stac on goal
   (see visit_distrib). *)
fun example_arg (goal,stac) argl1 =
  let
    fun farg x = hidef nntm_of_statearg ((goal,stac), #token x)
  in
    map_fst farg (visit_distrib argl1)
  end

(* -------------------------------------------------------------------------
   Selecting bigsteps
   ------------------------------------------------------------------------- *)

(* Children of the argument-tree node at path anl, as (index, record)
   pairs. A child with index i is stored under the key i :: anl.
   Collection starts at index i and stops at the first missing child or at
   the first child whose status is StacFresh. *)
fun children_argtree_aux argtree anl i =
  if dmem (i :: anl) argtree then
    let val child = dfind (i :: anl) argtree in
      if #sstatus child = StacFresh then [] else
      (i,child) :: children_argtree_aux argtree anl (i+1)
    end
  else []

fun children_argtree argtree anl = children_argtree_aux argtree anl 0

(* Root record (key []) of the i-th argument tree of stacv. *)
fun root_argtree stacv i = dfind [] (Vector.sub (stacv,i))

(* Tactic children of a goal as (index, record) pairs, from index i up to
   the first missing or StacFresh entry. *)
fun children_stacv_aux stacv i =
  if can (root_argtree stacv) i then
    let val child = root_argtree stacv i in
      if #sstatus child = StacFresh then [] else
      (i,child) :: children_stacv_aux stacv (i+1)
    end
  else []

fun children_stacv stacv = children_stacv_aux stacv 0

(* Index of the child to commit to, given (index, record) pairs.
   Saturated children are ignored. The first proved child wins; otherwise
   the child coming first after sorting by visit count with compare_rmax
   is chosen. Falls back to index 0 when every child is saturated, which
   is unsafe: that child may itself be saturated. *)
fun select_child argl1 =
  let
    val argl2 = filter (fn (_,x) => #sstatus x <> StacSaturated) argl1
  in
    case List.find (fn (_,x) => #sstatus x = StacProved) argl2 of
      SOME (i,_) => i
    | NONE =>
      if null argl2 then 0 else
      fst (fst (hd (dict_sort compare_rmax (map_assoc (#svis o snd) argl2))))
  end

(* Best next argument for tactic sn of the root goal of tree, below the
   argument path anl. Records argument examples in exarg.
   Raises ERR "best_arg" "empty" if no child has been explored. *)
fun best_arg tree (sn,anl) =
  let
    val node = dfind [] tree
    val goalrec = Vector.sub (#goalv node,0)
    val argtree = Vector.sub (#stacv goalrec,sn)
    val stac = dest_stac (#token (dfind [] argtree))
    val argl1 = children_argtree argtree anl
    val _ = if null argl1 then raise ERR "best_arg" "empty" else ()
    val _ = exarg := example_arg (#goal goalrec, stac) (map snd argl1) @ !exarg
  in
    select_child argl1
  end

(* Best tactic index for the root goal of tree. Records value examples in
   exval and policy examples in expol.
   Raises ERR "best_stac" "empty" if no tactic has been explored.
   TODO: use a softer selection instead of always following a proved
   tactic. *)
fun best_stac tree =
  let
    val node = dfind [] tree
    val goalrec = Vector.sub (#goalv node,0)
    val _ = exval := example_val goalrec @ !exval
    val argl1 = children_stacv (#stacv goalrec)
    val _ = if null argl1 then raise ERR "best_stac" "empty" else ()
    val _ = expol := example_pol (#goal goalrec) (map snd argl1) @ !expol
  in
    select_child argl1
  end

(* -------------------------------------------------------------------------
   MCTS big steps. Ending the search when there is no move available.
   ------------------------------------------------------------------------- *)

(* Maximum total number of goals processed by loop_node. *)
val max_bigsteps = 20

(* Simulation budget passed to search_loop for each big step. *)
val bigsteps_nsim = 1600

(* Record at argument path anl of tactic sn of the root goal of tree. *)
fun stacrec_of tree (sn,anl) =
  let
    val node = dfind [] tree
    val goalrec = Vector.sub (#goalv node,0)
    val argtree = Vector.sub (#stacv goalrec,sn)
  in
    dfind anl argtree
  end

(* Root goal, tactic sn and the tokens along the argument path anl.
   Since anl lists the most recent argument index first, the returned
   tokens are ordered from the first chosen argument to the last. *)
fun path_to_anl tree (sn,anl) =
  let
    val node = dfind [] tree
    val goalrec = Vector.sub (#goalv node,0)
    val (goal : goal) = #goal goalrec
    val argtree = Vector.sub (#stacv goalrec,sn)
    val stac = dest_stac (#token (dfind [] argtree))
    fun loop acc anl_aux =
      if null anl_aux then acc else
      loop (#token (dfind anl_aux argtree) :: acc) (tl anl_aux)
    val tokenl = loop [] anl
  in
    (goal,stac,tokenl)
  end

fun goallist_of_node node = map #goal (vector_to_list (#goalv node))

(* Commit to arguments one at a time for tactic sn of the root goal of
   prevtree, anl being the argument indices chosen so far.
   When the record's #atyl is empty (presumably no argument left to
   choose): returns SOME [] if its status is StacProved, SOME of the goals
   of node [(0,sn,anl)] if prevtree contains that node, and NONE otherwise.
   Otherwise runs a new search from the partially instantiated tactic,
   picks the next argument with best_arg and recurses. *)
fun loop_arg searchobj prevtree (sn,anl) =
  let
    val (goal,stac,tokenl) = path_to_anl prevtree (sn,anl)
    val _ =
      if null tokenl
      then print_endline ("Tactic: " ^ stac)
      else print_endline (string_of_token (last tokenl))
    val stacrec = stacrec_of prevtree (sn,anl)
    val argstatus = #sstatus stacrec
    val _ = if argstatus <> StacUndecided
            then print_endline (string_of_sstatus argstatus) else ()
  in
    if null (#atyl stacrec) then
      if argstatus = StacProved
        then SOME []
      else if dmem [(0,sn,anl)] prevtree
        then SOME (goallist_of_node (dfind [(0,sn,anl)] prevtree))
      else NONE
    else
      let
        val starttree = hidef (starttree_of_gstacarg searchobj)
          (goal,stac,tokenl)
        val (_,tree) =
          hidef (search_loop searchobj (SOME bigsteps_nsim)) starttree
        val newanl = List.tabulate (length anl, fn _ => 0)
        val an = best_arg tree (0, newanl)
        val _ = print_endline ("best arg: " ^ its an)
      in
        loop_arg searchobj tree (0, an :: newanl)
      end
  end

(* One big step on goal g: NONE if the initial node is saturated,
   otherwise search, commit to the best tactic and then to its arguments
   (see loop_arg). *)
fun loop_stac searchobj g =
  let
    val _ = print_endline ("Goal: " ^ string_of_goal g)
    val starttree = hidef (starttree_of_goal searchobj) g
  in
    if #nstatus (dfind [] starttree) = NodeSaturated then NONE else
    let
      val (_,tree) =
        hidef (search_loop searchobj (SOME bigsteps_nsim)) starttree
      val sn = best_stac tree
      val _ = print_endline ("best tac: " ^ its sn)
    in
      loop_arg searchobj tree (sn,[])
    end
  end

fun string_of_gl gl = String.concatWith "\n" (map string_of_goal gl)

(* Process the open goals gl breadth-first, n counting the goals already
   processed. Proved when no goal is left, timeout when n + length gl
   exceeds max_bigsteps, saturated when a big step on any goal of gl fails.
   Every goal of gl is processed even if an earlier one fails, so that
   their examples are still recorded. *)
fun loop_node searchobj n gl =
  (
  if length gl > 1
    then print_endline ("\nOpen goals: " ^ string_of_gl gl)
    else ();
  if null gl then BigStepsProved
  else if n + length gl > max_bigsteps then BigStepsTimeout
  else
    let val newglol = map (loop_stac searchobj) gl in
      if List.all isSome newglol
      then loop_node searchobj (n + length gl) (List.concat (map valOf newglol))
      else BigStepsSaturated
    end
  )

(* Run big steps from goal g. Returns the final status and the collected
   (value, policy, argument) examples. The global accumulators are
   emptied before and after the run, including when the run raises. *)
fun run_bigsteps searchobj g =
  let
    val _ = print_endline "init bigsteps"
    val _ = reset_examples ()
    val bstatus = loop_node searchobj 0 [g]
      handle e => (reset_examples (); raise e)
    val r = (bstatus,(!exval,!expol,!exarg))
    val _ = reset_examples ()
  in
    r
  end

(* Worker for run_bigsteps_eval. Examples are written to
   expdir/ngen/{val,pol,arg}/pb and the status to expdir/ngen/res/pb, where
   pb is the current theory and savestate level. On a non-Interrupt
   failure, the failing stage name is appended to expdir/ngen/err/pb before
   re-raising. *)
fun run_bigsteps_eval_aux (expdir,ngen) (thmdata,tacdata) (vnno,pnno,anno) g =
  let
    val pb = current_theory () ^ "_" ^ its (!savestate_level)
    val gendir = expdir ^ "/" ^ its ngen
    val valdir = gendir ^ "/val"
    val poldir = gendir ^ "/pol"
    val argdir = gendir ^ "/arg"
    val resdir = gendir ^ "/res"
    val errdir = gendir ^ "/err"
    val _ = app mkDir_err [expdir,gendir,valdir,poldir,argdir,resdir,errdir]
    val _ = print_endline "searchobj"
    val searchobj = build_searchobj (thmdata,tacdata) (vnno,pnno,anno) g
      handle Interrupt => raise Interrupt
        | e => (append_endline (errdir ^ "/" ^ pb) "searchobj"; raise e)
    val _ = print_endline "run_bigsteps"
    val (bstatus,(exv,exp,exa)) = run_bigsteps searchobj g
      handle Interrupt => raise Interrupt
        | e => (append_endline (errdir ^ "/" ^ pb) "run_bigsteps"; raise e)
  in
    write_tnnex (valdir ^ "/" ^ pb) (basicex_to_tnnex exv);
    write_tnnex (poldir ^ "/" ^ pb) (basicex_to_tnnex exp);
    write_tnnex (argdir ^ "/" ^ pb) (basicex_to_tnnex exa);
    writel (resdir ^ "/" ^ pb) [string_of_bstatus bstatus]
  end

(* Evaluation entry point: run big steps on g with hide_flag disabled and
   write the results (see run_bigsteps_eval_aux). The previous value of
   hide_flag is restored on every exit, including when an exception
   (Interrupt included) propagates. *)
fun run_bigsteps_eval (expdir,ngen) (thmdata,tacdata) (vnno,pnno,anno) g =
  let
    val mem = !hide_flag
    fun restore () = hide_flag := mem
  in
    hide_flag := false;
    (run_bigsteps_eval_aux (expdir,ngen) (thmdata,tacdata) (vnno,pnno,anno) g
     handle e => (restore (); raise e));
    restore ()
  end

(* -------------------------------------------------------------------------
   Toy example (* todo: follow proven *)
   ------------------------------------------------------------------------- *)

(*
load "aiLib"; open aiLib;
load "tttBigSteps"; open tttBigSteps;
load "tacticToe"; open tacticToe;

val goal:goal = ([],``?x. x*x=4``);
val thmdata = mlThmData.create_thmdata ();
val tacdata = mlTacticData.create_tacdata ();
val searchobj = build_searchobj (thmdata,tacdata) (NONE,NONE,NONE) goal;

val (b,exl) = run_bigsteps searchobj goal;
val (a,b,c) = exl;

val (_,t) = add_time (run_bigsteps searchobj) goal;
*)

(* TODO: handle the edge case where every tactic fails *)

end (* struct *)