1structure regAlloc  :> regAlloc  =
2struct
3
4(* app load ["NormalTheory", "Normal", "basic"] *)
5
6open HolKernel Parse boolLib bossLib;
7open pairLib pairSyntax PairRules NormalTheory Normal basic;
8
(* The constant Normal$atom; mk_atom instantiates its type variable per use. *)
val atom_tm = prim_mk_const {Thy = "Normal", Name = "atom"}

(* Build the application "atom tm", instantiating alpha in atom's type with
   the type of tm. *)
fun mk_atom tm =
  let val ty = type_of tm
  in mk_comb (inst [alpha |-> ty] atom_tm, tm)
  end
11
12(* --------------------------------------------------------------------*)
13(* --------------------------------------------------------------------*)
14
15
(* Short names for the finite sets and finite maps used throughout. *)
structure S = Binaryset
structure M = Binarymap

(* HOL type given to every generated register / memory / temporary variable.
   reg_alloc overwrites it with the type reported by pre_check.
   (An alternative noted by the author: numSyntax.num.) *)
val VarType = ref ``:word32``
19
20(* --------------------------------------------------------------------*)
21(* Datatypes                                                           *)
22(* --------------------------------------------------------------------*)
23
(* Outcome of alloc_one: either a register was found for the variable, or
   another live variable has been selected to make room. *)
datatype alloc_result
  = Alloc of term   (* the register assigned *)
  | Spill of term   (* the variable selected for spilling *)
27
28(* --------------------------------------------------------------------*)
29(* Configurable settings                                               *)
30(* The "DEBUG" flag controls whether debugging information should be   *)
31(*    printed or not.                                                  *)
32(* The "numRegs" stores how many registers are available.              *)
33(* By default the register set contains {r0,r1,...}                    *)
34(* Users can customize this set by specifying the "regL" or modify the *)
35(*   "mk_regs()" function.                                             *)
36(* --------------------------------------------------------------------*)
37
(* When true, allocation decisions (spills, saves and restores) are printed. *)
val DEBUG = ref true;

(* Number of registers available to the allocator. *)
val numRegs = ref 11;

(* The default register set [r0, r1, ..., r(k-1)] where k = !numRegs; each
   register is a variable of type !VarType.  A negative !numRegs is treated
   as 0 (a counter that stops only when it hits !numRegs exactly would
   otherwise never terminate). *)
fun mk_regs () =
  List.tabulate (Int.max (!numRegs, 0),
                 fn n => mk_var ("r" ^ Int.toString n, !VarType))

(* The register set in use; reset () rebuilds it with mk_regs. *)
val regL = ref (mk_regs ());
50
51(* --------------------------------------------------------------------*)
52(* Pre-defined variables and functions                                 *)
53(* --------------------------------------------------------------------*)
54
(* Register / memory-slot classifiers.  A register is a variable named
   r<digits> and a memory slot a variable named m<digits>; trailing primes
   (as produced by HOL variable renaming) are tolerated.  The name comes from
   dest_var rather than from the pretty-printed term.  As a result, three
   kinds of term are no longer misclassified:
     - compound terms whose printed form merely starts with 'r' or 'm'
       (e.g. "max x y", which used to be skipped by the allocator as if it
       were a memory slot);
     - ordinary source variables such as "rate" or "mask";
     - terms printed under printer settings such as show_types. *)
