1structure defunctionalize (* :> defunctionalize *) =
2struct
3
4open HolKernel Parse boolLib pairLib PairRules bossLib pairSyntax ParseDatatype TypeBase;
5
(*-----------------------------------------------------------------------------------------*)
(* We convert higher-order functions into equivalent first-order functions and hoist       *)
(* nested functions to the top level through a type-based closure conversion. After this   *)
(* conversion no nested functions remain, and a call to a function value is made by        *)
(* dispatching on the closure tag followed by a top-level call.                             *)
(* Function closures are represented as algebraic datatypes such that, for each function   *)
(* definition, a constructor taking the free variables of that function is created.        *)
(* For each arrow type we create a dispatch function, which takes a closure of that type   *)
(* together with an argument and calls the hoisted function selected by the closure tag.   *)
(* A nested function is hoisted to the top level, with its free variables passed as extra  *)
(* arguments. A call to the original function is then replaced by a call to the relevant   *)
(* dispatch function, passing a closure that holds the values of the function's free       *)
(* variables. The dispatch function examines the closure tag and passes control to the     *)
(* appropriate hoisted function. Thus, higher-order operations on functions are replaced   *)
(* by equivalent operations on first-order closure values.                                 *)
(*-----------------------------------------------------------------------------------------*)
22
23(*-----------------------------------------------------------------------------------------*)
24(* Map and set operation functions.                                                        *)
25(*-----------------------------------------------------------------------------------------*)
26
(* Short aliases for the finite-map and finite-set libraries used throughout this file. *)
structure M = Binarymap
      and S = Binaryset
29
30(*-----------------------------------------------------------------------------------------*)
31(* Auxiliary functions.                                                                    *)
32(*-----------------------------------------------------------------------------------------*)
33
(* Lexicographic ordering on strings; the base ordering for the map/set keys below. *)
fun strOrder (s1 : string, s2 : string) : order =
  String.compare (s1, s2)
39
(* Orders terms by their printed form.
   NOTE(review): distinct terms that print identically (e.g. variables with the same name
   but different types) compare EQUAL and therefore share one map/set slot; the result
   also depends on the current pretty-printer settings. *)
fun tvarOrder (t1 : term, t2 : term) : order =
  String.compare (term_to_string t1, term_to_string t2)

(* Orders HOL types by their printed form (same caveat as tvarOrder). *)
fun typeOrder (ty1 : hol_type, ty2 : hol_type) : order =
  String.compare (type_to_string ty1, type_to_string ty2)
46
(* True iff the term has a function type ("fun" type operator); false for every other
   type, including type variables (on which Type.dest_type raises). *)
fun is_fun t =
  (case Type.dest_type (type_of t) of
     ("fun", _) => true
   | _ => false)
  handle _ => false
50
(* The head (function name) of the left-hand side of an equation "f a1 ... an = rhs". *)
fun FunName f =
  fst (strip_comb (lhs f))
53
54(*-----------------------------------------------------------------------------------------*)
55(* Data structures.                                                                        *)
56(*-----------------------------------------------------------------------------------------*)
57
(* Definitions of the functions to be lifted to the top level.
   Format: [function name |-> its definition as a (paired) abstraction \args. body] *)
val Lifted = ref (M.mkDict tvarOrder)

(* Function types and the higher-order functions of each type.
   Format: [function type |-> set of function names] *)
val ClosFunc = ref (M.mkDict typeOrder)

(* Type information of the datatype representing the closures of a type.
   Format: [function type |-> TypeBasePure tyinfo of its closure datatype] *)
val ClosInfo = ref (M.mkDict typeOrder)

(* Name of the datatype representing the closures of a type.
   Format: [function type |-> datatype name (a string such as "clos1")] *)
val ClosName = ref (M.mkDict typeOrder)

(* Higher-order functions after hoisting.
   Format: [function name |-> (call of the hoisted function, closure value)] *)
val HOFunc = ref (M.mkDict tvarOrder)
70
(* Debug view of the collected state: the lifted definitions, and for each function type
   the list of function names recorded for it. *)
fun cF () =
  let
    val lifted = M.listItems (!Lifted)
    val by_type = List.map (fn (ty, names) => (ty, S.listItems names))
                           (M.listItems (!ClosFunc))
  in
    (lifted, by_type)
  end
74
(*-----------------------------------------------------------------------------------------*)
(* Identify higher-order functions (those functions used as arguments or returned as       *)
(* results), then build datatypes for them.                                                *)
(*-----------------------------------------------------------------------------------------*)
79
(* Add fname to the set of higher-order functions recorded for its type in ClosFunc,
   creating the set if this is the first function of that type. *)
fun record_f fname =
  let
    val ty = type_of fname
    val known = case M.peek (!ClosFunc, ty) of
                  NONE => S.empty tvarOrder
                | SOME names => names
  in
    ClosFunc := M.insert (!ClosFunc, ty, S.add (known, fname))
  end
90
(* Scan an expression: record function-valued subterms in ClosFunc (via record_f), and
   store every let-bound embedded function (let v = \x. ... in ...) in Lifted under its
   bound name v, so that it can later be hoisted to the top level. *)
fun identify_f e =
  let
    fun trav t =
      if is_let t then
        let
          val (v, bound, body) = dest_plet t
          val _ = (trav bound; trav body)
        in
          (* Key by the let-bound name v, not by the let body: Lifted maps function
             names to definitions everywhere else (identify_closure, build_type,
             mk_dispatch, defunctionalize), and keying by the body only worked by
             accident when the body was v itself. *)
          if is_pabs bound then Lifted := M.insert (!Lifted, v, bound) else ()
        end
      else if is_pair t then
        let val (l, r) = dest_pair t
        in trav l; trav r end
      else if is_cond t then
        (* the guard is not scanned *)
        let val (_, l, r) = dest_cond t
        in trav l; trav r end
      else if is_comb t then
        let val (rator, rand) = dest_comb t
        in
          (* NOTE(review): when the argument is function-valued this records the whole
             application t rather than the argument rand -- confirm this is intended. *)
          if is_fun rand then record_f t else ();
          (* only the operator side of an application is descended into *)
          if is_comb rator then trav rator else ()
        end
      else if is_pabs t then
        trav (#2 (dest_pabs t))
      else if is_fun t then
        record_f t
      else
        ()
  in
    trav e
  end
130
(* Reset ClosFunc and Lifted, then, for every definition "fname args = fbody", store
   \args. fbody in Lifted under fname and scan fbody for higher-order functions.
   Returns one unit per definition. *)
fun identify_closure defs =
  let
    fun record_def eqn =
      let
        val (fdecl, fbody) = dest_eq eqn
        val (fname, args) = dest_comb fdecl
      in
        Lifted := M.insert (!Lifted, fname, mk_pabs (args, fbody));
        identify_f fbody
      end
  in
    ClosFunc := M.mkDict typeOrder;
    Lifted := M.mkDict tvarOrder;
    List.map (fn def => record_def (concl (SPEC_ALL def))) defs
  end
146
147(*-----------------------------------------------------------------------------------------*)
148(* Build datatypes for closures.                                                           *)
149(*-----------------------------------------------------------------------------------------*)
150
(* Counters used to generate fresh names: "clos<n>" for closure datatypes and
   "cons<k>" for their constructors. *)
val closure_index = ref 0
and constructor_index = ref 0
153
(* Register newly built datatypes with HOL: add their tyinfos to the TypeBase, make them
   known to computeLib, then record them with Datatype.write_tyinfos. *)
fun register_type tyinfos_etc =
  let
    val registered = TypeBase.write (List.map fst tyinfos_etc)
  in
    app computeLib.write_datatype_info registered;
    Datatype.write_tyinfos tyinfos_etc
  end
162
(* Build and register a new datatype "clos<n>" representing the closures of function type
   tp. It has one constructor "cons<k>" per function in funcs, whose argument holds the
   free variables of that function's lifted definition. Updates ClosName and ClosInfo and
   returns the result of Datatype.primHol_datatype_from_astl. *)
fun build_type tp funcs =
  let
    (* Pretype of one constructor argument v. A type that already has a closure datatype
       name in ClosName is referred to by that name. This includes the datatype being
       built right now, which makes the definition recursive. Any other type is
       antiquoted unchanged, so type arguments (e.g. ":num list") and type variables
       are preserved instead of being dropped or raising. *)
    fun arg_pretype v =
      let val ty = type_of v
      in
        case M.peek (!ClosName, ty) of
          SOME name => dTyop {Args = [], Thy = NONE, Tyop = name}
        | NONE => dAQ ty
      end

    (* Constructor arguments for free variables fv: none, a single argument, or one
       right-nested product a # (b # c), the shape list_mk_pair builds in mk_dispatch.
       "prod" is binary, so a flat n-ary prod would be ill-formed for 3+ variables. *)
    fun build_type_args [] = []
      | build_type_args fv =
          let
            val pts = List.map arg_pretype fv
            val final = List.last pts
            val front = List.take (pts, length pts - 1)
            fun nest (p, acc) = dTyop {Args = [p, acc], Thy = NONE, Tyop = "prod"}
          in
            [List.foldr nest final front]
          end

    (* Fresh datatype name for tp. It is recorded in ClosName before the constructors
       are built, so that arguments of type tp refer back to this datatype. *)
    val clos_name =
      let
        val _ = closure_index := !closure_index + 1
        val name = "clos" ^ Int.toString (!closure_index)
        val _ = ClosName := M.insert (!ClosName, tp, name)
      in
        name
      end

    (* One constructor per lifted function, carrying that function's free variables. *)
    fun mk_constructor lifted_f =
      let
        val _ = constructor_index := !constructor_index + 1
        val fv = free_vars (M.find (!Lifted, lifted_f))
      in
        ("cons" ^ Int.toString (!constructor_index), build_type_args fv)
      end

    val clos_tp_info =
      [(clos_name, Constructors (List.map mk_constructor (S.listItems funcs)))]

    val new_clos_type =
      Datatype.primHol_datatype_from_astl (TypeBase.theTypeBase ()) clos_tp_info
    val _ = register_type (#2 new_clos_type)
    val _ = ClosInfo := M.insert (!ClosInfo, tp, #1 (hd (#2 new_clos_type)))
  in
    new_clos_type
  end
218
(* Build closure datatypes for every function type that carries higher-order functions
   in defs. All per-run state is reset first, so nothing from an earlier run leaks in. *)
fun build_types defs =
  (closure_index := 0;
   constructor_index := 0;
   identify_closure defs;
   ClosName := M.mkDict typeOrder;
   (* ClosInfo is ClosName's sibling (type |-> datatype info) and must be cleared too.
      Otherwise type2closure/type2dispatch would still map types to closure datatypes
      from a previous run. *)
   ClosInfo := M.mkDict typeOrder;
   List.map (fn (tp, fs) => build_type tp fs) (M.listItems (!ClosFunc))
  )
226
227(*-----------------------------------------------------------------------------------------*)
228(*  Conversions from original HOL types to closure types.                                  *)
229(*-----------------------------------------------------------------------------------------*)
230
(* The closure datatype standing for function type tp, or tp itself when no closure
   datatype has been built for it. *)
fun type2closure tp =
  case M.peek (!ClosInfo, tp) of
    SOME tinfo => TypeBasePure.ty_of tinfo
  | NONE => tp
234
(* Retype a variable to its closure type (see type2closure); non-variables are returned
   unchanged. *)
fun term2closure t =
  if is_var t then
    let val (name, ty) = dest_var t
    in mk_var (name, type2closure ty) end
  else
    t
240
(* The dispatch function for function type tp = a -> b: "dispatch<n>" of type
   clos<n> # a -> b, where clos<n> is tp's closure datatype. It is a constant when
   already defined and a variable otherwise. Raises if tp has no closure datatype. *)
fun type2dispatch tp =
  let
    val clos_type = TypeBasePure.ty_of (M.find (!ClosInfo, tp))
    (* the datatype is named "clos<n>"; take <n> *)
    val name = "dispatch" ^ String.extract (#1 (Type.dest_type clos_type), 4, NONE)
    val (arg_type, return_type) = dom_rng tp
    val df_type = mk_prod (clos_type, arg_type) --> return_type
  in
    mk_const (name, df_type) handle _ => mk_var (name, df_type)
  end
253
(*-----------------------------------------------------------------------------------------*)
(* Build dispatch functions.                                                               *)
(* A dispatch function is defined by pattern matching on the closure constructors.         *)
(*-----------------------------------------------------------------------------------------*)
258
(* Build the defining clauses (a conjunction) of the dispatch function for function type
   tp = a -> b. For each recorded function f of that type, paired with its constructor c:
       dispatch<n> (c fv', x) = f' (fv', x)      (dispatch<n> (c, x) = f' x if no free vars)
   where fv' are f's free variables retyped to closure types and f' is the hoisted
   version of f. Each (f' call, closure value) pair is also recorded in HOFunc under f.
   Pairs functions with constructors by position (zip), relying on both following the
   S.listItems order used in build_type. *)
fun mk_dispatch tp =
  let
    val tinfo = M.find (!ClosInfo, tp)
    val clos_type = TypeBasePure.ty_of tinfo
    val constructors = TypeBasePure.constructors_of tinfo
    (* the datatype is named "clos<n>"; take <n> *)
    val suffix = String.extract (#1 (Type.dest_type clos_type), 4, NONE)
    val lifted_names = S.listItems (M.find (!ClosFunc, tp))
    val (arg_type, return_type) = dom_rng tp
    val dispatcher =
      mk_var ("dispatch" ^ suffix, mk_prod (clos_type, arg_type) --> return_type)

    fun clause (fname, constr) =
      let
        val lam = M.find (!Lifted, fname)
        val (param, _) = dest_pabs lam
        val free = free_vars lam
        val free' = List.map term2closure free
        (* closure value: the bare constructor, or the constructor applied to the free vars *)
        val packed = if null free then constr else mk_comb (constr, list_mk_pair free')
        val lhs_tm = mk_comb (dispatcher, mk_pair (packed, param))
        val (base, _) = dest_const fname handle _ => dest_var fname
        val hoisted_arg = if null free then param else mk_pair (list_mk_pair free', param)
        val hoisted = mk_var (base ^ "'", type_of hoisted_arg --> return_type)
        val call = mk_comb (hoisted, hoisted_arg)
      in
        HOFunc := M.insert (!HOFunc, fname, (call, packed));
        mk_eq (lhs_tm, call)
      end
  in
    list_mk_conj (List.map clause (zip lifted_names constructors))
  end
310
311
(* Defining clauses of the dispatch functions.
   Format: [function type |-> conjunction of dispatch clauses (see mk_dispatch)] *)
val Dispatched = ref (M.mkDict typeOrder)
314
(* Reset HOFunc, then build and store in Dispatched the dispatch clauses of every function
   type recorded in ClosFunc. Returns one unit per type. *)
fun build_dispatch () =
  let
    fun store tp = Dispatched := M.insert (!Dispatched, tp, mk_dispatch tp)
  in
    HOFunc := M.mkDict tvarOrder;
    List.map (store o fst) (M.listItems (!ClosFunc))
  end
320
(*-----------------------------------------------------------------------------------------*)
(* convert_exp translates expressions;                                                     *)
(* convert_fun translates function definitions;                                            *)
(* defunctionalize translates the top-level declarations.                                  *)
(*-----------------------------------------------------------------------------------------*)
326
(* Functions after closure conversion.
   Format: [original function name |-> (new function variable, new definition)];
   while a function is being converted its definition slot holds ``T`` as a placeholder. *)
val Redefined = ref (M.mkDict tvarOrder)
(* Closure-convert an expression:
   - embedded let-bound functions are hoisted through convert_fun and the binding dropped;
   - calls of already converted functions call their new versions;
   - calls of function values go through the dispatcher of their type;
   - function values become their closure (HOFunc) or a closure-typed variable.
   Applications with more than one argument (e.g. binary operators) are left unchanged, and
   any failure while converting an application leaves that application unchanged. *)
fun convert_exp t =
  let
    (* Convert the application t = rator rand. *)
    fun convert_app () =
      let val (rator, rand) = dest_comb t
      in
        if length (#2 (strip_comb t)) > 1 then t      (* binary expressions *)
        else if not (is_fun rator) then mk_comb (convert_exp rator, convert_exp rand)
        else
          case M.peek (!Redefined, rator) of
            SOME (fname_var, _) =>
              (* A (re)defined function: call its converted version. It is a constant
                 once defined; inside its own definition it is a variable (recursive call). *)
              let
                val (fname_str, f_tp) = dest_var fname_var
                val head = mk_const (fname_str, f_tp) handle _ => fname_var
              in
                mk_comb (head, convert_exp rand)
              end
          | NONE =>
              (* a function value: pass its closure and the argument to the dispatcher *)
              let
                val tp = type_of rator
                val name = #1 (dest_const rator) handle _ => #1 (dest_var rator)
                val closure = mk_pair (mk_var (name, type2closure tp), convert_exp rand)
              in
                mk_comb (type2dispatch tp, closure)
              end
      end
  in
    if is_let t then
      let val (v, bound, body) = dest_plet t
      in
        if is_pabs bound then
          (* an embedded function: hoist it (recorded in Redefined) and drop the binding *)
          let
            val (arg, fbody) = dest_pabs bound
            val _ = convert_fun (mk_eq (mk_comb (v, arg), fbody))
          in
            convert_exp body
          end
        else
          mk_plet (v, convert_exp bound, convert_exp body)
      end
    else if is_cond t then
      (* The guard is converted too: it may apply a function value (e.g. "s x"). Left
         as is, it would still refer to the unconverted variable while the function's
         parameters have been retyped to closures. *)
      let val (guard, l, r) = dest_cond t
      in mk_cond (convert_exp guard, convert_exp l, convert_exp r) end
    else if is_pair t then
      let val (l, r) = dest_pair t
      in mk_pair (convert_exp l, convert_exp r) end
    else if is_pabs t then
      let val (vs, body) = dest_pabs t
      in mk_pabs (convert_exp vs, convert_exp body) end
    else if is_comb t then
      (convert_app () handle _ => t)       (* not a convertible application *)
    else if is_fun t then
      (case M.peek (!HOFunc, t) of       (* higher-order function: use its closure *)
         SOME (_, constr) => constr
       | NONE => mk_var ((#1 (dest_const t) handle _ => #1 (dest_var t)),
                         type2closure (type_of t)))
    else
      t
  end

and

(* Closure-convert a definition "fname args = fbody" and record the result in Redefined.
   A higher-order function (present in HOFunc) keeps the hoisted left-hand side built by
   mk_dispatch. Any other function becomes fname' with closure-converted arguments and
   result type. Returns the new equation, or f unchanged if anything fails. *)
convert_fun f =
  let
    val (fdecl, fbody) = dest_eq f
    val (fname, args) = dest_comb fdecl
    val (fname_str, _) = dest_const fname handle _ => dest_var fname

    (* Reserve the new name first (placeholder ``T``) so recursive calls in fbody resolve
       to it, then convert the body and store the finished definition. *)
    fun register new_fname lhs_tm =
      let
        val _ = Redefined := M.insert (!Redefined, fname, (new_fname, ``T``))
        val new_f = mk_eq (lhs_tm, convert_exp fbody)
      in
        Redefined := M.insert (!Redefined, fname, (new_fname, new_f));
        new_f
      end
  in
    case M.peek (!HOFunc, fname) of
      NONE =>                                   (* not a higher-order function *)
        let
          val args1 = convert_exp args
          val new_fname = mk_var (fname_str ^ "'",
                                  type_of args1 --> type2closure (type_of fdecl))
        in
          register new_fname (mk_comb (new_fname, args1))
        end
    | SOME (lt, _) =>                           (* a higher-order function *)
        register (#1 (dest_comb lt)) lt
  end
  handle _ => f
427
(* Defunctionalize a list of definitions. First build the closure datatypes and dispatch
   clauses. Then, for each closure type, define its dispatch function together with the
   hoisted versions of its functions. Finally redefine every remaining function with
   Define. Returns the equations of all closure types' groups followed by the
   remaining redefinitions. *)
fun defunctionalize defs =
  let
    val _ = build_types defs
    val _ = build_dispatch ()

    (* Dispatch clauses of tp plus the converted definitions of its lifted functions,
       processed as one group with Defn.Hol_defn. *)
    fun process_type tp =
      let
        val fs = S.listItems (M.find (!ClosFunc, tp))
        val fs' = List.map (fn fname =>
                     let val (args, body) = dest_pabs (M.find (!Lifted, fname))
                     in convert_fun (mk_eq (mk_comb (fname, args), body)) end) fs
        val spec = list_mk_conj (strip_conj (M.find (!Dispatched, tp)) @ fs')
      in
        (* NOTE(review): every group is submitted under the same defn name "x";
           confirm this cannot clash when there are several closure types. *)
        Defn.eqns_of (Defn.Hol_defn "x" `^spec`)
      end

    val _ = Redefined := M.mkDict tvarOrder
    val dispatch_spec = List.map (process_type o fst) (M.listItems (!ClosFunc))

    (* definitions not already converted as part of some closure group *)
    val remaining_funcs =
          List.filter (fn f => M.peek (!Redefined, #1 (dest_comb (lhs f))) = NONE)
                      (List.map (concl o SPEC_ALL) defs)
    val new_spec =
          List.map (fn x => let val f = convert_fun x in Define `^f` end) remaining_funcs
  in
    (* Every closure type contributes a group of equations. Taking only the head
       dropped all but the first group and raised Empty when there was none. *)
    List.concat dispatch_spec @ new_spec
  end
458
459(*-----------------------------------------------------------------------------------------*)
460(* Redefine functions in HOL and prove the correctness of the translation.                 *)
461(*-----------------------------------------------------------------------------------------*)
462
463(* Convert function arguments to closure arguments                                         *)
464
(* Convert a (possibly paired) argument to its closure form. Each function-typed
   variable s becomes the closure variable s' (of type type2closure) and yields the
   assumption  !i. s i = dispatch (s', i)  relating the two.
   Returns (assumptions, converted argument). Raises if a leaf is not a variable. *)
fun process_args args =
  if is_pair args then
    let
      val (arg1, arg2) = dest_pair args
      val (assms1, arg1') = process_args arg1
      val (assms2, arg2') = process_args arg2
    in
      (assms1 @ assms2, mk_pair (arg1', arg2'))
    end
  else if is_fun args then
    let
      val (arg_str, arg_tp) = dest_var args
      val new_arg = mk_var (arg_str ^ "'", type2closure arg_tp)
      (* a fresh input name not clashing with the argument variables *)
      val input = variant [args, new_arg] (mk_var ("i", #1 (dom_rng arg_tp)))
      val spec = mk_eq (mk_comb (args, input),
                        mk_comb (type2dispatch arg_tp, mk_pair (new_arg, input)))
    in
      (* Quantify the input. A free i would only state the equation at one point,
         not that s behaves like its closure everywhere. *)
      ([mk_forall (input, spec)], new_arg)
    end
  else
    let val _ = dest_var args       (* a non-function leaf must still be a variable *)
    in ([], args) end
488
489(* Build the equivalence statement for a function.                                       *)
490
(* Replace the head variable of an application f a1 ... an by the constant with the same
   name and type, keeping the arguments. Raises if the head is not a variable or no such
   constant exists. *)
fun var2const t =
  let val (head, args) = strip_comb t
  in list_mk_comb (mk_const (dest_var head), args) end
500
(* Build the correctness statement for a definition "f args = body":
     |- !vars. assumptions ==> f args = f' args'              (first-order result)
     |- !vars. assumptions ==> f args m = dispatch (f' args', m)  (function-valued result)
   f' is the converted function recorded in Redefined (as a constant), or f if none.
   args' and the assumptions come from process_args. *)
fun build_judgement f =
  let
    val fdecl = lhs f
    val (fname, args) = dest_comb fdecl
    val (assums, new_args) = process_args args
    val new_fname = var2const (#1 (M.find (!Redefined, fname))) handle _ => fname
    val new_fdecl = mk_comb (new_fname, new_args)
    val concl_tm =
      if not (is_fun fdecl) then
        mk_eq (fdecl, new_fdecl)
      else
        let
          val ftp = type_of fdecl
          (* fresh, so generalization cannot identify it with an argument variable *)
          val input = variant (free_vars fdecl @ free_vars new_fdecl)
                              (mk_var ("m", #1 (dom_rng ftp)))
        in
          mk_eq (mk_comb (fdecl, input),
                 mk_comb (type2dispatch ftp, mk_pair (new_fdecl, input)))
        end
  in
    (* Generalize the whole implication. Generalizing only the conclusion would bind
       the argument variables there and disconnect them from the assumptions. *)
    gen_all (if null assums then concl_tm
             else mk_imp (list_mk_conj assums, concl_tm))
  end
521
522(*
523  (build_judgement o concl o SPEC_ALL) (List.nth(defs,2))
524  val def = List.nth(defs,2)
525  val f = concl (SPEC_ALL def)
526*)
527
(* Defunctionalize defs and return the new definitions together with one correctness
   statement per original definition. Defunctionalization runs first because the
   statements look up the converted functions in Redefined. *)
fun elim_hof defs =
  let val newdefs = defunctionalize defs
  in (newdefs, List.map (build_judgement o concl o SPEC_ALL) defs) end
535
536(*-----------------------------------------------------------------------------------------*)
537(* Example 1.                                                                              *)
538(*-----------------------------------------------------------------------------------------*)
539
(* Sets of numbers as characteristic functions num -> bool. *)
val empty_def = Define `
  empty (x : num) = F`;

val member_def = Define `
  member (s : num -> bool, x : num) = s x`;

(* insert (s, x) is the set s extended with x: y is a member iff y = x or y is in s.
   (The closure must test "s y"; testing "s x" would ignore y entirely.) *)
val insert_def = Define `
  insert(s : num -> bool, x : num) =
    let s1 y = if x = y then T else s y
    in s1
  `;

val upto_def = Define `
  upto(n : num) =
    if n = 0 then empty else insert(upto(n-1),n)
  `;

val main_def = Define `
  main (n : num) = (upto n, 100)`;

val defs = [empty_def, member_def, insert_def, upto_def];
561
562(*
563val (newdefs, judgements) = elim_hof defs;
564
565val defs1 = List.take(defs, 3);
566val defs2 = List.drop(defs, 3);
567val newdefs1 = List.take(newdefs, 6);
568val newdefs2 = List.drop(newdefs, 6);
569
set_goal ([], List.nth(judgements, 3))      (* i.e. the goal ``upto n m = dispatch1 (upto' n, m)``, universally quantified *)
571
572Induct_on `n` THENL [
573  ONCE_REWRITE_TAC (defs2 @ newdefs2) THEN
574    expandf (RW_TAC arith_ss (defs1 @ newdefs1)),
575
576  ONCE_REWRITE_TAC (defs2 @ newdefs2) THEN
577    RW_TAC arith_ss [LET_THM] THEN
578    expandf (RW_TAC arith_ss (defs1 @ newdefs1)) THEN
579    Q.UNABBREV_TAC `s1` THEN
580    RW_TAC std_ss []
581]
582
583*)
584
585(*-----------------------------------------------------------------------------------------*)
586(* Example 2.                                                                              *)
587(*-----------------------------------------------------------------------------------------*)
588
(* Example 2: g returns one of two local closures (h1, or h2 which captures h1), and
   k applies the closure returned by g to a recursive result. *)
val f_def =
  Define `f x = x * 2 < x + 10`;

val g_def =
  Define `g (s, x) =
            let h1 = \y. y + x in
              if s x then h1
              else let h2 = \y. h1 y * x in h2`;

val k_def =
  Define `k x = if x = 0 then 1 else (g (f, x)) (k (x - 1))`

val defs = [f_def, g_def, k_def];
601
602(*-----------------------------------------------------------------------------------------*)
603
604end (* struct *)
605
606