structure monomorphisation (* :> monomorphisation *) =
struct

(*
app load ["basic"];
*)

open HolKernel Parse boolLib pairLib PairRules bossLib pairSyntax ParseDatatype TypeBase;

(*-----------------------------------------------------------------------------------------*)
(* This transformation eliminates polymorphism and produces a simply-typed intermediate    *)
(* form that enables good data representations.                                            *)
(* The basic idea is to duplicate a datatype declaration at each type used and a function  *)
(* declaration at each type used, resulting in multiple monomorphic clones of this datatype*)
(* and function.                                                                           *)
(*-----------------------------------------------------------------------------------------*)

(*-----------------------------------------------------------------------------------------*)
(* Map and set operation functions.                                                        *)
(*-----------------------------------------------------------------------------------------*)

structure M = Binarymap
structure S = Binaryset

(*-----------------------------------------------------------------------------------------*)
(* Auxiliary functions.                                                                    *)
(*-----------------------------------------------------------------------------------------*)

(* Total order on strings; used as the key order of the string-keyed maps below. *)
fun strOrder (s1:string, s2:string) = String.compare (s1, s2)

(* Order on terms, obtained by comparing their printed forms. *)
fun tvarOrder (t1:term, t2:term) =
  strOrder (term_to_string t1, term_to_string t2)

(* Order on terms that also takes the printed type into account, so that two terms *)
(* that print identically but have different types are kept apart.                 *)
fun tvarWithTypeOrder (t1:term, t2:term) =
  strOrder (term_to_string t1 ^ (type_to_string o type_of) t1,
            term_to_string t2 ^ (type_to_string o type_of) t2)

(* Order on types, obtained by comparing their printed forms. *)
fun typeOrder (t1:hol_type, t2:hol_type) =
  strOrder (type_to_string t1, type_to_string t2)

