1(*---------------------------------------------------------------------------*)
2(* Slightly optimised compiler                                               *)
3(*                                                                           *)
(* We apply a peephole optimisation similar to PRECEDE and FOLLOW during the *)
5(* compilation. The compiler tries to reduce the number of SEQs, PARs and    *)
(* ITEs.                                                                     *)
7(*                                                                           *)
8(* The difference between this version and the original compiler is that we  *)
9(* replace the function calls with their bodies in the function main (inline *)
10(* expansion).                                                               *)
11(*                                                                           *)
12(* This takes advantage of a global view of the structure of the hardware    *)
13(* instead of being restricted to the scope of a single function body.       *)
14(*                                                                           *)
15(* In order to implement the rewriting of the main function,                 *)
16(* alternative definitions for hwDefine, compileExp, compileProg and         *)
17(* compile are created.                                                      *)
18(*---------------------------------------------------------------------------*)
19
20(*---------------------------------------------------------------------------*)
21(* START BOILERPLATE                                                         *)
22(*---------------------------------------------------------------------------*)
23
24(*---------------------------------------------------------------------------*)
25(* Load theories                                                             *)
26(*---------------------------------------------------------------------------*)
27(*
28quietdec := true;
29loadPath :="dff" :: !loadPath;
30map load ["compile"];
31open compile;
32quietdec := false;
33*)
34
35(*---------------------------------------------------------------------------*)
36(* Boilerplate needed for compilation                                        *)
37(*---------------------------------------------------------------------------*)
38open HolKernel (* Parse boolLib bossLib compileTheory compile;*)
39open (*Tactic Tactical*) Parse boolLib bossLib pairSyntax composeTheory compileTheory
40infix THENL
41
42(*---------------------------------------------------------------------------*)
43(* Open theories                                                             *)
44(*---------------------------------------------------------------------------*)
45open compile;
46
47(*---------------------------------------------------------------------------*)
48(* END BOILERPLATE                                                           *)
49(*---------------------------------------------------------------------------*)
50
51
52(*---------------------------------------------------------------------------*)
53(* hwDefine2                                                                 *)
54(* Works just like hwDefine, except that it returns:                         *)
55(*  (|- eqns, |- ind, |- dev, |- def, |- tot )                               *)
56(* where the def is the function defined in terms of Seq, Par, Ite, Rec      *)
57(* and tot is the theorem that proves termination for recursive functions    *)
58(*                                                                           *)
(* The results of hwDefine2 are stored in a reference hwDefineLib2.          *)
60(*---------------------------------------------------------------------------*)
(* Log of the 5-tuples produced by hwDefine2, most recent first. *)
val hwDefineLib2 : (thm * thm * thm * thm * thm) list ref = ref [];

(* Theorems handed to PROVE_TAC when hwDefine2 proves termination goals. *)
val terminationTheorems : thm list ref = ref [];

(* Push th onto the front of terminationTheorems. *)
fun addTerminationTheorems th =
  let val known = !terminationTheorems
  in terminationTheorems := th :: known
  end;

