1structure monomorphisation (* :> monomorphisation *) =
2struct
3
4
5(*
6app load ["basic"];
7*)
8
9open HolKernel Parse boolLib pairLib PairRules bossLib pairSyntax ParseDatatype TypeBase;
10
11(*-----------------------------------------------------------------------------------------*)
(* This transformation eliminates polymorphism and produces a simply-typed intermediate    *)
(* form that enables good data representations.                                            *)
(* The basic idea is to duplicate each datatype declaration and each function declaration *)
(* at every type at which it is used, producing multiple monomorphic clones of that        *)
(* datatype or function.                                                                   *)
17(*-----------------------------------------------------------------------------------------*)
18
19(*-----------------------------------------------------------------------------------------*)
20(* Map and set operation functions.                                                        *)
21(*-----------------------------------------------------------------------------------------*)
22
23structure M = Binarymap
24structure S = Binaryset
25
26(*-----------------------------------------------------------------------------------------*)
27(* Auxiliary functions.                                                                    *)
28(*-----------------------------------------------------------------------------------------*)
29
(* Total order on strings: the standard lexicographic order of the SML
   Basis (String.compare). *)
fun strOrder (s1 : string, s2 : string) = String.compare (s1, s2);
35
(* Order on terms by their printed form (term_to_string). *)
fun tvarOrder (tm1 : term, tm2 : term) =
  let
    val (printed1, printed2) = (term_to_string tm1, term_to_string tm2)
  in
    strOrder (printed1, printed2)
  end
38
(* Order on typed terms: compare the printed terms first and break ties with
   the printed types.  Comparing the two components lexicographically (rather
   than comparing the concatenation of the two strings) means two terms are
   EQUAL only when both their printed forms and their printed types agree;
   a concatenation can make different (term, type) pairs collide. *)
fun tvarWithTypeOrder (t1 : term, t2 : term) =
  case strOrder (term_to_string t1, term_to_string t2) of
      EQUAL => strOrder (type_to_string (type_of t1), type_to_string (type_of t2))
    | ord => ord
41
(* Order on HOL types by their printed form (type_to_string). *)
fun typeOrder (ty1 : hol_type, ty2 : hol_type) =
  let
    val render = type_to_string
  in
    strOrder (render ty1, render ty2)
  end
  ;
45
(* Does the term [t] have a function type (type operator "fun")?
   Type variables are not function types: dest_type raises HOL_ERR on them,
   which is mapped to false.  Other exceptions (e.g. Interrupt) are no longer
   swallowed. *)
fun is_fun t =
  #1 (Type.dest_type (type_of t)) = "fun"
  handle HOL_ERR _ => false
49
(* Head of the left-hand side of the equation [f], i.e. the function being
   defined by [f] (with its arguments stripped off). *)
fun get_fname f =
  let
    val (eq_lhs, _) = dest_eq f
    val (head, _) = strip_comb eq_lhs
  in
    head
  end
