1structure refine =
2struct
3
4(*
5app load ["wordsLib", "Normal"];
6*)
7local
8open HolKernel Parse boolLib bossLib
9     wordsSyntax numSyntax pairSyntax NormalTheory
10in
11(*---------------------------------------------------------------------------*)
12
(* The constant C from theory Normal. *)
val C_tm = prim_mk_const {Thy = "Normal", Name = "C"};

(* mk_C tm builds the application "C tm", instantiating both type variables
   'a and 'b of C_tm to the type of tm. *)
fun mk_C tm =
  let val ty = type_of tm
      val C_inst = inst [alpha |-> ty, beta |-> ty] C_tm
  in mk_comb (C_inst, tm)
  end;

(* dest_C "C t" returns t; any other term raises ERR "dest_C" "". *)
val dest_C = dest_monop C_tm (ERR "dest_C" "");
18
19(*---------------------------------------------------------------------------*)
20(* To eventually go to wordsSyntax.                                          *)
21(*---------------------------------------------------------------------------*)
22
(* A word literal is any application of n2w, as recognised by is_n2w. *)
fun is_word_literal tm = is_n2w tm;
24
(* Flatten a tree of word_or (||) applications into its leaves, listed left
   to right, and split that list into (all-but-last, last).  A term that is
   not a word_or yields ([], tm). *)
fun strip_word_or tm =
 let fun leaves t acc =
       case total dest_word_or t
        of NONE => t :: acc
         | SOME (l, r) => leaves l (leaves r acc)
 in Lib.front_last (leaves tm [])
 end;
32
(* Combine a non-empty list of word terms with ||, associating to the right:
   [a, b, c] becomes a || (b || c).  Fails (via end_itlist) on []. *)
fun list_mk_word_or tms = end_itlist (fn a => fn b => mk_word_or (a, b)) tms;
34
(* Flatten a list of pairs, keeping order: [(a,b),(c,d)] becomes [a,b,c,d]. *)
fun pflat pairs =
  List.foldr (fn ((x, y), rest) => x :: y :: rest) [] pairs;
37
38(*---------------------------------------------------------------------------*)
39(* Translating constants into compound expressions.                          *)
40(*---------------------------------------------------------------------------*)
41
(* Print a type with Hol_pp.type_to_string and drop the first character of
   the result (presumably the leading ":" of HOL type syntax -- TODO
   confirm).  Raises Subscript if the printed string is empty. *)
fun type_to_name t =
 let val printed = Hol_pp.type_to_string t
 in String.extract (printed, 1, NONE)
 end;
44
(* Read a numeric (index) type such as :32 back as an Arbnum.num by printing
   it and parsing the printed digits.

   type_pp.pp_num_types is forced to true while the type is being PRINTED.
   with_flag must wrap type_to_name itself: writing
     with_flag (flag, true) Arbnum.fromString (type_to_name t)
   would evaluate type_to_name t first (SML is strict), so the flag would
   only be set while fromString ran, not while printing. *)
fun numeric_type_to_num t =
 Arbnum.fromString (Lib.with_flag (type_pp.pp_num_types, true) type_to_name t);
48
(* pad s i prefixes s with i copies of #"0"; when i <= 0, s is returned
   unchanged. *)
