(*  Title:      HOL/Tools/Old_Datatype/old_primrec.ML
    Author:     Norbert Voelker, FernUni Hagen
    Author:     Stefan Berghofer, TU Muenchen
    Author:     Florian Haftmann, TU Muenchen

Primitive recursive functions on datatypes.
*)

signature OLD_PRIMREC =
sig
  (* local-theory entry point on checked specifications *)
  val primrec: bool -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> local_theory ->
      {types: string list, result: term list * thm list} * local_theory
  (* local-theory entry point on unparsed (string) specifications *)
  val primrec_cmd: bool -> (binding * string option * mixfix) list ->
    Specification.multi_specs_cmd -> local_theory ->
      {types: string list, result: term list * thm list} * local_theory
  (* theory-level wrapper around primrec using a named target *)
  val primrec_global: bool -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list) * theory
  (* theory-level wrapper around primrec using an overloading target *)
  val primrec_overloaded: bool -> (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list) * theory
  (* definition and proof of the equations only, without noting theorems *)
  val primrec_simple: bool -> ((binding * typ) * mixfix) list -> term list -> local_theory ->
    {prefix: string, types: string list, result: term list * thm list} * local_theory
end;

structure Old_Primrec : OLD_PRIMREC =
struct

(* Internal error: a message plus, optionally, the equation it concerns.
   It is converted into a user-level error by distill. *)
exception PrimrecError of string * term option;

fun primrec_error msg = raise PrimrecError (msg, NONE);
fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);


(* preprocessing of equations *)

(* process_eqn is_fixed spec rec_fns: check one equation spec of the shape
   !!xs. f ls (C cargs) rs = rhs and add it to the association list rec_fns.
   rec_fns maps each function name f to (tname, rpos, eqns) where
     - tname is the type name of the constructor's result type,
     - rpos is the number of variable arguments left of the constructor pattern,
     - eqns holds one entry (C, (ls, cargs, rs, rhs, eqn)) per constructor C.
   The head f must be a Free accepted by is_fixed.  ls, cargs and rs must be
   pairwise distinct variables, and every free variable of rhs must be among
   them or satisfy is_fixed.  A constructor may occur only once per function,
   and rpos must agree across all equations of the same function.
   Errors raised without an equation are re-raised with spec attached by the
   handler at the end of this function. *)
fun process_eqn is_fixed spec rec_fns =
  let
    val (vs, Ts) = split_list (strip_qnt_vars \<^const_name>\<open>Pure.all\<close> spec);
    val body = strip_qnt_body \<^const_name>\<open>Pure.all\<close> spec;
    (* rename the bound variables apart from the frees already in the body
       before turning them into Frees *)
    val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
      (fn Free (v, _) => insert (op =) v | _ => I) body []));
    val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
      handle TERM _ => primrec_error "not a proper equation";
    val (recfun, args) = strip_comb lhs;
    val fname =
      (case recfun of
        Free (v, _) =>
          if is_fixed v then v
          else primrec_error "illegal head of function equation"
      | _ => primrec_error "illegal head of function equation");

    (* variables to the left (ls') and right (rs') of the single
       non-variable argument, which must be the constructor pattern *)
    val (ls', rest) = chop_prefix is_Free args;
    val (middle, rs') = chop_suffix is_Free rest;
    val rpos = length ls';

    val (constr, cargs') =
      if null middle then primrec_error "constructor missing"
      else strip_comb (hd middle);
    val (cname, T) = dest_Const constr
      handle TERM _ => primrec_error "ill-formed constructor";
    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
      primrec_error "cannot determine datatype associated with function";

    val (ls, cargs, rs) =
      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
      handle TERM _ => primrec_error "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    (* Report the offending variable names.  The handler at the end of this
       function attaches spec to the message. *)
    fun check_vars _ [] = ()
      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars));
  in
    if length middle > 1 then
      primrec_error "more than one non-variable in pattern"
    else
     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
      check_vars "extra variables on rhs: "
        (Term.add_frees rhs [] |> subtract (op =) lfrees
          |> filter_out (is_fixed o fst));
      (case AList.lookup (op =) rec_fns fname of
        NONE =>
          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
      | SOME (_, rpos', eqns) =>
          if AList.defined (op =) eqns cname then
            primrec_error "constructor already occurred as pattern"
          else if rpos <> rpos' then
            primrec_error "position of recursive argument inconsistent"
          else
            AList.update (op =)
              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
              rec_fns))
  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;