52
53(*-----------------------------------------------------------------------------------------*)
54(* Data structures.                                                                        *)
55(*-----------------------------------------------------------------------------------------*)
56
57(*
58val Imap = ref (M.mkDict tvarOrder)         (* the instantiation map *)
59                                            (* Format: [function's name |-> [type |-> instantiation set] ] *)
60
val MonoFunc = ref (M.mkDict tvarOrder)     (* monomorphic functions *)
                                            (* Format: [function's name |-> a set of new definitions] *)
63*)
64
(* Convert an instantiation map [type |-> set of types] into an association
   list of (type, list of types), in increasing key order (not used elsewhere
   in this file; presumably intended for inspecting maps). *)
fun smap m =
  M.foldr (fn (tp, insts, acc) => (tp, S.listItems insts) :: acc) [] m
66
(* Convert a whole instantiation map [name |-> [type |-> set of types]] into
   nested association lists (see smap), in increasing key order. *)
fun Smap imap =
  M.foldr (fn (f, m, acc) => (f, smap m) :: acc) [] imap
68
69(*
70val map1 = M.insert(M.mkDict typeOrder, ``:'c``, S.addList(S.empty typeOrder, [``:'num``, ``:'bool``]));
71val map2 = M.insert(M.mkDict typeOrder, ``:'b``, S.addList(S.empty typeOrder, [``:'c``, ``:'d``]));
72*)
73
74(*-----------------------------------------------------------------------------------------*)
75(* Union and composition of instantiation maps.                                            *)
76(*-----------------------------------------------------------------------------------------*)
77
(* Turn a list of type-instantiation rules {redex, residue} into an
   instantiation map: every redex is bound to the set of all residues it is
   given in [inst_rules]. *)
fun mk_map inst_rules =
  let
    fun add_rule ({redex, residue}, m) =
      let
        val seen = case M.peek (m, redex) of
                       NONE => S.empty typeOrder
                     | SOME s => s
      in
        M.insert (m, redex, S.add (seen, residue))
      end
  in
    List.foldl add_rule (M.mkDict typeOrder) inst_rules
  end
88
(* Union of two instantiation maps: each key of either map is bound to the
   union of the residue sets the two maps give it.  The result keeps the
   ordering of [map1]. *)
fun union_map map1 map2 =
  M.foldl (fn (tp, insts, acc) =>
             M.insert (acc, tp,
                       case M.peek (acc, tp) of
                           NONE => insts
                         | SOME prev => S.union (prev, insts)))
          map1
          map2
97
(* Compose two instantiation maps.  For every key of [map1], each type [tp]
   in its residue set is replaced by the residue set [map2] binds [tp] to;
   types that are not keys of [map2] are kept unchanged.  The result has the
   same keys as [map1].
   Fix: a type absent from [map2] used to reset the accumulated set
   (S.add (S.empty typeOrder, tp)), losing every type composed before it. *)
fun compose_map map1 map2 =
  let
    fun compose type_set =
       List.foldl (fn (tp, s) =>
                     case M.peek (map2, tp) of
                        NONE => S.add (s, tp)          (* keep what was accumulated *)
                      | SOME s' => S.union (s, s')
                  )
                  (S.empty typeOrder)
                  (S.listItems type_set)
  in
   List.foldl (fn ((tp, type_set), m) =>
                M.insert (m, tp, compose type_set)
              )
              (M.mkDict typeOrder)
              (M.listItems map1)
  end
115
(* Union of two function-level instantiation maps [name |-> instantiation
   map]: names present in both get the union_map of their two maps. *)
fun union_imap imap1 imap2 =
  M.foldl (fn (f, m, acc) =>
             case M.peek (acc, f) of
                 NONE => M.insert (acc, f, m)
               | SOME prev => M.insert (acc, f, union_map prev m))
          imap1
          imap2
124
(* Compose every per-function map of [imap] with the instantiation map
   [tmap] (see compose_map); the result is a fresh name-ordered map. *)
fun compose_imap imap tmap =
  M.foldl (fn (f, m, acc) => M.insert (acc, f, compose_map m tmap))
          (M.mkDict strOrder)
          imap
131
132(*-----------------------------------------------------------------------------------------*)
133(* Examine the type and build an instantiation map.                                        *)
134(*-----------------------------------------------------------------------------------------*)
135
(* Split a type into its components by recursively descending through
   product (#) and function (->) types; any other type (type variable or
   other operator application) is returned as a singleton.
   Lib.total is used instead of a catch-all handler so Interrupt still
   propagates and the recursive calls are not wrapped in a handler. *)
fun strip_type tp =
  case total dest_prod tp of
      SOME (t1, t2) => strip_type t1 @ strip_type t2
    | NONE =>
        (case total dom_rng tp of
             SOME (t1, t2) => strip_type t1 @ strip_type t2
           | NONE => [tp])
145
(* Build an instantiation map for the components of type [tp] (strip_type).
   For each component that TypeBase knows, its generic type is matched
   against the component; non-empty matches are recorded under the type
   operator's name.  Components for which any step fails (e.g. type
   variables, or types unknown to TypeBase) are skipped. *)
fun examine_type tp =
  let
    fun record (ty, acc) =
      let
        val generic = (TypeBasePure.ty_of o valOf o TypeBase.fetch) ty
        val op_name = #1 (dest_type ty)
        val theta = match_type generic ty
      in
        case (theta, M.peek (acc, op_name)) of
            ([], _) => acc
          | (_, NONE) => M.insert (acc, op_name, mk_map theta)
          | (_, SOME old) => M.insert (acc, op_name, union_map (mk_map theta) old)
      end
      handle _ => acc
  in
    List.foldl record (M.mkDict strOrder) (strip_type tp)
  end
161
162(*-----------------------------------------------------------------------------------------*)
163(* Build the instantiation map.                                                            *)
164(*-----------------------------------------------------------------------------------------*)
165
(* Find the function called [f_str]: a binding in [env] takes precedence;
   otherwise the first constant of that name returned by Term.decls is used.
   Returns NONE when neither exists. *)
fun peek_fname f_str env =
  case M.peek (env, f_str) of
      NONE => ((case Term.decls f_str of
                    c :: _ => SOME c
                  | [] => NONE)
               handle _ => NONE)
    | found => found
174
(* Traverse the expression [t] and build its instantiation map: a map from
   function / type-operator names to maps [type variable |-> set of types]
   recording the types at which they are used in [t].
   [env] maps the names of the let-bound functions in scope to their bound
   variables. *)

fun trav_exp t env =
  if basic.is_atomic t then
    examine_type (type_of t)
  else if is_let t then
    let val (v, bound, rest) = dest_plet t in
      if is_pabs bound then       (* an embedded (let-bound) function *)
        let
          val (_, body) = dest_pabs bound
          val f_str = #1 (dest_var v)
          val body_imap = trav_exp body env
          (* extend the enclosing environment (instead of starting from an
             empty one) so outer let-bound functions stay visible in [rest] *)
          val env' = M.insert (env, f_str, v)
          val rest_imap = trav_exp rest env'
          (* push the local function's instantiations through its body *)
          val body_imap' =
              case M.peek (rest_imap, f_str) of
                  NONE => body_imap
                | SOME insts => compose_imap body_imap insts
        in
          union_imap body_imap' rest_imap
        end
      else
        union_imap (trav_exp bound env) (trav_exp rest env)
    end
  else if is_cond t then
    let val (guard, t1, t2) = dest_cond t in
      union_imap (trav_exp guard env)
        (union_imap (trav_exp t1 env) (trav_exp t2 env))
    end
  else if is_pair t then
    let val (t1, t2) = dest_pair t in
      union_imap (trav_exp t1 env) (trav_exp t2 env)
    end
  else if is_pabs t then
    trav_exp (#2 (dest_pabs t)) env
  else if is_comb t then
    let val (rator, rand) = dest_comb t in
      if is_constructor rator then
        union_imap (examine_type (type_of rator)) (trav_exp rand env)
      else if is_const rator orelse is_var rator then   (* function application *)
        let
          (* a constant is looked up with peek_fname; a variable can only be
             a let-bound function from [env] -- a lambda-bound parameter has
             no entry and its (fixed) type needs no instantiation *)
          val (fstr, found) =
              if is_const rator then
                let val s = #1 (dest_const rator) in (s, peek_fname s env) end
              else
                let val s = #1 (dest_var rator) in (s, M.peek (env, s)) end
          val arg_imap = trav_exp rand env
          val call_imap =
              case found of
                  NONE => arg_imap
                | SOME fname =>
                    let val inst_rules = match_type (type_of fname) (type_of rator)
                    in
                      if null inst_rules then arg_imap
                      else union_imap arg_imap
                             (M.insert (M.mkDict strOrder, fstr, mk_map inst_rules))
                    end
        in
          union_imap call_imap (examine_type (type_of rator))
        end
      else
        (* the operator is itself compound -- a curried application such as
           [f x y] or an applied abstraction -- so traverse both sides *)
        union_imap (trav_exp rator env) (trav_exp rand env)
    end
  else M.mkDict strOrder
231
232(* val imap = M.mkDict strOrder; *)
233
(* Build the instantiation map of a list of definitions.  Definitions are
   processed from last to first (List.foldr): when a definition's body is
   traversed, the instantiations already recorded for its function by the
   definitions processed before it are composed through the body. *)
fun build_imap defs =
  let
    fun add_def (def, acc) =
      let
        val (lhs_tm, rhs_tm) = (dest_eq o concl o SPEC_ALL) def
        val name = #1 (dest_const (#1 (strip_comb lhs_tm)))
        val local_imap = trav_exp rhs_tm (M.mkDict strOrder)
        val propagated =
            case M.peek (acc, name) of
                NONE => local_imap
              | SOME insts => compose_imap local_imap insts
      in
        union_imap acc propagated
      end
  in
    List.foldr add_def (M.mkDict strOrder) defs
  end
248
249(*-----------------------------------------------------------------------------------------*)
(* Eliminate polymorphism by duplicating function definitions.                             *)
251(*-----------------------------------------------------------------------------------------*)
252
253(*
254val Duplicated  = ref (M.mkDict tvarOrder)      (* definitions of the monomorphic functions  *)
255                                                (* format: function name |-> new definition  *)
256*)
257
258(*-----------------------------------------------------------------------------------------*)
259(* Redefine functions in HOL and prove the correctness of the translation.                 *)
260(*-----------------------------------------------------------------------------------------*)
261
(* Rename the function defined by the equation [f]: the head constant of the
   left-hand side is replaced by a variable called [name] with the same type.
   Arguments and the right-hand side are left untouched (so recursive calls
   in the body are not renamed).  Returns the new equation. *)
fun change_f_name f name =
  let
    val (eq_lhs, eq_rhs) = dest_eq f
    val (head, args) = strip_comb eq_lhs
    val renamed = mk_var (name, #2 (dest_const head))
  in
    mk_eq (list_mk_comb (renamed, args), eq_rhs)
  end
271
(* Map from each (typed) function constant that has been cloned to the
   constant of its monomorphic clone; keys are ordered by printed name, then
   printed type. *)
val MonoFunc = ref (M.mkDict tvarWithTypeOrder)
(* Equations "f = fN" stating that each clone equals the original function
   at the clone's type; their conjunction is what elim_poly proves. *)
val judgements : term list ref = ref []
274
(* Create the monomorphic clones of the function defined by [def], one per
   combination of type instantiations recorded for it in [imap].
   Each clone is defined in HOL under the name f_strN, with N = 0 when the
   function needs no instantiation and N = 1, 2, ... otherwise.  Calls to
   functions cloned earlier (recorded in MonoFunc) are redirected to their
   clones, every new clone is added to MonoFunc, and an equation "f = fN" at
   the clone's type is pushed onto judgements.
   Returns the list of new definition theorems, in the order they were made. *)
fun duplicate_func imap def =
  let
    (* Prefix every rule list of [rules] with [tp |-> x], for each x in [xs]. *)
    fun one_type tp xs rules =
      List.concat (List.map (fn x => List.map (fn y => (tp |-> x) :: y) rules) xs)

    (* All combinations of instantiation rules, choosing one residue per type
       variable; the empty list yields the single empty combination, which
       makes the function total. *)
    fun mk_type_combination [] = [[]]
      | mk_type_combination ((tp, type_set) :: rest) =
          one_type tp (S.listItems type_set) (mk_type_combination rest)

    val f = (concl o SPEC_ALL) def
    val fname = #1 (strip_comb (#1 (dest_eq f)))
    val (f_str, f_type) = dest_const fname
    (* calls to functions cloned so far are redirected to their clones *)
    val mono_rules =
        List.map (fn (old_name, new_name) => old_name |-> new_name)
                 (M.listItems (!MonoFunc))

    (* Define [body] with the constant [old_c] renamed to [new_str] (at type
       [ty]) and record the clone in MonoFunc and judgements. *)
    fun register old_c new_str ty body =
      let
        val new_v = mk_var (new_str, ty)
        val new_body = subst [old_c |-> new_v] body
        val new_def = Define `^new_body`
        val _ = MonoFunc := M.insert (!MonoFunc, old_c, mk_const (new_str, ty))
        val _ = judgements := mk_eq (mk_const (f_str, ty), mk_const (new_str, ty))
                              :: !judgements
      in
        new_def
      end

    (* Instantiate the definition with [rule] and clone it as number [i]. *)
    fun clone_one i rule =
      let
        val new_f = subst mono_rules (inst rule f)
        val old_c = get_fname new_f
        val ty = #2 (dest_const old_c)
      in
        register old_c (f_str ^ Int.toString i) ty new_f
      end

    (* Clone for every rule, numbering the clones from [i] upwards. *)
    fun clone_all _ [] = []
      | clone_all i (rule :: rest) =
          let val d = clone_one i rule in d :: clone_all (i + 1) rest end

    val insts = M.listItems (M.find (imap, f_str)) handle _ => []
  in
    if null insts then
      (* the function is already monomorphic, so no instantiation is needed;
         its body must still be rewritten if it calls cloned functions *)
      [register fname (f_str ^ "0") f_type (subst mono_rules f)]
    else
      (* instantiate types and redirect polymorphic calls to their clones *)
      clone_all 1 (mk_type_combination insts)
  end
333
(* Clone every definition of [defs] at the types it is used at.  The global
   MonoFunc / judgements state is reset first.  Returns the new definition
   theorems (grouped in the order of [defs]) and the conjunction of the
   judgements.  When [defs] is empty, list_mk_conj is applied to []. *)
fun build_clone defs =
  let
    val imap = build_imap defs
    val () = MonoFunc := M.mkDict tvarWithTypeOrder
    val () = judgements := []
    (* List.map applies duplicate_func left to right, preserving the order
       in which the HOL definitions are made *)
    val new_defs = List.concat (List.map (duplicate_func imap) defs)
  in
    (new_defs, list_mk_conj (!judgements))
  end
345
346(*-----------------------------------------------------------------------------------------*)
347(* Mechanical proof.                                                                       *)
348(*-----------------------------------------------------------------------------------------*)
349
(* Eliminate polymorphism from [defs]: build the monomorphic clones, then
   prove the conjunction of judgements "f = fN" by extensionality (paired
   arguments included) followed by unfolding both the original and the new
   definitions.  Returns the new definitions and the proved theorem. *)
fun elim_poly defs =
  let
    val (newdefs, judgement) = build_clone defs
    val tac =
        SIMP_TAC std_ss [FUN_EQ_THM, pairTheory.FORALL_PROD]
        THEN SIMP_TAC std_ss (defs @ newdefs)
  in
    (newdefs, prove (judgement, tac))
  end
360
361(*-----------------------------------------------------------------------------------------*)
362(* Example 1.                                                                              *)
363(*-----------------------------------------------------------------------------------------*)
364
365(*
366val _ = Hol_datatype `
367  p = P of 'a # 'a`;
368
369val f_def = Define `
370  f = \x:'b. (P : 'b # 'b -> 'b p) (x, x)`;
371
372val g_def = Define `
373  g = \(y:'c,z:'d).
374      let h = \w : 'c p. f z in
375      let v = f y in
376      h v
377  `;
378
379val a_def = Define `
380  a = g (3, T)`;
381
382val b_def = Define `
383  b = g (F, 1)`;
384
385val defs = [f_def, g_def, a_def, b_def];
386
387val newdefs = elim_poly defs;
388
389val (newdefs, thm) = elim_poly defs;
390
391*)
392
393(*-----------------------------------------------------------------------------------------*)
394(* Example 2.                                                                              *)
395(*-----------------------------------------------------------------------------------------*)
396
397(*
398
399Hol_datatype `dt1 = C of 'a # 'b`;
400
401val f_def = Define `f (x:'a) = x`;
402val g_def = Define `g (x : 'c, y : 'd) =
403      let h = \z. C (f x, f z) in
404      h y`;
405val j_def = Define `j = (g(1, F), g(F, T))`;
406
407val defs = [f_def, g_def, j_def];
408
409val (newdefs, thm) = elim_poly defs;
410
411*)
412
413(*-----------------------------------------------------------------------------------------*)
414
415end (* struct *)