(* Register WORD_PRED_THM as an initially available termination theorem. *)
val _ = addTerminationTheorems wordsTheory.WORD_PRED_THM;
67
(* hwDefine2 defq                                                            *)
(*                                                                           *)
(* Parses the quotation defq and returns (defth, ind, devth, convth, totth). *)
(* The same tuple is also pushed onto the front of hwDefineLib2.             *)
(*                                                                           *)
(* `<eqn> measuring <f>`:                                                    *)
(*    The definition is made with Defn.mk_defn and its termination is proved *)
(*    using the measure <f> together with the theorems currently held in     *)
(*    terminationTheorems.  The right-hand side of the proved equation must  *)
(*    be a conditional (dest_cond) whose else-branch is an application,      *)
(*    because the argument of that application (rand) becomes the third      *)
(*    component of the TOTAL goal.                                           *)
(*                                                                           *)
(* anything else:                                                            *)
(*    The definition is made with Define.  If the defined head occurs in its *)
(*    own right-hand side an error is reported and ERR is raised.  The ind   *)
(*    and totality slots of the result are filled with boolTheory.TRUTH.     *)
fun hwDefine2 defq =
 let val absyn0 = Parse.Absyn defq
 in
   case absyn0
    of Absyn.APP(_,Absyn.APP(_,Absyn.IDENT(loc,"measuring"),def),f) =>
        let val (deftm,names) = Defn.parse_absyn def
            val hdeqn = hd (boolSyntax.strip_conj deftm)
            val (l,_) = boolSyntax.dest_eq hdeqn
            (* The measure is typed as (product of argument types) -> num *)
            val domty = pairSyntax.list_mk_prod
                          (map type_of (snd (boolSyntax.strip_comb l)))
            val fty = Pretype.fromType (domty --> numSyntax.num)
            val typedf = Parse.absyn_to_term
                             (Parse.term_grammar())
                             (Absyn.TYPED(loc,f,fty))
            val defn = Defn.mk_defn (hd names) deftm
            val tac = EXISTS_TAC (numSyntax.mk_cmeasure typedf)
            THEN CONJ_TAC
                       THENL [TotalDefn.WF_TAC,
                              TotalDefn.TC_SIMP_TAC
                              THEN (PROVE_TAC (!terminationTheorems))]
            val (defth,ind) = Defn.tprove(defn, tac)
            (* Split `f args = if test then t1 else f(...)` into its parts *)
            val (lt,rt) = boolSyntax.dest_eq(concl defth)
            val (_,args) = dest_comb lt
            val (test,t1,t2) = dest_cond rt
            val fb = mk_pabs(args,test)
            val f1 = mk_pabs(args,t1)
            val f2 = mk_pabs(args,rand t2)
            val totalth = prove
                    (Term`TOTAL(^fb,^f1,^f2)`,
                     RW_TAC std_ss [TOTAL_def,pairTheory.FORALL_PROD]
                      THEN EXISTS_TAC typedf
                      THEN TotalDefn.TC_SIMP_TAC
                      THEN (PROVE_TAC (!terminationTheorems)))
            val devth = PURE_REWRITE_RULE [GSYM DEV_IMP_def]
                                 (RecCompileConvert defth totalth)
            (* Computed once and shared by the library entry and the result *)
            val convth = RecConvert defth totalth
            val result = (defth,ind,devth,convth,totalth)
        in
         hwDefineLib2 := result :: !hwDefineLib2;
         result
        end
     | _ =>
        let val defth = SPEC_ALL(Define defq)
            val fhead = fst(strip_comb(lhs(concl defth)))
            val _ =
             if occurs_in fhead (rhs(concl defth))
              then (print "definition of ";
                    print (fst(dest_const fhead));
                    print " is recursive; must have a measure\n";
                    raise ERR "hwDefine2" "recursive definition without a measure")
              else ()
            val conv = Convert defth
            val devth = PURE_REWRITE_RULE[GSYM DEV_IMP_def] (Compile conv)
            (* No induction or totality theorem exists here: use TRUTH *)
            val result = (defth,boolTheory.TRUTH,devth,conv,boolTheory.TRUTH)
        in
         hwDefineLib2 := result :: !hwDefineLib2;
         result
        end
 end;
