1(*  Title:      HOL/Tools/Function/fun.ML
2    Author:     Alexander Krauss, TU Muenchen
3
4Command "fun": Function definitions with pattern splitting/completion
5and automated termination proofs.
6*)
7
signature FUNCTION_FUN =
sig
  (*configuration value used by the "fun" command (sequential mode enabled)*)
  val fun_config: Function_Common.function_config

  (*variant taking fixes with optional internal types (typ)*)
  val add_fun:
    (binding * typ option * mixfix) list ->
    Specification.multi_specs ->
    Function_Common.function_config ->
    local_theory -> Proof.context

  (*variant taking fixes with optional source-level types (string) and an
    additional bool flag that is handed on to Function.add_function_cmd*)
  val add_fun_cmd:
    (binding * string option * mixfix) list ->
    Specification.multi_specs_cmd ->
    Function_Common.function_config ->
    bool -> local_theory -> Proof.context
end
18
19structure Function_Fun : FUNCTION_FUN =
20struct
21
22open Function_Lib
23open Function_Common
24
25
(*Additional well-formedness checks for one equation in sequential mode.
  Returns () if the equation passes; otherwise raises an error whose message
  names the violated restriction and prints the equation.  Rejected are:
    - equations for which split_def yields a non-empty gs component
      ("Conditional equations"),
    - argument patterns with a head that is not a constructor registered
      with Ctr_Sugar for the head's result type ("Non-constructor pattern"),
    - argument patterns containing more Bound occurrences than there are
      entries in qs ("Nonlinear patterns").*)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
      str ^ " not allowed in sequential mode.",
      Syntax.string_of_term ctxt geq])

    (*Does the head term denote a constructor known to Ctr_Sugar?
      The fallback clause in the inner function keeps the match exhaustive,
      so an entry of ctrs that is not a plain Const simply does not match
      instead of raising exception Match.*)
    fun is_constructor (Const (hd_s, hd_T)) =
          (case body_type hd_T of
            Type (Tname, _) =>
              (case Ctr_Sugar.ctr_sugar_of ctxt Tname of
                SOME {ctrs, ...} => exists (fn Const (s, _) => s = hd_s | _ => false) ctrs
              | NONE => false)
          | _ => false)
      | is_constructor _ = false

    (*A pattern is a bound variable or a constructor applied to patterns.*)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
          in
            if is_constructor hd then () else err "Non-constructor pattern";
            (*List.app: the recursive checks are run for their effect only*)
            List.app check_constr_pattern args
          end

    val (_, qs, gs, args, _) = split_def ctxt (K true) geq

    val _ = if not (null gs) then err "Conditional equations" else ()
    val _ = List.app check_constr_pattern args

    (* just count occurrences to check linearity *)
    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
      then err "Nonlinear patterns" else ()
  in
    ()
  end
61
(*For every fixed function f build the catch-all equation
    !!a... . f a... = undefined
  where the number of arguments is given by arity_of applied to the
  function name.*)
fun mk_catchall fixes arity_of =
  let
    fun mk_eqn ((fname, fT), _) =
      let
        val n = arity_of fname

        (*the first n binder types are the argument types; the remaining
          ones are folded back into the result type*)
        val (argTs, restTs) = chop n (binder_types fT)
        val rT = restTs ---> body_type fT

        val vars = map Free (Name.invent Name.context "a" n ~~ argTs)

        val lhs = list_comb (Free (fname, fT), vars)
        val rhs = Const (\<^const_name>\<open>undefined\<close>, rT)
      in
        fold_rev Logic.all vars (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)))
      end
  in
    map mk_eqn fixes
  end
80
(*Append one catch-all equation per fixed function to the specification.
  The arity of each function is the number of arguments found by split_def
  in the given equations; the first equation for a name determines it.*)
fun add_catchall ctxt fixes spec =
  let
    fun name_and_arity eq =
      let val (fname, _, _, args, _) = split_def ctxt (K true) eq
      in (fname, length args) end

    val arities = map name_and_arity spec

    (*raises if a function name has no equation in spec*)
    fun arity_of fname = the (AList.lookup (op =) arities fname)
  in
    spec @ mk_catchall fixes arity_of
  end
88
(*Checks on the result of equation splitting.  tss holds one list of split
  equations per input equation: first those for the original equations
  origs, then those for the equations added by completion.
    - If the added equations produced any results, warn about missing
      patterns, printing at most three of them.
    - If an original equation produced no result, fail: it is redundant.*)
