1structure Normal :> Normal =
2struct
3(*
4quietdec := true;
5app load ["basic","NormalTheory","pairLib"];
6open HolKernel Parse boolLib bossLib;
7open pairLib pairSyntax pairTheory PairRules NormalTheory basic;
8quietdec := false;
9*)
10
11open HolKernel Parse boolLib bossLib;
12open pairLib pairSyntax pairTheory PairRules NormalTheory basic;
13
(* ERR fname msg builds a HOL_ERR exception attributed to structure Normal. *)
fun ERR fname msg = mk_HOL_ERR "Normal" fname msg;

(* The constant ``atom`` of theory Normal. *)
val atom_tm = prim_mk_const {Thy = "Normal", Name = "atom"};

(* mk_atom tm builds ``atom tm``, instantiating atom's type variable alpha
   to the type of tm. *)
fun mk_atom tm =
  let val atom_at_ty = inst [alpha |-> type_of tm] atom_tm
  in
    with_exn mk_comb (atom_at_ty, tm) (ERR "mk_atom" "Non-boolean argument")
  end;

(* Global flag read by K_Normalize in its conditional case to choose between
   the C_ATOM_COND and C_ATOM_COND_EX forms; the let case of K_Normalize
   sets it temporarily. *)
val branch_join : bool ref = ref false
23
(*---------------------------------------------------------------------------*)
(* Pre-processing. (KLS: not sure where this is used.)                       *)
(* PRE_PROCESS_RULE simplifies with arith_ss extended by AND_COND, OR_COND   *)
(* and BRANCH_NORM. arith_ss extends std_ss, which in turn extends bool_ss.  *)
(* bool_ss simplifies boolean connectives and conditional expressions: it    *)
(* contains the de Morgan theorems for moving negations inwards over the     *)
(* connectives (conjunction, disjunction, implication and conditional        *)
(* expressions), and the rules describing the connectives when the constants *)
(* T and F appear as their arguments. On top of that, arith_ss can decide    *)
(* formulas of Presburger arithmetic and normalise arithmetic expressions.   *)
(*---------------------------------------------------------------------------*)
35
(* Simplify a theorem with arith_ss plus AND_COND, OR_COND and BRANCH_NORM. *)
val PRE_PROCESS_RULE : thm -> thm =
    SIMP_RULE arith_ss [AND_COND, OR_COND, BRANCH_NORM];
37
(* pre_process th: first simplify with arith_ss and ELIM_USELESS_LET, then
   rewrite with AND_COND and OR_COND, and finally paired-beta-reduce. *)
local
  val drop_useless_lets = SIMP_RULE arith_ss [ELIM_USELESS_LET]
  val split_conds = REWRITE_RULE [AND_COND, OR_COND]
in
  fun pre_process th = PBETA_RULE (split_conds (drop_useless_lets th))
end;
40
41(*---------------------------------------------------------------------------*)
42(* Normalization                                                             *)
43(* This intermediate language is a combination of K-normal forms             *)
44(* and A-normal forms                                                        *)
45(*---------------------------------------------------------------------------*)
46
(* The continuation combinator ``C`` of theory Normal. *)
val C_tm = prim_mk_const{Name="C",Thy="Normal"};

(* mk_C tm builds ``C tm``, instantiating C's type variable alpha to the type
   of tm.  Any failure is reported as a Normal.mk_C error, consistent with
   mk_atom. *)
fun mk_C tm =
  with_exn mk_comb (inst [alpha |-> type_of tm] C_tm, tm)
           (ERR "mk_C" "cannot apply C to the given term");

(* dest_C ``C e`` returns e.
   Raises HOL_ERR (Normal.dest_C) when the term is not an application of C;
   previously the error carried an empty message. *)
