1(* ========================================================================= *)
2(* FILE          : tttBigSteps.sml                                           *)
3(* DESCRIPTION   : Successions of non-backtrackable moves chosen after one   *)
4(*                 MCTS call for each step                                   *)
5(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
6(* DATE          : 2020                                                      *)
7(* ========================================================================= *)
8
9structure tttBigSteps :> tttBigSteps =
10struct
11
12open HolKernel Abbrev boolLib aiLib smlRedirect psMCTS mlTreeNeuralNetwork
13  tttToken tttTrain tttSearch tttRecord tacticToe
14
(* Error constructor for this file: mk_HOL_ERR applied to the structure name.
   Used below as  ERR "function_name" "message". *)
val ERR = mk_HOL_ERR "tttBigSteps"
16
17(* -------------------------------------------------------------------------
18   Global example datasets
19   ------------------------------------------------------------------------- *)
20
(* A set of training examples: terms paired with a real-valued target. *)
type ex = (term * real) list

(* Global accumulators for the examples collected during a run.
   exval and expol are extended by best_stac, exarg by best_arg;
   run_bigsteps clears all three before and after a run. *)
val exval = ref []
val expol = ref []
val exarg = ref []
26
(* Outcome of a big-steps run, as returned by loop_node. *)
datatype bstatus = BigStepsProved | BigStepsSaturated | BigStepsTimeout

(* Name of a bstatus constructor, as written to the result file. *)
fun string_of_bstatus BigStepsProved = "BigStepsProved"
  | string_of_bstatus BigStepsSaturated = "BigStepsSaturated"
  | string_of_bstatus BigStepsTimeout = "BigStepsTimeout"
33
34(* -------------------------------------------------------------------------
35   Extracting examples
36   ------------------------------------------------------------------------- *)
37
(* Value example for one goal record.
   Returns [] when the goal's visit count (#gvis) is below 0.5; otherwise a
   singleton list pairing nntm_of_stateval of the goal with a target:
   1.0 if the goal is GoalProved, 0.0 if GoalSaturated, and
   #gsum / #gvis in every other case. *)
fun example_val goalrec =
  let
    val visits = #gvis goalrec
    val status = #gstatus goalrec
    (* the input term is only built when an example is actually produced *)
    fun labelled target = [(nntm_of_stateval (#goal goalrec), target)]
  in
    if visits < 0.5 then []
    else if status = GoalProved then labelled 1.0
    else if status = GoalSaturated then labelled 0.0
    else labelled (#gsum goalrec / visits)
  end
46
(* Policy examples for the tactic records argl1 tried on goal.
   Saturated records are dropped and at most the first 8 of the rest are
   kept, each paired with its visit count (#svis).
   Returns [] when the summed visit count is at most 0.5 or when fewer than
   two records remain. Otherwise every kept record yields
   (nntm_of_statepol (goal, tactic), target) where target is 1.0 for a
   StacProved record and visits / total for the others.
   Side effect: prints the total and the first three visit counts after
   sorting with compare_rmax. *)
fun example_pol goal argl1 =
  let
    fun not_saturated x = #sstatus x <> StacSaturated
    val candidates = first_n 8 (map_assoc #svis (filter not_saturated argl1))
    val tot = sum_real (map snd candidates)
  in
    if tot <= 0.5 orelse length candidates <= 1 then [] else
    let
      val top3 = first_n 3 (map snd (dict_sort compare_rmax candidates))
      val _ = print_endline ("tot: " ^ rts tot ^ ", " ^
        String.concatWith " " (map rts top3))
      fun target (x,vis) =
        if #sstatus x = StacProved then 1.0 else vis / tot
      fun encode x = hidef nntm_of_statepol (goal,dest_stac (#token x))
    in
      map (fn (x,vis) => (encode x, target (x,vis))) candidates
    end
  end
67
(* Argument examples for the argument records argl1 tried for tactic stac
   on goal.
   Saturated records are dropped and at most the first 8 of the rest are
   kept, each paired with its visit count (#svis).
   Returns [] when the summed visit count is at most 0.5 or when fewer than
   two records remain. Otherwise every kept record yields
   (nntm_of_statearg ((goal,stac), token), target) where target is 1.0 for a
   StacProved record and visits / total for the others.
   Side effect: prints the total and the first three visit counts after
   sorting with compare_rmax. *)
fun example_arg (goal,stac) argl1 =
  let
    fun not_saturated x = #sstatus x <> StacSaturated
    val candidates = first_n 8 (map_assoc #svis (filter not_saturated argl1))
    val tot = sum_real (map snd candidates)
  in
    if tot <= 0.5 orelse length candidates <= 1 then [] else
    let
      val top3 = first_n 3 (map snd (dict_sort compare_rmax candidates))
      val _ = print_endline ("tot: " ^ rts tot ^ ", " ^
        String.concatWith " " (map rts top3))
      fun target (x,vis) =
        if #sstatus x = StacProved then 1.0 else vis / tot
      fun encode x = hidef nntm_of_statearg ((goal,stac),#token x)
    in
      map (fn (x,vis) => (encode x, target (x,vis))) candidates
    end
  end
88
89(* -------------------------------------------------------------------------
90   Selecting bigsteps
91   ------------------------------------------------------------------------- *)
92
(* Collects (index, record) pairs for the children i, i+1, ... of the node
   at path anl in argtree. Child i is stored under the key i :: anl.
   Enumeration stops at the first missing key or at the first child whose
   status is StacFresh. *)
fun children_argtree_aux argtree anl i =
  let val path = i :: anl in
    if not (dmem path argtree) then [] else
    let val child = dfind path argtree in
      if #sstatus child = StacFresh
      then []
      else (i,child) :: children_argtree_aux argtree anl (i+1)
    end
  end

(* Children of the node at path anl, enumerated from index 0. *)
fun children_argtree argtree anl =
  let val first_index = 0 in
    children_argtree_aux argtree anl first_index
  end
102
(* Root record (key []) of the i-th argument tree of stacv.
   Raises whatever Vector.sub or dfind raise when the entry is absent. *)
fun root_argtree stacv i =
  let val argtree = Vector.sub (stacv,i) in
    dfind [] argtree
  end

(* Collects (index, root record) pairs for entries i, i+1, ... of stacv.
   Enumeration stops at the first index where root_argtree fails or at the
   first root whose status is StacFresh. *)
fun children_stacv_aux stacv i =
  if not (can (root_argtree stacv) i) then [] else
  let val root = root_argtree stacv i in
    if #sstatus root = StacFresh
    then []
    else (i,root) :: children_stacv_aux stacv (i+1)
  end

(* Root records of stacv, enumerated from index 0. *)
fun children_stacv stacv =
  let val first_index = 0 in
    children_stacv_aux stacv first_index
  end
114
(* Chooses the index of the argument to commit to below path anl in the
   sn-th argument tree of the first goal of the root node of tree.
   Side effect: prepends the argument examples for these children to exarg.
   Raises ERR "best_arg" "empty" when there is no child.
   Selection, over the non-saturated children: the first StacProved child
   if there is one; otherwise the head of the list sorted on visit count
   with compare_rmax; 0 when every child is saturated. *)
fun best_arg tree (sn,anl) =
  let
    val goalrec = Vector.sub (#goalv (dfind [] tree),0)
    val argtree = Vector.sub (#stacv goalrec,sn)
    val stac = dest_stac (#token (dfind [] argtree))
    val children = children_argtree argtree anl
    val _ = if null children then raise ERR "best_arg" "empty" else ()
    val _ = exarg := example_arg (#goal goalrec, stac) (map snd children) @ !exarg
    fun is_proved (_,x) = #sstatus x = StacProved
    fun is_live (_,x) = #sstatus x <> StacSaturated
    val live = filter is_live children
  in
    case List.find is_proved live of
      SOME (an,_) => an
    | NONE =>
        if null live then 0 (* unsafe if no predictions *)
        else
          let val ranked = dict_sort compare_rmax (map_assoc (#svis o snd) live)
          in fst (fst (hd ranked)) end
  end
136
(* Chooses the index of the tactic to commit to for the first goal of the
   root node of tree.
   Side effects: prepends the value example of that goal to exval and the
   policy examples for its tactics to expol.
   Raises ERR "best_stac" "empty" when no tactic has been tried.
   Selection, over the non-saturated tactics: the first StacProved tactic
   if there is one; otherwise the head of the list sorted on visit count
   with compare_rmax; 0 when every tactic is saturated. *)
fun best_stac tree =
  let
    val goalrec = Vector.sub (#goalv (dfind [] tree),0)
    val _ = exval := example_val goalrec @ !exval
    val children = children_stacv (#stacv goalrec)
    val _ = if null children then raise ERR "best_stac" "empty" else ()
    val _ = expol := example_pol (#goal goalrec) (map snd children) @ !expol
    fun is_proved (_,x) = #sstatus x = StacProved
    fun is_live (_,x) = #sstatus x <> StacSaturated
    val live = filter is_live children
  in
    (* to do change to a softer version of not always selecting
       the winning path *)
    case List.find is_proved live of
      SOME (sn,_) => sn
    | NONE =>
        if null live then 0 (* unsafe if no predictions *)
        else
          let val ranked = dict_sort compare_rmax (map_assoc (#svis o snd) live)
          in fst (fst (hd ranked)) end
  end
159
160(* -------------------------------------------------------------------------
161   MCTS big steps. Ending the search when there is no move available.
162   ------------------------------------------------------------------------- *)
163
(* Step budget for loop_node: the run ends with BigStepsTimeout once the
   number of goals already processed plus the number of currently open goals
   exceeds this value. *)
val max_bigsteps = 20
165
(* Record stored at path anl in the sn-th argument tree of the first goal
   of the root node of tree. *)
fun stacrec_of tree (sn,anl) =
  let
    val goalrec = Vector.sub (#goalv (dfind [] tree),0)
  in
    dfind anl (Vector.sub (#stacv goalrec,sn))
  end
174
(* Describes the path (sn,anl) in tree as a triple (goal, stac, tokenl):
   the first goal of the root node, the tactic string at the root of the
   sn-th argument tree, and the tokens found along anl.
   Tokens are gathered for anl and each of its successive tails, so the
   token of the full path anl comes last in tokenl. *)
fun path_to_anl tree (sn,anl) =
  let
    val goalrec = Vector.sub (#goalv (dfind [] tree),0)
    val goal : goal = #goal goalrec
    val argtree = Vector.sub (#stacv goalrec,sn)
    val stac = dest_stac (#token (dfind [] argtree))
    fun collect [] acc = acc
      | collect (path as _ :: parent) acc =
          collect parent (#token (dfind path argtree) :: acc)
    val tokenl = collect anl []
  in
    (goal,stac,tokenl)
  end
189
(* Goals of all goal records of a search node, in vector order. *)
fun goallist_of_node node =
  let val goalrecs = vector_to_list (#goalv node) in
    map #goal goalrecs
  end
191
(* Commits to the arguments of the tactic at (sn,anl) in prevtree, one
   argument per search.
   Prints the tactic (when no argument has been chosen yet) or the last
   chosen token, and the record's status unless it is StacUndecided.
   When the record expects no further argument (#atyl is empty):
     - SOME [] if it is StacProved,
     - SOME gl with the goals of the node stored under [(0,sn,anl)] in
       prevtree when that node exists,
     - NONE otherwise.
   Otherwise a new tree is built from (goal,stac,tokenl), searched with
   search_loop, the best next argument is picked with best_arg, and the loop
   continues in the new tree. *)
fun loop_arg searchobj prevtree (sn,anl) =
  let
    val (goal,stac,tokenl) = path_to_anl prevtree (sn,anl)
    val _ = print_endline
      (if null tokenl then "Tactic: " ^ stac
       else string_of_token (last tokenl))
    val stacrec = stacrec_of prevtree (sn,anl)
    val argstatus = #sstatus stacrec
    val _ =
      if argstatus = StacUndecided then ()
      else print_endline (string_of_sstatus argstatus)
    val childid = [(0,sn,anl)]
  in
    if not (null (#atyl stacrec)) then
      let
        val starttree =
          hidef (starttree_of_gstacarg searchobj) (goal,stac,tokenl)
        val (_,tree) = hidef (search_loop searchobj (SOME 1600)) starttree
        (* the path in the fresh tree has the same depth as anl, all zeros *)
        val newanl = map (fn _ => 0) anl
        val an = best_arg tree (0, newanl)
        val _ = print_endline ("best arg: " ^ its an)
      in
        loop_arg searchobj tree (0, an :: newanl)
      end
    else if argstatus = StacProved then SOME []
    else if dmem childid prevtree
      then SOME (goallist_of_node (dfind childid prevtree))
    else NONE
  end
222
(* Performs one big step on goal g: prints the goal, builds a start tree,
   and returns NONE if its root is already NodeSaturated. Otherwise runs
   search_loop, picks a tactic with best_stac, prints its index, and hands
   over to loop_arg to choose the tactic's arguments. *)
fun loop_stac searchobj g =
  let
    val _ = print_endline ("Goal: " ^ string_of_goal g)
    val starttree = hidef (starttree_of_goal searchobj) g
    val rootstatus = #nstatus (dfind [] starttree)
  in
    if rootstatus = NodeSaturated then NONE else
    let
      val (_,tree) = hidef (search_loop searchobj (SOME 1600)) starttree
      val sn = best_stac tree
      val _ = print_endline ("best tac: " ^ its sn)
    in
      loop_arg searchobj tree (sn,[])
    end
  end
237
(* Renders a list of goals, one goal per line. *)
fun string_of_gl gl =
  let val rendered = map string_of_goal gl in
    String.concatWith "\n" rendered
  end
239
(* Main big-steps loop over the list gl of open goals; n counts the goals
   processed so far.
   Prints the open goals when there is more than one.
   Returns BigStepsProved when gl is empty, BigStepsTimeout when
   n + length gl exceeds max_bigsteps, BigStepsSaturated when loop_stac
   returns NONE for some goal; otherwise recurses on the concatenation of
   the goal lists produced for each goal. *)
fun loop_node searchobj n gl =
  let
    val ngoals = length gl
    val _ =
      if ngoals > 1 then print_endline ("\nOpen goals: " ^ string_of_gl gl)
      else ()
  in
    if null gl then BigStepsProved
    else if n + ngoals > max_bigsteps then BigStepsTimeout
    else
      let val results = map (loop_stac searchobj) gl in
        if List.all isSome results
        then loop_node searchobj (n + ngoals) (List.concat (map valOf results))
        else BigStepsSaturated
      end
  end
259
(* Runs the big-steps loop on goal g and returns
   (status, (value examples, policy examples, argument examples)).
   The global example accumulators are cleared before the run and again
   once their contents have been read. *)
fun run_bigsteps searchobj g =
  let
    fun reset () = (exval := []; expol := []; exarg := [])
    val _ = print_endline "init bigsteps"
    val _ = reset ()
    val bstatus = loop_node searchobj 0 [g]
    val collected = (!exval,!expol,!exarg)
  in
    reset ();
    (bstatus,collected)
  end
270
(* Runs big steps on goal g and writes the outcome under expdir/ngen.
   The problem name is the current theory name joined by "_" to the current
   savestate level. Files named after the problem are written to:
     val/, pol/, arg/  - the collected examples (via write_tnnex),
     res/              - the final status string,
     err/              - the name of the failing phase ("searchobj" or
                         "run_bigsteps") when that phase raises; the
                         exception is then re-raised.
   hide_flag is set to false for the duration of the call and restored to
   its previous value afterwards, both on normal return and when an
   exception (including Interrupt) propagates, so a failing problem does
   not leave the global flag modified for later callers. *)
fun run_bigsteps_eval (expdir,ngen) (thmdata,tacdata) (vnno,pnno,anno) g =
  let
    val mem = !hide_flag
    val _ = hide_flag := false
    fun run () =
      let
        val pb = current_theory () ^ "_" ^ its (!savestate_level)
        val gendir = expdir ^ "/" ^ its ngen
        val valdir = gendir ^ "/val"
        val poldir = gendir ^ "/pol"
        val argdir = gendir ^ "/arg"
        val resdir = gendir ^ "/res"
        val errdir = gendir ^ "/err"
        val _ = app mkDir_err
          [expdir,gendir,valdir,poldir,argdir,resdir,errdir]
        val _ = print_endline "searchobj"
        val searchobj = build_searchobj (thmdata,tacdata) (vnno,pnno,anno) g
          handle Interrupt => raise Interrupt
            | e => (append_endline (errdir ^ "/" ^ pb) "searchobj"; raise e)
        val _ = print_endline "run_bigsteps"
        val (bstatus,(exv,exp,exa)) = run_bigsteps searchobj g
          handle Interrupt => raise Interrupt
            | e => (append_endline (errdir ^ "/" ^ pb) "run_bigsteps"; raise e)
      in
        write_tnnex (valdir ^ "/" ^ pb) (basicex_to_tnnex exv);
        write_tnnex (poldir ^ "/" ^ pb) (basicex_to_tnnex exp);
        write_tnnex (argdir ^ "/" ^ pb) (basicex_to_tnnex exa);
        writel (resdir ^ "/" ^ pb) [string_of_bstatus bstatus]
      end
  in
    (* restore the flag before letting any exception escape *)
    (run () handle e => (hide_flag := mem; raise e));
    hide_flag := mem
  end
298
299
300
301(* -------------------------------------------------------------------------
302   Toy example (* todo: follow proven *)
303   ------------------------------------------------------------------------- *)
304
305(*
306load "aiLib"; open aiLib;
307load "tttBigSteps"; open tttBigSteps;
308load "tacticToe"; open tacticToe;
309
310val goal:goal = ([],``?x. x*x=4``);
311val thmdata = mlThmData.create_thmdata ();
312val tacdata = mlTacticData.create_tacdata ();
val searchobj = build_searchobj (thmdata,tacdata) (NONE,NONE,NONE) goal;

val (b,exl) = run_bigsteps searchobj goal;
val (a,b,c) = exl;

val (_,t) = add_time (run_bigsteps searchobj) goal;
val (winb,ex) = run_bigsteps searchobj goal;
320*)
321
(* todo: take care of the edge case when all tactics fail *)
323
324end (* struct *)
325