fun pad s i =
 let val zeros = CharVector.tabulate (Int.max (i, 0), fn _ => #"0")
 in zeros ^ s
 end;
53
(* Render the word literal w = n2w n (with dimension type dim) as a binary
   string of the word's width, left-padded with zeros.  If n needs more bits
   than the width, the unpadded binary string is returned (pad ignores
   non-positive counts). *)
fun bit_pattern w =
 let val (num_tm, dim) = wordsSyntax.dest_n2w w
     val bits = Arbnum.toBinString (numSyntax.dest_numeral num_tm)
     val width = Arbnum.toInt (numeric_type_to_num dim)
 in
   pad bits (width - String.size bits)
 end;
63
(* Build the literal n2w n, where n is parsed from the binary string s and
   width is the dimension (index) type of the resulting word. *)
fun word_of_string (s, width) =
    let val n = mk_numeral (Arbnum.fromBinString s)
    in mk_n2w (n, width)
    end;
66
(* The numeric index type :32, used as the dimension of 32-bit words. *)
val index32 = fcpLib.index_type (Arbnum.fromInt 32);

(* 32-bit word literal from a binary string. *)
val word32_of_string = fn s => word_of_string (s, index32);
69
(* word_of_int (i, width) builds n2w i, with the word's dimension given by
   the numeric index type for width. *)
fun word_of_int (i, width) =
  let val n = numSyntax.term_of_int i
      val dim = fcpLib.index_type (Arbnum.fromInt width)
  in mk_n2w (n, dim)
  end;
73
(* Cut the first 32 characters of the binary string s into four consecutive
   8-character groups (in order from the start of s) and turn each group
   into a 32-bit word literal.  Raises Subscript if s has fewer than 32
   characters. *)
fun chunk s =
 let fun byte k = String.substring (s, 8 * k, 8)
     val (g1, g2, g3, g4) = (byte 0, byte 1, byte 2, byte 3)
 in
   (word32_of_string g1, word32_of_string g2,
    word32_of_string g3, word32_of_string g4)
 end;
84
(* Constants for splitting 32-bit literals into bytes: zero32 is 0w:word32,
   n8/n16/n24 are num shift amounts, and n256 is the smallest value that
   IMMEDIATE_CONST_CONV expands. *)
val zero32 = ``0w:word32``;
val (n8, n16, n24) = (term_of_int 8, term_of_int 16, term_of_int 24);
val n256 = Arbnum.fromInt 256;
90
(* Given four 32-bit byte literals (b1,b2,b3,b4), with shift amounts
   24/16/8 attached to b1/b2/b3, build a chain of single lets that
   reassembles the value:

     let vi = c in let vi = vi << s in ... let vk = v0 || ... || b4 in vk

   Bytes among b1..b3 that equal zero32 are skipped and do not consume a
   variable index; b4 is always or-ed in directly. *)
fun bytes_to_let (b1,b2,b3,b4) =
 let (* keep only the non-zero upper bytes, with their shift amounts *)
     val shifted = List.filter (fn (b, _) => b <> zero32)
                               [(b1, n24), (b2, n16), (b3, n8)]
     fun var_name i = "v" ^ Int.toString i
     (* attach variable v<i> to the i-th kept byte *)
     fun number _ [] = []
       | number i ((b, s) :: rest) =
           (mk_var (var_name i, type_of b), (b, s)) :: number (i + 1) rest
     val named = number 0 shifted
     val ored = list_mk_word_or (map (fn (v, _) => v) named @ [b4])
     val result = mk_var (var_name (length named), type_of b4)
     (* for each kept byte: bind v to the byte, then rebind v to v << shift *)
     val steps =
       List.concat (map (fn (v, (b, s)) => [(v, b), (v, mk_word_lsl (v, s))])
                        named)
 in list_mk_anylet (map (fn bind => [bind]) (steps @ [(result, ored)]), result)
 end;
104
(* Conversion: rewrite a 32-bit word literal c = n2w n with n >= 256 into an
   equal chain of lets built by bytes_to_let, proving the equation with
   wordsLib.WORD_CONV.

   Fails with HOL_ERR when c is not n2w applied to a numeral, when n < 256,
   or when the word's dimension is not :32.  The width check is needed
   because chunk and bytes_to_let only handle 32 bits: for a narrower word,
   chunk would raise Subscript.  Unlike HOL_ERR, Subscript is not treated as
   an ordinary conversion failure by DEPTH_CONV and would abort the whole
   traversal. *)
fun IMMEDIATE_CONST_CONV c =
 let val (ntm, dim) = dest_n2w c
     val n = numSyntax.dest_numeral ntm
 in if dim <> index32 orelse Arbnum.<(n,n256) then failwith "CONST_CONV" else
    let val bstr = bit_pattern c
        val res = bytes_to_let (chunk bstr)
    in EQT_ELIM (wordsLib.WORD_CONV (mk_eq(c,res)))
    end
 end;
113
114(*---------------------------------------------------------------------------*)
115(* Want to remove returned constants, in favour of returned registers.       *)
116(* So "if P then c else d" where c or d are constants, needs to become       *)
117(*                                                                           *)
118(*   if P then (let x = c in x) else (let x = d in x)                        *)
119(*---------------------------------------------------------------------------*)
120
(* Given tm = COND P a b and theorems th1 : |- a = a', th2 : |- b = b',
   return |- COND P a b = COND P a' b'. *)
fun MK_COND tm th1 th2 =
 let val cond_p = rator (rator tm)   (* COND P *)
 in MK_COMB (AP_TERM cond_p th1, th2)
 end;
127
(* letify c proves |- c = LET (\x. x) c for a fresh variable x, i.e. it
   wraps c in a trivial let.  (BETA_RULE acts on the whole theorem, so this
   shape assumes c itself has no beta-redexes, as for constants and
   literals.) *)
fun letify c =
 let val x = genvar (type_of c)
     val let_tm = mk_let (mk_abs (x, x), c)
     (* |- LET (\x. x) c = (\x. x) c *)
     val unfolded = REWR_CONV LET_THM let_tm
 in SYM (BETA_RULE unfolded)
 end;
133
(* Conversion on "if t then a1 else a2": when at least one branch is a
   constant or a word literal, wrap each such branch in a trivial let (via
   letify) and leave the other branch unchanged.  Every failure (tm not a
   conditional, no qualifying branch, or any other HOL_ERR) is reported as
   ERR "COND_CONST_ELIM_CONV" "". *)
fun COND_CONST_ELIM_CONV tm =
 let fun is_constant t = is_const t orelse is_word_literal t
     fun wrap t = if is_constant t then letify t else REFL t
     val (_, a1, a2) = dest_cond tm
 in if not (is_constant a1) andalso not (is_constant a2)
      then failwith ""
    else MK_COND tm (wrap a1) (wrap a2)
 end
  handle HOL_ERR _ => raise ERR "COND_CONST_ELIM_CONV" "";
144
145(*---------------------------------------------------------------------------*)
146(* Expand constants, and eliminate returned constants (return values must    *)
147(* be in registers).                                                         *)
148(*---------------------------------------------------------------------------*)
149
(* Pass 1: expand 32-bit literals >= 256 everywhere in the theorem. *)
val refine0 = CONV_RULE (DEPTH_CONV IMMEDIATE_CONST_CONV);
(* Pass 2: wrap constant branches of conditionals in trivial lets. *)
val refine0a = CONV_RULE (DEPTH_CONV COND_CONST_ELIM_CONV);
(* Pass 3: rewrite with NormalTheory.FLATTEN_LET (presumably flattening
   nested lets -- see that theorem's statement). *)
val refine0b = Ho_Rewrite.PURE_REWRITE_RULE [NormalTheory.FLATTEN_LET];

(* Run passes 1, 2 and 3 in that order. *)
fun refine_const th = refine0b (refine0a (refine0 th));
155
156(*---------------------------------------------------------------------------*)
157(* Some refinements after compilation                                        *)
158(*---------------------------------------------------------------------------*)
159
(* A let whose bound expression is a conditional can be pushed into the
   branches of that conditional:
     (let x = if v1 then v2 else v3 in f x) =
       if v1 then (let x = v2 in f x) else (let x = v3 in f x) *)
val LIFT_COND_ABOVE_LET = Q.prove
(`!f v1 v2 v3.
  (let x = (if v1 then v2 else v3) in f x) =
    if v1 then (let x = v2 in f x) else (let x = v3 in f x)`,
 RW_TAC std_ss [LET_DEF]);

(* A let whose body is a conditional on v1 can be split into a conditional
   of two lets.  v1 is a free variable of the statement, so it cannot depend
   on the bound x.  The statement is not universally quantified; all its
   variables (including the one named "val") are free. *)
val LIFT_COND_ABOVE_LET1 = Q.prove
(`(let x = val in (if v1 then v2 x else v3 x)) =
  if v1 then (let x = val in v2 x) else (let x = val in v3 x)`,
 RW_TAC std_ss [LET_DEF]);

(* A trivial let (let x = e in x) around a conditional e can be dropped. *)
val LIFT_COND_ABOVE_TRIVLET = Q.prove
(`(let x = (if v1 then v2 else v3) in x) =
  if v1 then v2 else v3`,
 SIMP_TAC std_ss [LET_DEF]);

(* An identity let can be removed: LET (\x.x) y = y. *)
val ID_LET = Q.prove
(`LET (\x.x) y = y`,
 SIMP_TAC std_ss [LET_DEF]);
179
(* lift_cond def: given an equational theorem def (possibly universally
   quantified, |- !xs. lhs = body), return |- lhs = body', where body' is
   body with every let whose bound expression is a conditional pushed into
   the branches of that conditional:

     let v = (if J then M1 else M2) in N
       ~~>  if J then (let v = M1 in N) else (let v = M2 in N)

   While rebuilding lets, paired lets whose variable structure and bound
   expression are both pairs are split component-wise.  Lets binding an
   atomic expression to an atomic variable structure are substituted away
   (atomicity as decided by basic.is_atomic).

   The equation body = body' is proved first with SIMP_TAC bool_ss
   [LET_THM], then with RW_TAC std_ss [LET_THM].  If the traversal changes
   nothing, or neither proof succeeds, SPEC_ALL def is returned unchanged. *)
fun lift_cond def =
  let
    (* Build "let x = e1 in e2", splitting pairs and substituting atoms. *)
    fun mk_flat_let (x, e1, e2) =
        if is_pair x andalso is_pair e1 then
           let val (x1,x2) = dest_pair x
               val (e1',e1'') = dest_pair e1
           in  mk_flat_let(x1, e1', mk_flat_let(x2, e1'', e2))
           end
        else if basic.is_atomic e1 andalso basic.is_atomic x then
           subst [x |-> e1] e2
        else
           mk_plet(x, e1, e2)

    (* Rebuild t bottom-up, lifting conditionals out of let-bound
       expressions. *)
    fun trav t =
      if is_let t then
        let val (v,M,N) = dest_plet t
            (* Traverse the body exactly once and share the result;
               re-traversing N in a branch makes nested lets exponential. *)
            val N' = trav N
        in if is_cond M then
             let val (J, M1, M2) = dest_cond M
             in mk_cond (J, mk_flat_let (v, trav M1, N'),
                            mk_flat_let (v, trav M2, N'))
             end
           else
             mk_flat_let (v, trav M, N')
        end
      else if is_comb t then
        let val (M1,M2) = dest_comb t
        in mk_comb (trav M1, trav M2)
        end
      else if is_pabs t then
        let val (v,M) = dest_pabs t
        in mk_pabs (v, trav M)
        end
      else t

    val spec_def = SPEC_ALL def
    val fbody = rhs (concl spec_def)
    val fbody' = trav fbody
    val goal = mk_eq (fbody, fbody')

    (* SOME |- fbody = fbody' if tac proves the goal, NONE if it fails. *)
    fun attempt tac = SOME (prove (goal, tac)) handle HOL_ERR _ => NONE
  in
    if aconv fbody fbody' then spec_def
    else
      case attempt (SIMP_TAC bool_ss [LET_THM]) of
          SOME eq => TRANS spec_def eq
        | NONE =>
            (case attempt (RW_TAC std_ss [LET_THM]) of
                 SOME eq => TRANS spec_def eq
               | NONE => spec_def)
  end
228
229(*---------------------------------------------------------------------------*)
230(*   Convert a definition to its equivalent refined format                   *)
231(*---------------------------------------------------------------------------*)
232
233(*
234val LIFT_COND_ABOVE_C = Q.prove (
235  `!f v1 v2 v3. C (let x = (if v1 then v2 else v3) in f x) =
236     if v1 then C v2 (\x. C (f x)) else C v3 (\y. C (f y))`,
237  SIMP_TAC std_ss [C_def, LET_THM, FUN_EQ_THM] THEN
238  Cases_on `v1` THEN
239  RW_TAC std_ss []
240 );
241
242fun lift_cond exp =
243 let val t = dest_C exp               (* eliminate the C *)
244 in
245  if is_let t then                    (*  exp = LET (\v. N) M  *)
246    let val (v,M,N) = dest_plet t
247        val (th0, th1) = (lift_cond (mk_C M), lift_cond (mk_C N))
248    in  if is_cond M then
249          let val (J, M1, M2) = dest_cond M
250              val M1' = mk_plet (v, lift_cond M1, N')
251              val M2' = mk_plet (v, lift_cond M2, N')
252
253              val th3 =
254                     let val f = mk_pabs(v,N)
255                         val th = INST_TYPE [alpha |-> type_of N,
256                                            beta  |-> type_of N,
257                                            gamma |-> type_of v]
258                                 (LIFT_COND_ABOVE_C)
259                         val t2 = rhs (concl (SPECL [f, J, M1, M2] th))
260                         val th2 = prove (mk_eq(exp,t2),
261                                          RW_TAC std_ss [LET_THM, C_def])
262                     in
263                         PairRules.PBETA_RULE (SIMP_RULE std_ss [pairTheory.LAMBDA_PROD] th2)
264                     end
265    in
266       th5
267    end
268
269   else
270    REFL exp
271 end;
272*)
273
274(*---------------------------------------------------------------------------*)
275end (* local open in .... end *)
276end
277