1(*  Title:      HOL/HOLCF/Tools/Domain/domain_isomorphism.ML
2    Author:     Brian Huffman
3
4Defines new types satisfying the given domain equations.
5*)
6
signature DOMAIN_ISOMORPHISM =
sig
  (* Defines new types satisfying a (possibly mutually recursive) family of
     domain equations.  Each element is
       (type parameter names, type name, mixfix, right-hand side type,
        optional morphism names).
     Returns the rep/abs isomorphism info for each new type together with
     the take-function/induction info, and the extended theory. *)
  val domain_isomorphism :
      (string list * binding * mixfix * typ
       * (binding * binding) option) list ->
      theory ->
      (Domain_Take_Proofs.iso_info list
       * Domain_Take_Proofs.take_induct_info) * theory

  (* Defines one map function per given isomorphism (named <type>_map),
     proves their deflation theorems, and returns the new constants and
     theorems together with the extended theory. *)
  val define_map_functions :
      (binding * Domain_Take_Proofs.iso_info) list ->
      theory ->
      {
        map_consts : term list,
        map_apply_thms : thm list,
        map_unfold_thms : thm list,
        map_cont_thm : thm,
        deflation_map_thms : thm list
      }
      * theory

  (* Same as domain_isomorphism, but the right-hand side types are given as
     strings to be parsed; only the resulting theory is returned. *)
  val domain_isomorphism_cmd :
    (string list * binding * mixfix * string * (binding * binding) option) list
      -> theory -> theory
end
32
33structure Domain_Isomorphism : DOMAIN_ISOMORPHISM =
34struct
35
(* Basic HOL simpset extended with simp_thms and the beta_cfun_proc simproc;
   used below to simplify goals after unfolding fixed-point definitions. *)
val beta_ss =
  @{context}
  |> put_simpset HOL_basic_ss
  |> (fn ctxt => ctxt addsimps @{thms simp_thms} addsimprocs [@{simproc beta_cfun_proc}])
  |> simpset_of

(* Tests whether type T belongs to sort cpo in theory thy. *)
fun is_cpo thy T =
  let val cpo_sort = @{sort cpo}
  in Sign.of_sort thy (T, cpo_sort) end
41
42
43(******************************************************************************)
44(************************** building types and terms **************************)
45(******************************************************************************)
46
47open HOLCF_Library
48
49infixr 6 ->>
50infixr -->>
51
(* Frequently used HOLCF types: the universal domain, deflations over it,
   and deflations over its lifting. *)
val udomT = @{typ udom}
val deflT = @{typ "udom defl"}
val udeflT = @{typ "udom u defl"}

(* Builds the term  defl TYPE(T) :: udom defl. *)
fun mk_DEFL T =
  Const (@{const_name defl}, Term.itselfT T --> deflT) $ Logic.mk_type T

(* Inverse of mk_DEFL: extracts T from  defl TYPE(T);  raises TERM otherwise. *)
fun dest_DEFL (Const (@{const_name defl}, _) $ t) = Logic.dest_type t
  | dest_DEFL t = raise TERM ("dest_DEFL", [t])

(* Builds the term  liftdefl TYPE(T) :: udom u defl. *)
fun mk_LIFTDEFL T =
  Const (@{const_name liftdefl}, Term.itselfT T --> udeflT) $ Logic.mk_type T

(* Inverse of mk_LIFTDEFL: extracts T from  liftdefl TYPE(T);  raises TERM
   otherwise. *)
fun dest_LIFTDEFL (Const (@{const_name liftdefl}, _) $ t) = Logic.dest_type t
  | dest_LIFTDEFL t = raise TERM ("dest_LIFTDEFL", [t])

(* Continuous application  u_defl $ t. *)
fun mk_u_defl t = mk_capply (@{const "u_defl"}, t)

(* emb :: T ->> udom  and  prj :: udom ->> T  for a given type T. *)
fun emb_const T = Const (@{const_name emb}, T ->> udomT)
fun prj_const T = Const (@{const_name prj}, udomT ->> T)

(* Coercion from T to U through the universal domain:  prj oo emb. *)
fun coerce_const (T, U) = mk_cfcomp (prj_const U, emb_const T)

(* The predicate constants isodefl (for udom defl) and isodefl' (for
   udom u defl), instantiated at function type T ->> T. *)
fun isodefl_const T =
  Const (@{const_name isodefl}, (T ->> T) --> deflT --> HOLogic.boolT)

fun isodefl'_const T =
  Const (@{const_name isodefl'}, (T ->> T) --> udeflT --> HOLogic.boolT)

(* Builds  deflation t  at the type of t. *)
fun mk_deflation t =
  Const (@{const_name deflation}, Term.fastype_of t --> boolT) $ t

(* Splits a term of the form  Trueprop (lhs = rhs)  into the pair (lhs, rhs);
   raises TERM if the term does not have that shape.  Note: this takes a
   plain term, not a cterm. *)
fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t)

(* Inverse of dest_eqs: builds  Trueprop (t = u). *)
fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u))
87
88(******************************************************************************)
89(****************************** isomorphism info ******************************)
90(******************************************************************************)
91
(* Instantiates the generic theorem deflation_abs_rep with the abs_inverse
   and rep_inverse facts of one isomorphism, then renumbers the schematic
   variables of the result starting from zero. *)
