1(* ========================================================================= *)
2(* FILE          : mleCombinLib.sml                                        *)
3(* DESCRIPTION   : Tools for term synthesis on combinator datatype           *)
4(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
5(* DATE          : 2020                                                      *)
6(* ========================================================================= *)
7
8structure mleCombinLib :> mleCombinLib =
9struct
10
11open HolKernel Abbrev boolLib aiLib psTermGen hhExportFof
12
(* Exception builder for errors raised by this library. *)
val ERR = mk_HOL_ERR "mleCombinLib"

(* Base directory of the AI_tasks examples; export paths are built below it. *)
val selfdir = String.concat [HOLDIR, "/examples/AI_tasks"]
15
16(* -------------------------------------------------------------------------
17   Combinator terms
18   ------------------------------------------------------------------------- *)
19
(* Variables of the combinator encoding: the constants s and k, the binary
   application operator a, and the upper-case variables X, Y, Z, V1, V2, V3
   (forall_capital quantifies exactly over upper-case names). *)
local
  fun avar name = mk_var (name, alpha)
in
  val cS = avar "s"
  val cK = avar "k"
  val cA = mk_var ("a", ``:'a -> 'a -> 'a``)
  val vx = avar "X"
  val vy = avar "Y"
  val vz = avar "Z"
  val v1 = avar "V1"
  val v2 = avar "V2"
  val v3 = avar "V3"
end
30
(* constructors *)

(* mk_cA (f, x) builds the HOL term `a f x`, the encoding of f applied to x. *)
fun mk_cA (f, x) = list_mk_comb (cA, [f, x])

(* Infix synonym for mk_cA; left-associative, so s oo x oo y = (s oo x) oo y. *)
infix oo
fun op oo (f, x) = mk_cA (f, x)
(* dest_cA tm inverts mk_cA: for tm = `a f x` it returns (f, x).
   Raises HOL_ERR (origin "dest_cA") when the head of tm is not cA. When the
   head is cA, the argument list is handed to pair_of_list, which presumably
   raises unless the list has exactly two elements. *)
fun dest_cA tm =
  let
    val (head, args) = strip_comb tm
  in
    if not (term_eq head cA) then raise ERR "dest_cA" ""
    else pair_of_list args
  end
(* list_mk_cA [f, x1, ..., xn] builds the left-nested application
   a (... (a (a f x1) x2) ...) xn; a singleton list yields its only element.
   Raises HOL_ERR with an explanatory message on the empty list. *)
fun list_mk_cA tml =
  case tml of
    [] => raise ERR "list_mk_cA" "empty list"
  | tm :: rest => List.foldl (fn (x, acc) => mk_cA (acc, x)) tm rest
(* strip_cA_aux tm peels cA-applications off the left spine of tm and returns
   the pieces last-argument-first: for `a (a f x1) x2` it gives [x2, x1, f].
   A variable, or a term whose head is not cA, is returned as a singleton.
   strip_cA returns the same pieces in source order [f, x1, x2]. *)
fun strip_cA_aux tm =
  if is_var tm then [tm]
  else
    let
      val (head, args) = strip_comb tm
    in
      if not (term_eq head cA) then [tm]
      else
        let val (fpart, xpart) = pair_of_list args in
          xpart :: strip_cA_aux fpart
        end
    end

fun strip_cA tm = List.rev (strip_cA_aux tm)
54
(* theorems *)

(* forall_capital tm universally quantifies tm over those of its free
   variables whose names begin with an upper-case letter, in the order
   returned by free_vars_lr. *)
fun forall_capital tm =
  let
    fun capitalised v =
      let val (name, _) = dest_var v in Char.isUpper (hd_string name) end
    val bound = filter capitalised (free_vars_lr tm)
  in
    list_mk_forall (bound, tm)
  end
63
(* Combinatory-logic rewrite rules as unquantified equations:
   s X Y Z = (X Z) (Y Z)   and   k X Y = X. *)
val s_thm_bare =
  mk_eq (list_mk_cA [cS, vx, vy, vz], mk_cA (mk_cA (vx, vz), mk_cA (vy, vz)))
val k_thm_bare = mk_eq (list_mk_cA [cK, vx, vy], vx)
val eq_axl_bare = [s_thm_bare, k_thm_bare]

(* The same equations with their upper-case variables universally quantified. *)
val eq_axl = List.map forall_capital eq_axl_bare
68
69(* -------------------------------------------------------------------------
70   Generating combinator terms using psTermGen
71   ------------------------------------------------------------------------- *)
72
(* Operators handed to psTermGen: the digit is the arity and the letter the
   underlying combinator; to_apply translates them back into cA-terms. *)
local
  val unary = ``:'a -> 'a``
  val binary = ``:'a -> 'a -> 'a``
in
  val s2 = mk_var ("s2", binary)
  val s1 = mk_var ("s1", unary)
  val s0 = mk_var ("s0", alpha)
  val k1 = mk_var ("k1", unary)
  val k0 = mk_var ("k0", alpha)
end
78
(* to_apply tm translates a term built from s2, s1, s0, k1, k0 into the
   cA-encoding: s2 with two arguments, s1/k1 with one, s0/k0 with none become
   s or k applied (via cA) to the translated arguments.
   Raises HOL_ERR (origin "to_apply") for any other operator/arity pairing. *)
fun to_apply tm =
  let
    val (oper, argl) = strip_comb tm
    fun is_op t = term_eq oper t
    fun fail () = raise ERR "to_apply" ""
  in
    case argl of
      [c1, c2] =>
        if is_op s2 then list_mk_cA [cS, to_apply c1, to_apply c2]
        else fail ()
    | [c1] =>
        if is_op s1 then mk_cA (cS, to_apply c1)
        else if is_op k1 then mk_cA (cK, to_apply c1)
        else fail ()
    | _ =>
        if is_op s0 then cS
        else if is_op k0 then cK
        else fail ()
  end
91
local
  (* Operator set handed to psTermGen; to_apply maps the output to cA-terms. *)
  val nf_operators = [s2, s1, s0, k1, k0]
in
  (* random_nf size: one term drawn by psTermGen.random_term at the given
     size (per psTermGen's notion of size), in cA-encoding. *)
  fun random_nf size = to_apply (random_term nf_operators (size, alpha))

  (* gen_nf_exhaustive size: the terms produced by psTermGen.gen_term at the
     given size, each translated into cA-encoding. *)
  fun gen_nf_exhaustive size =
    List.map to_apply (gen_term nf_operators (size, alpha))
end
97
98(* -------------------------------------------------------------------------
99   Position
100   ------------------------------------------------------------------------- *)
101
(* One step of a position path. Presumably Left/Right select the two children
   of a binary application -- TODO confirm against users of pos. *)
datatype pose = Left | Right

(* Total order on steps: Left < Right. *)
fun pose_compare (a, b) =
  let
    fun rank Left = 0
      | rank Right = 1
  in
    Int.compare (rank a, rank b)
  end

(* Single-letter rendering of a step: "L" or "R". *)
fun pose_to_string Left = "L"
  | pose_to_string Right = "R"
112
(* Inverse of pose_to_string: "L" -> Left, "R" -> Right.
   Raises HOL_ERR (origin "string_to_pose") on any other input; the message
   names the rejected string so malformed position data can be located. *)
fun string_to_pose s =
  case s of
    "L" => Left
  | "R" => Right
  | _ => raise ERR "string_to_pose" ("unknown pose: \"" ^ s ^ "\"")
116
(* A position is a list of steps, rendered as space-separated letters:
   [Left, Right] <-> "L R". string_to_pos splits on any whitespace. *)
fun pos_to_string pos =
  let val letters = List.map pose_to_string pos in
    String.concatWith " " letters
  end

fun string_to_pos s =
  let val tokens = String.tokens Char.isSpace s in
    List.map string_to_pose tokens
  end

(* Order on positions via aiLib.list_compare (presumably lexicographic). *)
fun pos_compare (p1, p2) = list_compare pose_compare (p1, p2)
122
123(* -------------------------------------------------------------------------
124   Combinators
125   ------------------------------------------------------------------------- *)
126
(* Combinator syntax: variables V1..V3, the S and K combinators, and binary
   application A (f, x). *)
datatype combin = V1 | V2 | V3 | S | K | A of combin * combin

(* Number of leaves (atoms) in a combinator. *)
fun combin_size c =
  let
    fun count (A (f, x), acc) = count (x, count (f, acc))
      | count (_, acc) = acc + 1
  in
    count (c, 0)
  end
132
133(* -------------------------------------------------------------------------
134   Printing combinators
135   ------------------------------------------------------------------------- *)
136
(* strip_A_aux c walks the left spine of applications and returns the pieces
   last-argument-first: A (A (f, x1), x2) ~> [x2, x1, f]. *)
fun strip_A_aux c =
  case c of
    A (f, x) => x :: strip_A_aux f
  | atom => [atom]

(* strip_A c lists the head and arguments in order:
   A (A (f, x1), x2) ~> [f, x1, x2]. *)
fun strip_A c =
  let
    fun go (A (f, x), args) = go (f, x :: args)
      | go (h, args) = h :: args
  in
    go (c, [])
  end

(* list_mk_A_aux takes [xn, ..., x1, f] (reversed) and rebuilds the left-nested
   application; raises HOL_ERR (origin "list_mk_A") on the empty list. *)
fun list_mk_A_aux l =
  case l of
    [] => raise ERR "list_mk_A" ""
  | [c] => c
  | x :: rest => A (list_mk_A_aux rest, x)

(* list_mk_A [f, x1, ..., xn] = A (... A (A (f, x1), x2) ..., xn);
   raises HOL_ERR (origin "list_mk_A") on the empty list. *)
fun list_mk_A l =
  case l of
    [] => raise ERR "list_mk_A" ""
  | h :: rest => List.foldl (fn (x, acc) => A (acc, x)) h rest
148
(* Concrete syntax: atoms print as their constructor names; an application
   prints its flattened spine in parentheses, e.g. A (A (S, K), V1) ~> "(S K V1)". *)
fun combin_to_string c =
  case c of
    A _ =>
      let val parts = List.map combin_to_string (strip_A c) in
        "(" ^ String.concatWith " " parts ^ ")"
      end
  | S => "S"
  | K => "K"
  | V1 => "V1"
  | V2 => "V2"
  | V3 => "V3"
156
(* string_to_combin s parses the syntax produced by combin_to_string.
   A bare atom name is first wrapped in parentheses (presumably because
   lisp_parser expects a parenthesised expression -- TODO confirm); the
   resulting one-element list is turned back into the atom by list_mk_A.
   Unknown atom names make assoc fail. *)
fun string_to_combin s =
  let
    val table = List.map (fn c => (combin_to_string c, c)) [S, K, V1, V2, V3]
    val input = if mem s (List.map fst table) then "(" ^ s ^ ")" else s
    fun build (Lterm l) = list_mk_A (List.map build l)
      | build (Lstring name) = assoc name table
  in
    build (singleton_of_list (lisp_parser input))
  end
168
(* Total order on combinators: atoms precede applications; atoms are ordered
   by their printed names; applications are compared as (function, argument)
   pairs via aiLib.cpl_compare. *)
fun combin_compare (A x, A y) = cpl_compare combin_compare combin_compare (x, y)
  | combin_compare (A _, _) = GREATER
  | combin_compare (_, A _) = LESS
  | combin_compare (c1, c2) =
      String.compare (combin_to_string c1, combin_to_string c2)
174
175(* -------------------------------------------------------------------------
176   Rewriting combinators
177   ------------------------------------------------------------------------- *)
178
(* next_pos p drops the trailing Right steps of p and turns the last
   remaining Left into Right, e.g. [Left, Right] ~> [Right].
   Raises HOL_ERR (origin "next_pos") when p contains no Left.
   next_pos_aux performs the same step on the reversed path. *)
fun next_pos_aux (Left :: rest) = Right :: rest
  | next_pos_aux (Right :: rest) = next_pos_aux rest
  | next_pos_aux [] = raise ERR "next_pos" ""

fun next_pos l =
  let val reversed = List.rev l in List.rev (next_pos_aux reversed) end
185
(* Raised inside combin_norm to abandon a normalisation that exceeds its
   step or size budget. *)
exception Break

(* combin_nf c holds when c contains no redex, i.e. no subterm of the form
   S x y z or K x y. *)
fun combin_nf (A (A (A (S, _), _), _)) = false
  | combin_nf (A (A (K, _), _)) = false
  | combin_nf (A (f, x)) = combin_nf f andalso combin_nf x
  | combin_nf _ = true
193
(* combin_norm n c rewrites c with the S and K rules until no redex remains
   and returns SOME of the normal form. Returns NONE when more than n redexes
   would have to be contracted, or when a redex about to be contracted has
   more than 100 leaves. *)
fun combin_norm n mainc =
  let
    val steps = ref 0
    (* Count one contraction of redex; abort once a budget is exceeded. *)
    fun tick redex =
      (steps := !steps + 1;
       if combin_size redex > 100 orelse !steps > n then raise Break else ())
    (* One top-down rewriting pass; the result may still contain redexes
       created by rewriting the children of an application. *)
    fun pass c =
      case c of
        A (A (A (S, x), y), z) => (tick c; pass (A (A (x, z), A (y, z))))
      | A (A (K, x), _) => (tick c; pass x)
      | A (f, x) => A (pass f, pass x)
      | atom => atom
    fun loop c = if combin_nf c then SOME c else loop (pass c)
  in
    loop mainc handle Break => NONE
  end
210
211(* -------------------------------------------------------------------------
212   Generating combinators
213   ------------------------------------------------------------------------- *)
214
(* cterm_to_combin maps a cA-term to a combin. The atoms cS, cK, v1, v2 and v3
   are recognised by term equality; any other term must be a cA-application,
   otherwise dest_cA raises. *)
fun cterm_to_combin cterm =
  let
    val atoms = [(cS, S), (cK, K), (v1, V1), (v2, V2), (v3, V3)]
  in
    case List.find (fn (t, _) => term_eq cterm t) atoms of
      SOME (_, c) => c
    | NONE =>
        let val (f, x) = dest_cA cterm in
          A (cterm_to_combin f, cterm_to_combin x)
        end
  end

(* combin_to_cterm is the inverse translation, from combin to cA-term. *)
fun combin_to_cterm c =
  case c of
    A (f, x) => mk_cA (combin_to_cterm f, combin_to_cterm x)
  | S => cS
  | K => cK
  | V1 => v1
  | V2 => v2
  | V3 => v3
226
(* contains_sk c holds when an S or K occurs anywhere in c. *)
fun contains_sk c =
  case c of
    A (f, x) => contains_sk f orelse contains_sk x
  | S => true
  | K => true
  | _ => false

(* has_bigarity c holds when some application spine in c, at the top or
   inside an argument, carries more than four arguments. *)
fun has_bigarity c =
  case strip_A c of
    _ :: args => List.length args > 4 orelse List.exists has_bigarity args
  | [] => raise Empty
239
(* Order combinators by their number of leaves. *)
fun compare_csize (c1, c2) =
  let val (n1, n2) = (combin_size c1, combin_size c2) in Int.compare (n1, n2) end

(* First element of l after aiLib.dict_sort by size, i.e. one of minimal
   size; raises Empty on an empty list. *)
fun smallest_csize l =
  case dict_sort compare_csize l of
    c :: _ => c
  | [] => raise Empty
242
(* gen_headnf_aux n nmax d keeps sampling until d holds nmax entries, then
   returns (d, n), where n counts the samples drawn.
   Each sample c is a random normal form (random_nf, size 1..20) converted to
   combin; c V1 V2 V3 is normalised with a budget of 100 steps, and a failed
   normalisation is treated as K (and therefore rejected). A result is
   rejected when it still mentions S or K, has more than 20 leaves, or has a
   spine with more than 4 arguments. Otherwise d maps the normal form to the
   smallest sample seen for it. The new size of d is printed whenever a new
   key is added. *)
fun gen_headnf_aux n nmax d =
  if dlength d >= nmax then (d, n)
  else
    let
      val c = cterm_to_combin (random_nf (random_int (1, 20)))
      val cnorm =
        case combin_norm 100 (A (A (A (c, V1), V2), V3)) of
          SOME nf => nf
        | NONE => K
      val rejected =
        contains_sk cnorm orelse combin_size cnorm > 20 orelse has_bigarity cnorm
      val continue = gen_headnf_aux (n + 1) nmax
    in
      if rejected then continue d
      else if not (dmem cnorm d) then
        (print_endline (its (dlength d + 1)); continue (dadd cnorm c d))
      else if compare_csize (c, dfind cnorm d) = LESS then
        continue (dadd cnorm c d)
      else continue d
    end

fun gen_headnf nmax d = gen_headnf_aux 0 nmax d
266
267(* -------------------------------------------------------------------------
268   Export
269   ------------------------------------------------------------------------- *)
270
(* Directory holding the exported train/test data. *)
val targetdir = String.concat [selfdir, "/combin_target"]

(* distrib_il il renders the entries of aiLib.count_dict over il (presumably
   value -> number of occurrences) as space-separated "value-count" items, in
   the order returned by dlist. *)
fun distrib_il il =
  let
    val counts = dlist (count_dict (dempty Int.compare) il)
    val items = List.map (fn (v, k) => its v ^ "-" ^ its k) counts
  in
    String.concatWith " " items
  end
280
(* export_data (train, test) creates targetdir and writes into it:
   - train_export / test_export: each (headnf, witness) pair, sorted, as the
     two lines "headnf: ..." and "combin: ...";
   - distrib-headnf / distrib-witness: size distributions (see distrib_il)
     of the first and second components over train @ test. *)
fun export_data (train, test) =
  let
    val _ = mkDir_err targetdir
    val pair_order = cpl_compare combin_compare combin_compare
    fun render (headnf, witness) =
      "headnf: " ^ combin_to_string headnf ^
      "\ncombin: " ^ combin_to_string witness
    fun write_sorted (name, data) =
      writel (targetdir ^ "/" ^ name) (List.map render (dict_sort pair_order data))
    val everything = train @ test
    fun sizes proj = List.map (combin_size o proj) everything
  in
    write_sorted ("train_export", train);
    write_sorted ("test_export", test);
    writel (targetdir ^ "/distrib-headnf") [distrib_il (sizes fst)];
    writel (targetdir ^ "/distrib-witness") [distrib_il (sizes snd)]
  end
300
(* import_data file reads targetdir/file as written by export_data: lines are
   taken two at a time ("headnf: ..." then "combin: ...") and parsed back into
   (headnf, witness) pairs. *)
fun import_data file =
  let
    val lines = readl (targetdir ^ "/" ^ file)
    val raw_pairs = List.map pair_of_list (mk_batch 2 lines)
    fun field tag line = string_to_combin (snd (split_string tag line))
    fun decode (a, b) = (field "headnf: " a, field "combin: " b)
  in
    List.map decode raw_pairs
  end
313
314(* -------------------------------------------------------------------------
315   TPTP Export
316   ------------------------------------------------------------------------- *)
317
(* goal_of_headnf headnf pairs the S/K axioms eq_axl with the conjecture
   ?Vc. !V1 V2 V3. Vc V1 V2 V3 = headnf   (in cA-encoding). *)
fun goal_of_headnf headnf =
  let
    val vc = mk_var ("Vc", alpha)
    val lhs = list_mk_cA [vc, v1, v2, v3]
    val rhs = combin_to_cterm headnf
    val body = list_mk_forall ([v1, v2, v3], mk_eq (lhs, rhs))
  in
    (eq_axl, mk_exists (vc, body))
  end
327
(* export_goal dir (goal, n) writes goal with fof_export_goal to
   selfdir/TPTP/<dir>/i/<n>.p. It first creates selfdir/TPTP, then
   selfdir/TPTP/<dir> and its subdirectories i, eprover, vampire,
   eprover_schedule and vampire_casc.
   Side effect: sets the export flags name_flag, type_flag and p_flag to
   false; they are not restored afterwards. *)
fun export_goal dir (goal, n) =
  let
    (* Reuse selfdir rather than re-spelling HOLDIR ^ "/examples/AI_tasks". *)
    val tptp_dir = selfdir ^ "/TPTP"
    val goal_dir = tptp_dir ^ "/" ^ dir
    val subdirs = ["i", "eprover", "vampire", "eprover_schedule", "vampire_casc"]
    val file = goal_dir ^ "/i/" ^ its n ^ ".p"
  in
    mkDir_err tptp_dir;
    mkDir_err goal_dir;
    app (fn sub => mkDir_err (goal_dir ^ "/" ^ sub)) subdirs;
    name_flag := false;
    type_flag := false;
    p_flag := false;
    fof_export_goal file goal
  end
345
346(*
347load "aiLib"; open aiLib;
348load "mleCombinLib"; open mleCombinLib;
349
350val data = import_data "test_export";
351val gl = map (goal_of_headnf o fst) data;
352app (export_goal "combin_test") (number_snd 0 gl);
353
354val data = import_data "train_export";
355val gl = map (goal_of_headnf o fst) data;
356app (export_goal "combin_train") (number_snd 0 gl);
357
358val data = import_data "test_export" @ import_data "train_export";
359val l1 = map (combin_size o snd) data;
360val l2 = dlist (count_dict (dempty Int.compare) l1);
361
362
363load "psTermGen"; open psTermGen;
364val s2 = mk_var ("s2", ``:'a -> 'a -> 'a``)
365val s1 = mk_var ("s1", ``:'a -> 'a``)
366val s0 = mk_var ("s0", alpha)
367val k1 = mk_var ("k1", ``:'a -> 'a``)
368val k0 = mk_var ("k0", alpha)
369fun f n = nterm [s0,s1,k0,k1,s2] (n,alpha);
370sum_real (List.tabulate (21,f));
371
372*)
373
374
375
376end (* struct *)
377
378