(* Does the term have a function type?  Any failure while destructing the type *)
(* (e.g. the type is a type variable) is reported as false.                    *)
fun is_fun t =
  (#1 (Type.dest_type (type_of t)) = "fun")
  handle _ => false

(* Head symbol of the left-hand side of an equation `f x1 ... xn = body`. *)
fun get_fname f =
  #1 (strip_comb (#1 (dest_eq f)))

(*-----------------------------------------------------------------------------------------*)
(* Data structures.                                                                        *)
(*-----------------------------------------------------------------------------------------*)

(*
val Imap = ref (M.mkDict tvarOrder)       (* the instantiation map *)
    (* Format: [function's name |-> [type |-> instantiation set] ] *)

val MonoFunc = ref (M.mkDict tvarOrder)   (* monomorphic functions *)
    (* Format: [function's name |-> a set of new definitions] *)
*)

(* Flatten a [type |-> set] map into an association list with list-valued entries. *)
fun smap m = List.map (fn (tp, s) => (tp, S.listItems s)) (M.listItems m)

(* Flatten a whole instantiation map [name |-> [type |-> set]] into nested lists. *)
fun Smap imap = List.map (fn (f, m) => (f, smap m)) (M.listItems imap)

(*
val map1 = M.insert(M.mkDict typeOrder, ``:'c``, S.addList(S.empty typeOrder, [``:'num``, ``:'bool``]));
val map2 = M.insert(M.mkDict typeOrder, ``:'b``, S.addList(S.empty typeOrder, [``:'c``, ``:'d``]));
*)

(*-----------------------------------------------------------------------------------------*)
(* Union and composition of instantiation maps.                                            *)
(*-----------------------------------------------------------------------------------------*)

(* Turn a list of type-substitution rules {redex, residue} into a map               *)
(* [redex |-> set of residues]; repeated redexes accumulate into the same set.      *)
fun mk_map inst_rules =
  let
    fun add_rule ({redex, residue} : {redex : hol_type, residue : hol_type}, acc) =
      let
        val seen = case M.peek (acc, redex) of
                     NONE   => S.empty typeOrder
                   | SOME s => s
      in
        M.insert (acc, redex, S.add (seen, residue))
      end
  in
    List.foldl add_rule (M.mkDict typeOrder) inst_rules
  end

(* Pointwise union of two [type |-> set] maps. *)
fun union_map map1 map2 =
  List.foldl (fn ((tp, insts), m) =>
               case M.peek (m, tp) of
                 NONE           => M.insert (m, tp, insts)
               | SOME old_insts => M.insert (m, tp, S.union (old_insts, insts)))
             map1
             (M.listItems map2)

(* Compose map1 with map2: every type in a set of map1 that is a key of map2 is      *)
(* replaced by the set map2 assigns to it; every other type is kept unchanged.       *)
fun compose_map map1 map2 =
  let
    fun compose type_set =
      List.foldl (fn (tp, s) =>
                   case M.peek (map2, tp) of
                     (* keep the types collected so far; only add this one *)
                     NONE    => S.add (s, tp)
                   | SOME s' => S.union (s, s'))
                 (S.empty typeOrder)
                 (S.listItems type_set)
  in
    List.foldl (fn ((tp, type_set), m) => M.insert (m, tp, compose type_set))
               (M.mkDict typeOrder)
               (M.listItems map1)
  end

(* Union of two instantiation maps [name |-> [type |-> set]]. *)
fun union_imap imap1 imap2 =
  List.foldl (fn ((f, m), imap) =>
               case M.peek (imap, f) of
                 NONE       => M.insert (imap, f, m)
               | SOME old_m => M.insert (imap, f, union_map old_m m))
             imap1
             (M.listItems imap2)

(* Compose every per-function map of an instantiation map with `map`. *)
fun compose_imap imap map =
  List.foldl (fn ((f, m), imap') => M.insert (imap', f, compose_map m map))
             (M.mkDict strOrder)
             (M.listItems imap)

(*-----------------------------------------------------------------------------------------*)
(* Examine the type and build an instantiation map.                                        *)
(*-----------------------------------------------------------------------------------------*)

(* Split a type into its components by recursively taking product types and then *)
(* function types apart; a type that is neither is returned as a singleton list.  *)
fun strip_type tp =
  (let val (t1, t2) = dest_prod tp
   in (strip_type t1) @ (strip_type t2)
   end)
  handle _ =>
    ((let val (t1, t2) = dom_rng tp
      in (strip_type t1) @ (strip_type t2)
      end)
     handle _ => [tp])

(* For each component of `tp` that is an instance of a datatype known to TypeBase, *)
(* record how the declared type was instantiated.  The result is keyed by the name *)
(* of the type operator.  Components for which the TypeBase lookup or the type     *)
(* match fails are skipped.                                                        *)
fun examine_type tp =
  List.foldl (fn (t, imap) =>
               (let
                  val original_t = (TypeBasePure.ty_of o valOf o TypeBase.fetch) t
                  val pstr = #1 (dest_type t)
                  val inst_rules = match_type original_t t
                in
                  if null inst_rules then imap
                  else
                    case M.peek (imap, pstr) of
                      NONE   => M.insert (imap, pstr, mk_map inst_rules)
                    | SOME m => M.insert (imap, pstr, union_map (mk_map inst_rules) m)
                end)
               handle _ => imap)
             (M.mkDict strOrder)
             (strip_type tp)

(*-----------------------------------------------------------------------------------------*)
(* Build the instantiation map.                                                            *)
(*-----------------------------------------------------------------------------------------*)

(* Find a function by its name (a string): first among the locally bound functions *)
(* in `env`, then among the declared constants of that name.  Returns NONE when    *)
(* neither lookup succeeds.                                                        *)
fun peek_fname f_str env =
  case M.peek (env, f_str) of
    SOME x => SOME x
  | NONE   => (SOME (hd (Term.decls f_str)) handle _ => NONE)
(*  SOME (#1 ((strip_comb o lhs o concl o SPEC_ALL o DB.definition) (f_str ^ "_def")))  (* be a predefined function *) *)

(* Traverse an expression and build its instantiation map.                           *)
(*   t   : the expression                                                            *)
(*   env : map from names of let-bound local functions to their variables            *)
fun trav_exp t env =
  if basic.is_atomic t then
    examine_type (type_of t)
  else if is_let t then
    let val (v, bound, rest) = dest_plet t in
      if is_pabs bound then                     (* an embedded function *)
        let
          val (arg, body) = dest_pabs bound
          val f_str = #1 (dest_var v)
          val body_imap = trav_exp body env
          (* extend (rather than replace) the environment so that enclosing *)
          (* local functions stay visible inside `rest`                      *)
          val env' = M.insert (env, f_str, v)
          val rest_imap = trav_exp rest env'
          (* if the local function is never instantiated in `rest`, its body *)
          (* map is used as it stands                                        *)
          val body_imap' = compose_imap body_imap (M.find (rest_imap, f_str))
                           handle _ => body_imap
        in
          union_imap body_imap' rest_imap
        end
      else
        union_imap (trav_exp bound env) (trav_exp rest env)
    end
  else if is_cond t then
    let val (guard, thn, els) = dest_cond t in
      union_imap (trav_exp guard env)
                 (union_imap (trav_exp thn env) (trav_exp els env))
    end
  else if is_pair t then
    let val (left, right) = dest_pair t in
      union_imap (trav_exp left env) (trav_exp right env)
    end
  else if is_pabs t then
    let val (_, body) = dest_pabs t in
      trav_exp body env
    end
  else if is_comb t then
    let val (rator, rand) = dest_comb t
    in
      if is_constructor rator then
        union_imap (examine_type (type_of rator)) (trav_exp rand env)
      else if is_const rator orelse is_var rator then      (* function application *)
        let
          val fstr = if is_const rator then #1 (dest_const rator)
                     else #1 (dest_var rator)
          val arg_imap = trav_exp rand env
          val call_imap =
            case peek_fname fstr env of
              (* unknown function (e.g. a lambda-bound variable): nothing to record *)
              NONE => arg_imap
            | SOME fname =>
                let val inst_rules = match_type (type_of fname) (type_of rator)
                in
                  if null inst_rules then arg_imap
                  else union_imap arg_imap
                         (M.insert (M.mkDict strOrder, fstr, mk_map inst_rules))
                end
        in
          union_imap call_imap (examine_type (type_of rator))
        end
      else
        (* the operator is itself compound (curried application, beta-redex, ...) *)
        union_imap (trav_exp rator env) (trav_exp rand env)
    end
  else
    M.mkDict strOrder

(* val imap = M.mkDict strOrder; *)

(* Build the instantiation map of a list of definitions.  The list is folded from    *)
(* the right; the instantiations found in each body are composed with those already *)
(* recorded for the function being defined (when there are any).                    *)
(* NOTE(review): this appears to expect callees to precede their callers in `defs`, *)
(* as in the examples at the end of this file -- confirm.                           *)
fun build_imap defs =
  let
    fun compose (f_def, imap) =
      let
        val env = M.mkDict strOrder
        val (f_lhs, f_body) = (dest_eq o concl o SPEC_ALL) f_def
        val f_str = #1 (dest_const (#1 (strip_comb f_lhs)))
        val body_imap = trav_exp f_body env
        val imap' = compose_imap body_imap (M.find (imap, f_str))
                    handle _ => body_imap
      in
        union_imap imap imap'
      end
  in
    List.foldr compose (M.mkDict strOrder) defs
  end

(*-----------------------------------------------------------------------------------------*)
(* Eliminate polymorphism by duplicating function definitions.                             *)
(*-----------------------------------------------------------------------------------------*)

(*
val Duplicated = ref (M.mkDict tvarOrder)   (* definitions of the monomorphic functions *)
    (* format: function name |-> new definition *)
*)

(*-----------------------------------------------------------------------------------------*)
(* Redefine functions in HOL and prove the correctness of the translation.                 *)
(*-----------------------------------------------------------------------------------------*)

(* Rebuild the equation `f` with its head constant replaced by a variable called `name` *)
(* of the same type.                                                                    *)
fun change_f_name f name =
  let
    val (f_lhs, f_rhs) = dest_eq f
    val (fname, argL) = strip_comb f_lhs
    val (_, f_type) = dest_const fname
    val new_fname = mk_var (name, f_type)
    val new_f_lhs = list_mk_comb (new_fname, argL)
  in
    mk_eq (new_f_lhs, f_rhs)
  end

(* a map from (typed) polymorphic function constants to their monomorphic clones *)
val MonoFunc = ref (M.mkDict tvarWithTypeOrder)

(* a list of judgements stating that each monomorphic clone equals the function *)
(* it was cloned from                                                           *)
val judgements = ref ([] : term list)

(* Create the clones of a function according to the instantiation information in the *)
(* instantiation map.                                                                *)
(*   imap : instantiation map, as produced by build_imap                             *)
(*   def  : defining theorem of the function                                         *)
(* Returns the list of new definitions.  As side effects, every clone is defined in  *)
(* the current theory via Define and is recorded in MonoFunc and judgements.         *)
fun duplicate_func imap def =
  let
    (* prefix every combination in `rules` with each instantiation of `tp` in turn *)
    fun one_type tp [] rules = []
      | one_type tp (x::xs) rules =
          (List.map (fn y => (tp |-> x) :: y) rules) @ one_type tp xs rules

    (* compute all the combinations of type instantiation rules *)
    fun mk_type_combination [] = [[]]
      | mk_type_combination ((tp, type_set) :: xs) =
          one_type tp (S.listItems type_set) (mk_type_combination xs)

    val f = (concl o SPEC_ALL) def
    val (f_lhs, f_rhs) = dest_eq f
    val fname = #1 (strip_comb f_lhs)
    val (f_str, f_type) = dest_const fname

    (* rewrite rules sending already-cloned functions to their clones *)
    val mono_rules = List.map (fn (old_name, new_name) => old_name |-> new_name)
                              (M.listItems (!MonoFunc))

    val index = ref 0
    val insts = M.listItems (M.find (imap, f_str)) handle _ => []

    (* Define the clone numbered `idx` from equation `eqn`, whose head constant is *)
    (* `old_fname`, and register it in MonoFunc and judgements.                    *)
    fun define_clone (eqn, old_fname, idx) =
      let
        val clone_str = f_str ^ Int.toString idx
        val clone_type = #2 (dest_const old_fname)
        val clone_var = mk_var (clone_str, clone_type)
        val clone_eqn = subst [old_fname |-> clone_var] eqn
        val clone_def = Define `^clone_eqn`
        (* the constant only exists once Define has run *)
        val clone_const = mk_const (clone_str, clone_type)
        val _ = MonoFunc := M.insert (!MonoFunc, old_fname, clone_const)
        val _ = judgements := mk_eq (mk_const (f_str, clone_type), clone_const)
                              :: (!judgements)
      in
        clone_def
      end
  in
    if null insts then
      (* the function is already monomorphic, no instantiations are needed;   *)
      (* however, its body must still be rewritten in case it calls functions *)
      (* that have been cloned                                                *)
      [define_clone (subst mono_rules f, fname, !index)]
    else
      (* instantiate types and replace all polymorphic function calls with the *)
      (* corresponding monomorphic calls; clones are numbered from 1           *)
      List.map (fn rule =>
                 let
                   val new_f = subst mono_rules (inst rule f)
                   val _ = index := !index + 1
                 in
                   define_clone (new_f, get_fname new_f, !index)
                 end)
               (mk_type_combination insts)
  end

(* Clone every definition in `defs`.  Resets MonoFunc and judgements first.         *)
(* Returns the new definitions together with the conjunction of all judgements.     *)
(* NOTE(review): with an empty `defs` there are no judgements to conjoin, and       *)
(* list_mk_conj presumably fails on the empty list -- confirm before relying on it. *)
fun build_clone defs =
  let
    val imap = build_imap defs
    val _ = MonoFunc := M.mkDict tvarWithTypeOrder
    val _ = judgements := []
    val new_defs = List.foldl (fn (def, fs) => fs @ (duplicate_func imap def))
                              [] defs
  in
    (new_defs, list_mk_conj (!judgements))
  end

(*-----------------------------------------------------------------------------------------*)
(* Mechanical proof.                                                                       *)
(*-----------------------------------------------------------------------------------------*)

(* Monomorphise `defs` and prove that every clone equals the function it was cloned *)
(* from.  Returns the new definitions and the equivalence theorem.                  *)
fun elim_poly defs =
  let
    val (newdefs, judgement) = build_clone defs
    val thm = prove (judgement,
                SIMP_TAC std_ss [FUN_EQ_THM, pairTheory.FORALL_PROD] THEN
                SIMP_TAC std_ss (defs @ newdefs)
              )
  in
    (newdefs, thm)
  end

(*-----------------------------------------------------------------------------------------*)
(* Example 1.                                                                              *)
(*-----------------------------------------------------------------------------------------*)

(*
val _ = Hol_datatype `
    p = P of 'a # 'a`;

val f_def = Define `
    f = \x:'b. (P : 'b # 'b -> 'b p) (x, x)`;

val g_def = Define `
    g = \(y:'c,z:'d).
          let h = \w : 'c p. f z in
          let v = f y in
          h v
    `;

val a_def = Define `
    a = g (3, T)`;

val b_def = Define `
    b = g (F, 1)`;

val defs = [f_def, g_def, a_def, b_def];

val newdefs = elim_poly defs;

val (newdefs, thm) = elim_poly defs;

*)

(*-----------------------------------------------------------------------------------------*)
(* Example 2.                                                                              *)
(*-----------------------------------------------------------------------------------------*)

(*

Hol_datatype `dt1 = C of 'a # 'b`;

val f_def = Define `f (x:'a) = x`;
val g_def = Define `g (x : 'c, y : 'd) =
     let h = \z. C (f x, f z) in
     h y`;
val j_def = Define `j = (g(1, F), g(F, T))`;

val defs = [f_def, g_def, j_def];

val (newdefs, thm) = elim_poly defs;

*)

(*-----------------------------------------------------------------------------------------*)

end (* struct *)