1(*  Title:      HOL/Tools/Function/partial_function.ML
2    Author:     Alexander Krauss, TU Muenchen
3
4Partial function definitions based on least fixed points in ccpos.
5*)
6
signature PARTIAL_FUNCTION =
sig
  (*declaration registering a mode: mode name, fixpoint operator, monotonicity
    predicate, fixpoint unfolding rule, fixpoint induction rule, and an optional
    user-level induction rule*)
  val init: string -> term -> term -> thm -> thm -> thm option -> declaration
  (*monotonicity prover: partial_function_mono rules plus case splitting*)
  val mono_tac: Proof.context -> int -> tactic
  (*define a partial function in the given mode from already checked fixes and
    equation; returns the defined constant and its recursive equation*)
  val add_partial_function: string -> (binding * typ option * mixfix) list ->
    Attrib.binding * term -> local_theory -> (term * thm) * local_theory
  (*as add_partial_function, but fixes and equation are read from strings*)
  val add_partial_function_cmd: string -> (binding * string option * mixfix) list ->
    Attrib.binding * string -> local_theory -> (term * thm) * local_theory
end;
16
17structure Partial_Function: PARTIAL_FUNCTION =
18struct
19
20open Function_Lib
21
22
23(*** Context Data ***)
24
(*setup data of one partial_function mode:
  fixp             - fixpoint operator applied to the uncurried functional
  mono             - monotonicity predicate applied to the functional
  fixp_eq          - unfolding rule for the fixpoint
  fixp_induct      - generic fixpoint induction rule
  fixp_induct_user - optional rule from which the raw_induct theorem is derived*)
datatype setup_data = Setup_Data of
 {fixp: term,
  mono: term,
  fixp_eq: thm,
  fixp_induct: thm,
  fixp_induct_user: thm option};
31
(*apply the morphism phi to every term and theorem of the setup data*)
fun transform_setup_data phi (Setup_Data {fixp, mono, fixp_eq, fixp_induct, fixp_induct_user}) =
  let
    val on_term = Morphism.term phi;
    val on_thm = Morphism.thm phi;
  in
    Setup_Data
     {fixp = on_term fixp,
      mono = on_term mono,
      fixp_eq = on_thm fixp_eq,
      fixp_induct = on_thm fixp_induct,
      fixp_induct_user = Option.map on_thm fixp_induct_user}
  end;
41
(*registered partial_function modes, indexed by mode name*)
structure Modes = Generic_Data
(
  type T = setup_data Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  (*K true treats entries with the same mode name as equal, so merging
    never fails on duplicate names*)
  fun merge data = Symtab.merge (K true) data;
)
49
(*declaration storing the setup data under the given mode name; the data is
  transformed by phi followed by trim_context_morphism before being stored*)
fun init mode fixp mono fixp_eq fixp_induct fixp_induct_user phi =
  let
    val raw_data =
      Setup_Data {fixp = fixp, mono = mono, fixp_eq = fixp_eq,
        fixp_induct = fixp_induct, fixp_induct_user = fixp_induct_user};
    val stored_data =
      transform_setup_data (phi $> Morphism.trim_context_morphism) raw_data;
  in
    Modes.map (Symtab.update (mode, stored_data))
  end;