(* process_fun descr eqns (i, fname) (fnames, fnss): translate the equations
   of function fname, which recurses over the datatype with index i in descr,
   into one rec-combinator argument per constructor.  Functions that are
   reached through recursive calls on a right-hand side are processed too.
   eqns is the association list built by process_eqn.
   fnames maps datatype indices to the function names already handled.
   fnss collects (i, (fname, ls, fns)), where ls are the variables left of
   the recursive argument in the first equation and fns are the
   rec-combinator arguments.
   Raises PrimrecError if one datatype index would be paired with two
   different functions, or one function with two different indices. *)
fun process_fun descr eqns (i, fname) (fnames, fnss) =
  let
    val (_, (tname, _, constrs)) = nth descr i;

    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)

    fun subst [] t fs = (t, fs)
      | subst subs (Abs (a, T, t)) fs =
          fs
          |> subst subs t
          |-> (fn t' => pair (Abs (a, T, t')))
      | subst subs (t as (_ $ _)) fs =
          let
            val (f, ts) = strip_comb t;
          in
            if is_Free f
              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
              let
                val (fname', _) = dest_Free f;
                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
                val (ls, rs) = chop rpos ts
                val (x', rs') =
                  (case rs of
                    x' :: rs => (x', rs)
                  | [] => primrec_error ("not enough arguments in recursive application\n" ^
                      "of function " ^ quote fname' ^ " on rhs"));
                (* the recursive argument may itself be applied to further
                   arguments xs; they are passed on to the result variable y *)
                val (x, xs) = strip_comb x';
              in
                (case AList.lookup (op =) subs x of
                  NONE =>
                    fs
                    |> fold_map (subst subs) ts
                    |-> (fn ts' => pair (list_comb (f, ts')))
                | SOME (i', y) =>
                    fs
                    |> fold_map (subst subs) (xs @ ls @ rs')
                    ||> process_fun descr eqns (i', fname')
                    |-> (fn ts' => pair (list_comb (y, ts'))))
              end
            else
              fs
              |> fold_map (subst subs) (f :: ts)
              |-> (fn f' :: ts' => pair (list_comb (f', ts')))
          end
      | subst _ t fs = (t, fs);

    (* translate rec equations into function arguments suitable for rec comb *)

    (* For constructor cname, abstract the substituted rhs over the
       constructor arguments, one fresh variable per recursive argument,
       and the ls and rs variables.  A constructor without an equation
       yields undefined and a warning. *)
    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
      (case AList.lookup (op =) eqns cname of
        NONE => (warning ("No equation for constructor " ^ quote cname ^
          "\nin definition of function " ^ quote fname);
            (fnames', fnss', (Const (\<^const_name>\<open>undefined\<close>, dummyT)) :: fns))
      | SOME (ls, cargs', rs, rhs, eq) =>
          let
            val recs = filter (Old_Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
            val rargs = map fst recs;
            (* fresh names, distinct from those in rhs, for the results of
               the recursive calls *)
            val subs = map (rpair dummyT o fst)
              (rev (Term.rename_wrt_term rhs rargs));
            val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
              (Free x, (Old_Datatype_Aux.body_index y, Free z))) recs subs) rhs (fnames', fnss')
                handle PrimrecError (s, NONE) => primrec_error_eqn s eq
          in
            (fnames'', fnss'', fold_rev absfree (cargs' @ subs @ ls @ rs) rhs' :: fns)
          end)

  in
    (case AList.lookup (op =) fnames i of
      NONE =>
        if exists (fn (_, v) => fname = v) fnames then
          primrec_error ("inconsistent functions for datatype " ^ quote tname)
        else
          let
            val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
            val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
              ((i, fname) :: fnames, fnss, [])
          in
            (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
          end
    | SOME fname' =>
        if fname = fname' then (fnames, fnss)
        else primrec_error ("inconsistent functions for datatype " ^ quote tname))
  end;


(* prepare functions needed for definitions *)

(* For datatype i of the mutual block with recursor rec_name:
   - if a function was processed for i, prepend its rec-combinator
     arguments fs' to fs and record the definition (fname, ls, rec_name);
   - otherwise prepend one unit-valued undefined dummy per constructor and
     warn.  Each dummy takes one argument per constructor argument plus one
     per recursive constructor argument. *)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
  (case AList.lookup (op =) fns i of
    NONE =>
      let
        val dummy_fns = map (fn (_, cargs) => Const (\<^const_name>\<open>undefined\<close>,
          replicate (length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs))
            dummyT ---> HOLogic.unitT)) constrs;
        val _ = warning ("No function definition for datatype " ^ quote tname)
      in
        (dummy_fns @ fs, defs)
      end
  | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name) :: defs));