val dest_C = dest_monop C_tm (ERR "dest_C" "not an application of C");
50
(*---------------------------------------------------------------------------*)
(* Conversion of an expression to its continuation form.                     *)
(* The helpers below support K_Normalize, which returns a theorem whose      *)
(* left-hand side is ``C e`` and which rewrites C e into continuation form.  *)
(*---------------------------------------------------------------------------*)
55
56(* Rewrite the first two components with the obtained theorems *)
57
(* SUBST_2_RULE (lem1,lem2) th: on the right-hand side of th, which has the
   shape \k. C e1 (\x. C e2 ...), rewrite the first component once with lem1
   and the second (under the inner abstraction) once with lem2. *)
fun SUBST_2_RULE (lem1,lem2) =
  let
    val first_conv = RATOR_CONV (REWRITE_CONV [Once lem1])
    val second_conv =
        RAND_CONV (PABS_CONV (RATOR_CONV (REWRITE_CONV [Once lem2])))
  in
    CONV_RULE (RHS_CONV (PABS_CONV (first_conv THENC second_conv)))
  end;
63
64(* CONV_RULE (RHS_CONV (PABS_CONV (
65          RAND_CONV (PABS_CONV (RATOR_CONV
66                       (REWRITE_CONV [Once lem2])))))) th3
67*)
68
69(* Normalize multiplication expressions *)
70
(* norm_mult (th0,th1) (d0,d1) exp normalizes a multiplication
   exp = ``C (d0 op d1)``.  exp is rewritten once with C_BINOP/C_WORDS_BINOP
   and the operand theorems are substituted in.  An operand that is an 8-bit
   literal has its theorem rewritten once with ATOM_ID first (presumably
   stripping the atom marker) -- for either operand, independently. *)
fun norm_mult (th0,th1) (d0,d1) exp =
  let
    fun prepare operand th =
        if is_8bit_literal operand then REWRITE_RULE [Once ATOM_ID] th
        else th
    val th0' = prepare d0 th0
    val th1' = prepare d1 th1
  in
    SUBST_2_RULE (th0', th1')
      (ONCE_REWRITE_CONV [C_BINOP,C_WORDS_BINOP] exp)
  end
85
86(* Rewrite the four components of atomic branch with the obtained theorems *)
87
local
  (* Rewrite once with a single lemma. *)
  fun once_rw lem = REWRITE_CONV [Once lem]

  (* Shared skeleton of the two atomic-conditional normalizers.
     intro_thm   : theorem introducing the continuation form of the
                   conditional (C_ATOM_COND or C_ATOM_COND_EX);
     branch_path : path from the innermost abstraction body to the branch
                   pair; it is the only structural difference between the
                   two forms.
     In K_Normalize, lem0/lem1 normalize the two operands of the condition
     and lem2/lem3 the then- and else-branches. *)
  fun norm_cond intro_thm branch_path (lem0,lem1,lem2,lem3) exp =
    let
      val th1 = ONCE_REWRITE_CONV [intro_thm] exp
      val branches =
          RATOR_CONV (RAND_CONV (RATOR_CONV (once_rw lem2))) THENC
          RAND_CONV (RATOR_CONV (once_rw lem3))
      val body =
          PABS_CONV (
            RATOR_CONV (once_rw lem0) THENC
            RAND_CONV (PABS_CONV (
              RATOR_CONV (once_rw lem1) THENC
              RAND_CONV (PABS_CONV (branch_path branches)))))
      val th2 = CONV_RULE (RHS_CONV body) th1
    in
      (PBETA_RULE o REWRITE_RULE [C_ATOM]) th2
    end
in
  (* Form used when branch_join is set (via C_ATOM_COND). *)
  fun Normalize_Atom_Cond lems exp =
      norm_cond C_ATOM_COND (RATOR_CONV o RAND_CONV) lems exp

  (* Form used when branch_join is clear (via C_ATOM_COND_EX). *)
  fun Normalize_Atom_Cond_Ex lems exp =
      norm_cond C_ATOM_COND_EX RAND_CONV lems exp