127
128
129(*---------------------------------------------------------------------------*)
130(* CompileExp2 exp                                                           *)
131(* -->                                                                       *)
132(* [REC assumption] |- <circuit> ===> DEV exp                                *)
133(*---------------------------------------------------------------------------*)
(* CompileExp2 tm                                                            *)
(*                                                                           *)
(* Compiles the expression tm, recursing through Seq, Par, Ite, Rec and Let, *)
(* and returns a theorem of the shape  <circuit> ===> DEV tm.                *)
(* A sub-result is "atomic" when its circuit is itself `DEV g`.  All-atomic  *)
(* Seq/Par/Ite nodes are kept as one DEV via DEV_IMP_REFL; a Seq with just   *)
(* one atomic side uses PRECEDE_DEV or FOLLOW_DEV instead of SEQ_INTRO.      *)
(*                                                                           *)
(* Raises ERR "CompileExp2" when tm does not have a function type, when      *)
(* dest_exp fails on it, or when its operator is not one of the five above.  *)
fun CompileExp2 tm =
 let (* Left operand of ===> in the conclusion of a compiled theorem *)
     fun circuit_of th = rand(rator(concl th))
     (* True iff the circuit is an application of the constant DEV.  The     *)
     (* constant name is inspected directly instead of comparing             *)
     (* term_to_string output, which varies with pretty-printer settings.    *)
     (* A circuit that is not of that shape counts as non-atomic.            *)
     fun is_atm th =
       (fst(dest_const(rator(circuit_of th))) = "DEV")
       handle HOL_ERR _ => false
     (* The function g of an atomic circuit `DEV g` *)
     fun get_function th = rand(circuit_of th)
     val _ = if not(fst(dest_type(type_of tm)) = "fun")
              then (print_term tm; print "\n";
                    print "Devices can only compute functions.\n";
                    raise ERR "CompileExp2" "attempt to compile non-function")
              else ()
     val (opr,args) = dest_exp tm
                      handle HOL_ERR _
                      => raise ERR "CompileExp2" "bad expression"
 in
  if null args orelse is_combinational tm
     then ISPEC ``DEV ^tm`` DEV_IMP_REFL
  else
    case fst(dest_const opr) of
       "Seq" => let val th1 = CompileExp2(hd args)
                    val th2 = CompileExp2(hd(tl args))
                in
                 case (is_atm th1, is_atm th2) of
                    (true,true) =>
                      ISPEC ``DEV (Seq ^ (get_function th1) ^ (get_function th2))``
                      DEV_IMP_REFL
                  | (true,false) =>
                      MATCH_MP (ISPEC (get_function th1) PRECEDE_DEV) th2
                  | (false,true) =>
                      MATCH_MP (ISPEC (get_function th2) FOLLOW_DEV)  th1
                  | (false,false) =>
                      MATCH_MP SEQ_INTRO (CONJ th1 th2)
                end
     | "Par" => let val th1 = CompileExp2(hd args)
                    val th2 = CompileExp2(hd(tl args))
                in
                if is_atm th1 andalso is_atm th2 then
                   ISPEC ``DEV (Par ^ (get_function th1) ^ (get_function th2))``
                   DEV_IMP_REFL
                else
                   MATCH_MP PAR_INTRO (CONJ th1 th2)
                end
     | "Ite" => let val th1 = CompileExp2(hd args)
                    val th2 = CompileExp2(hd(tl args))
                    val th3 = CompileExp2(hd(tl(tl args)))
                in
                if is_atm th1 andalso is_atm th2 andalso is_atm th3 then
                   ISPEC ``DEV (Ite ^ (get_function th1) ^ (get_function th2)
                                    ^ (get_function th3))``
                   DEV_IMP_REFL
                else
                   MATCH_MP ITE_INTRO (LIST_CONJ [th1,th2,th3])
                end
     | "Rec" => let val thl = map (CompileExp2) args
                    (* Functions of the compiled components instantiate      *)
                    (* REC_INTRO; UNDISCH turns its antecedent into a        *)
                    (* hypothesis of the resulting theorem.                  *)
                    val var_list = map (rand o rand o concl o SPEC_ALL) thl
                   in
                    MATCH_MP
                     (UNDISCH(SPEC_ALL(ISPECL var_list REC_INTRO)))
                     (LIST_CONJ thl)
                   end
     | "Let" => let (* Unfold Let, compile the result, then fold Let back *)
                    val th1 = REWR_CONV Let tm
                    val th2 = CompileExp2(rhs(concl th1))
                in
                 CONV_RULE (RAND_CONV(RAND_CONV(REWR_CONV(SYM th1)))) th2
                end
     | _     => raise ERR "CompileExp2" "this shouldn't happen"
 end;
197
198
199
200(*---------------------------------------------------------------------------*)
(* CompileProg2 prog tm --> rewrites tm with prog, then compiles the result  *)
202(*---------------------------------------------------------------------------*)
(* Unfold tm with the equations in prog, compile the unfolded expression     *)
(* with CompileExp2, then rewrite the right operand of the resulting theorem *)
(* with the reversed unfolding equation so that it mentions tm again.        *)
fun CompileProg2 prog tm =
 let val unfolded = REWRITE_CONV prog tm
     val compiled = CompileExp2 (rhs (concl unfolded))
     val fold_back = RAND_CONV (REWRITE_CONV [GSYM unfolded])
 in
  CONV_RULE fold_back compiled
 end;
209
210(*---------------------------------------------------------------------------*)
(* Compile2 (|- f args = bdy) = CompileProg2 [|- f args = bdy] ``f``         *)
212(*---------------------------------------------------------------------------*)
(* Compile2 th: th (after SPEC_ALL) must be an equation whose entire         *)
(* left-hand side is a constant; that constant is then compiled by           *)
(* CompileProg2 using th as the only program equation.                       *)
(* Raises ERR "Compile2" when either requirement fails.                      *)
fun Compile2 th =
 let val (func,_) =
      dest_eq(concl(SPEC_ALL th))
      handle HOL_ERR _ => raise ERR "Compile2" "not an equation"
     (* Both errors now name this function as their origin *)
     val _ = if is_const func
              then ()
              else raise ERR "Compile2" "rator of lhs not a constant"
 in
  CompileProg2 [th] func
 end;