fun further_checks ctxt origs tss =
  let
    fun show t = Syntax.string_of_term ctxt t

    fun fail_redundant t =
      error (cat_lines ["Equation is redundant (covered by preceding clauses):", show t])
    fun warn_missing strs =
      warning (cat_lines ("Missing patterns in function definition:" :: strs))

    val (tss', added) = chop (length origs) tss

    (*shown can only be empty if hidden is empty as well*)
    val (shown, hidden) = chop 3 (flat added)
    val _ =
      if null shown then ()
      else
        let
          val suffix =
            if null hidden then []
            else ["(" ^ string_of_int (length hidden) ^ " more)"]
        in
          warn_missing (map show shown @ suffix)
        end

    fun check_covered (t, ts) = if null ts then fail_redundant t else ()
    val _ = List.app check_covered (origs ~~ tss')
  in
    ()
  end
109
(*Preprocessor for function specifications.  Without the sequential option
  this is Function_Common.empty_preproc.  In sequential mode the equations
  are checked, completed with catch-all equations and split; the result is
  the tuple (all split equations, restore_spec, sort, case_names).*)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if not sequential then
    Function_Common.empty_preproc check_defs config ctxt fixes spec
  else
    let
      val (bindings, eq_groups) = split_list spec
      val n_bindings = length bindings

      (*every specification entry must consist of exactly one equation*)
      val user_eqs = map the_single eq_groups

      val _ = check_defs ctxt fixes user_eqs (* Standard checks *)
      val _ = map (check_pats ctxt) user_eqs (* More checks for sequential mode *)

      val completed = add_catchall ctxt fixes user_eqs (* Completion *)

      val split_eqs = Function_Split.split_all_equations ctxt completed
      val _ = further_checks ctxt user_eqs split_eqs

      (*regroup a flat list of results like split_eqs and pair the groups
        belonging to the user's equations with their bindings*)
      fun restore_spec thms =
        bindings ~~ take n_bindings (unflat split_eqs thms)

      (*position in fixes of the function defined by each split equation
        that stems from a user equation*)
      val fun_names = map (fn ((name, _), _) => name) fixes
      fun index_of eq =
        let val f = fname_of eq
        in find_index (fn name => f = name) fun_names end
      val fun_indices = map index_of (flat (take n_bindings split_eqs))

      (*group xs by the function index of the corresponding equation*)
      fun sort xs =
        map (map snd)
          (partition_list (fn i => fn (j, _) => i = j) 0 (length fun_names - 1)
            (fun_indices ~~ xs))

      (*pad with empty bindings for the groups of the added equations*)
      val padded_bindings =
        bindings @ replicate (length split_eqs - n_bindings) Binding.empty_atts

      (* using theorem names for case name currently disabled *)
      val case_names =
        flat (map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
          (padded_bindings ~~ split_eqs))
    in
      (flat split_eqs, restore_spec, sort, case_names)
    end
147
(*install sequential_preproc as the preprocessor of the function package*)
val _ =
  sequential_preproc
  |> Function_Common.set_preproc
  |> Context.theory_map
  |> Theory.setup
149
150
151
(*configuration of the "fun" command: sequential mode on, everything else off*)
val fun_config =
  FunctionConfig {
    sequential = true,
    default = NONE,
    domintros = false,
    partials = false
  }
154
(*Run the definition step add with the tactic "pat_completeness followed by
  auto", then prove termination in the resulting context.  The first
  component of either result is discarded; the final context is returned.*)
fun gen_add_fun add lthy =
  let
    fun pat_completeness_auto ctxt =
      Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt

    val (_, lthy_defined) = add pat_completeness_auto lthy

    val termination_tac = Function_Common.termination_prover_tac false lthy_defined
    val (_, lthy_terminated) =
      Function.prove_termination NONE termination_tac lthy_defined
  in
    lthy_terminated
  end
167
(*"fun" for internal specifications: Function.add_function wrapped by gen_add_fun*)
fun add_fun fixes specs config lthy =
  gen_add_fun (Function.add_function fixes specs config) lthy
(*"fun" for source-level specifications: Function.add_function_cmd wrapped by
  gen_add_fun; the flag int is passed on after the tactic argument*)
fun add_fun_cmd fixes specs config int lthy =
  gen_add_fun (fn tac => Function.add_function_cmd fixes specs config tac int) lthy
170
171
172
(*outer syntax: register the "fun" command, parsed with fun_config as the
  configuration given to function_parser*)
val _ =
  let
    fun define_fun (config, (fixes, specs)) = add_fun_cmd fixes specs config
  in
    Outer_Syntax.local_theory' \<^command_keyword>\<open>fun\<close>
      "define general recursive functions (short version)"
      (function_parser fun_config >> define_fun)
  end
178
179end
180