end;
123
124
125(*  for debugging
126fun mk_let_norm (c_thm, th0, th1) =
127  let val t0 = rhs (concl (SPEC_ALL c_thm))
128      val (k, t1) = dest_pabs t0
129      val (c1, ts1) = strip_comb t1
130      val (x, t2) = dest_pabs (hd (tl ts1))
131      val (c2, ts2) = strip_comb t2
132      val cont = hd (tl ts2)
133      val (e1, e2) = (rhs (concl th0), rhs (concl th1))
134
135      val cfk = mk_pabs(x, mk_comb(e2, cont))       (* \x. C (f x) (\y. k y) *)
136      val theta = match_type (hd (#2 (dest_type (type_of e1)))) (type_of cfk)
137      val cexp = mk_pabs(k, mk_comb(inst theta e1, cfk))       (* \k. C e (\x. C (f x) (\y. k y)) *)
138  in cexp
139  end
140*)
141
(*---------------------------------------------------------------------------*)
(* K_Normalize exp, where exp = ``C t``, returns a theorem |- C t = ...      *)
(* rewriting C t into continuation form by structural recursion on t:        *)
(*   - word literals wider than 8 bits, conditionals whose condition fails   *)
(*     is_atom_cond, and unrecognised terms: REFL exp;                       *)
(*   - atomic t: an instance of C_ATOM_INTRO;                                *)
(*   - let, paired abstraction, pair, atomic conditional, binary operator    *)
(*     and application: the sub-terms are normalized recursively and the    *)
(*     results spliced into C_LET / C_PAIR / C_ATOM_COND(_EX) / C_BINOP /    *)
(*     C_APP respectively.                                                   *)
(* Raises HOL_ERR (from dest_C) when exp is not an application of C.         *)
(* Reads the global flag branch_join and sets it temporarily in the let      *)
(* case; the previous value is restored even if normalization raises.        *)
(*---------------------------------------------------------------------------*)
fun K_Normalize exp =
 let val t = dest_C exp               (* eliminate the C *)
 in
  if is_word_literal t andalso not (is_8bit_literal t) then
    REFL exp                          (* word constants wider than 8 bits *)
  else if is_atomic t then ISPEC t C_ATOM_INTRO
  else if is_let t then                  (*  exp = C (LET (\v. N) M)  *)
   let val (v,M,N) = dest_plet t
       (* Normalize M and N with branch_join set iff M is a conditional
          (let v = if ... then ... else ... in ...).  Lib.with_flag restores
          the previous value of branch_join afterwards, also when one of the
          recursive calls raises; the hand-written save/restore it replaces
          leaked the flag on exceptions.
          NOTE(review): N is normalized under the same flag as M; confirm
          that this is intended. *)
       val (th0, th1) =
           Lib.with_flag (branch_join, is_cond M)
             (fn () => (K_Normalize (mk_C M), K_Normalize (mk_C N))) ()
       val th3 = SIMP_CONV bool_ss [Once C_LET] exp
                 handle Conv.UNCHANGED =>
                   (* case let (v1,v2,..) = ... in ... : instantiate C_LET
                      by hand and prove the equation directly *)
                   let val t1 = mk_pabs(v,N)
                       val th = INST_TYPE [alpha |-> type_of N,
                                           gamma |-> type_of v]
                                (GEN_ALL C_LET)
                       val t2 = rhs (concl (SPECL [t1, M] th))
                       val th2 = prove (mk_eq(exp,t2),
                                        SIMP_TAC bool_ss [LET_THM, C_def])
                   in
                     PBETA_RULE th2
                   end
       val th4 = SUBST_2_RULE (th0, th1) th3
   in
     (PBETA_RULE o REWRITE_RULE [C_ATOM]) th4
   end

  else if is_pabs t then                        (*  exp = C (\(x...).M) *)
    let
       val (v,M) = dest_pabs t
       val m_type = type_of M
       val t1 = mk_C M
       val id = let val x = mk_var("x",alpha) in mk_abs(x,x) end
       (* t3 = C (\v. C M (\x. x)); proved equal to exp via C_def below *)
       val t2 = mk_comb (inst [beta |-> m_type] t1,
                         inst [alpha |-> m_type] id)
       val t3 = mk_C (mk_pabs(v, t2))
       val th0 = K_Normalize t1
       val th1 = prove (mk_eq(exp, t3), SIMP_TAC std_ss [C_def])
       val th2 = CONV_RULE (RHS_CONV (RAND_CONV (PABS_CONV (
                    RATOR_CONV (ONCE_REWRITE_CONV [th0]))))) th1
    in
       (PBETA_RULE o REWRITE_RULE [C_ATOM]) th2
    end

  else if is_pair t then                        (*  exp = C (M,N) *)
    let
       val (M,N) = dest_pair t
       val th0 = K_Normalize (mk_C M)
       val th1 = K_Normalize (mk_C N)
       val th2 = ONCE_REWRITE_CONV [C_PAIR] exp
       val th3 = SUBST_2_RULE (th0,th1) th2
    in
       (PBETA_RULE o REWRITE_RULE [C_ATOM]) th3
    end

  else if is_cond t then                  (*  exp = C (if J then M else N) *)
    let
       val (J,M,N) = dest_cond t
    in
       if is_atom_cond J then
         let
            val (_, [P,Q]) = strip_comb J
            val (lem0, lem1, lem2, lem3) =
               (K_Normalize (mk_C P),
                K_Normalize (mk_C Q),
                K_Normalize (mk_C M),
                K_Normalize (mk_C N))
            (* branch_join selects between the two forms; the let case above
               sets it while normalizing a let whose bound expression is a
               conditional. *)
            val th4 = if !branch_join then Normalize_Atom_Cond (lem0,lem1,lem2,lem3) exp
                      else Normalize_Atom_Cond_Ex (lem0,lem1,lem2,lem3) exp
            (* When the first operand of the condition is an 8-bit literal,
               additionally rewrite once with COND_SWAP. *)
            val th5 =
                if is_8bit_literal P then
                  CONV_RULE (RHS_CONV (PABS_CONV (RAND_CONV (RATOR_CONV (
                    RATOR_CONV (REWRITE_CONV [Once COND_SWAP])))))) th4
                else th4
         in
            th5
         end
       else
         REFL exp
    end

  else if is_comb t then
    let fun norm_app (M,N) =
          let val th0 = K_Normalize (mk_C M)
              val th1 = K_Normalize (mk_C N)
              val th2 = ONCE_REWRITE_CONV [C_APP] exp
              val th3 = SUBST_2_RULE (th0,th1) th2
          in
            (PBETA_RULE o REWRITE_RULE [C_ATOM]) th3
          end
    in
      case strip_comb t of
        (operator,[d0,d1]) =>
          if is_binop operator then
            let val th0 = K_Normalize (mk_C d0)
                val th1 = K_Normalize (mk_C d1)
                val th2 =
                    if is_mult_op operator then
                      norm_mult (th0,th1) (d0,d1) exp
                    else if is_8bit_literal d0 then
                      (if is_8bit_literal d1 then       (* const op const *)
                         SUBST_2_RULE (REWRITE_RULE [Once ATOM_ID] th0, th1)
                           (ONCE_REWRITE_CONV [C_BINOP,C_WORDS_BINOP] exp)
                       else                             (* const op var *)
                         SUBST_2_RULE (th0,th1)
                           (ONCE_REWRITE_CONV [C_BINOP_SYM,C_WORDS_BINOP_SYM] exp))
                    else (* var op const || var op var *)
                      SUBST_2_RULE (th0, th1)
                        (ONCE_REWRITE_CONV [C_BINOP,C_WORDS_BINOP] exp)
            in
              (PBETA_RULE o REWRITE_RULE [C_ATOM]) th2
            end
          else norm_app (dest_comb t)
      | _ => norm_app (dest_comb t)
    end
  else
    REFL exp
 end;
270
271(*---------------------------------------------------------------------------*)
272(*   Convert a function to its equivalent KNF                                *)
273(*---------------------------------------------------------------------------*)
274
(* normalize def converts a definition |- lhs = e into the let-based normal
   form: the rhs is wrapped in C (C_INTRO), rewritten with the continuation
   form produced by K_Normalize, turned back into lets (C_2_LET), and finally
   rewritten with BRANCH_NORM.  Clears the global branch_join flag first. *)
fun normalize def =
 let
   (* top-level conditionals need not introduce a new binding *)
   val () = branch_join := false
   val spec_def = SPEC_ALL def
   val with_C = CONV_RULE (RHS_CONV (ONCE_REWRITE_CONV [C_INTRO])) spec_def
   (* |- C e = <continuation form of e> *)
   val cont_form = K_Normalize (mk_C (rhs (concl spec_def)))
   val let_form =
       SIMP_RULE bool_ss [C_2_LET] (PURE_ONCE_REWRITE_RULE [cont_form] with_C)
 in
   REWRITE_RULE [BRANCH_NORM] let_form
 end
292
(*---------------------------------------------------------------------------*)
(* Beta-reduction on let expressions (rule "atom_let").                      *)
(* Reduces expressions such as  let x = y in x + y  to  y + y, thereby       *)
(* eliminating aliases between variables.                                    *)
(*---------------------------------------------------------------------------*)
298
(* identify_atom tm rebuilds tm, wrapping every atomic right-hand side of a
   let-binding in ``atom`` (via mk_atom); used by beta_reduction below.
   Lets, applications and paired abstractions are traversed; any other term
   is returned unchanged. *)
fun identify_atom tm =
  let
     fun trav t =
       if is_let t then
           let val (v,M,N) = dest_plet t
               (* NOTE(review): a paired variable bound to a pair
                  (let (a,b) = (x,y) in ...) was once marked "to be revised";
                  it is handled like any other non-atomic binding, which is
                  what the former separate (identical) branch did. *)
               val M' = if is_atomic M then mk_atom M else trav M
           in mk_plet (v, M', trav N)
           end
       else if is_comb t then
           let val (M1,M2) = dest_comb t
               val M1' = trav M1
               val M2' = trav M2
           in mk_comb (M1',M2')
           end
       else if is_pabs t then
           let val (v,M) = dest_pabs t
           in mk_pabs (v,trav M)
           end
       else t
  in
    trav tm
  end;
325
(* beta_reduction def marks atomic let-bindings in the rhs of def with
   ``atom`` (identify_atom), puts that marked rhs into def, and simplifies
   with BETA_REDUCTION. *)
fun beta_reduction def =
  let
    val marked = identify_atom (rhs (concl (SPEC_ALL def)))
    (* |- <marked with atom markers rewritten by ATOM_ID> = marked *)
    val unmark_eq = CONV_RULE (LHS_CONV (REWRITE_CONV [ATOM_ID])) (REFL marked)
  in
    SIMP_RULE bool_ss [BETA_REDUCTION] (ONCE_REWRITE_RULE [unmark_eq] def)
  end
337
338(*---------------------------------------------------------------------------*)
339(* Elimination of Unnecessary Definitions                                    *)
340(* If e1 has no side effect and x does not appear free in                    *)
341(* e2, we can replace let x = e1 in e2 just with e2.                         *)
342(*---------------------------------------------------------------------------*)
343
(* Drop let-bindings whose variable is unused, via ELIM_USELESS_LET. *)
val ELIM_LET_RULE : thm -> thm = SIMP_RULE bool_ss [ELIM_USELESS_LET]
345
(*---------------------------------------------------------------------------*)
(* Flattening of nested let-expressions, by simplification with the          *)
(* FLATTEN_LET theorem (presumably rewriting                                 *)
(*   let x = (let y = e1 in e2) in e3  to  let y = e1 in let x = e2 in e3    *)
(* -- check the exact statement of FLATTEN_LET).                             *)
(*---------------------------------------------------------------------------*)
351
(* Simplify with std_ss and the FLATTEN_LET theorem. *)
val FLATTEN_LET_RULE : thm -> thm = SIMP_RULE std_ss [FLATTEN_LET]
353
(*---------------------------------------------------------------------------*)
(* Convert the normal form into SSA form: every bound variable (including    *)
(* let-bound ones) is renamed to a fresh name, so that each bound variable   *)
(* name in a term is different from the others.                              *)
(*---------------------------------------------------------------------------*)
358
(* to_ssa stem tm renames the bound variable of every abstraction in tm to
   stem1, stem2, ..., numbered in the order the abstractions are visited
   (operator before operand, binder before body). *)
fun to_ssa stem tm =
 let
   val counter = ref 0
   fun fresh_name () =
       (counter := !counter + 1; stem ^ Lib.int_to_string (!counter))
   fun walk M =
     if is_comb M then
       let val (opr, arg) = dest_comb M
           val opr' = walk opr          (* must run before walking arg *)
       in mk_comb (opr', walk arg)
       end
     else if is_pabs M then
       let val (v, body) = dest_pabs (rename_bvar (fresh_name ()) M)
       in mk_pabs (v, walk body)
       end
     else M
 in
   walk tm
 end;
379
(* SSA_CONV tm proves |- tm = tm' by alpha-conversion, where tm' is tm with
   its bound variables renamed v1, v2, ... by to_ssa. *)
fun SSA_CONV tm = Thm.ALPHA tm (to_ssa "v" tm);
384
(* SSA_RULE def turns a definition |- f x = e (or |- f = \x. e) into
   |- f = \x'. e', with the abstraction put into paired form (LAMBDA_PROD)
   and all bound variables renamed by SSA_CONV.
   Only a single (possibly paired) argument on the lhs is split off. *)
fun SSA_RULE def =
 let val spec_def = SPEC_ALL def
     val (t1,t2) = dest_eq (concl spec_def)
     (* args is only used when the rhs is not already an abstraction.  The
        former "is_pabs t1" branch destructed t2 and therefore could only
        fail in exactly that situation; a non-application lhs now simply
        has no arguments. *)
     val (fname,args) =
         if is_comb t1 then ((I##single) o dest_comb) t1 else (t1,[])
     val flag = is_pabs t2
     val body = if flag then #2 (dest_pabs t2) else t2
     (* spec_def equals def when def has no outer quantifiers *)
     val lem1 = if flag then spec_def else
                prove (mk_eq (fname, list_mk_pabs(args,body)),
                       SIMP_TAC std_ss [FUN_EQ_THM, FORALL_PROD, Once def])
     val t3 = if flag then t2 else list_mk_pabs (args,body)
     val lem2 = SIMP_CONV std_ss [LAMBDA_PROD] t3
                handle Conv.UNCHANGED => REFL t3
     val lem3 = TRANS lem1 lem2
     val lem4 = SSA_CONV (rhs (concl lem3))
 in
    ONCE_REWRITE_RULE [lem4] lem3
 end;
403
404(* Original
405fun SSA_RULE def =
406  let
407      val t0 = concl (SPEC_ALL def)
408      val (t1,t2) = (lhs t0, rhs t0)
409      val flag = is_pabs t2
410      val (fname, args) =
411          if is_comb t1 then dest_comb t1
412          else (t1, #1 (dest_pabs t2))
413      val body = if flag then #2 (dest_pabs t2) else t2
414      val lem1 = if flag then def
415                 else prove (mk_eq (fname, mk_pabs(args,body)),
416                        SIMP_TAC std_ss [FUN_EQ_THM, FORALL_PROD, Once def])
417      val t3 = if flag then t2 else mk_pabs (args,body)
418      val lem2 = QCONV SIMP_CONV std_ss [LAMBDA_PROD] t3
419                  (* handle HOL_ERR _ => REFL t3 *)
420      val lem3 = TRANS lem1 lem2
421      val t4 = rhs (concl (GEN_ALL lem3))  (* SPEC_ALL? *)
422      val lem4 = SSA_CONV t4
423      val lem5 = ONCE_REWRITE_RULE [lem4] lem3
424  in
425    lem5
426  end;
427*)
428
(*---------------------------------------------------------------------------*)
(* Normalized forms after a series of optimizations                          *)
(*---------------------------------------------------------------------------*)
432
433end (* Normal *)
434