223
224
225
226(*---------------------------------------------------------------------------*)
227(* convertTotal                                                              *)
(* Rewrites the terms in the proof of totality into ones defined in terms of *)
229(* Seq, Par, Ite, and Rec followed by inline expansions                      *)
230(* See the example at the end of the file.                                   *)
231(*---------------------------------------------------------------------------*)
(* convertTotal defths totalth                                               *)
(*                                                                           *)
(* totalth is expected to conclude TOTAL(f1,f2,f3).  Each component is       *)
(* converted with Convert_CONV and the converted form is rewritten with      *)
(* defths (the inline expansion).  The result is the theorem                 *)
(*    |- TOTAL(f1,f2,f3) = TOTAL(f1',f2',f3')                                *)
(* proved by RW_TAC from defths and the three conversion theorems.           *)
fun convertTotal defths totalth =
    let val (f1,f2f3) = dest_pair (rand (concl totalth))
        val (f2,f3) = dest_pair f2f3

        val f1th = Convert_CONV f1
        val f2th = Convert_CONV f2
        val f3th = Convert_CONV f3

        (* Rewrite the rhs of th with defths.  If the rewrite raises (for    *)
        (* instance because nothing changed) fall back to reflexivity, but   *)
        (* never swallow a user interrupt.                                   *)
        fun inline th =
            let val tm = rhs (concl th)
            in REWRITE_CONV defths tm
               handle Interrupt => raise Interrupt
                    | _ => REFL tm
            end

        fun inlined th = rhs (concl (inline th))

        val f1' = inlined f1th
        val f2' = inlined f2th
        val f3' = inlined f3th

   in prove(``^(concl totalth) = TOTAL( ^f1', ^f2' ,^f3')``,
            RW_TAC std_ss (defths @ [f1th,f2th,f3th]))
   end;
251
252
253(*---------------------------------------------------------------------------*)
254(* inlineCompile                                                             *)
255(* Performs an inline expansion before compiling                             *)
256(*---------------------------------------------------------------------------*)
(* inlineCompile maintm defths totalths                                      *)
(*                                                                           *)
(* maintm   : the left-hand side of the definition to compile                *)
(* defths   : definition equations; the one whose lhs equals maintm is the   *)
(*            main definition, the others are expanded inline into it        *)
(* totalths : totality theorems given to the final METIS_TAC proof           *)
(*                                                                           *)
(* The main definition is rewritten with Let and the auxiliary definitions,  *)
(* compiled with Compile2 and refined with DEPTHR ATM_REFINE.  The           *)
(* conclusion of the result is then proved afresh by METIS_TAC from the      *)
(* discharged theorem, totalths and their convertTotal forms.                *)
(*                                                                           *)
(* Raises ERR "inlineCompile" if no theorem in defths defines maintm.        *)
fun inlineCompile maintm defths totalths =
       let fun findMain th = ((fst o dest_eq o concl) th) = maintm
           (* One pass: theorems defining maintm versus the auxiliary ones *)
           val (mainths,auxdefths) = List.partition findMain defths
           val maindef =
               case mainths of
                  th :: _ => th
                | [] => raise ERR "inlineCompile"
                              "no definition of the main function in defths"
           val mainth = REFINE (DEPTHR ATM_REFINE)
              (Compile2 ((CONV_RULE (REWRITE_CONV (Let::auxdefths))) maindef))
       in prove(concl mainth,
                METIS_TAC ((DISCH_ALL mainth) ::
                           (totalths @ (map (convertTotal (Let::defths)) totalths))))
       end;
267
268
269
270(*---------------------------------------------------------------------------
271  Example:
272
273  load "vsynth";open vsynth;
274
275
276  val (MULT_def,_,_,MULT_c,MULT_t) = hwDefine2
277      `(MULT(n:num,m:num,acc) = if n=0 then acc else MULT(n-1,m,m+acc))
278       measuring FST`;
279
280  val (FACT_def,_,_,FACT_c,FACT_t) = hwDefine2
281      `(FACT(n,acc) = if n=0 then acc else FACT(n-1,MULT(n,acc,0)))
282       measuring FST`;
283
284   (* Example of convertTotal *)
285   convertTotal [MULT_c] FACT_t;
286
287   (* Example of inlineCompile *)
288   val FACT_dev = inlineCompile ``FACT`` [FACT_c,MULT_c] [MULT_t,FACT_t];
289
290   val FACT_dev = inlineCompile ``FACT`` [FACT_c] [FACT_t];
291
292
293  val _ = AddBinop("cSUBT", (``UNCURRY $- : num#num->num`` ,"-"));
294  val _ = AddBinop("cEQ",  (``UNCURRY $= : num#num->bool``,"=="));
295(*  val _ = AddUnop("cMULT", (``MULT :num#num#num->num``,"*")); *)
296  val _ = add_combinational ["MULT"];
297
298   val FACT_cir = NEW_MAKE_CIRCUIT FACT_dev;
299
300---------------------------------------------------------------------------*)
301
302
303
304
305