1structure ACF  =
2struct
3
4local
5open HolKernel Parse boolLib bossLib pairLib pairSyntax pairTheory PairRules ACFTheory;
6
7in
8(*---------------------------------------------------------------------------*)
9(* Convert HOL programs to combinator-based pseudo-ASTs                      *)
(* HOL term programs are translated to equivalent A-Combinator Forms (ACF) *)
11(*---------------------------------------------------------------------------*)
12
13(*---------------------------------------------------------------------------*)
14(* Variable order defined for binary sets                                    *)
15(*---------------------------------------------------------------------------*)
16
(* varOrder (t1, t2) is the comparison used for Binarysets of variables.   *)
(* Variables are ordered by name (lexicographic string order); variables   *)
(* that share a name are further ordered by type, so that e.g. x:num and   *)
(* x:bool are kept as distinct set elements instead of being conflated.    *)
(* Fails with HOL_ERR (from dest_var) if either argument is not a variable.*)
fun varOrder (t1:term, t2:term) =
  let val (s1, ty1) = dest_var t1
      val (s2, ty2) = dest_var t2
  in
    case String.compare (s1, s2) of
      EQUAL => Type.compare (ty1, ty2)
    | ord => ord
  end;
25
26
27(*---------------------------------------------------------------------------*)
(* Ensure that each lambda-bound variable in a term (this includes          *)
(* let-bound variables) has a name different from all the others.           *)
30(*---------------------------------------------------------------------------*)
31
(* std_bvars stem tm renames every lambda-bound variable of tm (let-bound   *)
(* variables included, since LET takes an abstraction) to a name of the     *)
(* form stem ^ <n>.  The numbers come from a Lib.istream that starts at 0;  *)
(* each new name is obtained by advancing the stream (Lib.next) and then    *)
(* reading it (Lib.state).  Abstractions are renamed in the order met by a  *)
(* walk that visits an abstraction before its body and the operator of an   *)
(* application before its operand.                                          *)
fun std_bvars stem tm =
 let val counter =
         Lib.mk_istream (fn n => n + 1) 0 (fn n => stem ^ Lib.int_to_string n)
     fun fresh () = Lib.state (Lib.next counter)
     fun walk t =
       if is_comb t then
         let val (fn_part, arg_part) = dest_comb t
             val fn_part' = walk fn_part
         in mk_comb (fn_part', walk arg_part)
         end
       else if not (is_abs t) then t
       else
         let val renamed = rename_bvar (fresh ()) t
             val (bv, body) = dest_abs renamed
         in mk_abs (bv, walk body)
         end
 in
   walk tm
 end;
52
(* STD_BVARS_CONV stem tm proves |- tm = tm', where tm' is tm with its       *)
(* bound variables renamed by std_bvars.  The two terms are alpha-          *)
(* equivalent, so the theorem is justified by Thm.ALPHA.                   *)
fun STD_BVARS_CONV stem tm = Thm.ALPHA tm (std_bvars stem tm);
57
(* Rule and tactic forms of STD_BVARS_CONV: rename the bound variables of a  *)
(* theorem's conclusion, or of the current goal, using the given stem.      *)
fun STD_BVARS stem = CONV_RULE (STD_BVARS_CONV stem);
fun STD_BVARS_TAC stem = CONV_TAC (STD_BVARS_CONV stem);
60
61(*---------------------------------------------------------------------------*)
62(* Part of the Compiler frontend ... largely taken from ANF.sml              *)
63(*---------------------------------------------------------------------------*)
64
65(*****************************************************************************)
66(* Error reporting function                                                  *)
67(*****************************************************************************)
68
(* ERR func message builds a HOL_ERR exception attributed to the structure  *)
(* name "COMBINATOR" and the given function name.                          *)
fun ERR func message = mk_HOL_ERR "COMBINATOR" func message;
70
71(*****************************************************************************)
72(* List of definitions (useful for rewriting)                                *)
73(*****************************************************************************)
74
(* The definitions of the combinators sc, cj and tr, for use in rewriting.  *)
val SimpThms = sc_def :: cj_def :: tr_def :: [];
76
77(*****************************************************************************)
78(* An expression is just a HOL term built using expressions defined earlier  *)
79(* in a program (see description of programs below) and sc, cj and tr:       *)
80(*                                                                           *)
81(*  expr := sc expr expr                                                     *)
82(*        | cj expr expr expr                                                *)
83(*        | tr expr expr expr                                                *)
84(*                                                                           *)
85(*****************************************************************************)
86
87(*****************************************************************************)
88(* to_combinator ``\(x1,...,xn). tm[x1,...,xn]`` returns a theorem           *)
89(*                                                                           *)
90(*  |- (\(x1,...,xn). tm[x1,...,xn]) = p                                     *)
91(*                                                                           *)
92(* where p is a combinatory expression built from the combinators            *)
93(* sc and cj.                                                                *)
94(*****************************************************************************)
95
(* is_word_literal tm holds when tm is an application of a constant named   *)
(* n2w to a numeral (e.g. ``n2w 5``).  It is a total predicate.  For any     *)
(* other term, including an application whose head is a variable or an      *)
(* abstraction (where dest_thy_const fails), it returns false instead of    *)
(* raising.  Callers such as to_combinator use it in an orelse chain and    *)
(* rely on falling through to later cases.                                  *)
fun is_word_literal tm =
  (is_comb tm andalso
   let val (c, args) = strip_comb tm
   in #Name (dest_thy_const c) = "n2w" andalso numSyntax.is_numeral (hd args)
   end)
  handle HOL_ERR _ => false;
103
104
(* to_combinator f, with f = \(x1,...,xn). t, proves |- f = p where p is a  *)
(* combinator expression.  Cases on the body t:                             *)
(*  - variable, word literal, numeral, constant or pair: |- f = f;          *)
(*  - if b then t1 else t2: rewritten with cj_def, then the three branch    *)
(*    abstractions are converted recursively;                               *)
(*  - let v = M in N: rewritten with sc_def, then both parts are converted  *)
(*    recursively;                                                          *)
(*  - anything else (e.g. a plain application): |- f = f.                   *)
(* Fails with HOL_ERR if f is not a (paired) abstraction.                   *)
fun to_combinator f =
 let val (args,t) = dest_pabs f
 in
   if is_var t orelse is_word_literal t orelse numSyntax.is_numeral t
      orelse is_const t orelse is_pair t then
     REFL f

   else if is_cond t then
     let val (b,t1,t2) = dest_cond t
         val fb = mk_pabs(args,b)
         val f1 = mk_pabs(args,t1)
         val f2 = mk_pabs(args,t2)
         val thb = PBETA_CONV (mk_comb(fb,args))
         val th1 = PBETA_CONV (mk_comb(f1,args))
         val th2 = PBETA_CONV (mk_comb(f2,args))
         val th3 = ISPECL [fb,f1,f2] cj_def
         val ptm = mk_pabs
                    (args,
                     mk_cond(mk_comb(fb,args),mk_comb(f1,args),mk_comb(f2,args)))
         val th4 = TRANS th3 (PALPHA (rhs(concl th3)) ptm)
         (* Reduce the applications (fb args), (f1 args), (f2 args) back to  *)
         (* b, t1 and t2 under the paired abstraction.                      *)
         val th5 = CONV_RULE
                    (RHS_CONV
                      (PABS_CONV
                        (RAND_CONV(REWR_CONV th2)
                          THENC RATOR_CONV(RAND_CONV(REWR_CONV th1))
                          THENC RATOR_CONV(RATOR_CONV(RAND_CONV(REWR_CONV thb))))))
                    th4
         val th6 = GSYM th5
     in
       (* Recurse into the three curried arguments on the right-hand side. *)
       CONV_RULE
         (RHS_CONV
           ((RAND_CONV to_combinator)
            THENC (RATOR_CONV(RAND_CONV to_combinator))
            THENC (RATOR_CONV(RATOR_CONV(RAND_CONV to_combinator)))))
         th6
     end

   else if is_let t then  (*  t = LET (\v. N) M  *)
     let val (v,M,N) = dest_plet t
         val liveVarS = Binaryset.addList (Binaryset.empty varOrder, free_vars N)
         (* Remove the let-bound variable(s) from the live set.  v may be a
            tuple, and it need not occur free in N.  Binaryset.delete raises
            NotFound for a missing item, so only present items are deleted. *)
         val extraVarS =
             List.foldl (fn (x, s) => if Binaryset.member (s, x)
                                      then Binaryset.delete (s, x)
                                      else s)
                        liveVarS (free_vars v)
         val th1 =
           if Binaryset.isEmpty extraVarS then
             let val f1 = mk_pabs(args, M)
                 val f2 = mk_pabs(v, N)
             in
               ISPECL [f1,f2] sc_def
             end
           else
             (* N needs variables besides v: thread them through as a tuple
                paired with M's value. *)
             let val extraVars = list_mk_pair (Binaryset.listItems extraVarS)
                 val args1 = mk_pair (extraVars,v)
                 val f1 = mk_pabs(args, mk_pair (extraVars, M))
                 val f2 = mk_pabs(args1,N)
             in
               ISPECL [f1,f2] sc_def
             end
         val th2 = CONV_RULE(RHS_CONV(SIMP_CONV std_ss [LAMBDA_PROD])) th1
         val th3 = TRANS th2 (SYM (QCONV (fn t => SIMP_CONV std_ss [LAMBDA_PROD, Once LET_THM] t) f))
         val th4 = SYM th3
     in
       CONV_RULE
         (RHS_CONV
           ((RAND_CONV to_combinator)
            THENC (RATOR_CONV (RAND_CONV (RAND_CONV to_combinator)))))
         th4
     end

   else
     (* Any other body (e.g. an application) is left unchanged. *)
     REFL f
 end;
182
183(*****************************************************************************)
184(* Predicate to test whether a term occurs in another term                   *)
185(*****************************************************************************)
186
(* occurs_in t1 t2 holds iff some subterm of t2 is alpha-equivalent to t1. *)
fun occurs_in needle haystack =
  let val is_needle = aconv needle
  in can (find_term is_needle) haystack
  end;
188
189(*****************************************************************************)
(* convert (|- f x = e) returns a theorem                                    *)
(*                                                                           *)
(*  |- f = p                                                                 *)
(*                                                                           *)
(* where p is a combinatory expression built from the combinators sc and cj *)
195(*****************************************************************************)
196
(* convert defth takes a non-recursive definition |- f args = rt and       *)
(* returns |- f = p, where p is the combinator form of \args. rt computed    *)
(* by to_combinator.  Fails with HOL_ERR, after printing a diagnostic, if   *)
(* defth is not an equation, if its lhs is not an application of a          *)
(* constant, or if rt has free variables that do not occur in the lhs.      *)
fun convert defth =
 let val (lt,rt) =
         dest_eq(concl(SPEC_ALL defth))
         handle HOL_ERR _
         => (print "not an equation\n"; raise ERR "Convert" "not an equation")
     val (func,args) =
         dest_comb lt
         handle HOL_ERR _ =>
         (print "lhs not a comb\n"; raise ERR "Convert" "lhs not a comb")
     val _ = if not(is_const func)
              then (print_term func; print " is not a constant\n";
                    raise ERR "Convert" "rator of lhs not a constant")
              else ()
     (* Computed once; reported in reverse order, as before. *)
     val unbound = subtract (free_vars rt) (free_vars lt)
     val _ = if List.null unbound then ()
             else (print "definition rhs has unbound variables: ";
                   List.app (fn t => (print_term t; print " ")) (rev unbound);
                   print "\n";
                   raise ERR "Convert" "definition rhs has unbound variables")
     val f = mk_pabs(args,rt)
     val th1 = to_combinator f               (* |- (\args. rt) = p *)
     val th2 = PABS args (SPEC_ALL defth)    (* |- (\args. f args) = (\args. rt) *)
     val th3 = TRANS th2 th1
 in
   (* Eta-reduce \args. f args to f on the left-hand side. *)
   CONV_RULE (LHS_CONV PETA_CONV) th3
 end;
226
227
228(*****************************************************************************)
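(* rec_convert (|- f x = if f1 x then f2 x else f(f3 x))                     *)
(*                                                                           *)
(* returns a theorem                                                         *)
(*                                                                           *)
(*  |- f = tr (p1,p2,p3)                                                     *)
(*                                                                           *)
(* where p1, p2 and p3 are combinatory expressions built from the            *)
(* combinators sc and cj.  No totality theorem is passed in: the remaining   *)
(* antecedent of rec_INTRO (an existential over a relation R, presumably the *)
(* termination condition) is left as an assumption of the returned theorem.  *)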
239(*****************************************************************************)
240
241
(* I_DEF_ALT: the identity combinator written as a lambda, |- I = \x. x.     *)
(* NOTE(review): this and the two theorems below are not used by the code   *)
(* visible in this file; presumably they are exported for clients.          *)
val I_DEF_ALT = Q.prove(`I = \x.x`,SIMP_TAC std_ss [FUN_EQ_THM]);
(* rec_INTRO with every occurrence of I rewritten to \x. x.                  *)
val rec_INTRO_ALT = REWRITE_RULE[I_DEF_ALT] rec_INTRO;
(* sc_o: function composition expressed with the sequencing combinator sc.  *)
val sc_o = Q.prove(`!f g. f o g = sc g f`,
         SIMP_TAC std_ss [combinTheory.o_DEF,sc_def]);
246
247(* Define `f1 (x,a) = if x = 0 then x + a else f1(x - 1, a * x)` *)
248
(* rec_convert defth takes a tail-recursive definition of the form          *)
(*   |- f args = if b then t1 else f t2'                                    *)
(* (f must not occur in b, t1 or t2') and instantiates rec_INTRO with        *)
(* \args. b, \args. t1 and \args. t2'.  The three abstractions in the        *)
(* result are converted with to_combinator.  The existential antecedent of  *)
(* rec_INTRO that remains after discharging the defining equation is left   *)
(* as an assumption (UNDISCH).  Fails with HOL_ERR if defth is not of this  *)
(* shape, has unbound rhs variables, is not tail recursive, or is not       *)
(* recursive at all.                                                        *)
fun rec_convert defth =
 let val (lt,rt) =
         dest_eq(concl(SPEC_ALL defth))
         handle HOL_ERR _ => raise ERR "RecConvert" "not an equation"
     val (func,args) =
         dest_comb lt
         handle HOL_ERR _ => raise ERR "RecConvert" "lhs not a comb"
     val (b,t1,t2) =
         dest_cond rt
         handle HOL_ERR _ =>
           raise ERR "RecConvert" "rhs is not an if-then-else expression"
     val unbound = subtract (free_vars rt) (free_vars lt)
     val _ = if List.null unbound then ()
             else (print "definition rhs has unbound variables: ";
                   List.app (fn t => (print_term t; print " ")) (rev unbound);
                   print "\n";
                   raise ERR "RecConvert" "definition rhs has unbound variables")
     (* The only supported shape: a single recursive call in tail position
        of the else-branch, with func occurring nowhere else. *)
     val tail_recursive =
         is_comb t2
         andalso aconv (rator t2) func
         andalso not(occurs_in func b)
         andalso not(occurs_in func t1)
         andalso not(occurs_in func (rand t2))
 in
  if tail_recursive
  then
   let val fb = mk_pabs(args,b)
       val f1 = mk_pabs(args,t1)
       val f2 = mk_pabs(args,rand t2)
       val thm = ISPECL[func,fb,f1,f2] rec_INTRO

       (* Restate the first antecedent of rec_INTRO as a paired forall with
          the paired beta-redexes reduced, so that it matches PGEN args defth. *)
       val M = fst(dest_imp(concl thm))
       val (v,body) = dest_forall M
       val M' = rhs(concl(DEPTH_CONV PBETA_CONV
                 (mk_pforall(args,subst [v |-> args]body))))
       val MeqM' = prove(mk_eq(M,M'),
                    Ho_Rewrite.REWRITE_TAC[LAMBDA_PROD]
                     THEN PBETA_TAC THEN REFL_TAC)
       val thm1 = PURE_REWRITE_RULE[MeqM'] thm
       val thm2 = PGEN args defth
       val thm3 = MP thm1 thm2

       (* Apply the same normalisation to the second conjunct of the
          remaining existential antecedent. *)
       val N = fst(dest_imp(concl thm3))
       val (R,tm) = dest_exists N
       val (tm1,tm2) = dest_conj tm
       val (v2,body2) = dest_forall tm2
       val tm2' = rhs(concl(DEPTH_CONV PBETA_CONV
                 (mk_pforall(args,subst [v2 |-> args]body2))))
       val N' = mk_exists(R,mk_conj(tm1,tm2'))
       val NeqN' = prove(mk_eq(N,N'),
                    Ho_Rewrite.REWRITE_TAC[LAMBDA_PROD]
                     THEN PBETA_TAC THEN REFL_TAC)
       val thm4 = PURE_REWRITE_RULE[NeqN'] thm3
       (* The antecedent N' becomes an assumption of the result. *)
       val thm5 = UNDISCH thm4
   in
     CONV_RULE (RHS_CONV
                 (RAND_CONV to_combinator THENC
                  RATOR_CONV(RAND_CONV (RAND_CONV to_combinator)) THENC
                  RATOR_CONV(RAND_CONV (RATOR_CONV (RAND_CONV to_combinator)))))
               thm5
   end
  else if occurs_in func rt
   then (print "definition of: "; print_term func;
         print " is not tail recursive"; print "\n";
         raise ERR "RecConvert" "not tail recursive")
   else raise ERR "RecConvert" "definition is not recursive (use convert)"
 end;
308
309(*---------------------------------------------------------------------------*)
(* Convert a possibly (tail) recursive equation to combinator form, either   *)
(* by calling convert or rec_convert.                                        *)
312(*---------------------------------------------------------------------------*)
313
(* toComb def converts a (possibly universally quantified) equation         *)
(* |- !... . f args = r to combinator form.  It uses rec_convert when f     *)
(* occurs in r and convert otherwise.  It returns the tuple                 *)
(*   (is_recursive, lhs of the combinator theorem, args, combinator thm).   *)
fun toComb def =
 let val eqn = snd (strip_forall (concl def))
     val (l, r) = dest_eq eqn
     val (func, args) = dest_comb l
     val is_recursive = occurs_in func r
     val converter = if is_recursive then rec_convert else convert
     val comb_exp_thm = converter def
 in
   (is_recursive, lhs (concl comb_exp_thm), args, comb_exp_thm)
 end;
322
323(*---------------------------------------------------------------------------*)
324(* Given an environment and a possibly (tail) recursive definition, convert  *)
325(* to combinator form, then add the result to the environment.               *)
326(*---------------------------------------------------------------------------*)
327
(* toACF env def converts def with toComb and conses the entry              *)
(* (func, (is_recursive, def, combinator theorem)) onto the association     *)
(* list env.  func is the left-hand side of the combinator theorem.         *)
fun toACF env def =
  case toComb def of
    (is_recursive, func, _, const_eq_comb) =>
      (func, (is_recursive, def, const_eq_comb)) :: env;
333
334
335end
336end
337