fun deflation_abs_rep (info : Domain_Take_Proofs.iso_info) : thm =
  Drule.zero_var_indexes
    (@{thm deflation_abs_rep} OF [#abs_inverse info, #rep_inverse info])
100
101(******************************************************************************)
102(*************** fixed-point definitions and unfolding theorems ***************)
103(******************************************************************************)
104
(* Pairs each element of xs with the corresponding component of the
   right-nested tuple term t: all elements but the last get an mk_fst
   projection of the remaining tail, and the last gets the tail itself. *)
fun mk_projs xs t =
  (case xs of
    [] => []
  | [x] => [(x, t)]
  | x :: rest => (x, mk_fst t) :: mk_projs rest (mk_snd t))
108
(* Defines a family of constants simultaneously as projections of one
   fixed point of a tupled functional.

   spec: one (binding, Trueprop (lhs = rhs)) equation per constant.  Each lhs
   must be a constant, optionally applied (via Rep_cfun) to Free variables
   and/or to TYPE arguments; anything else raises TERM in mk_eqn.

   Returns ((proj_thms, unfold_thms, cont_thm), thy') where
   - proj_thms:   one  lhs == <projection of the fixpoint>  per constant,
   - unfold_thms: per-constant theorems split off the tupled fixed-point
                  unfolding, stored as <name>_unfold,
   - cont_thm:    continuity of the tupled functional.
   NOTE(review): spec is assumed non-empty (foldr1 below presumably fails on
   an empty list). *)
fun add_fixdefs
    (spec : (binding * term) list)
    (thy : theory) : (thm list * thm list * thm) * theory =
  let
    val binds = map fst spec
    val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec)
    (* abstract the tuple of right-hand sides over the tuple of left-hand sides *)
    val functional = lambda_tuple lhss (mk_tuple rhss)
    val fixpoint = mk_fix (mk_cabs functional)

    (* project components of fixpoint *)
    val projs = mk_projs lhss fixpoint

    (* convert parameters to lambda abstractions: a continuous application
       f`x becomes  f == LAM x. rhs, a TYPE argument becomes an abstraction
       over a variable of that itself-type; peeled recursively until a bare
       constant remains *)
    fun mk_eqn (lhs, rhs) =
        case lhs of
          Const (@{const_name Rep_cfun}, _) $ f $ (x as Free _) =>
            mk_eqn (f, big_lambda x rhs)
        | f $ Const (@{const_name Pure.type}, T) =>
            mk_eqn (f, Abs ("t", T, rhs))
        | Const _ => Logic.mk_equals (lhs, rhs)
        | _ => raise TERM ("lhs not of correct form", [lhs, rhs])
    val eqns = map mk_eqn projs

    (* register constant definitions *)
    val (fixdef_thms, thy) =
      (Global_Theory.add_defs false o map Thm.no_attributes)
        (map Thm.def_binding binds ~~ eqns) thy

    (* prove applied version of definitions: unfold the new definitions,
       then simplify with beta_ss *)
    fun prove_proj (lhs, rhs) =
      let
        fun tac ctxt = rewrite_goals_tac ctxt fixdef_thms THEN
          (simp_tac (put_simpset beta_ss ctxt)) 1
        val goal = Logic.mk_equals (lhs, rhs)
      in Goal.prove_global thy [] [] goal (tac o #context) end
    val proj_thms = map prove_proj projs

    (* mk_tuple lhss == fixpoint *)
    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2]
    val tuple_fixdef_thm = foldr1 pair_equalI proj_thms

    (* continuity of the functional, proved by repeatedly matching the
       cont2cont rules *)
    val cont_thm =
      let
        val prop = mk_trp (mk_cont functional)
        val rules = Named_Theorems.get (Proof_Context.init_global thy) @{named_theorems cont2cont}
        fun tac ctxt = REPEAT_ALL_NEW (match_tac ctxt (rev rules)) 1
      in
        Goal.prove_global thy [] [] prop (tac o #context)
      end

    (* fixed-point unfolding of the whole tuple, with split_conv unfolded *)
    val tuple_unfold_thm =
      (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
      |> Local_Defs.unfold (Proof_Context.init_global thy) @{thms split_conv}

    (* split the tuple equation into one equation per component using
       Pair_eqD1 / Pair_eqD2, mirroring the nesting produced by mk_projs *)
    fun mk_unfold_thms [] _ = []
      | mk_unfold_thms (n::[]) thm = [(n, thm)]
      | mk_unfold_thms (n::ns) thm = let
          val thmL = thm RS @{thm Pair_eqD1}
          val thmR = thm RS @{thm Pair_eqD2}
        in (n, thmL) :: mk_unfold_thms ns thmR end
    val unfold_binds = map (Binding.suffix_name "_unfold") binds

    (* register unfold theorems *)
    val (unfold_thms, thy) =
      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
        (mk_unfold_thms unfold_binds tuple_unfold_thm) thy
  in
    ((proj_thms, unfold_thms, cont_thm), thy)
  end
178
179
180(******************************************************************************)
181(****************** deflation combinators and map functions *******************)
182(******************************************************************************)
183
(* Computes a deflation term for type T by rewriting  defl TYPE(T).
   Rewrite rules are the domain_defl_simps equations (in reverse order of
   retrieval) followed by the entries of tab1 (as DEFL T = d) and tab2
   (as LIFTDEFL T = p).  Remaining  defl TYPE('a)  /  liftdefl TYPE('a)  for
   a type variable 'a become the free variables "da" :: udom defl and
   "pa" :: udom u defl respectively. *)
fun defl_of_typ
    (thy : theory)
    (tab1 : (typ * term) list)
    (tab2 : (typ * term) list)
    (T : typ) : term =
  let
    val ctxt = Proof_Context.init_global thy
    val simp_eqs =
      Named_Theorems.get ctxt @{named_theorems domain_defl_simps}
      |> rev
      |> map (HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.concl_of)
    val tab_eqs = map (apfst mk_DEFL) tab1 @ map (apfst mk_LIFTDEFL) tab2
    (* maps a term whose dest-image is TFree 'a to Free (prefix ^ "a", resT) *)
    fun tfree_proc dest prefix resT t =
      (case dest t of
        TFree (a, _) => SOME (Free (prefix ^ Library.unprefix "'" a, resT))
      | _ => NONE) handle TERM _ => NONE
    val defl_proc = tfree_proc dest_DEFL "d" deflT
    val liftdefl_proc = tfree_proc dest_LIFTDEFL "p" udeflT
  in
    Pattern.rewrite_term thy (simp_eqs @ tab_eqs) [defl_proc, liftdefl_proc] (mk_DEFL T)
  end
205
206(******************************************************************************)
207(********************* declaring definitions and theorems *********************)
208(******************************************************************************)
209
(* Declares a new global constant named bind, with the type of rhs, and
   defines it by  const == rhs.  Returns the constant term, its definition
   theorem, and the extended theory. *)
fun define_const
    (bind : binding, rhs : term)
    (thy : theory)
    : (term * thm) * theory =
  let
    val (const, thy1) =
      Sign.declare_const_global ((bind, Term.fastype_of rhs), NoSyn) thy
    val def_spec =
      Thm.no_attributes (Thm.def_binding bind, Logic.mk_equals (const, rhs))
    val (def_thm, thy2) =
      yield_singleton (Global_Theory.add_defs false) def_spec thy1
  in
    ((const, def_thm), thy2)
  end
223
(* Stores thm, without attributes, under the binding built by
   Binding.qualify_name true dbind name; returns the stored theorem and the
   extended theory. *)
fun add_qualified_thm name (dbind, thm) thy =
  let val qualified = Binding.qualify_name true dbind name
  in yield_singleton Global_Theory.add_thms ((qualified, thm), []) thy end
227
228(******************************************************************************)
229(*************************** defining map functions ***************************)
230(******************************************************************************)
231
(* Defines one map function <name>_map per isomorphism in spec by a
   simultaneous fixed point (add_fixdefs), proves deflation_<name>_map for
   each, and records the types (Domain_Take_Proofs.add_rec_type) and the
   deflation theorems (Domain_Take_Proofs.add_deflation_thm) in theory data.
   Each map function takes one continuous function argument per cpo-sorted
   type parameter of its absT. *)
fun define_map_functions
    (spec : (binding * Domain_Take_Proofs.iso_info) list)
    (thy : theory) =
  let

    (* retrieve components of spec *)
    val dbinds = map fst spec
    val iso_infos = map snd spec
    val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos
    val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos

    (* type of a map function: one  U ->> U  argument per cpo type argument
       U of T, combined with -->>, ending in  T ->> T *)
    fun mapT (T as Type (_, Ts)) =
        (map (fn T => T ->> T) (filter (is_cpo thy) Ts)) -->> (T ->> T)
      | mapT T = T ->> T

    (* declare map functions *)
    fun declare_map_const (tbind, (lhsT, _)) thy =
      let
        val map_type = mapT lhsT
        val map_bind = Binding.suffix_name "_map" tbind
      in
        Sign.declare_const_global ((map_bind, map_type), NoSyn) thy
      end
    val (map_consts, thy) = thy |>
      fold_map declare_map_const (dbinds ~~ dom_eqns)

    (* defining equations for map functions:
         map_const`a1`...`an = abs oo (map of repT) oo rep
       where each type variable 'a is mapped to the free variable "a" *)
    local
      fun unprime a = Library.unprefix "'" a
      fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T)
      fun map_lhs (map_const, lhsT) =
          (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (snd (dest_Type lhsT)))))
      val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns)
      (* type parameters are taken from the first equation *)
      val Ts = (snd o dest_Type o fst o hd) dom_eqns
      val tab = (Ts ~~ map mapvar Ts) @ tab1
      fun mk_map_spec (((rep_const, abs_const), _), (lhsT, rhsT)) =
        let
          val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT
          val body = Domain_Take_Proofs.map_of_typ thy tab rhsT
          val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const))
        in mk_eqs (lhs, rhs) end
    in
      val map_specs =
          map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns)
    end

    (* register recursive definition of map functions *)
    val map_binds = map (Binding.suffix_name "_map") dbinds
    val ((map_apply_thms, map_unfold_thms, map_cont_thm), thy) =
      add_fixdefs (map_binds ~~ map_specs) thy

    (* prove deflation theorems for map functions: assuming each function
       argument is a deflation, the conjunction of  deflation (map ...)  for
       all domains is proved by fixed-point induction (cont_fix_ind) *)
    val deflation_abs_rep_thms = map deflation_abs_rep iso_infos
    val deflation_map_thm =
      let
        fun unprime a = Library.unprefix "'" a
        fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T)
        fun mk_assm T = mk_trp (mk_deflation (mk_f T))
        fun mk_goal (map_const, (lhsT, _)) =
          let
            val (_, Ts) = dest_Type lhsT
            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
          in mk_deflation map_term end
        val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns
        val goals = map mk_goal (map_consts ~~ dom_eqns)
        val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
        val adm_rules =
          @{thms adm_conj adm_subst [OF _ adm_deflation]
                 cont2cont_fst cont2cont_snd cont_id}
        val bottom_rules =
          @{thms fst_strict snd_strict deflation_bottom simp_thms}
        val tuple_rules =
          @{thms split_def fst_conv snd_conv}
        val deflation_rules =
          @{thms conjI deflation_ID}
          @ deflation_abs_rep_thms
          @ Domain_Take_Proofs.get_deflation_thms thy
      in
        (* steps: unfold map definitions; fixed-point induction; discharge
           admissibility; base case (bottom); step case: split tuples, split
           conjunctions, and close with the deflation rules / premises *)
        Goal.prove_global thy [] assms goal (fn {prems, context = ctxt} =>
         EVERY
          [rewrite_goals_tac ctxt map_apply_thms,
           resolve_tac ctxt [map_cont_thm RS @{thm cont_fix_ind}] 1,
           REPEAT (resolve_tac ctxt adm_rules 1),
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps bottom_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps tuple_rules) 1,
           REPEAT (eresolve_tac ctxt @{thms conjE} 1),
           REPEAT (resolve_tac ctxt (deflation_rules @ prems) 1 ORELSE assume_tac ctxt 1)])
      end
    (* split a right-nested conjunction into one theorem per binding *)
    fun conjuncts [] _ = []
      | conjuncts (n::[]) thm = [(n, thm)]
      | conjuncts (n::ns) thm = let
          val thmL = thm RS @{thm conjunct1}
          val thmR = thm RS @{thm conjunct2}
        in (n, thmL):: conjuncts ns thmR end
    val deflation_map_binds = dbinds |>
        map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map")
    val (deflation_map_thms, thy) = thy |>
      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
        (conjuncts deflation_map_binds deflation_map_thm)

    (* register indirect recursion in theory data: for each type, which of
       its arguments are cpo-sorted *)
    local
      fun register_map (dname, args) =
        Domain_Take_Proofs.add_rec_type (dname, args)
      val dnames = map (fst o dest_Type o fst) dom_eqns
      fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => []
      val argss = map args dom_eqns
    in
      val thy =
          fold register_map (dnames ~~ argss) thy
    end

    (* register deflation theorems *)
    val thy = fold Domain_Take_Proofs.add_deflation_thm deflation_map_thms thy

    val result =
      {
        map_consts = map_consts,
        map_apply_thms = map_apply_thms,
        map_unfold_thms = map_unfold_thms,
        map_cont_thm = map_cont_thm,
        deflation_map_thms = deflation_map_thms
      }
  in
    (result, thy)
  end