58
(*names of all registered modes*)
fun known_modes ctxt = Symtab.keys (Modes.get (Context.Proof ctxt));
60
(*setup data registered for a mode name, if any; the stored theorems are
  transferred into ctxt via transfer_morphism'*)
fun lookup_mode ctxt mode =
  (case Symtab.lookup (Modes.get (Context.Proof ctxt)) mode of
    NONE => NONE
  | SOME data => SOME (transform_setup_data (Morphism.transfer_morphism' ctxt) data));
64
65
66(*** Automated monotonicity proofs ***)
67
(*rewrite the conclusion of the focused subgoal with its k-th premise
  (0-based, as with nth)*)
fun rewrite_with_asm_tac ctxt k =
  ctxt |> Subgoal.FOCUS (fn {context = focus_ctxt, prems, ...} =>
    Local_Defs.unfold0_tac focus_ctxt [nth prems k]);
72
(*If t is an application of a case combinator known to Ctr_Sugar, return its
  scrutinee (the last argument of the combinator's defining equations) together
  with a conversion that rewrites t with the case rules. Any extra arguments
  beyond the combinator's arity are skipped via fun_conv. Returns NONE for
  everything else, including under-applied case combinators: for those the
  scrutinee is missing, and a negative funpow count would not terminate.*)
fun dest_case ctxt t =
  (case strip_comb t of
    (Const (case_comb, _), args) =>
      (case Ctr_Sugar.ctr_sugar_of_case ctxt case_comb of
         NONE => NONE
       | SOME {case_thms, ...} =>
           let
             val lhs = Thm.prop_of (hd case_thms)
               |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst;
             val arity = length (snd (strip_comb lhs));
             val extra_args = length args - arity;
           in
             if extra_args < 0 then NONE
             else
               let
                 val conv = funpow extra_args Conv.fun_conv
                   (Conv.rewrs_conv (map mk_meta_eq case_thms));
               in
                 SOME (nth args (arity - 1), conv)
               end
           end)
  | _ => NONE);
89
(*split on case expressions:
  the focused subgoal must have the shape  _ $ (_ $ (%f. body))  where body is
  a case combinator application recognized by dest_case. If its scrutinee has
  no loose bound variables, do case analysis on the scrutinee, rewrite the
  conclusion with the first premise, drop that premise (thin_rl), and rewrite
  the case combinator under the abstraction with the case rules.
  Otherwise the tactic fails (no_tac).*)
val split_cases_tac = Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
  SUBGOAL (fn (t, i) => case t of
    _ $ (_ $ Abs (_, _, body)) =>
      (case dest_case ctxt body of
         NONE => no_tac
       | SOME (arg, conv) =>
           let open Conv in
              (*scrutinee must not depend on bound variables*)
              if Term.is_open arg then no_tac
              else ((DETERM o Induct.cases_tac ctxt false [[SOME arg]] NONE [])
                THEN_ALL_NEW (rewrite_with_asm_tac ctxt 0)
                THEN_ALL_NEW eresolve_tac ctxt @{thms thin_rl}
                (*under all parameters: Trueprop argument, then predicate argument,
                  then inside the abstraction*)
                THEN_ALL_NEW (CONVERSION
                  (params_conv ~1 (fn ctxt' =>
                    arg_conv (arg_conv (abs_conv (K conv) ctxt'))) ctxt))) i
           end)
  | _ => no_tac) 1);
107
(*monotonicity proof: unfold curry_def, then repeatedly apply the
  partial_function_mono rules (in reverse of the order returned by
  Named_Theorems.get) or split case expressions, as long as progress is made*)
fun mono_tac ctxt =
  let
    val mono_rules =
      rev (Named_Theorems.get ctxt \<^named_theorems>\<open>partial_function_mono\<close>);
    val step_tac = resolve_tac ctxt mono_rules ORELSE' split_cases_tac ctxt;
  in
    K (Local_Defs.unfold0_tac ctxt [@{thm curry_def}])
    THEN' (TRY o REPEAT_ALL_NEW step_tac)
  end;
114
115
116(*** Auxiliary functions ***)
117
(*Returns t $ u, but first instantiates the type variables of t so that the
  domain type of t matches the type of u. Raises TYPE if no such match exists.*)
fun apply_inst ctxt t u =
  let
    val dom_T = domain_type (fastype_of t);
    val arg_T = fastype_of u;
    val tyenv =
      Sign.typ_match (Proof_Context.theory_of ctxt) (dom_T, arg_T) Vartab.empty
        handle Type.TYPE_MATCH => raise TYPE ("apply_inst", [dom_T, arg_T], [t, u]);
    val t' = map_types (Envir.norm_type tyenv) t;
  in
    t' $ u
  end;
130
(*apply cv to the head of ct, i.e. below all application arguments*)
fun head_conv cv ct =
  (case try Thm.dest_comb ct of
    SOME _ => Conv.fun_conv (head_conv cv) ct
  | NONE => cv ct);
133
134
135(*** currying transformation ***)
136
(*the constant Product_Type.curry at type ((A * B) => C) => A => B => C*)
fun curry_const (A, B, C) =
  let
    val prodT = HOLogic.mk_prodT (A, B);
  in
    Const (\<^const_name>\<open>Product_Type.curry\<close>, (prodT --> C) --> A --> B --> C)
  end;
140
(*curry f for f : S * T => U, giving curry f : S => T => U.
  Raises TYPE if the domain of f is not a product type. Previously any binary
  type constructor was accepted here, which silently built an ill-typed term.*)
fun mk_curry f =
  (case fastype_of f of
    Type ("fun", [Type (\<^type_name>\<open>prod\<close>, [S, T]), U]) =>
      curry_const (S, T, U) $ f
  | T => raise TYPE ("mk_curry", [T], [f]));
146
(* iterated versions for a function of the given arity: arity - 1 applications
of mk_curry resp. HOLogic.mk_case_prod. Nonstandard left-nested tuples arise
naturally from "split o split o split"*)
fun curry_n arity f = funpow (arity - 1) mk_curry f;
fun uncurry_n arity f = funpow (arity - 1) HOLogic.mk_case_prod f;
151
(*small simpsets on top of HOL_basic_ss for currying/uncurrying rewrites*)
local
  fun basic_ss thms =
    simpset_of (put_simpset HOL_basic_ss @{context} addsimps thms);
in
  val curry_uncurry_ss =
    basic_ss [@{thm Product_Type.curry_case_prod}, @{thm Product_Type.case_prod_curry}];
  val split_conv_ss = basic_ss [@{thm Product_Type.split_conv}];
  val curry_K_ss = basic_ss [@{thm Product_Type.curry_K}];
end;
163
(* instantiate generic fixpoint induction and eliminate the canonical assumptions;
  curry induction predicate *)
(* The rule's schematic variables are instantiated positionally with
  [uncurry, _, curry, _, P_inst], where P_inst = (%f. P (curry f)) for a fresh
  predicate P : fT => bool. The simp tactics then work on premises 3, 4 and 5
  (see the inline comments), premises 1 and 2 are discharged with mono_thm and
  f_def, and the conclusion is rewritten with curry_case_prod. Finally the
  result is exported from ctxt' back to ctxt, so the fixed P becomes schematic.
  NOTE(review): the parameters args and F are not used in this function. *)
fun specialize_fixp_induct ctxt args fT fT_uc F curry uncurry mono_thm f_def rule =
  let
    val ([P], ctxt') = Variable.variant_fixes ["P"] ctxt
    val P_inst = Abs ("f", fT_uc, Free (P, fT --> HOLogic.boolT) $ (curry $ Bound 0))
  in
    (* FIXME ctxt vs. ctxt' (!?) *)
    rule
    |> infer_instantiate' ctxt
      ((map o Option.map) (Thm.cterm_of ctxt) [SOME uncurry, NONE, SOME curry, NONE, SOME P_inst])
    |> Tactic.rule_by_tactic ctxt
      (Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 3 (* discharge U (C f) = f *)
       THEN Simplifier.simp_tac (put_simpset curry_K_ss ctxt) 4 (* simplify bot case *)
       THEN Simplifier.full_simp_tac (put_simpset curry_uncurry_ss ctxt) 5) (* simplify induction step *)
    |> (fn thm => thm OF [mono_thm, f_def])
    |> Conv.fconv_rule (Conv.concl_conv ~1    (* simplify conclusion *)
         (Raw_Simplifier.rewrite ctxt false [mk_meta_eq @{thm Product_Type.curry_case_prod}]))
    |> singleton (Variable.export ctxt' ctxt)
  end
184
(* Specialize an instantiated user-level induction rule to a curried predicate.
  The conclusion of inst_rule is expected to be  ?P ?x  (a Trueprop of an
  application of two schematic variables). ?P is instantiated with the
  uncurried fresh predicate P over the argument types, and ?x with the tuple of
  args. Before that, premises 4 and 3 are simplified with curry_uncurry_ss, and
  tuple-typed parameters of premise 3 are split into components with
  split_paired_all (length args - 1 times). Afterwards split_conv is applied
  throughout and the result is exported from ctxt' back to ctxt. *)
fun mk_curried_induct args ctxt inst_rule =
  let
    val cert = Thm.cterm_of ctxt
    val ([P], ctxt') = Variable.variant_fixes ["P"] ctxt

    (* one split_paired_all step per tuple component beyond the first *)
    val split_paired_all_conv =
      Conv.every_conv (replicate (length args - 1) (Conv.rewr_conv @{thm split_paired_all}))

    (* NOTE(review): the bound ctxt' below shadows the outer ctxt' and is unused *)
    val split_params_conv =
      Conv.params_conv ~1 (fn ctxt' =>
        Conv.implies_conv split_paired_all_conv Conv.all_conv)

    (* schematic predicate and argument of the conclusion  ?P ?x *)
    val (P_var, x_var) =
       Thm.prop_of inst_rule |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop
      |> strip_comb |> apsnd hd
      |> apply2 dest_Var
    val P_rangeT = range_type (snd P_var)
    val PT = map (snd o dest_Free) args ---> P_rangeT
    val x_inst = cert (foldl1 HOLogic.mk_prod args)
    val P_inst = cert (uncurry_n (length args) (Free (P, PT)))

    val inst_rule' = inst_rule
      |> Tactic.rule_by_tactic ctxt
        (Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 4
         THEN Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 3
         THEN CONVERSION (split_params_conv ctxt
           then_conv (Conv.forall_conv (K split_paired_all_conv) ctxt)) 3)
      |> Thm.instantiate ([], [(P_var, P_inst), (x_var, x_inst)])
      |> Simplifier.full_simplify (put_simpset split_conv_ss ctxt)
      |> singleton (Variable.export ctxt' ctxt)
  in
    inst_rule'
  end;
218
219
220(*** partial_function definition ***)
221
(*Define a partial function f in the given mode from a single equation
  f x1 ... xn = rhs (n >= 1, each xi a variable). The functional
  F = (%f x1 ... xn. rhs) is uncurried over the left-nested tuple of the
  arguments, and its monotonicity is proved with mono_tac. f is then defined as
  the curried fixpoint of the uncurried functional. The recursive equation is
  noted under the user's binding. The equation (simps), the monotonicity
  theorem (mono), the specialized fixpoint induction (fixp_induct) and, if the
  mode provides a user-level rule, raw_induct are noted qualified by f.
  Returns the defined constant and the noted recursive equation.*)
fun gen_add_partial_function prep mode fixes_raw eqn_raw lthy =
  let
    val setup_data = the (lookup_mode lthy mode)
      handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".",
        "Known modes are " ^ commas_quote (known_modes lthy) ^ "."]);
    val Setup_Data {fixp, mono, fixp_eq, fixp_induct, fixp_induct_user} = setup_data;

    val ((fixes, [(eq_abinding, eqn)]), _) = prep fixes_raw [(eqn_raw, [], [])] lthy;
    val ((_, plain_eqn), args_ctxt) = Variable.focus NONE eqn lthy;

    val ((f_binding, fT), mixfix) = the_single fixes;
    val f_bname = Binding.name_of f_binding;

    fun note_qualified (name, thms) =
      Local_Theory.note ((derived_name f_binding name, []), thms) #> snd

    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn);
    val (head, args) = strip_comb lhs;
    (*the argument tuple type below is built with foldl1, which fails with an
      uninformative exception on an empty list: report a proper error instead*)
    val _ =
      if null args then
        error ("Partial function " ^ quote f_bname ^ " must take at least one argument")
      else ();
    (*dest_Free would otherwise fail with a bare TERM exception on
      non-variable arguments, e.g. constructor patterns*)
    val argnames = map (fst o dest_Free) args
      handle TERM _ =>
        error ("Arguments of partial function " ^ quote f_bname ^ " must be variables");
    val F = fold_rev lambda (head :: args) rhs;

    val arity = length args;
    val (aTs, bTs) = chop arity (binder_types fT);

    (*uncurried variant of f, taking the left-nested tuple of its arguments*)
    val tupleT = foldl1 HOLogic.mk_prodT aTs;
    val fT_uc = tupleT :: bTs ---> body_type fT;
    val f_uc = Var ((f_bname, 0), fT_uc);
    val x_uc = Var (("x", 1), tupleT);
    val uncurry = lambda head (uncurry_n arity head);
    val curry = lambda f_uc (curry_n arity f_uc);

    val F_uc =
      lambda f_uc (uncurry_n arity (F $ curry_n arity f_uc));

    val mono_goal = apply_inst lthy mono (lambda f_uc (F_uc $ f_uc $ x_uc))
      |> HOLogic.mk_Trueprop
      |> Logic.all x_uc;

    val mono_thm = Goal.prove_internal lthy [] (Thm.cterm_of lthy mono_goal)
        (K (mono_tac lthy 1))
    val inst_mono_thm = Thm.forall_elim (Thm.cterm_of lthy x_uc) mono_thm

    val f_def_rhs = curry_n arity (apply_inst lthy fixp F_uc);
    val f_def_binding =
      Thm.make_def_binding (Config.get lthy Function_Lib.function_internals) f_binding
    val ((f, (_, f_def)), lthy') =
      Local_Theory.define ((f_binding, mixfix), ((f_def_binding, []), f_def_rhs)) lthy;

    val eqn = HOLogic.mk_eq (list_comb (f, args),
        Term.betapplys (F, f :: args))
      |> HOLogic.mk_Trueprop;

    val unfold =
      (infer_instantiate' lthy' (map (SOME o Thm.cterm_of lthy') [uncurry, F, curry]) fixp_eq
        OF [inst_mono_thm, f_def])
      |> Tactic.rule_by_tactic lthy' (Simplifier.simp_tac (put_simpset curry_uncurry_ss lthy') 1);

    val specialized_fixp_induct =
      specialize_fixp_induct lthy' args fT fT_uc F curry uncurry inst_mono_thm f_def fixp_induct
      |> Drule.rename_bvars' (map SOME (f_bname :: f_bname :: argnames));

    val mk_raw_induct =
      infer_instantiate' args_ctxt
        ((map o Option.map) (Thm.cterm_of args_ctxt) [SOME uncurry, NONE, SOME curry])
      #> mk_curried_induct args args_ctxt
      #> singleton (Variable.export args_ctxt lthy')
      #> (fn thm => infer_instantiate' lthy'
          [SOME (Thm.cterm_of lthy' F)] thm OF [inst_mono_thm, f_def])
      #> Drule.rename_bvars' (map SOME (f_bname :: argnames @ argnames))

    val raw_induct = Option.map mk_raw_induct fixp_induct_user
    (*prove the recursive equation by unfolding the fixpoint at the head of the lhs*)
    val rec_rule =
      let open Conv in
        Goal.prove lthy' argnames [] eqn (fn _ =>
          CONVERSION ((arg_conv o arg1_conv o head_conv o rewr_conv) (mk_meta_eq unfold)) 1
          THEN resolve_tac lthy' @{thms refl} 1)
      end;
    val ((_, [rec_rule']), lthy'') = lthy' |> Local_Theory.note (eq_abinding, [rec_rule])
  in
    lthy''
    |> Spec_Rules.add Spec_Rules.Equational ([f], [rec_rule'])
    |> note_qualified ("simps", [rec_rule'])
    |> note_qualified ("mono", [mono_thm])
    |> (case raw_induct of NONE => I | SOME thm => note_qualified ("raw_induct", [thm]))
    |> note_qualified ("fixp_induct", [specialized_fixp_induct])
    |> pair (f, rec_rule')
  end;
309
(*entry points: add_partial_function checks already parsed fixes and equation,
  add_partial_function_cmd reads them from strings*)
fun add_partial_function mode fixes eqn lthy =
  gen_add_partial_function Specification.check_multi_specs mode fixes eqn lthy;
fun add_partial_function_cmd mode fixes eqn lthy =
  gen_add_partial_function Specification.read_multi_specs mode fixes eqn lthy;
312
(*parser for a mode name enclosed in parentheses*)
val mode = \<^keyword>\<open>(\<close> |-- Parse.name --| \<^keyword>\<open>)\<close>;

(*outer syntax: partial_function (mode) vars where spec;
  only the resulting local theory of add_partial_function_cmd is kept*)
val _ =
  Outer_Syntax.local_theory \<^command_keyword>\<open>partial_function\<close> "define partial function"
    ((mode -- (Parse.vars -- (Parse.where_ |-- Parse_Spec.simple_spec)))
      >> (fn (mode, (vars, spec)) => add_partial_function_cmd mode vars spec #> #2));
319
320end;
321