structure defunctionalize (* :> defunctionalize *) =
struct

open HolKernel Parse boolLib pairLib PairRules bossLib pairSyntax ParseDatatype TypeBase;

(*-----------------------------------------------------------------------------------------*)
(* We convert higher-order functions into equivalent first-order functions and hoist       *)
(* nested functions to the top level through a type-based closure conversion. After this   *)
(* conversion no nested functions exist, and a function call is made by dispatching on     *)
(* the closure tag followed by a top-level call.                                           *)
(* Function closures are represented as algebraic data types in such a way that, for each  *)
(* function definition, a constructor taking the free variables of this function is        *)
(* created. For each arrow type we create a dispatch function, which converts the          *)
(* definition of a function of this arrow type into a closure constructor application.     *)
(* A nested function is hoisted to the top level with its free variables passed as extra   *)
(* arguments. After that, a call to the original function is replaced by a call to the     *)
(* relevant dispatch function, passing a closure containing the values of this function's  *)
(* free variables. The dispatch function examines the closure tag and passes control to    *)
(* the appropriate hoisted function. Thus, higher-order operations on functions are        *)
(* replaced by equivalent operations on first-order closure values.                        *)
(*-----------------------------------------------------------------------------------------*)

(*-----------------------------------------------------------------------------------------*)
(* Map and set operation functions.                                                        *)
(*-----------------------------------------------------------------------------------------*)

structure M = Binarymap
structure S = Binaryset

(*-----------------------------------------------------------------------------------------*)
(* Auxiliary functions.                                                                    *)
(*-----------------------------------------------------------------------------------------*)

(* Total order on strings (LESS / EQUAL / GREATER). *)
fun strOrder (s1:string, s2:string) = String.compare (s1, s2);

(* Order on terms, obtained by comparing their printed forms.                              *)
(* NOTE(review): two distinct terms that print identically compare EQUAL here.             *)
fun tvarOrder (t1:term, t2:term) =
  strOrder (term_to_string t1, term_to_string t2)

(* Order on types, obtained by comparing their printed forms. *)
fun typeOrder (t1:hol_type, t2:hol_type) =
  strOrder (type_to_string t1, type_to_string t2);

(* Does the term have a function type?  Any failure while destructing the type (e.g. the   *)
(* type is a type variable) is answered with false.                                        *)
fun is_fun t =
  #1 (Type.dest_type (type_of t)) = "fun"
  handle _ => false

(* The head symbol on the left-hand side of an equation  f args = body. *)
fun FunName f =
  #1 (strip_comb (#1 (dest_eq f)))

(* The name of a constant or, failing that, of a variable. *)
fun name_of t =
  #1 (dest_const t) handle _ => #1 (dest_var t)

(*-----------------------------------------------------------------------------------------*)
(* Data structures.                                                                        *)
(*-----------------------------------------------------------------------------------------*)

val Lifted = ref (M.mkDict tvarOrder)    (* the definitions of the functions that are lifted *)
                                         (* Format: [function's name |-> function's body (a paired abstraction)] *)
val ClosFunc = ref (M.mkDict typeOrder)  (* the types and the higher order functions associated with them *)
                                         (* Format: [function's type |-> a set of function names] *)

val ClosInfo = ref (M.mkDict typeOrder)  (* A mapping from a type to the information of its datatype representing a closure *)
                                         (* Format: [type |-> datatype's info] *)
val ClosName = ref (M.mkDict typeOrder)  (* A mapping from a type to the name of its datatype representing a closure *)
                                         (* Format: [type |-> datatype's name (a string)] *)

val HOFunc = ref (M.mkDict tvarOrder)    (* higher order functions *)
                                         (* Format: [function's name |-> (new function's lhs, closure value)] *)

(* Debugging aid: dump the contents of Lifted and ClosFunc as association lists. *)
fun cF () =
  (M.listItems (!Lifted),
   List.map (fn (tp, s) => (tp, S.listItems s)) (M.listItems (!ClosFunc)));

(*-----------------------------------------------------------------------------------------*)
(* Identify higher order functions (those functions used in arguments and returns);        *)
(* then build datatypes for them.                                                          *)
(*-----------------------------------------------------------------------------------------*)

(* Add fname to the set stored in ClosFunc under fname's type, creating the set on first use. *)
fun record_f fname =
  let
    val tp = type_of fname
    val known = case M.peek(!ClosFunc, tp) of
                   NONE   => S.empty tvarOrder
                 | SOME s => s
  in
    ClosFunc := M.insert(!ClosFunc, tp, S.add(known, fname))
  end;

(* Walk an expression, recording function-typed sub-terms in ClosFunc and let-bound        *)
(* abstractions in Lifted.  The condition of a conditional, the bound variables of an      *)
(* abstraction and the operand of an application are not traversed.                        *)
fun identify_f e =
  let
    fun trav t =
      if is_let t then
        let
          val (v, bound, body) = dest_plet t
          val _ = (trav bound; trav body)
        in
          (* An embedded function: remember it under the NAME it is bound to.  (The        *)
          (* original keyed the entry by the let body, which only coincides with the       *)
          (* name when the body is exactly the bound variable; every lookup in this file   *)
          (* is by name.)                                                                  *)
          if is_pabs bound then Lifted := M.insert(!Lifted, v, bound)
          else ()
        end
      else if is_pair t then
        let val (l, r) = dest_pair t
        in (trav l; trav r)
        end
      else if is_cond t then
        let val (_, thn, els) = dest_cond t
        in (trav thn; trav els)
        end
      else if is_comb t then
        let val (rator, rand) = dest_comb t
        in
          (* An application whose operand is function-typed is itself recorded. *)
          (if is_fun rand then record_f t else ();
           if is_comb rator then trav rator else ())
        end
      else if is_pabs t then
        trav (#2 (dest_pabs t))
      else if is_fun t then
        record_f t
      else
        ()
  in
    trav e
  end

(* Reset ClosFunc and Lifted, then register every definition in defs (a list of theorems   *)
(* of the form  |- f args = body) in Lifted and scan its body with identify_f.             *)
(* Returns one unit per definition.                                                        *)
fun identify_closure defs =
  let
    fun mk_clos_data f =
      let
        val (fdecl, fbody) = dest_eq f
        val (fname, args) = dest_comb fdecl
        val _ = Lifted := M.insert(!Lifted, fname, mk_pabs(args, fbody))
      in
        identify_f fbody
      end
  in
    (ClosFunc := M.mkDict typeOrder;
     Lifted := M.mkDict tvarOrder;
     List.map (mk_clos_data o concl o SPEC_ALL) defs
    )
  end

(*-----------------------------------------------------------------------------------------*)
(* Build datatypes for closures.                                                           *)
(*-----------------------------------------------------------------------------------------*)

val closure_index = ref 0;       (* counter used to name closure datatypes "clos<n>" *)
val constructor_index = ref 0;   (* counter used to name constructors "cons<n>"      *)

(* Register freshly built datatype information in HOL: write it to the TypeBase, to        *)
(* computeLib and finally through Datatype.write_tyinfos.                                  *)
fun register_type tyinfos_etc =
  let
    val (tyinfos, _) = unzip tyinfos_etc
    val tyinfos = TypeBase.write tyinfos
    val () = app computeLib.write_datatype_info tyinfos
  in
    Datatype.write_tyinfos tyinfos_etc
  end

(* Build and register the closure datatype for the function type tp.                       *)
(*   tp    : the function type being closure-converted                                     *)
(*   funcs : the set of function names of that type (one constructor is made for each)     *)
(* Side effects: bumps closure_index / constructor_index, extends ClosName and ClosInfo.   *)
(* Returns the result of Datatype.primHol_datatype_from_astl.                              *)
fun build_type tp funcs =
  let
    (* The pretype of one free variable: the closure datatype registered for its type if   *)
    (* there is one, otherwise the (argument-less) type operator of its type.              *)
    fun tyop_of_var v =
      let val t = type_of v
      in
        dTyop{Args = [], Thy = NONE,
              Tyop = (M.find(!ClosName, t) handle _ => #1 (Type.dest_type t))}
      end

    (* Right-nested product of a list of pretypes, as a list of zero or one element.       *)
    (* Nesting to the right matches the list_mk_pair used when the closure value is built  *)
    (* in mk_dispatch; "prod" takes exactly two arguments, so three or more free variables *)
    (* must not be passed to it flat.                                                      *)
    fun prod_of [] = []
      | prod_of [t] = [t]
      | prod_of (t :: ts) = [dTyop{Args = t :: prod_of ts, Thy = NONE, Tyop = "prod"}]

    (* The arguments of a constructor: the free variables of a function body. *)
    fun build_type_args fv = prod_of (List.map tyop_of_var fv)

    (* The name of the datatype representing a closure for the input type. *)
    val clos_name =
      let
        val _ = closure_index := !closure_index + 1
        val x = "clos" ^ Int.toString (!closure_index)
        val _ = ClosName := M.insert(!ClosName, tp, x)
      in
        x
      end

    (* One constructor per function of this type, carrying that function's free variables. *)
    fun mk_constructor lifted_f =
      let
        val _ = constructor_index := !constructor_index + 1
        val fv = free_vars (M.find (!Lifted, lifted_f))
      in
        ("cons" ^ Int.toString (!constructor_index), build_type_args fv)
      end

    (* The type information of the datatype. *)
    val clos_tp_info =
      [(clos_name, Constructors (List.map mk_constructor (S.listItems funcs)))]

    val new_clos_type =
      Datatype.primHol_datatype_from_astl (TypeBase.theTypeBase()) clos_tp_info;
    val _ = register_type (#2 new_clos_type)
    val _ = ClosInfo := M.insert(!ClosInfo, tp, #1 (hd (#2 new_clos_type)))
  in
    new_clos_type
  end
  ;

(* Build datatypes for all higher order functions found in defs.  All per-run tables and   *)
(* counters are reset first, so entries from an earlier run cannot leak into this one.     *)
fun build_types defs =
  (closure_index := 0;
   constructor_index := 0;
   identify_closure defs;
   ClosName := M.mkDict typeOrder;
   ClosInfo := M.mkDict typeOrder;
   List.map (fn (tp, fs) => build_type tp fs) (M.listItems (!ClosFunc))
  )

(*-----------------------------------------------------------------------------------------*)
(* Conversions from original HOL types to closure types.                                   *)
(*-----------------------------------------------------------------------------------------*)

(* From an original type to its closure type; types without a closure are left unchanged. *)
fun type2closure tp =
  TypeBasePure.ty_of (M.find(!ClosInfo, tp))
  handle _ => tp

(* Re-type a variable with its closure type; anything that is not a variable is returned  *)
(* unchanged.                                                                              *)
fun term2closure t =
  let val (name, tp) = dest_var t
  in mk_var(name, type2closure tp)
  end
  handle _ => t

(* Name and type of the dispatch function for the function type tp, together with the      *)
(* datatype info of its closure type and the range of tp.  The dispatch function is named  *)
(* "dispatch<n>" where the closure datatype is named "clos<n>".  Raises whatever M.find    *)
(* raises when tp has no entry in ClosInfo.                                                *)
fun dispatch_sig tp =
  let
    val tinfo = M.find(!ClosInfo, tp)
    val clos_type = TypeBasePure.ty_of tinfo
    (* take the value of n from "clos<n>" *)
    val f_index = String.extract (#1 (Type.dest_type clos_type), 4, NONE)
    val (arg_type, return_type) = dom_rng tp
  in
    (tinfo, "dispatch" ^ f_index, mk_prod(clos_type, arg_type) --> return_type, return_type)
  end

(* From an original type to its dispatch function: the constant if it has already been     *)
(* defined, otherwise a variable of the same name and type.                                *)
fun type2dispatch tp =
  let
    val (_, df_name, df_type, _) = dispatch_sig tp
  in
    mk_const(df_name, df_type)           (* the dispatch function has been defined     *)
    handle _ => mk_var(df_name, df_type) (* the dispatch function has not been defined *)
  end

(*-----------------------------------------------------------------------------------------*)
(* Build dispatch functions.                                                               *)
(* A dispatch function is in pattern-matching format.                                      *)
(*-----------------------------------------------------------------------------------------*)

(* Build the specification of the dispatch function for type tp: a conjunction with one    *)
(* clause  dispatch<n> (cons<i> fvs, arg) = f' (fvs, arg)  per function f of that type.    *)
(* The dispatch function and the primed functions are variables here; they only become     *)
(* constants when the specification is handed to Hol_defn in defunctionalize.              *)
(* Side effect: records (new lhs, closure value) for every function in HOFunc.             *)
fun mk_dispatch tp =
  let
    val (tinfo, df_name, df_type, return_type) = dispatch_sig tp
    val clos_consL = TypeBasePure.constructors_of tinfo
    val funL = S.listItems (M.find(!ClosFunc, tp))
    val df_var = mk_var(df_name, df_type)

    (* Construct one dispatching clause of the pattern-matching definition. *)
    fun mk_clause (fname, constructor) =
      let
        val f_body = M.find(!Lifted, fname)
        val (f_arg, _) = dest_pabs f_body
        val fv = free_vars f_body
        val fv' = List.map term2closure fv
        (* a function without free variables is represented by the bare constructor *)
        val clos_arg = if null fv then constructor
                       else mk_comb(constructor, list_mk_pair fv')
        val lt = mk_comb(df_var, mk_pair(clos_arg, f_arg))

        (* right-hand side: the hoisted function f', taking the free variables first *)
        val (old_name, _) = dest_const fname handle _ => dest_var fname
        val new_arg = if null fv then f_arg else mk_pair(list_mk_pair fv', f_arg)
        val new_fname = mk_var(old_name ^ "'", type_of new_arg --> return_type)
        val new_f = mk_comb (new_fname, new_arg)
        val _ = HOFunc := M.insert(!HOFunc, fname, (new_f, clos_arg))
      in
        mk_eq(lt, new_f)
      end
  in
    list_mk_conj (List.map mk_clause (zip funL clos_consL))
  end


val Dispatched = ref (M.mkDict typeOrder)  (* definitions of dispatch functions *)
                                           (* format: type |-> conjunction of dispatch clauses *)

(* Build dispatch functions for all introduced datatypes.  HOFunc and Dispatched are       *)
(* rebuilt from scratch.  Returns one unit per closure type.                               *)
fun build_dispatch () =
  (HOFunc := M.mkDict tvarOrder;
   Dispatched := M.mkDict typeOrder;
   List.map (fn tp => Dispatched := M.insert(!Dispatched, tp, mk_dispatch tp))
            (List.map fst (M.listItems (!ClosFunc)))
  )

(*-----------------------------------------------------------------------------------------*)
(* convert_exp translates expressions;                                                     *)
(* convert_fun translates functions;                                                       *)
(* defunctionalize translates top-level declarations.                                      *)
(*-----------------------------------------------------------------------------------------*)

val Redefined = ref (M.mkDict tvarOrder)   (* definitions of the functions after closure conversion *)
                                           (* format: function name |-> (new function name, new definition) *)

(* Translate an expression into its first-order form.                                      *)
(*  - a let binding an abstraction is converted as a function and dropped from the result; *)
(*  - applications with more than one argument are returned unchanged;                     *)
(*  - a unary application of an already converted function calls the new function;         *)
(*  - any other unary application of a function-typed head goes through the dispatch       *)
(*    function of the head's type;                                                         *)
(*  - a bare function-typed term becomes its closure value (or a closure-typed variable).  *)
(* Any exception raised while translating an application yields the application unchanged. *)
fun convert_exp t =
  if is_let t then
    let val (v, bound, body) = dest_plet t
    in
      if is_pabs bound then                     (* an embedded function *)
        let
          val (arg, fbody) = dest_pabs bound
          val _ = convert_fun (mk_eq(mk_comb(v, arg), fbody))
        in
          (* the binding itself disappears: uses of v become closure values *)
          convert_exp body
        end
      else
        mk_plet (v, convert_exp bound, convert_exp body)
    end
  else if is_cond t then
    let val (guard, thn, els) = dest_cond t     (* the guard is left untouched *)
    in mk_cond (guard, convert_exp thn, convert_exp els)
    end
  else if is_pair t then
    let val (l, r) = dest_pair t
    in mk_pair (convert_exp l, convert_exp r)
    end
  else if is_pabs t then
    let val (vs, body) = dest_pabs t
    in mk_pabs (convert_exp vs, convert_exp body)
    end
  else if is_comb t then
    (let val (rator, rand) = dest_comb t
     in
       if length (#2 (strip_comb t)) > 1 then t              (* binary expressions *)
       else if is_fun rator then                             (* function application *)
         if Option.isSome (M.peek(!Redefined, rator)) then   (* a pre-defined function *)
           let
             val fname_var = #1 (M.find(!Redefined, rator))
             val (fname_str, f_tp) = dest_var fname_var
             val fname_const = mk_const (fname_str, f_tp)
                               handle _ => fname_var         (* recursive function *)
           in
             mk_comb(fname_const, convert_exp rand)
           end
         else
           let
             val tp = type_of rator
             val clos_var = mk_var(name_of rator, type2closure tp)
             val closure = mk_pair(clos_var, convert_exp rand)
           in
             mk_comb (type2dispatch tp, closure)
           end
       else
         mk_comb(convert_exp rator, convert_exp rand)
     end
     handle _ => t)                                          (* not function application *)
  else if is_fun t then
    case M.peek(!HOFunc, t) of                               (* higher order function *)
       NONE => mk_var(name_of t, type2closure (type_of t))
     | SOME (_, constr) => constr
  else t

and

(* Translate one equation  f args = body  into the equation defining the new function f'.  *)
(* For a function recorded in HOFunc the left-hand side built by mk_dispatch is reused;    *)
(* otherwise a fresh variable f' is applied to the converted arguments.  The function is   *)
(* entered in Redefined with a placeholder definition before its body is converted (so     *)
(* that recursive calls are recognised) and again with the final equation afterwards.      *)
(* If anything raises, the input equation is returned unchanged.                           *)
convert_fun f =
  let
    val (fdecl, fbody) = dest_eq f
    val (fname, args) = dest_comb fdecl
    val (fname_str, _) = dest_const fname handle _ => dest_var fname
    val new_lhs =
      case M.peek(!HOFunc, fname) of
         NONE =>                                   (* not a higher order function *)
           let
             val args1 = convert_exp args
             val new_f_tp = type_of args1 --> type2closure (type_of fdecl)
           in
             mk_comb (mk_var(fname_str ^ "'", new_f_tp), args1)
           end
       | SOME (lt, _) => lt                        (* a higher order function *)
    val (new_fname, _) = dest_comb new_lhs
    val _ = Redefined := M.insert(!Redefined, fname, (new_fname, ``T``))
    val new_f = mk_eq(new_lhs, convert_exp fbody)
    val _ = Redefined := M.insert(!Redefined, fname, (new_fname, new_f))
  in
    new_f
  end
  handle _ => f

(* Closure-convert a list of definitions.                                                  *)
(* For every closure type, the dispatch clauses and the converted higher order functions   *)
(* of that type are defined together with Hol_defn; the remaining functions are then       *)
(* converted and defined one by one with Define.  Returns the list of all new defining     *)
(* equations: those of every closure type first, followed by the remaining functions.      *)
fun defunctionalize defs =
  let
    val _ = build_types defs
    val _ = build_dispatch ()

    fun process_type tp =
      let
        val fs = S.listItems (M.find (!ClosFunc, tp))
        val fs' = List.map (fn fname =>
                     let
                       val fbody = M.find(!Lifted, fname)
                       val (args, body) = dest_pabs fbody
                     in
                       convert_fun (mk_eq(mk_comb(fname, args), body))
                     end) fs
        val spec = list_mk_conj (strip_conj (M.find(!Dispatched, tp)) @ fs')
      in
        Defn.eqns_of (Defn.Hol_defn "x" `^spec`)
      end

    val _ = Redefined := M.mkDict tvarOrder
    val dispatch_spec = List.map process_type (List.map fst (M.listItems (!ClosFunc)))

    (* the functions that were not already converted as members of some closure type *)
    val remaining_funcs =
      List.filter
        (fn f => not (Option.isSome (M.peek(!Redefined, #1 (dest_comb (lhs f))))))
        (List.map (concl o SPEC_ALL) defs)
    val new_spec =
      List.map (fn x => let val f = convert_fun x in Define `^f` end) remaining_funcs
  in
    (* Return the equations of ALL closure types.  Taking only the head of dispatch_spec   *)
    (* dropped every type but the first and raised Empty when there was none.              *)
    List.concat dispatch_spec @ new_spec
  end

(*-----------------------------------------------------------------------------------------*)
(* Redefine functions in HOL and prove the correctness of the translation.                 *)
(*-----------------------------------------------------------------------------------------*)

(* Convert function arguments to closure arguments.                                        *)
(* Returns (assumptions, new arguments): every function-typed variable a is replaced by a  *)
(* closure-typed variable a', with the assumption  a i = dispatch (a', i)  relating them.  *)
(* Raises (from dest_var) if a leaf of the argument tuple is not a variable.               *)
fun process_args args =
  if is_pair args then
    let
      val (arg1, arg2) = dest_pair args
      val (assms1, arg1') = process_args arg1
      val (assms2, arg2') = process_args arg2
    in
      (assms1 @ assms2, mk_pair(arg1', arg2'))
    end
  else
    let
      val (arg_str, arg_tp) = dest_var args
    in
      if is_fun args then
        let
          val new_args = mk_var (arg_str ^ "'", type2closure arg_tp)
          val input = mk_var("i", #1 (dom_rng arg_tp))
        in
          ([mk_eq(mk_comb(args, input),
                  mk_comb (type2dispatch arg_tp, mk_pair(new_args, input)))],
           new_args)
        end
      else
        ([], args)
    end

(* Replace the variable at the head of an application by the constant of the same name    *)
(* and type.  Raises if the head is not a variable or no such constant exists.            *)
fun var2const t =
  if is_comb t then
    let val (rator, rand) = dest_comb t
    in mk_comb(var2const rator, rand)
    end
  else
    let val (v, tp) = dest_var t
    in mk_const(v, tp)
    end

(* Build the equivalence statement for a function defined by the equation f.               *)
(* For a first-order result:   assumptions ==> !vars. f args = f' args'                    *)
(* For a function result:      assumptions ==> !vars. f args m = dispatch (f' args', m)    *)
fun build_judgement f =
  let
    val (fdecl, fbody) = dest_eq f
    val (fname, args) = dest_comb fdecl
    val (assums, new_args) = process_args args
    val new_fname = var2const (#1 (M.find (!Redefined, fname))) handle _ => fname
    val new_fdecl = mk_comb (new_fname, new_args)
    val x =
      if not (is_fun fdecl) then mk_eq(fdecl, new_fdecl)
      else
        let
          val ftp = type_of fdecl
          val input = mk_var("m", #1 (dom_rng ftp))
          val new_fdecl' = mk_comb (type2dispatch ftp, mk_pair(new_fdecl, input))
        in
          mk_eq(mk_comb(fdecl, input), new_fdecl')
        end
    val x' = gen_all x
  in
    if null assums then x'
    else mk_imp(list_mk_conj assums, x')
  end

(*
  (build_judgement o concl o SPEC_ALL) (List.nth(defs,2))
  val def = List.nth(defs,2)
  val f = concl (SPEC_ALL def)
*)

(* Eliminate higher order functions: returns the new definitions together with one         *)
(* equivalence statement (a term still to be proved) per original definition.              *)
fun elim_hof defs =
  let
    val newdefs = defunctionalize defs
    val judgements = List.map (build_judgement o concl o SPEC_ALL) defs
  in
    (newdefs, judgements)
  end

(*-----------------------------------------------------------------------------------------*)
(* Example 1.                                                                              *)
(*-----------------------------------------------------------------------------------------*)

val empty_def = Define `
  empty (x : num) = F`;

val member_def = Define `
  member (s : num -> bool, x : num) = s x`;

(* NOTE(review): the else-branch reads `s x`; for set insertion `s y` looks intended.      *)
(* The definition is kept as written - confirm before changing it.                         *)
val insert_def = Define `
  insert(s : num -> bool, x : num) =
    let s1 y = if x = y then T else s x
    in s1
  `;

val upto_def = Define `
  upto(n : num) =
    if n = 0 then empty else insert(upto(n-1),n)
  `;

val main_def = Define `
  main (n : num) = (upto n, 100)`;

val defs = [empty_def, member_def, insert_def, upto_def];

(*
val (newdefs, judgements) = elim_hof defs;

val defs1 = List.take(defs, 3);
val defs2 = List.drop(defs, 3);
val newdefs1 = List.take(newdefs, 6);
val newdefs2 = List.drop(newdefs, 6);

set_goal ([], List.nth(judgements, 3))   (* set_goal ([], ``!m. dispatch1(upto' n, j) = upto n m``) *)

Induct_on `n` THENL [
  ONCE_REWRITE_TAC (defs2 @ newdefs2) THEN
    expandf (RW_TAC arith_ss (defs1 @ newdefs1)),

  ONCE_REWRITE_TAC (defs2 @ newdefs2) THEN
    RW_TAC arith_ss [LET_THM] THEN
    expandf (RW_TAC arith_ss (defs1 @ newdefs1)) THEN
    Q.UNABBREV_TAC `s1` THEN
    RW_TAC std_ss []
]

*)

(*-----------------------------------------------------------------------------------------*)
(* Example 2.                                                                              *)
(*-----------------------------------------------------------------------------------------*)

val f_def = Define `
  f x = x * 2 < x + 10`;

val g_def = Define `
  g (s, x) =
    let h1 = \y. y + x in
    if s x then h1 else let h2 = \y. h1 y * x in h2`;

val k_def = Define `
  k x = if x = 0 then 1 else (g (f,x)) (k(x-1))`

(* This shadows the Example 1 list of the same name. *)
val defs = [f_def, g_def, k_def];

(*-----------------------------------------------------------------------------------------*)

end (* struct *)