358
359(******************************************************************************)
360(******************************* main function ********************************)
361(******************************************************************************)
362
(* Parses the type string str in a context where the type variables of the
   already-collected sorts are declared; returns the type together with
   sorts extended by the type's free type variables. *)
fun read_typ thy str sorts =
  let
    val declare_sorts = fold (Variable.declare_typ o TFree) sorts
    val ctxt = declare_sorts (Proof_Context.init_global thy)
    val T = Syntax.read_typ ctxt str
  in
    (T, Term.add_tfreesT T sorts)
  end
369
(* Certifies an internal type: rejects schematic type variables (turning the
   TYPE exception into a user error), adds its free type variables to sorts,
   and fails if one type variable name now appears with different sorts. *)
fun cert_typ sign raw_T sorts =
  let
    val T =
      Type.no_tvars (Sign.certify_typ sign raw_T)
      handle TYPE (msg, _, _) => error msg
    val sorts' = Term.add_tfreesT T sorts
    val dups = duplicates (op =) (map fst sorts')
    val _ =
      if null dups then ()
      else error ("Inconsistent sort constraints for " ^ commas dups)
  in
    (T, sorts')
  end
380
(* Main construction, shared by the internal and the command interface.
   prep_typ parses/certifies each right-hand side type, accumulating the
   sorts of its type variables.  Steps:
     1. check the equations in a temporary theory with the new types
        declared (and assumed to be of sort domain);
     2. define deflation combinators <name>_defl by simultaneous fixed point;
     3. define each type with Domaindef.add_domaindef from its combinator;
     4. prove DEFL equations, define <name>_rep / <name>_abs via udom, prove
        the isomorphism and isodefl theorems;
     5. define map functions, prove isodefl_<name> and <name>_map_ID;
     6. define take functions and prove lub_take theorems.
   Returns the iso infos and take-induction info with the final theory. *)
fun gen_domain_isomorphism
    (prep_typ: theory -> 'a -> (string * sort) list -> typ * (string * sort) list)
    (doms_raw: (string list * binding * mixfix * 'a * (binding * binding) option) list)
    (thy: theory)
    : (Domain_Take_Proofs.iso_info list
       * Domain_Take_Proofs.take_induct_info) * theory =
  let
    (* this theory is used just for parsing *)
    val tmp_thy = thy |>
      Sign.add_types_global (map (fn (tvs, tbind, mx, _, _) =>
        (tbind, length tvs, mx)) doms_raw)

    fun prep_dom thy (vs, t, mx, typ_raw, morphs) sorts =
      let val (typ, sorts') = prep_typ thy typ_raw sorts
      in ((vs, t, mx, typ, morphs), sorts') end

    val (doms : (string list * binding * mixfix * typ * (binding * binding) option) list,
         sorts : (string * sort) list) =
      fold_map (prep_dom tmp_thy) doms_raw []

    (* lookup function for sorts of type variables
       NOTE(review): sorts only contains type variables occurring in some
       right-hand side, so `the` raises Option for a declared parameter that
       occurs in none of them *)
    fun the_sort v = the (AList.lookup (op =) sorts v)

    (* declare arities in temporary theory *)
    val tmp_thy =
      let
        fun arity (vs, tbind, _, _, _) =
          (Sign.full_name thy tbind, map the_sort vs, @{sort "domain"})
      in
        fold Axclass.arity_axiomatization (map arity doms) tmp_thy
      end

    (* check bifiniteness of right-hand sides *)
    fun check_rhs (_, _, _, rhs, _) =
      if Sign.of_sort tmp_thy (rhs, @{sort "domain"}) then ()
      else error ("Type not of sort domain: " ^
        quote (Syntax.string_of_typ_global tmp_thy rhs))
    (* evaluated only for the error side effect; the result list is discarded *)
    val _ = map check_rhs doms

    (* domain equations *)
    fun mk_dom_eqn (vs, tbind, _, rhs, _) =
      let fun arg v = TFree (v, the_sort v)
      in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end
    val dom_eqns = map mk_dom_eqn doms

    (* check for valid type parameters: no duplicates within one domain, and
       every domain has the same parameter set (as a set) as the first one;
       the mapped result is discarded *)
    val (tyvars, _, _, _, _) = hd doms
    val _ = map (fn (tvs, tname, _, _, _) =>
      let val full_tname = Sign.full_name tmp_thy tname
      in
        (case duplicates (op =) tvs of
          [] =>
            if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
            else error ("Mutually recursive domains must have same type parameters")
        | dups => error ("Duplicate parameter(s) for domain " ^ Binding.print tname ^
            " : " ^ commas dups))
      end) doms
    val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms
    (* NOTE(review): morphs are threaded into mk_rep_abs below but ignored
       there; rep/abs constants are always named <type>_rep / <type>_abs *)
    val morphs = map (fn (_, _, _, _, morphs) => morphs) doms

    (* determine deflation combinator arguments: the functional abstracts a
       tuple variable "t" over all rhs deflations; afterwards, the type
       variables and d<a>/p<a> frees actually occurring in it decide which
       arguments each combinator takes *)
    val lhsTs : typ list = map fst dom_eqns
    val defl_rec = Free ("t", mk_tupleT (map (K deflT) lhsTs))
    val defl_recs = mk_projs lhsTs defl_rec
    val defl_recs' = map (apsnd mk_u_defl) defl_recs
    fun defl_body (_, _, _, rhsT, _) =
      defl_of_typ tmp_thy defl_recs defl_recs' rhsT
    val functional = Term.lambda defl_rec (mk_tuple (map defl_body doms))

    val tfrees = map fst (Term.add_tfrees functional [])
    val frees = map fst (Term.add_frees functional [])
    fun get_defl_flags (vs, _, _, _, _) =
      let
        fun argT v = TFree (v, the_sort v)
        fun mk_d v = "d" ^ Library.unprefix "'" v
        fun mk_p v = "p" ^ Library.unprefix "'" v
        val args = maps (fn v => [(mk_d v, mk_DEFL (argT v)), (mk_p v, mk_LIFTDEFL (argT v))]) vs
        val typeTs = map argT (filter (member (op =) tfrees) vs)
        val defl_args = map snd (filter (member (op =) frees o fst) args)
      in
        (typeTs, defl_args)
      end
    val defl_flagss = map get_defl_flags doms

    (* declare deflation combinator constants: TYPE arguments first, then
       the deflation arguments as continuous arguments *)
    fun declare_defl_const ((typeTs, defl_args), (_, tbind, _, _, _)) thy =
      let
        val defl_bind = Binding.suffix_name "_defl" tbind
        val defl_type =
          map Term.itselfT typeTs ---> map fastype_of defl_args -->> deflT
      in
        Sign.declare_const_global ((defl_bind, defl_type), NoSyn) thy
      end
    val (defl_consts, thy) =
      fold_map declare_defl_const (defl_flagss ~~ doms) thy

    (* defining equations for type combinators *)
    fun mk_defl_term (defl_const, (typeTs, defl_args)) =
      let
        val type_args = map Logic.mk_type typeTs
      in
        list_ccomb (list_comb (defl_const, type_args), defl_args)
      end
    val defl_terms = map mk_defl_term (defl_consts ~~ defl_flagss)
    val defl_tab = map fst dom_eqns ~~ defl_terms
    val defl_tab' = map fst dom_eqns ~~ map mk_u_defl defl_terms
    fun mk_defl_spec (lhsT, rhsT) =
      mk_eqs (defl_of_typ tmp_thy defl_tab defl_tab' lhsT,
              defl_of_typ tmp_thy defl_tab defl_tab' rhsT)
    val defl_specs = map mk_defl_spec dom_eqns

    (* register recursive definition of deflation combinators *)
    val defl_binds = map (Binding.suffix_name "_defl") dbinds
    val ((defl_apply_thms, defl_unfold_thms, defl_cont_thm), thy) =
      add_fixdefs (defl_binds ~~ defl_specs) thy

    (* define types using deflation combinators *)
    fun make_repdef ((vs, tbind, mx, _, _), defl) thy =
      let
        val spec = (tbind, map (rpair dummyS) vs, mx)
        val ((_, _, _, {DEFL, ...}), thy) =
          Domaindef.add_domaindef spec defl NONE thy
        (* declare domain_defl_simps rules *)
        val thy =
          Context.theory_map (Named_Theorems.add_thm @{named_theorems domain_defl_simps} DEFL) thy
      in
        (DEFL, thy)
      end
    val (DEFL_thms, thy) = fold_map make_repdef (doms ~~ defl_terms) thy

    (* prove DEFL equations  DEFL(lhsT) = DEFL(rhsT)  by rewriting with
       domain_defl_simps and, if needed, one unfold theorem *)
    fun mk_DEFL_eq_thm (lhsT, rhsT) =
      let
        val goal = mk_eqs (mk_DEFL lhsT, mk_DEFL rhsT)
        val DEFL_simps =
          Named_Theorems.get (Proof_Context.init_global thy) @{named_theorems domain_defl_simps}
        fun tac ctxt =
          rewrite_goals_tac ctxt (map mk_meta_eq (rev DEFL_simps))
          THEN TRY (resolve_tac ctxt defl_unfold_thms 1)
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end
    val DEFL_eq_thms = map mk_DEFL_eq_thm dom_eqns

    (* register DEFL equations *)
    val DEFL_eq_binds = map (Binding.prefix_name "DEFL_eq_") dbinds
    val (_, thy) = thy |>
      (Global_Theory.add_thms o map Thm.no_attributes)
        (DEFL_eq_binds ~~ DEFL_eq_thms)

    (* define rep/abs functions as coercions through udom (prj oo emb) *)
    fun mk_rep_abs ((tbind, _), (lhsT, rhsT)) thy =
      let
        val rep_bind = Binding.suffix_name "_rep" tbind
        val abs_bind = Binding.suffix_name "_abs" tbind
        val ((rep_const, rep_def), thy) =
            define_const (rep_bind, coerce_const (lhsT, rhsT)) thy
        val ((abs_const, abs_def), thy) =
            define_const (abs_bind, coerce_const (rhsT, lhsT)) thy
      in
        (((rep_const, abs_const), (rep_def, abs_def)), thy)
      end
    val ((rep_abs_consts, rep_abs_defs), thy) = thy
      |> fold_map mk_rep_abs (dbinds ~~ morphs ~~ dom_eqns)
      |>> ListPair.unzip

    (* prove isomorphism and isodefl rules, stored as <type>.rep_iso,
       <type>.abs_iso and <type>.isodefl_abs_rep *)
    fun mk_iso_thms ((tbind, DEFL_eq), (rep_def, abs_def)) thy =
      let
        fun make thm =
            Drule.zero_var_indexes (thm OF [DEFL_eq, abs_def, rep_def])
        val rep_iso_thm = make @{thm domain_rep_iso}
        val abs_iso_thm = make @{thm domain_abs_iso}
        val isodefl_thm = make @{thm isodefl_abs_rep}
        val thy = thy
          |> snd o add_qualified_thm "rep_iso" (tbind, rep_iso_thm)
          |> snd o add_qualified_thm "abs_iso" (tbind, abs_iso_thm)
          |> snd o add_qualified_thm "isodefl_abs_rep" (tbind, isodefl_thm)
      in
        (((rep_iso_thm, abs_iso_thm), isodefl_thm), thy)
      end
    val ((iso_thms, isodefl_abs_rep_thms), thy) =
      thy
      |> fold_map mk_iso_thms (dbinds ~~ DEFL_eq_thms ~~ rep_abs_defs)
      |>> ListPair.unzip

    (* collect info about rep/abs *)
    val iso_infos : Domain_Take_Proofs.iso_info list =
      let
        fun mk_info (((lhsT, rhsT), (repC, absC)), (rep_iso, abs_iso)) =
          {
            repT = rhsT,
            absT = lhsT,
            rep_const = repC,
            abs_const = absC,
            rep_inverse = rep_iso,
            abs_inverse = abs_iso
          }
      in
        map mk_info (dom_eqns ~~ rep_abs_consts ~~ iso_thms)
      end

    (* definitions and proofs related to map functions *)
    val (map_info, thy) =
        define_map_functions (dbinds ~~ iso_infos) thy
    val { map_consts, map_apply_thms, map_cont_thm, ...} = map_info

    (* prove isodefl rules for map functions: the conjunction over all
       domains of  isodefl (map f...) (defl d...), under isodefl/isodefl'
       assumptions on the arguments, by parallel fixed-point induction over
       the deflation combinators and the map functions *)
    val isodefl_thm =
      let
        fun unprime a = Library.unprefix "'" a
        fun mk_d T = Free ("d" ^ unprime (fst (dest_TFree T)), deflT)
        fun mk_p T = Free ("p" ^ unprime (fst (dest_TFree T)), udeflT)
        fun mk_f T = Free ("f" ^ unprime (fst (dest_TFree T)), T ->> T)
        fun mk_assm t =
          case try dest_LIFTDEFL t of
            SOME T => mk_trp (isodefl'_const T $ mk_f T $ mk_p T)
          | NONE =>
            let val T = dest_DEFL t
            in mk_trp (isodefl_const T $ mk_f T $ mk_d T) end
        fun mk_goal (map_const, (T, _)) =
          let
            val (_, Ts) = dest_Type T
            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
            val defl_term = defl_of_typ thy (Ts ~~ map mk_d Ts) (Ts ~~ map mk_p Ts) T
          in isodefl_const T $ map_term $ defl_term end
        (* assumptions come from the deflation arguments of the first domain *)
        val assms = (map mk_assm o snd o hd) defl_flagss
        val goals = map mk_goal (map_consts ~~ dom_eqns)
        val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
        val adm_rules =
          @{thms adm_conj adm_isodefl cont2cont_fst cont2cont_snd cont_id}
        val bottom_rules =
          @{thms fst_strict snd_strict isodefl_bottom simp_thms}
        val tuple_rules =
          @{thms split_def fst_conv snd_conv}
        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
        val map_ID_simps = map (fn th => th RS sym) map_ID_thms
        val isodefl_rules =
          @{thms conjI isodefl_ID_DEFL isodefl_LIFTDEFL}
          @ isodefl_abs_rep_thms
          @ rev (Named_Theorems.get (Proof_Context.init_global thy) @{named_theorems domain_isodefl})
      in
        Goal.prove_global thy [] assms goal (fn {prems, context = ctxt} =>
         EVERY
          [rewrite_goals_tac ctxt (defl_apply_thms @ map_apply_thms),
           resolve_tac ctxt [@{thm cont_parallel_fix_ind} OF [defl_cont_thm, map_cont_thm]] 1,
           REPEAT (resolve_tac ctxt adm_rules 1),
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps bottom_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps tuple_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps map_ID_simps) 1,
           REPEAT (eresolve_tac ctxt @{thms conjE} 1),
           REPEAT (resolve_tac ctxt (isodefl_rules @ prems) 1 ORELSE assume_tac ctxt 1)])
      end
    val isodefl_binds = map (Binding.prefix_name "isodefl_") dbinds
    (* split a right-nested conjunction into one theorem per binding *)
    fun conjuncts [] _ = []
      | conjuncts (n::[]) thm = [(n, thm)]
      | conjuncts (n::ns) thm = let
          val thmL = thm RS @{thm conjunct1}
          val thmR = thm RS @{thm conjunct2}
        in (n, thmL):: conjuncts ns thmR end
    val (isodefl_thms, thy) = thy |>
      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
        (conjuncts isodefl_binds isodefl_thm)
    val thy =
      fold (Context.theory_map o Named_Theorems.add_thm @{named_theorems domain_isodefl})
        isodefl_thms thy

    (* prove map_ID theorems:  map ID ... ID = ID  for each domain *)
    fun prove_map_ID_thm
        (((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
      let
        val Ts = snd (dest_Type lhsT)
        (* local copy; same test as the top-level is_cpo thy *)
        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
        val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts))
        val goal = mk_eqs (lhs, mk_ID lhsT)
        fun tac ctxt = EVERY
          [resolve_tac ctxt @{thms isodefl_DEFL_imp_ID} 1,
           stac ctxt DEFL_thm 1,
           resolve_tac ctxt [isodefl_thm] 1,
           REPEAT (resolve_tac ctxt @{thms isodefl_ID_DEFL isodefl_LIFTDEFL} 1)]
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end
    val map_ID_binds = map (Binding.suffix_name "_map_ID") dbinds
    val map_ID_thms =
      map prove_map_ID_thm
        (map_consts ~~ dom_eqns ~~ DEFL_thms ~~ isodefl_thms)
    val (_, thy) = thy |>
      (Global_Theory.add_thms o map (rpair [Domain_Take_Proofs.map_ID_add]))
        (map_ID_binds ~~ map_ID_thms)

    (* definitions and proofs related to take functions *)
    val (take_info, thy) =
        Domain_Take_Proofs.define_take_functions
          (dbinds ~~ iso_infos) thy
    val { take_consts, chain_take_thms, take_0_thms, take_Suc_thms, ...} =
        take_info

    (* least-upper-bound lemma for take functions: the tuple of lubs of the
       take functions equals the tuple of  map ID ... ID  terms *)
    val lub_take_lemma =
      let
        val lhs = mk_tuple (map mk_lub take_consts)
        (* local copy; same test as the top-level is_cpo thy *)
        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
        fun mk_map_ID (map_const, (lhsT, _)) =
          list_ccomb (map_const, map mk_ID (filter is_cpo (snd (dest_Type lhsT))))
        val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns))
        val goal = mk_trp (mk_eq (lhs, rhs))
        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
        val start_rules =
            @{thms lub_Pair [symmetric] ch2ch_Pair} @ chain_take_thms
            @ @{thms prod.collapse split_def}
            @ map_apply_thms @ map_ID_thms
        val rules0 =
            @{thms iterate_0 Pair_strict} @ take_0_thms
        val rules1 =
            @{thms iterate_Suc prod_eq_iff fst_conv snd_conv}
            @ take_Suc_thms
        (* unfold fix via fix_def2 and compare the two chains pointwise by
           induction on the iteration count *)
        fun tac ctxt =
            EVERY
            [simp_tac (put_simpset HOL_basic_ss ctxt addsimps start_rules) 1,
             simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms fix_def2}) 1,
             resolve_tac ctxt @{thms lub_eq} 1,
             resolve_tac ctxt @{thms nat.induct} 1,
             simp_tac (put_simpset HOL_basic_ss ctxt addsimps rules0) 1,
             asm_full_simp_tac (put_simpset beta_ss ctxt addsimps rules1) 1]
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end

    (* prove lub of take equals ID, stored as <type>.lub_take *)
    fun prove_lub_take (((dbind, take_const), map_ID_thm), (lhsT, _)) thy =
      let
        val n = Free ("n", natT)
        val goal = mk_eqs (mk_lub (lambda n (take_const $ n)), mk_ID lhsT)
        fun tac ctxt =
            EVERY
            [resolve_tac ctxt @{thms trans} 1,
             resolve_tac ctxt [map_ID_thm] 2,
             cut_tac lub_take_lemma 1,
             REPEAT (eresolve_tac ctxt @{thms Pair_inject} 1), assume_tac ctxt 1]
        val lub_take_thm = Goal.prove_global thy [] [] goal (tac o #context)
      in
        add_qualified_thm "lub_take" (dbind, lub_take_thm) thy
      end
    val (lub_take_thms, thy) =
        fold_map prove_lub_take
          (dbinds ~~ take_consts ~~ map_ID_thms ~~ dom_eqns) thy

    (* prove additional take theorems *)
    val (take_info2, thy) =
        Domain_Take_Proofs.add_lub_take_theorems
          (dbinds ~~ iso_infos) take_info lub_take_thms thy
  in
    ((iso_infos, take_info2), thy)
  end
736
(* Internal interface: right-hand sides are already types, certified with
   cert_typ. *)
fun domain_isomorphism doms thy = gen_domain_isomorphism cert_typ doms thy

(* Command interface: right-hand sides are strings parsed with read_typ;
   the iso/take info is dropped and only the theory is kept. *)
fun domain_isomorphism_cmd doms_raw thy =
  snd (gen_domain_isomorphism read_typ doms_raw thy)
739
740(******************************************************************************)
741(******************************** outer syntax ********************************)
742(******************************************************************************)
743
local

(* Parses one equation of the form
     [type_args] name [mixfix] = typ [morphisms rep_name abs_name]
   into (type args, binding, mixfix, rhs type string, morphism bindings). *)
val parse_domain_iso :
    (string list * binding * mixfix * string * (binding * binding) option)
      parser =
  (Parse.type_args -- Parse.binding -- Parse.opt_mixfix -- (@{keyword "="} |-- Parse.typ) --
    Scan.option (@{keyword "morphisms"} |-- Parse.!!! (Parse.binding -- Parse.binding)))
    >> (fn ((((vs, t), mx), rhs), morphs) => (vs, t, mx, rhs, morphs))

(* one or more equations (Parse.and_list1), for mutually recursive domains *)
val parse_domain_isos = Parse.and_list1 parse_domain_iso

in

(* registers the outer-syntax command "domain_isomorphism" as a theory
   transformation *)
val _ =
  Outer_Syntax.command @{command_keyword domain_isomorphism} "define domain isomorphisms (HOLCF)"
    (parse_domain_isos >> (Toplevel.theory o domain_isomorphism_cmd))

end
762
763end
764