(* make definition *)

(* make_def ctxt fixes fs (fname, ls, rec_name): build the right-hand side
   %ls x. rec_name fs x ls for the entry of fixes named fname.  It is
   type-checked in ctxt against the type declared there.  The result pairs
   that fixes entry with a concealed attribute binding named by
   Thm.def_name.
   NOTE(review): the partial SOME pattern raises Bind if no entry of fixes
   is named fname. *)
fun make_def ctxt fixes fs (fname, ls, rec_name) =
  let
    val SOME (var, varT) = get_first (fn ((b, T), mx: mixfix) =>
      if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
    val def_name = Thm.def_name (Long_Name.base_name fname);
    (* innermost abstraction is the recursion argument (Bound 0); the ls
       follow as Bound (length ls) down to Bound 1 *)
    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
    val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
  in (var, ((Binding.concealed (Binding.name def_name), []): Attrib.binding, rhs)) end;


(* find datatypes which contain all datatypes in tnames' *)

(* Returns, preserving order, the pairs (tname, info) for those names in
   the last argument whose datatype block (descr) mentions every name in
   tnames'.  Raises PrimrecError for a name not found in dt_info. *)
fun find_dts _ _ [] = []
  | find_dts dt_info tnames' (tname :: tnames) =
      (case Symtab.lookup dt_info tname of
        NONE => primrec_error (quote tname ^ " is not a datatype")
      | SOME (dt : Old_Datatype_Aux.info) =>
          if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
            (tname, dt) :: (find_dts dt_info tnames' tnames)
          else find_dts dt_info tnames' tnames);


(* distill primitive definition(s) from primrec specification *)

(* distill ctxt fixes eqs: analyse the equations eqs and return
   ((prefix, tnames, (fs, defs)), prove), where
   - prefix joins the base names of the defined functions with "_",
   - tnames are the datatype names used by the equations,
   - fs are the rec-combinator arguments,
   - defs are the definitions to be made.
   prove ctxt defs proves every equation in eqs by rewriting with the
   recursor rewrite rules and the given definitional theorems, then refl.
   The first datatype block containing all of tnames is used.  Any
   PrimrecError becomes a user error, quoting the offending equation when
   one is known. *)
fun distill ctxt fixes eqs =
  let
    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
    val tnames = distinct (op =) (map (#1 o snd) eqns);
    val dts = find_dts (Old_Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
    val main_fns = map (fn (tname, {index, ...}) =>
      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
    val {descr, rec_names, rec_rewrites, ...} =
      if null dts then primrec_error
        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      else snd (hd dts);
    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    val defs = map (make_def ctxt fixes fs) raw_defs;
    (* every function with equations must have been reached by process_fun *)
    val names = map snd fnames;
    val names_eqns = map fst eqns;
    val _ =
      if eq_set (op =) (names, names_eqns) then ()
      else primrec_error ("functions " ^ commas_quote names_eqns ^
        "\nare not mutually recursive");
    val rec_rewrites' = map mk_meta_eq rec_rewrites;
    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
    fun prove ctxt defs =
      let
        val frees = fold (Variable.add_free_names ctxt) eqs [];
        val rewrites = rec_rewrites' @ map (snd o snd) defs;
      in
        map (fn eq => Goal.prove ctxt frees [] eq
          (fn {context = ctxt', ...} =>
            EVERY [rewrite_goals_tac ctxt' rewrites, resolve_tac ctxt' [refl] 1])) eqs
      end;
  in ((prefix, tnames, (fs, defs)), prove) end
  handle PrimrecError (msg, some_eqn) =>
    error ("Primrec definition error:\n" ^ msg ^
      (case some_eqn of
        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
      | NONE => ""));


(* primrec definition *)

(* Make the definitions computed by distill with Local_Theory.define.  The
   new constants are passed to BNF_FP_Rec_Sugar_Util.print_def_consts
   together with the int flag.  Returns the prefix, the datatype names, the
   defined terms and the equations proved in the resulting local theory. *)
fun primrec_simple int fixes ts lthy =
  let
    val ((prefix, tnames, (_, defs)), prove) = distill lthy fixes ts;
  in
    lthy
    |> fold_map Local_Theory.define defs
    |> tap (uncurry (BNF_FP_Rec_Sugar_Util.print_def_consts int))
    |-> (fn defs =>
      `(fn lthy => {prefix = prefix, types = tnames, result = (map fst defs, prove lthy defs)}))
  end;

local

(* Shared implementation of primrec and primrec_cmd; prep_spec checks or
   reads the raw specification.  After defining the functions it
   - registers the spec rules,
   - notes each equation under its own binding qualified by prefix,
   - notes all equations as prefix.simps with the simp and nitpick_simp
     attributes,
   - declares them as default code equations. *)
fun gen_primrec prep_spec int raw_fixes raw_spec lthy =
  let
    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
    val spec_name = Binding.conglomerate (map (#1 o #1) fixes);
    fun attr_bindings prefix = map (fn ((b, attrs), _) =>
      (Binding.qualify false prefix b, attrs)) spec;
    fun simp_attr_binding prefix =
      (Binding.qualify true prefix (Binding.name "simps"), @{attributes [simp, nitpick_simp]});
  in
    lthy
    |> primrec_simple int fixes (map snd spec)
    |-> (fn {prefix, types, result = (ts, simps)} =>
      Spec_Rules.add spec_name (Spec_Rules.equational_primrec types) ts simps
      #> fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps)
      #-> (fn simps' => Local_Theory.note (simp_attr_binding prefix, maps snd simps')
      #-> (fn (_, simps'') =>
        Code.declare_default_eqns (map (rpair true) simps'')
        #> pair {types = types, result = (ts, simps'')})))
  end;

in

val primrec = gen_primrec Specification.check_multi_specs;
val primrec_cmd = gen_primrec Specification.read_multi_specs;

end;

(* Run primrec in a named target initialised from thy, export the
   equations back to the outer context and exit to the theory. *)
fun primrec_global int fixes specs thy =
  let
    val lthy = Named_Target.theory_init thy;
    val ({result = (ts, simps), ...}, lthy') = primrec int fixes specs lthy;
    val simps' = Proof_Context.export lthy' lthy simps;
  in ((ts, simps'), Local_Theory.exit_global lthy') end;

(* Like primrec_global, but runs primrec in the overloading target built
   from ops. *)
fun primrec_overloaded int ops fixes specs thy =
  let
    val lthy = Overloading.overloading ops thy;
    val ({result = (ts, simps), ...}, lthy') = primrec int fixes specs lthy;
    val simps' = Proof_Context.export lthy' lthy simps;
  in ((ts, simps'), Local_Theory.exit_global lthy') end;

end;