local
  fun indexed_var prefix x =
    is_var x andalso
    let
      val name = #1 (dest_var x)
    in
      String.isPrefix prefix name andalso
      let
        val rest = Substring.extract (name, String.size prefix, NONE)
        val (digits, suffix) = Substring.splitl Char.isDigit rest
      in
        not (Substring.isEmpty digits) andalso
        Substring.isEmpty (Substring.dropl (fn c => c = #"'") suffix)
      end
    end
in
  fun is_reg x = indexed_var "r" x
  fun is_mem x = indexed_var "m" x
end
57
(* Total order on terms by their printed form; the key order of every
   register-environment map and set in this file. *)
fun tvarOrder (t1 : term, t2 : term) =
  String.compare (term_to_string t1, term_to_string t2);
64
(* True iff exp has a function type, i.e. it can be applied to an argument.
   Any failure while inspecting the type yields false. *)

fun is_fun exp = Lib.can (dom_rng o type_of) exp;
71
72(* --------------------------------------------------------------------*)
73(* make variables                                                      *)
74(* mvar -- memory variables                                            *)
75(* tvar -- temporary variables; used for spilling.                     *)
76(* --------------------------------------------------------------------*)
77
(* Variable-name generator: num2name "m" 3 = "m3". *)
fun num2name stem i = stem ^ Lib.int_to_string i

(* A fresh counter stream whose state is named stem<counter>, starting at 0. *)
fun fresh_stream stem = Lib.mk_istream (fn i => i + 1) 0 (num2name stem)

(* mStrm numbers memory slots (m...), tStrm numbers spill temporaries (t...). *)
val mStrm = ref (fresh_stream "m")
val tStrm = ref (fresh_stream "t")

fun reset_mvar () = mStrm := fresh_stream "m"
fun reset_tvar () = tStrm := fresh_stream "t"

(* next_* advances the stream and then reads its state as a variable name. *)
fun next_mvar () = mk_var (state (next (!mStrm)), !VarType)
fun next_tvar () = mk_var (state (next (!tStrm)), !VarType)

(* Name of the most recently issued memory slot (no advance). *)
fun cur_mvar () = state (!mStrm)
87
88(* --------------------------------------------------------------------*)
89(* The variables that have not been allocated                          *)
90(* --------------------------------------------------------------------*)
91
(* Free variables of exp that are not yet placed in a register or a memory
   slot, in free_vars order. *)
fun fv exp =
  let fun placed v = is_reg v orelse is_mem v
  in List.filter (not o placed) (free_vars exp)
  end
97
98(* --------------------------------------------------------------------*)
99(* Attempt to allocate a register                                      *)
100(* "regenv" -- the current allocation scheme;                          *)
101(* "cont" -- the continuation that contains live variables;            *)
102(* "x" -- the variable to be allocated.                                *)
103(* --------------------------------------------------------------------*)
104
(* Allocate a register for x, or pick a variable to spill.
   cont   -- its free, unplaced variables (fv cont) are considered live;
   regenv -- the current allocation (variable -> register / memory slot);
   x      -- the variable to place.
   Results:
     - x is already a register: Alloc x.
     - Otherwise, the first register of !regL that is not occupied by a live
       variable: Alloc r.
     - Otherwise, the last variable of (fv cont) that currently sits in one
       of the registers of !regL: Spill y.
   Raises Fail when no register is free and no live variable can be
   spilled.  (A bare valOf used to surface here as an uninformative Option
   exception.) *)

fun alloc_one cont regenv x =
  if is_reg x then Alloc x
  else
    let
      val allregs = !regL
      val free = fv cont
      (* locations (registers, and memory slots too) held by live variables *)
      fun occupy (y, live) =
        if is_reg y then S.add (live, y)
        else (S.add (live, M.find (regenv, y)) handle NotFound => live)
      val live = List.foldl occupy (S.empty tvarOrder) free
    in
      case List.find (fn r => not (S.member (live, r))) allregs of
        SOME r => Alloc r
      | NONE =>
          let
            (* evictable: not itself a register, and currently assigned one
               of the allocatable registers (not a memory slot) *)
            fun spillable y =
              let val k = M.find (regenv, y)
              in not (is_reg y) andalso not (is_mem k) andalso
                 Lib.mem k allregs
              end handle NotFound => false
            val y =
              case List.find spillable (List.rev free) of
                SOME y => y
              | NONE =>
                  raise Fail ("alloc_one: no free register for " ^
                              term_to_string x ^
                              " and no live variable can be spilled")
            val _ =
              if !DEBUG then
                (print ("register allocation failed for " ^
                        term_to_string x ^ "; ");
                 print ("spilling " ^ term_to_string y ^
                        " from " ^ term_to_string (M.find (regenv, y)) ^ "\n"))
              else ()
          in
            Spill y
          end
    end;
143
144(* --------------------------------------------------------------------*)
145(* For Tuple and Function calls                                        *)
146(* Rules for replacing variables with registers or memory slots.       *)
147(* --------------------------------------------------------------------*)
148
(* Location of an operand of a tuple or function call.  Registers, word
   literals, numerals and constants denote themselves.  Any other term must
   be mapped by regenv to a register or a memory slot:
     - Fail is raised if it is mapped to anything else;
     - NotFound escapes if it is not mapped at all. *)
fun find_pos (regenv, x) =
  let
    val self_located =
      is_reg x orelse is_word_literal x orelse
      numSyntax.is_numeral x orelse is_const x
  in
    if self_located then x
    else
      let val loc = M.find (regenv, x)
      in if is_reg loc orelse is_mem loc then loc
         else raise Fail "find_pos..."
      end
  end;
158
(* Substitution sending each variable among xs to its location (find_pos).
   Non-variables contribute nothing.  The pairs appear in reverse order of
   xs. *)
fun tuple_subst_rules xs regenv =
  List.rev
    (List.mapPartial
       (fn x => if is_var x then SOME (x |-> find_pos (regenv, x)) else NONE)
       xs);
166
167(* --------------------------------------------------------------------*)
168(* Attempt to allocate registers for a tuple.                          *)
169(* If the tuple contains only one variable, we always allocate a       *)
170(* register for it by spilling another variable;                       *)
171(* on the other hand, if the tuple contains several variables, and some*)
172(* of them couldn't be assigned registers, we always spill them into   *)
173(* the memory (this can be optimized a little).                        *)
174(* --------------------------------------------------------------------*)
175
(* Allocate locations for x, which is either a single variable or a tuple.

   Single variable: the alloc_one result is returned.  On Alloc r, x is also
   mapped to r in the returned environment; on Spill, regenv is returned
   unchanged.

   Tuple: components are offered to alloc_one left to right.  Each handled
   component is paired onto the continuation, so it counts as live for the
   following ones.  A component that gets Alloc r is mapped to r; one that
   gets Spill is mapped to a fresh memory slot instead.  The result is Alloc
   of the tuple with every component replaced by its location, together
   with the extended environment. *)
fun alloc cont regenv x =
  if not (is_pair x) then
    (case alloc_one cont regenv x of
       Alloc r => (Alloc r, M.insert (regenv, x, r))
     | Spill y => (Spill y, regenv))
  else
    let
      val xs = strip_pair x
      fun place (t, (live, env)) =
        let
          val loc =
            case alloc_one live env t of
              Alloc r => r
            | Spill _ => next_mvar ()
        in
          (mk_pair (live, t), M.insert (env, t, loc))
        end
      val (_, env) = List.foldl place (cont, regenv) xs
    in
      (Alloc (subst (tuple_subst_rules xs env) x), env)
    end;
194
195(* --------------------------------------------------------------------*)
196(* Register Allocation                                                 *)
197(* Auxiliary data structures and functions                             *)
198(* --------------------------------------------------------------------*)
199
(* Result of the main allocation pass g:
     NoSpill (e, env) -- e has been rewritten and env is the resulting
                         allocation;
     ToSpill (e, ys)  -- e could not be allocated; the variables ys must be
                         saved to memory before e is retried (g_repeat). *)
datatype AllocResult
  = NoSpill of term * (term, term) M.dict
  | ToSpill of term * term list

(* Raised for expression forms the allocator does not handle. *)
exception regAlloc of string

(* Raised by find_reg for a variable that has no register (for instance one
   that has been spilled to memory); carries that variable. *)
exception NoReg of term
206
(* Register holding x.  Registers, word literals, numerals and constants
   stand for themselves; any other term must be mapped by regenv to a
   register.  Raises NoReg x if x is unmapped, or is mapped to something
   that is not a register (e.g. a memory slot). *)
fun find_reg (regenv, x) =
  if is_reg x orelse is_word_literal x orelse
     numSyntax.is_numeral x orelse is_const x
  then x
  else
    case M.peek (regenv, x) of
      SOME loc => if is_reg loc then loc else raise NoReg x
    | NONE => raise NoReg x;
216
(* Substitution replacing each variable among xs by the register holding it
   (find_reg).  Non-variables contribute nothing, and the pairs appear in
   reverse order of xs.  Raises NoReg for a variable without a register. *)
fun mk_subst_rules xs regenv =
  List.rev
    (List.mapPartial
       (fn x => if is_var x then SOME (x |-> find_reg (regenv, x)) else NONE)
       xs);
224
(* Build "let x = e1 in e2", flattening e1 when it is itself a let-chain:
     let x = (let v = M in N) in e2  -->  let v = M in let x = N in e2
   The rule is applied recursively along N. *)

fun concat e1 x e2 =
  if not (is_let e1) then mk_plet (x, e1, e2)
  else
    let val (v, bound, body) = dest_plet e1
    in mk_plet (v, bound, concat body x e2)
    end;
231
(* Map x to location r in regenv.  A register x is left unmapped, and the
   environment is returned unchanged. *)

fun add x r regenv =
  if not (is_reg x) then M.insert (regenv, x, r) else regenv;
235
236
237(* --------------------------------------------------------------------*)
238(* Spill to memory                                                     *)
239(* Restore from memory                                                 *)
240(* --------------------------------------------------------------------*)
241
(* Spill x to memory.  A fresh memory slot m is drawn and the function
   returns "let m = <x's current location> in exp", together with regenv
   updated so that x now maps to m.  Raises NotFound if x is not in
   regenv. *)
fun save x exp regenv =
  let
    val slot = next_mvar ()
    val _ =
      if !DEBUG
      then print ("saving " ^ term_to_string x ^ " to " ^ term_to_string slot ^ "\n")
      else ()
    val src = M.find (regenv, x)
  in
    (mk_plet (slot, src, exp), M.insert (regenv, x, slot))
  end
252
(* Reload x into a fresh temporary t.  The function returns
   "let t = <x's location in regenv> in exp[x := t]".  When called from
   g'_and_restore, that location is x's memory slot. *)
fun restore x exp regenv =
  let
    val tmp = next_tvar ()
  in
    (if !DEBUG then
       print ("restoring " ^ term_to_string x ^
              " from " ^ term_to_string (M.find (regenv, x)) ^ "\n")
     else ());
    mk_plet (tmp, M.find (regenv, x), subst_exp [x |-> tmp] exp)
  end
262
263(* --------------------------------------------------------------------*)
264(* Register Allocation                                                 *)
265(* Main algorithm                                                      *)
266(* --------------------------------------------------------------------*)
267
268(* g' is for allocating registers in expressions *)
269
(* --------------------------------------------------------------------*)
(* Main allocation pass.  g', g'_and_restore, g, g_repeat and g'_if are *)
(* mutually recursive:                                                 *)
(*   g'             rewrites a single expression, replacing variables  *)
(*                  by their locations (let-chains go to g);           *)
(*   g'_and_restore runs g' and, when it raises NoReg x, reloads the   *)
(*                  spilled x (restore) and continues with g;          *)
(*   g              walks a let-chain, allocating each bound variable; *)
(*   g_repeat       runs g, saving requested variables to memory and   *)
(*                  retrying until no further spill is requested;      *)
(*   g'_if          allocates both branches of a conditional.          *)
(* Shared parameters:                                                  *)
(*   dest   -- passed along, but never inspected by the active code;   *)
(*   cont   -- its free variables are treated as live;                 *)
(*   regenv -- maps source variables to registers or memory slots.     *)
(* --------------------------------------------------------------------*)
270fun g' dest cont regenv exp =
 (* word literals, numerals, constants, function-typed terms and memory
    slots are left unchanged *)
271 if is_word_literal exp orelse numSyntax.is_numeral exp orelse
272    is_const exp orelse is_fun exp orelse is_mem exp
273   then NoSpill (exp, regenv) else
 (* a variable is replaced by its register; find_reg raises NoReg when it
    has none (e.g. it was spilled) *)
274 if is_var exp
275    then NoSpill (subst (mk_subst_rules [exp] regenv) exp, regenv) else
 (* conditional: the first two arguments of the comparison must be in
    registers; the branches are then allocated by g'_if *)
276 if is_cond exp
277    then let val (c,e1,e2) = dest_cond exp
278             val (cmpop,ds) = strip_comb c
279             val (d0,d1) = (hd ds, hd (tl ds))
280             (* val (ds0,ds1) = (#2 (strip_comb d0), #2 (strip_comb d1))
281                val c' = list_mk_comb
282                         (cmpop, [subst (mk_subst_rules ds0 regenv) d0,
283                                  subst (mk_subst_rules ds1 regenv) d1])
284             *)
285             val c' = list_mk_comb (cmpop, [find_reg(regenv,d0),
286                                            find_reg(regenv,d1)])
287             fun f e1 e2 = mk_cond(c', e1, e2)
288         in
289            g'_if dest cont regenv exp f e1 e2
290        end else
 (* tuple: components may live in registers or in memory (find_pos) *)
291 if is_pair exp
292    then NoSpill (subst (tuple_subst_rules (strip_pair exp) regenv) exp,
293                  regenv) else
294 if is_let exp
295    then g dest cont regenv exp else
 (* applications:
      - operands of a binary operator must be in registers;
      - arguments of a function call (tuples flattened) may also be in
        memory;
      - any other application raises regAlloc *)
296 if is_comb exp
297    then
298      let val (opr,xs) = strip_comb exp
299      in if is_binop opr (* includes is_cmpop opr orelse is_relop opr *)
300            then NoSpill (subst (mk_subst_rules xs regenv) exp, regenv) else
301         if is_fun opr   (* function application *)
302            then
303              let val ys = List.foldl (fn (t,ls) => strip_pair t @ ls) [] xs
304              in NoSpill (subst (tuple_subst_rules ys regenv) exp, regenv)
305              end
306            else raise regAlloc
307                         ("g': this case hasn't been implemented -- "^
308                          term_to_string exp)
309      end
310 else NoSpill (exp, regenv)
311
312and
313
314(*---------------------------------------------------------------------------*)
315(* When g' accesses a spilled variable, a NoReg exception is raised since    *)
316(* we cannot find the variable in regenv. Thus a restore is required, which  *)
317(* restores the value of a variable from the memory to a register.           *)
318(*---------------------------------------------------------------------------*)
319
320g'_and_restore dest cont regenv exp =
321  g' dest cont regenv exp
322  handle NoReg x (* x is stored in the memory, assign it a register *)
323  => g dest cont regenv (restore x exp regenv)
324    (* restore the spilled x from the memory *)
 (* restore yields "let t = m in exp[x := t]", so g allocates the fresh
    temporary t like any other let-bound variable *)
325
326and
327
328(*---------------------------------------------------------------------------*)
329(* g deals with the let v = ... in ... structure.                            *)
330(*---------------------------------------------------------------------------*)
331
332g dest cont regenv exp =
333 if not (is_let exp) then
334     g'_and_restore dest cont regenv exp
335 else (*  exp = LET (\v. N) M  *)
336 let val (x,M,N) = dest_plet exp
     (* while M is being allocated, the variables used by N (and cont)
        are live *)
337     val cont' = mk_pair(N,cont) (* concat N dest cont *)
338 in
     (* x is already a location (e.g. the "let m = r in ..." that save
        emits): M is kept as is and only N is allocated, with x passed as
        dest; on a spill request the whole let is handed back *)
339  if is_mem x orelse is_reg x
340  then case g x cont regenv N
341        of ToSpill(e2,ys) => ToSpill (exp, ys)
342         | NoSpill(e2,regenv2) => NoSpill(mk_plet(x,M,e2),regenv2)
343  else
     (* allocate M first, then look for a register for x *)
344    case g'_and_restore x cont' regenv M
345     of ToSpill(exp1, ys) => ToSpill(concat exp1 x N, ys)
346      | NoSpill(exp1, regenv1) =>
347          case (alloc cont' regenv1 x)
            (* No register for x: alloc_one proposed spilling y.  If N
               still uses y, prefer a register-held variable of env that
               N does not use.  The catch-all handler falls back to y when
               there is no such variable (valOf NONE); note that it also
               swallows any other exception raised in this expression. *)
348           of (Spill(y), env) =>
349                let val op_vars = free_vars N
350                in (if Lib.mem y op_vars then
351                      let val opt =
352                            List.find (fn (key,value) =>
353                                        is_reg value andalso
354                                        not (Lib.mem key op_vars))
355                                      (M.listItems env)
356                          val (thing,_) = valOf opt
357                      in
358                        ToSpill (exp, [thing])
359                      end
360                    else ToSpill (exp, [y])
361                   ) handle _ => ToSpill (exp, [y])
362                end
            (* x gets register r; allocate the rest of the chain and
               splice the rewritten M in front of it *)
363            | (Alloc(r), env) =>
364                 let val (exp2,env') = g_repeat dest cont env N
365                 in NoSpill(concat exp1 r exp2, env')
366                 end
367     end
368and
    (* Run g.  While g asks for variables to be spilled, save each of them
       (save wraps the term in "let m = r in ..." and remaps the variable
       to m in regenv) and retry.
       NOTE(review): there is no bound on the number of retries. *)
369g_repeat dest cont regenv exp =   (* early spilling *)
370  case g dest cont regenv exp
371   of NoSpill(exp', regenv') => (exp', regenv')
372    | ToSpill(exp, xs) =>
373        let val (exp,regenv) =
374              List.foldl (fn (x,(exp,env)) => save x exp env)
375               (exp, regenv)
376               xs
377        in
378          g_repeat dest cont regenv exp
379        end
380and
    (* Allocate both branches of a conditional independently, each starting
       from regenv.  The two resulting environments are merged, with entries
       from the second branch overriding those from the first, and constr
       rebuilds the conditional.  The exp parameter is unused here.
       Merging without checking for conflicts presumably relies on SSA form
       (see the non-SSA variant kept in the comment below). *)
381g'_if dest cont regenv exp constr e1 e2 =
382 let val (e1', regenv1) = g_repeat dest cont regenv e1
383     val (e2', regenv2) = g_repeat dest cont regenv e2
384     val regenv' =
385       List.foldl (fn ((key,value), m) => M.insert(m,key,value)) regenv1
386                  (M.listItems regenv2)
387  in
388    NoSpill(constr e1' e2', regenv')
389  end
390;
391
392(*---------------------------------------------------------------------------
393 The following code is needed only for non-SSA expressions
394    val regenv' =
395      List.foldl
396      (fn (x,regenv') =>
397          if is_reg x then regenv' else
398          let val r1 = M.find(regenv1,x)
399              val r2 = M.find(regenv2,x) in
400            if r1 <> r2 then regenv' else
401            M.insert(regenv',x,r1)
402          end
403        handle NotFound => regenv')
404      (M.mkDict tvarOrder)
405      (fv cont)
406  in
407    case
408      List.filter
409      (fn x => not (is_reg x) andalso
410               (M.peek(regenv',x) = NONE) andalso x <> dest)
411      (fv cont)
412    of
413      [] => NoSpill(constr e1' e2', regenv')
414    | xs => ToSpill(exp, xs)
415  end
416  ---------------------------------------------------------------------------*)
417
418(* --------------------------------------------------------------------*)
419(* Reduce the number of memory slots by reusing memory variables.      *)
420(* The mechanism is similar to that of allocating registers; but we    *)
421(* assume an unlimited number of memory slots.                         *)
422(* --------------------------------------------------------------------*)
423
(* Lowest-numbered memory slot m<i> (1 <= i <= current counter) that env
   does not assign to any free variable of cont.  When every existing slot
   is taken, a new one is drawn with next_mvar. *)

fun first_avail_slot env cont =
  let
    (* numeric index of a slot name such as "m7" *)
    fun slot_index name =
      valOf (Int.fromString (String.extract (name, 1, NONE)))
    (* memory slots currently held by free variables of cont *)
    fun collect (t, acc) =
      case M.peek (env, t) of
        SOME loc => if is_mem loc then loc :: acc else acc
      | NONE => acc
    val live = List.foldl collect [] (free_vars cont)
    val top = slot_index (state (!mStrm))
    fun search i =
      if i > top then next_mvar ()
      else
        let val slot = mk_var ("m" ^ Int.toString i, !VarType)
        in if Lib.mem slot live then search (i + 1) else slot
        end
  in
    search 1
  end
444
445(* reuse memory slots that will not be "live" any more *)
446
(* alloc_mem (args, body): renumber the memory slots used in body so that a
   slot can be reused once the variable it held is no longer live.
   - body is first simplified with ELIM_USELESS_LET.
   - Stores into memory are wrapped in the constant "save" and loads from
     memory in "loc" (both of type !VarType -> !VarType).  These locals
     shadow the top-level save function inside this block.
   - Memory slots occurring in args keep their names, and the slot counter
     is advanced past them before the traversal.
     (Assumes those argument slots are m1..mk -- TODO confirm.) *)
447fun alloc_mem (args,body) =
448 let val body' = rhs (concl
449                   (QCONV(SIMP_CONV bool_ss [ELIM_USELESS_LET]) body))
450     val (save,loc) = (mk_const("save", (!VarType) --> (!VarType)),
451                       mk_const("loc", (!VarType) --> (!VarType)))
     (* trav t env cont:
          env  -- maps original memory variables to their new slots;
          cont -- a term whose free variables, looked up in env, are the
                  slots first_avail_slot treats as occupied. *)
452     fun trav t env cont =
453       if is_let t then
454           let val (v,M,N) = dest_plet t
455           in
               (* binder is a memory slot (a spill "let m = r in ..."):
                  take the first free slot, wrap the stored value in save;
                  M itself is not traversed *)
456              if is_mem v then
457                let val v' = first_avail_slot env (mk_pair(N,cont))
458                    val M' = mk_comb (save, M)
459                    val env' = M.insert (env, v, v')
460                    val N' = trav N env' cont
461                in
462                    mk_plet (v', M', N')
463                end
               (* load from a memory slot through loc; M.find raises
                  NotFound if that slot has not been renamed *)
464              else if is_var M andalso is_mem M then
465                let val M' = mk_comb (loc, M.find (env, M))
466                    val N' = trav N env cont
467                in
468                    mk_plet (v, M', N')
469                end
               (* tuple binder: every memory component gets its own slot,
                  chosen with N (and cont) considered live.
                  NOTE(review): M is traversed with env', i.e. after the
                  binders were renamed -- harmless if binders do not occur
                  in M. *)
470              else if is_pair v then
471                let val cont' = mk_pair(N,cont)
472                    val (v',env') =
473                        List.foldl (fn (x, (w, env')) =>
474                          if not (is_mem x) then (w,env')
475                          else
476                            let val x' = first_avail_slot env' cont'
477                            in (subst [x |-> x'] w, M.insert(env', x, x'))
478                            end) (v,env) (strip_pair v)
479                in
480                    mk_plet (v', trav M env' cont', trav N env' cont)
481                end
482              else mk_plet (v, trav M env cont, trav N env cont)
483            end
484       else if is_comb t then
485            let val (M1,M2) = dest_comb t
486                val M1' = trav M1 env cont
487                val M2' = trav M2 env cont
488            in mk_comb(M1',M2')
489            end
       (* bare memory reference: read through loc; a slot missing from env
          keeps its own name *)
490       else if is_mem t then
491           mk_comb (loc, M.find (env, t) handle _ => t)
492       else t
493
     (* argument memory slots map to themselves.  The next_mvar calls made
        here are discarded by the reset_mvar below. *)
494     val memenv = List.foldl (fn (t, m) =>
495                    if is_mem t then (next_mvar(); M.insert(m,t,t)) else m)
496                  (M.mkDict tvarOrder) (strip_pair args)
     (* advance the slot counter once per argument slot *)
497     fun move_mindex i =
498       if i = M.numItems memenv then ()
499       else (next_mvar();move_mindex (i + 1))
500  in
501     (reset_mvar ();
502      move_mindex 0;
503      trav body' memenv ``()``
504     )
505  end;
506
507(* --------------------------------------------------------------------*)
508(* Refinement for Tail Recursion.                                      *)
509(* --------------------------------------------------------------------*)
510
(* parallel_move dst src exp: build a let-chain that copies the components
   of src into the corresponding components of dst, followed by exp.  Each
   copy has the form "let d = atom s in ...".
   - When a destination is still read by a later move, its old value is
     first saved in a free spot, so later moves still see the old value.
   - A memory-to-memory copy goes through a register.
   The moves helper has no case for dst and src of different lengths; such
   input raises Match. *)
511fun parallel_move dst src exp =
512 let val sane = ref true
513   val (dstL, srcL) = (strip_pair dst, strip_pair src)
   (* scratch memory names m1 .. m(k+1), where k = length srcL *)
514   val (tmpL,_) = List.foldl (fn (_,(l,i)) =>
515                                (l @ [mk_var("m" ^ Int.toString i, !VarType)],
516                                 i+1))
517                    ([],1) (hd srcL :: srcL)
   (* candidate scratch locations: registers first, then memory *)
518   val spotL = !regL @ tmpL
   (* first spot that is neither in l nor free in exp; raises Option if
      there is none *)
519   fun get_avail_spot l =
520        let val l' = l @ free_vars exp
521        in valOf (List.find (fn x => List.all (not o equal x) l') spotL)
522        end
523
   (* register used for memory-to-memory copies.  If no register is free it
      falls back to the first register, which may hold a live value; sane
      records whether that fallback was used. *)
524   val transfer_r =
525      let val r = get_avail_spot (srcL @ dstL)
526      in if is_mem r then hd spotL else r
527      end
528
   (* one copy d := s in front of t; memory-to-memory goes via a register *)
529   fun move (d,s,r) t =
530    if is_mem s andalso is_mem d
531    then let val real_r =
532               if is_mem r then (sane := false; transfer_r) else r
533         in
534            mk_plet(real_r, mk_atom s, mk_plet(d, mk_atom real_r, t))
535         end
536    else mk_plet(d, mk_atom s, t)
537
   (* Sequentialise the copies.  If d is read by a later move, the outermost
      let first copies d to a spot; later sources are then renamed d -> spot
      and s -> d, since d now holds s's value. *)
538   fun moves (t,[],[]) = t
539     | moves (t, d::dL, s::sL) =
540         if d = s then moves(t, dL, sL)
541         else
542          if List.exists (equal d) sL (* d is still live *)
543          then let val spot = get_avail_spot (d :: s :: sL)
544                   val transfer_r = get_avail_spot (spot :: d :: s :: sL)
545                   val tmp = List.map
546                              (fn x => if x = d then spot else
547                                       if x = s then d else x)
548                              sL
549                   val t' = move (d,s,transfer_r) (moves(t, dL, tmp))
550               in (* d is stored in spot in the beginning *)
551                 move (spot,d, transfer_r) t'
552               end
553          else
554             move(d,s,get_avail_spot (d :: s :: sL))
555                 (moves(t, dL, List.map (fn x => if x = s then d else x) sL))
   (* NOTE(review): !sane is read before moves runs.  sane is only
      modified inside move, which executes while moves builds the term, so
      at this point it is always true and the else-branch below is
      unreachable.  The intent was presumably to build moves(exp, ...)
      first and inspect sane afterwards.  If this is revived, also
      re-check the order of the load and store around the moves. *)
556   val exp' =
557     if !sane then moves(exp, dstL, srcL)
558     else (* store value of transferring register in memory if needed *)
559       let val sdL = srcL @ dstL
560           val opt = List.find (fn x => is_mem x andalso
561                                        List.all (not o equal x) sdL) spotL
562           val tmp_m = valOf opt
563       in (* load the transferring register *)
564          mk_plet(transfer_r, mk_atom tmp_m,
565                  moves(mk_plet(tmp_m, mk_atom transfer_r, exp), dstL, srcL))
566                    (* store the transferring register *)
567       end
568 in
569    exp'
570 end
571
(* refine_tail_recursion def, where def is |- fname = \args. body.
   Each recursive call occurring as "let v = fname src in N" is rewritten:
     - v is replaced by the call in N, then src by args;
     - the result is preceded by the parallel moves from src into args.
   The new equation is proved from lem, FUN_EQ_THM, LET_ATOM, ATOM_ID and
   the let-unfolding theorems collected in lems.
   - If that proof fails, a message is printed and def is used instead.
   - Either way, the result is rewritten with ATOM_ID.
   - Non-recursive definitions are returned unchanged.
   dest_eq / dest_pabs fail if def does not have this shape; reg_alloc
   guards the call with a handler. *)
572fun refine_tail_recursion def =
573 let val (fname, fbody) = dest_eq (concl def)
574   val (args,body) = dest_pabs fbody
   (* if the body is a let, unfold the outermost LET once in def *)
575   val lem = if is_let body
576              then CONV_RULE (RHS_CONV(SIMP_CONV pure_ss [Once LET_THM])) def
577              else def
578   val body = #2 (dest_pabs (rhs (concl lem)))
579   val is_recursive = Lib.can (find_term (equal fname)) body
   (* theorems unfolding each rewritten let; used in the final proof *)
580   val lems = ref []
581 in
582   if not is_recursive then def else
583   let fun trav t =
584          if is_let t then
585            let val (v,M,N) = dest_plet t
               (* a recursive call bound by a let: record the unfolding
                  theorem and replace the rest by the parallel moves.
                  N is not traversed any further. *)
586            in if is_comb M andalso #1(strip_comb M) = fname then
587                 let val _ = lems := PBETA_RULE
588                                       (SIMP_CONV bool_ss [Once LET_THM] t)
589                                     :: !lems
590                     val src = #2(dest_comb M)
591                     val exp = subst [src |-> args] (subst [v |-> M] N)
592                 in
593                   parallel_move args src exp
594                 end
595               else
596                 mk_plet (v, M, trav N)
597            end
598          else if is_cond t then
           (* here v is the condition, M and N the two branches *)
599            let val (v,M,N) = dest_cond t
600            in mk_cond(v, trav M, trav N)
601            end
602          else if is_comb t then
603            let val (M,N) = dest_comb t
604                val M' = trav M
605                val N' = trav N
606            in mk_comb(M',N')
607            end
608          else if is_pabs t then
609            let val (v,M) = dest_pabs t
610            in mk_pabs(v,trav M)
611            end
612          else
613            t
614
615        val body' = trav body
       (* best effort: any failure of the proof falls back to def.
          NOTE(review): the printed message has no trailing newline. *)
616        val th = prove(mk_eq(fname, mk_pabs(args,body')),
617                   REWRITE_TAC [Once lem] THEN
618                   SIMP_TAC bool_ss ([FUN_EQ_THM, LET_ATOM, ATOM_ID] @ (!lems))
619                 ) handle _ =>
620           (print "failed to convert a tail recursion into expected format!";
621            def)
622   in
623      REWRITE_RULE [ATOM_ID] th
624   end
625 end
626
627(* --------------------------------------------------------------------------*)
628(* Register Allocation                                                       *)
629(* --------------------------------------------------------------------------*)
630
(* Restore the allocator's global state:
     - rebuild the register list from !numRegs and !VarType;
     - restart the memory-slot counter;
     - restart the temporary-variable counter. *)
fun reset () =
  let
    val () = regL := mk_regs ()
    val () = reset_mvar ()
  in
    reset_tvar ()
  end;
632
633(*---------------------------------------------------------------------------*)
634(* Assign registers to inputs; memory slots will be used when there are too  *)
635(* many parameters.                                                           *)
636(*---------------------------------------------------------------------------*)
637
(* Initial environment for the function's parameters.  The i-th component
   of args (0-based) is mapped to register r<i> while i < !numRegs, and to a
   fresh memory slot otherwise (next_mvar, drawn left to right). *)
fun args_env args =
  let
    fun assign (_, env, []) = env
      | assign (i, env, v :: vs) =
          let
            val loc =
              if i < !numRegs then mk_var ("r" ^ Int.toString i, !VarType)
              else next_mvar ()
          in
            assign (i + 1, M.insert (env, v, loc), vs)
          end
  in
    assign (0, M.mkDict tvarOrder, strip_pair args)
  end
648
649(*---------------------------------------------------------------------------*)
650(* step1: configure the environment;                                         *)
651(* step2: obtain an allocation scheme by term rewriting;                     *)
652(* step3: prove the correctness of the scheme by showing the result is       *)
653(* alpha-equivalent to the source program.                                   *)
654(*---------------------------------------------------------------------------*)
655
656(*
657fun preset_regs tm =
658 let fun find tm acc =
659      if is_let tm
660*)
661
(* reg_alloc def: register-allocate the function defined by def, which is
   either |- fname = \args. body or, with no arguments, |- fname = body.
   On success, returns a theorem equating fname to a version of the body in
   which variables are replaced by registers r<i> and memory slots.  The
   theorem is chained from def through an alpha-equivalence (or LET_THM)
   proof.  If pre_check rejects the program, a message is printed and def
   is returned unchanged. *)
662fun reg_alloc def =
663 let
664   val (fname, fbody) = dest_eq (concl def)
665   val (args,body) = dest_pabs fbody handle _ => (Term`()`, fbody)  (* no argument *)
   (* pre_check is not defined in this file; it yields a validity flag and
      the variable type used for all generated variables *)
666   val (sane,var_type) = pre_check(args,body)
667 in
668  if sane then
669   let (* set the variable type according to the given program *)
670      val _ = (VarType := var_type; reset())
      (* parameters go to r0, r1, ... and then to memory slots *)
671      val regenv = args_env args
672      val args1 = subst (tuple_subst_rules (strip_pair args) regenv) args
      (* the first register serves both as dest and as the initial cont *)
673      val dest = hd (!regL)
674      val cont = dest
      (* allocate registers (with spilling), then compact memory slots *)
675      val body1 = #1 (g_repeat dest cont regenv body)
676      val body2 = alloc_mem(args1,body1)
      (* th1 relates body3, the form of body2 simplified with LET_SAVE,
         LET_LOC and loc_def (then PBETA_RULE), back to body2 *)
677      val tha = QCONV(SIMP_CONV pure_ss [LET_SAVE, LET_LOC, loc_def]) body2
678      val th1 = SYM (PBETA_RULE tha)
679      val body3 = lhs (concl th1)
680      val fbody' = if args1 = Term`()` then body3
681                   else mk_pabs (args1,body3)
      (* the allocated body must be alpha-equivalent to the original.
         Otherwise an equality proof by unfolding lets is attempted; if
         that also fails, a message is printed and the exception is
         re-raised *)
682      val th2 = ALPHA fbody fbody'
683                handle _ => prove (mk_eq(fbody, fbody'), SIMP_TAC std_ss [LET_THM])
684                handle e => (print "the allocation is incomplete or incorrect";
685                             Raise e)
      (* rewrite body3 back to body2, chain with def, then expand save and
         loc *)
686      val th3 = CONV_RULE (RHS_CONV (ONCE_REWRITE_CONV [th1])) th2
687      val th4 = TRANS def th3
688      val th5 = (PBETA_RULE o REWRITE_RULE [save_def, loc_def]) th4
      (* best effort: keep th5 if the tail-recursion refinement fails *)
689      val th6 = refine_tail_recursion th5 handle _ => th5
690   in
691     th6
692   end
693  else
694   ( print("The source program is invalid!\
695         \ (e.g. all variables are not of the same type)");
696     def
697   )
698  end;
699
700
701end
702