(*  Title:      HOL/Nominal/nominal_primrec.ML
    Author:     Norbert Voelker, FernUni Hagen
    Author:     Stefan Berghofer, TU Muenchen

Package for defining functions on nominal datatypes by primitive recursion.
Taken from HOL/Tools/primrec.ML

The functions are defined via the recursion combinator of the nominal
datatype; the user-supplied equations are then stated as goals, which are
reduced to the (freshness / invariant) premises of the combinator's
rewrite rules that the user still has to prove.
*)

signature NOMINAL_PRIMREC =
sig
  (*primrec invs fctxt fixes params specs lthy:
    invs  -- optional invariant terms (one predicate per datatype),
    fctxt -- optional freshness context,
    fixes -- the functions to be defined,
    params -- additional fixed parameters,
    specs -- the defining equations.
    Returns a proof state whose goals are the given equations.*)
  val primrec: term list option -> term option ->
    (binding * typ option * mixfix) list ->
    (binding * typ option * mixfix) list ->
    Specification.multi_specs -> local_theory -> Proof.state
  (*Same as primrec, but with unparsed (string) input.*)
  val primrec_cmd: string list option -> string option ->
    (binding * string option * mixfix) list ->
    (binding * string option * mixfix) list ->
    Specification.multi_specs_cmd -> local_theory -> Proof.state
end;

structure NominalPrimrec : NOMINAL_PRIMREC =
struct

(*Internal error concerning a single equation; it is turned into a user
  error mentioning the offending equation by primrec_eq_err.*)
exception RecError of string;

fun primrec_err s = error ("Nominal primrec definition error:\n" ^ s);
fun primrec_eq_err lthy s eq =
  primrec_err (s ^ "\nin\n" ^ quote (Syntax.string_of_term lthy eq));


(* preprocessing of equations *)

(*Strip the outermost Pure.all quantifiers of t and replace the bound
  variables by Free variables; their names are renamed apart from the free
  variables already occurring in the body.*)
fun unquantify t =
  let
    val (vs, Ts) = split_list (strip_qnt_vars \<^const_name>\<open>Pure.all\<close> t);
    val body = strip_qnt_body \<^const_name>\<open>Pure.all\<close> t;
    val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
      (fn Free (v, _) => insert (op =) v | _ => I) body []))
  in (curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body) end;

(*Check a single equation spec and add it to the accumulator rec_fns.
  The equation must have the form  f ls (C cargs) rs = rhs  where f is a
  fixed function name, ls and rs are variables and C is a constructor.
  rec_fns maps each function name to (tname, rpos, eqns):
    tname -- name of the datatype of the constructor pattern,
    rpos  -- number of variable arguments to the left of the pattern,
    eqns  -- maps constructor names to (ls, cargs, rs, rhs, eq).
  Any RecError is reported as a user error mentioning spec.*)
fun process_eqn lthy is_fixed spec rec_fns =
  let
    val eq = unquantify spec;
    val (lhs, rhs) =
      HOLogic.dest_eq (HOLogic.dest_Trueprop (Logic.strip_imp_concl eq))
      handle TERM _ => raise RecError "not a proper equation";

    val (recfun, args) = strip_comb lhs;
    val fname = case recfun of Free (v, _) => if is_fixed v then v
          else raise RecError "illegal head of function equation"
      | _ => raise RecError "illegal head of function equation";

    (*split the arguments into leading variables, the pattern, and
      trailing variables*)
    val (ls', rest)  = chop_prefix is_Free args;
    val (middle, rs') = chop_suffix is_Free rest;
    val rpos = length ls';

    val (constr, cargs') = if null middle then raise RecError "constructor missing"
      else strip_comb (hd middle);
    val (cname, T) = dest_Const constr
      handle TERM _ => raise RecError "ill-formed constructor";
    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
      raise RecError "cannot determine datatype associated with function"

    val (ls, cargs, rs) =
      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
      handle TERM _ => raise RecError "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    fun check_vars _ [] = ()
      | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
  in
    if length middle > 1 then
      raise RecError "more than one non-variable in pattern"
    else
     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
      (*free variables of the rhs must be pattern variables or fixed names*)
      check_vars "extra variables on rhs: "
        (map dest_Free (Misc_Legacy.term_frees rhs) |> subtract (op =) lfrees
          |> filter_out (is_fixed o fst));
      case AList.lookup (op =) rec_fns fname of
        NONE =>
          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
      | SOME (_, rpos', eqns) =>
          if AList.defined (op =) eqns cname then
            raise RecError "constructor already occurred as pattern"
          else if rpos <> rpos' then
            raise RecError "position of recursive argument inconsistent"
          else
            AList.update (op =)
              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns))
              rec_fns)
  end
  handle RecError s => primrec_eq_err lthy s spec;

val param_err = "Parameters must be the same for all recursive functions";

(*Translate the equations of function fname, which recurses over datatype
  number i of descr, into arguments for the recursion combinator.
  fnames maps datatype indices to the function already assigned to them;
  fnss collects entries (i, (fname, ls, rs, fns)) where fns contains one
  combinator argument per constructor.  Functions that are called
  recursively on a right-hand side are processed on demand.  Raises
  RecError if a datatype would get two different functions, or a function
  two datatypes.*)
fun process_fun lthy descr eqns (i, fname) (fnames, fnss) =
  let
    val (_, (tname, _, constrs)) = nth descr i;

    (* substitute "fname ls x rs" by "y" for (x, (_, y)) in subs *)

    fun subst [] t fs = (t, fs)
      | subst subs (Abs (a, T, t)) fs =
          fs
          |> subst subs t
          |-> (fn t' => pair (Abs (a, T, t')))
      | subst subs (t as (_ $ _)) fs =
          let
            val (f, ts) = strip_comb t;
          in
            if is_Free f
              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
              let
                (*application of one of the functions being defined*)
                val (fname', _) = dest_Free f;
                val (_, rpos, eqns') = the (AList.lookup (op =) eqns fname');
                val (ls, rs'') = chop rpos ts
                val (x', rs) = case rs'' of
                    x' :: rs => (x', rs)
                  | [] => raise RecError ("not enough arguments in recursive application\n"
                      ^ "of function " ^ quote fname' ^ " on rhs");
                (*the parameters ls / rs of the call must be exactly those of
                  the definition; rs' are the remaining extra arguments*)
                val rs' = (case eqns' of
                    (_, (ls', _, rs', _, _)) :: _ =>
                      let val (rs1, rs2) = chop (length rs') rs
                      in
                        if ls = map Free ls' andalso rs1 = map Free rs' then rs2
                        else raise RecError param_err
                      end
                  | _ => raise RecError ("no equations for " ^ quote fname'));
                val (x, xs) = strip_comb x'
              in case AList.lookup (op =) subs x
               of NONE =>
                    fs
                    |> fold_map (subst subs) ts
                    |-> (fn ts' => pair (list_comb (f, ts')))
                | SOME (i', y) =>
                    (*recursive call on argument x: replace it by the result
                      variable y and make sure fname' is processed for
                      datatype i'*)
                    fs
                    |> fold_map (subst subs) (xs @ rs')
                    ||> process_fun lthy descr eqns (i', fname')
                    |-> (fn ts' => pair (list_comb (y, ts')))
              end
            else
              fs
              |> fold_map (subst subs) (f :: ts)
              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
          end
      | subst _ t fs = (t, fs);

    (* translate rec equations into function arguments suitable for rec comb *)

    (*Constructors without an equation get undefined (with a warning);
      otherwise the argument is the rhs, with recursive calls replaced by
      fresh variables, abstracted over cargs' followed by those variables.*)
    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
      (case AList.lookup (op =) eqns cname of
          NONE => (warning ("No equation for constructor " ^ quote cname ^
            "\nin definition of function " ^ quote fname);
              (fnames', fnss', (Const (\<^const_name>\<open>undefined\<close>, dummyT))::fns))
        | SOME (ls, cargs', rs, rhs, eq) =>
            let
              val recs = filter (Old_Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
              val rargs = map fst recs;
              (*fresh names for the results of the recursive calls*)
              val subs = map (rpair dummyT o fst)
                (rev (Term.rename_wrt_term rhs rargs));
              val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
                (Free x, (Old_Datatype_Aux.body_index y, Free z))) recs subs) rhs (fnames', fnss')
                  handle RecError s => primrec_eq_err lthy s eq
            in (fnames'', fnss'', fold_rev absfree (cargs' @ subs) rhs' :: fns)
            end)

  in (case AList.lookup (op =) fnames i of
      NONE =>
        if exists (fn (_, v) => fname = v) fnames then
          raise RecError ("inconsistent functions for datatype " ^ quote tname)
        else
          let
            val SOME (_, _, eqns' as (_, (ls, _, rs, _, _)) :: _) =
              AList.lookup (op =) eqns fname;
            val (fnames', fnss', fns) = fold_rev (trans eqns') constrs
              ((i, fname)::fnames, fnss, [])
          in
            (fnames', (i, (fname, ls, rs, fns))::fnss')
          end
    | SOME fname' =>
        if fname = fname' then (fnames, fnss)
        else raise RecError ("inconsistent functions for datatype " ^ quote tname))
  end;


(* prepare functions needed for definitions *)

(*Prepend the combinator arguments for datatype i to fs.  If no function
  was defined for that datatype, dummy undefined arguments with result type
  unit are used (with a warning) and no definition is recorded; otherwise
  (fname, ls, rs, rec_name, tname) is added to defs.*)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
  case AList.lookup (op =) fns i of
     NONE =>
       let
         val dummy_fns = map (fn (_, cargs) => Const (\<^const_name>\<open>undefined\<close>,
           replicate (length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs))
             dummyT ---> HOLogic.unitT)) constrs;
         val _ = warning ("No function definition for datatype " ^ quote tname)
       in
         (dummy_fns @ fs, defs)
       end
   | SOME (fname, ls, rs, fs') => (fs' @ fs, (fname, ls, rs, rec_name, tname) :: defs);


(* make definition *)

(*Build the definition  fname == %ls x rs. rec_name fs x  (type-checked in
  ctxt).  Returns the argument for Local_Theory.define together with the
  body of the definition in which the lambda-bound parameters are replaced
  by the corresponding free variables.*)
fun make_def ctxt fixes fs (fname, ls, rs, rec_name, tname) =
  let
    val used = map fst (fold Term.add_frees fs []);
    val x = (singleton (Name.variant_list used) "x", dummyT);
    val frees = ls @ x :: rs;
    val raw_rhs = fold_rev absfree frees
      (list_comb (Const (rec_name, dummyT), fs @ [Free x]))
    val def_name = Thm.def_name (Long_Name.base_name fname);
    val rhs = singleton (Syntax.check_terms ctxt) raw_rhs;
    (*binding and mixfix of fname as given by the user*)
    val SOME var = get_first (fn ((b, _), mx) =>
      if Binding.name_of b = fname then SOME (b, mx) else NONE) fixes;
  in
    ((var, ((Binding.name def_name, []), rhs)),
     subst_bounds (rev (map Free frees), strip_abs_body rhs))
  end;


(* find datatypes which contain all datatypes in tnames' *)

(*Returns, in order, each (tname, info) from the list of names whose
  descr mentions every name in tnames'; fails if a name is not a nominal
  datatype.*)
fun find_dts (dt_info : NominalDatatype.nominal_datatype_info Symtab.table) _ [] = []
  | find_dts dt_info tnames' (tname::tnames) =
      (case Symtab.lookup dt_info tname of
          NONE => primrec_err (quote tname ^ " is not a nominal datatype")
        | SOME dt =>
            if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
              (tname, dt)::(find_dts dt_info tnames' tnames)
            else find_dts dt_info tnames' tnames);

(*Longest common prefix of two lists with respect to eq.*)
fun common_prefix eq ([], _) = []
  | common_prefix eq (_, []) = []
  | common_prefix eq (x :: xs, y :: ys) =
      if eq (x, y) then x :: common_prefix eq (xs, ys) else [];

local

(*Common implementation of primrec and primrec_cmd; prep_spec and
  prep_term turn the (possibly unparsed) input into terms.*)
fun gen_primrec prep_spec prep_term invs fctxt raw_fixes raw_params raw_spec lthy =
  let
    val (fixes', spec) = fst (prep_spec (raw_fixes @ raw_params) raw_spec lthy);
    (*only the first length raw_fixes entries are functions to be defined*)
    val fixes = List.take (fixes', length raw_fixes);
    val (names_atts, spec') = split_list spec;
    val eqns' = map unquantify spec'
    val eqns = fold_rev (process_eqn lthy (fn v => Variable.is_fixed lthy v
      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) spec' [];
    val dt_info = NominalDatatype.get_nominal_datatypes (Proof_Context.theory_of lthy);
    (*all equations must use the same parameters ls and rs*)
    val lsrs :: lsrss = maps (fn (_, (_, _, eqns)) =>
      map (fn (_, (ls, _, rs, _, _)) => ls @ rs) eqns) eqns
    val _ =
      (if forall (curry (eq_set (op =)) lsrs) lsrss andalso forall
         (fn (_, (_, _, (_, (ls, _, rs, _, _)) :: eqns)) =>
               forall (fn (_, (ls', _, rs', _, _)) =>
                 ls = ls' andalso rs = rs') eqns
           | _ => true) eqns
       then () else primrec_err param_err);
    val tnames = distinct (op =) (map (#1 o snd) eqns);
    val dts = find_dts dt_info tnames tnames;
    (*index of each such datatype, paired with the first function defined
      by recursion on it*)
    val main_fns =
      map (fn (tname, {index, ...}) =>
        (index,
          (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns))
      dts;
    val {descr, rec_names, rec_rewrites, ...} =
      if null dts then
        primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      else snd (hd dts);
    (*flatten each constructor argument (dTs, dT) into dTs @ [dT]*)
    val descr = map (fn (i, (tname, args, constrs)) => (i, (tname, args,
      map (fn (cname, cargs) => (cname, fold (fn (dTs, dT) => fn dTs' =>
        dTs' @ dTs @ [dT]) cargs [])) constrs))) descr;
    val (fnames, fnss) = fold_rev (process_fun lthy descr eqns) main_fns ([], []);
    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    val defs' = map (make_def lthy fixes fs) defs;
    (*every function with equations must have been reached*)
    val names1 = map snd fnames;
    val names2 = map fst eqns;
    val _ = if eq_set (op =) (names1, names2) then ()
      else primrec_err ("functions " ^ commas_quote names2 ^
        "\nare not mutually recursive");
    val (defs_thms, lthy') = lthy |>
      fold_map (apfst (snd o snd) oo Local_Theory.define o fst) defs';
    val qualify = Binding.qualify false
      (space_implode "_" (map (Long_Name.base_name o #1) defs));
    val names_atts' = map (apfst qualify) names_atts;

    (*For a user equation eq: the datatype index i, the index of its
      constructor within that datatype, and the constructor arguments.*)
    fun mk_idx eq =
      let
        val Free (name, _) = head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
          (Logic.strip_imp_concl eq))));
        val SOME i = AList.lookup op = (map swap fnames) name;
        val SOME (_, _, constrs) = AList.lookup op = descr i;
        val SOME (_, _, eqns'') = AList.lookup op = eqns name;
        val SOME (cname, (_, cargs, _, _, _)) = find_first
          (fn (_, (_, _, _, _, eq')) => eq = eq') eqns''
      in (i, find_index (fn (cname', _) => cname = cname') constrs, cargs) end;

    (*rec rewrite rules grouped per datatype, one per constructor*)
    val rec_rewritess =
      unflat (map (fn (_, (_, _, constrs)) => constrs) descr) rec_rewrites;
    (*leading schematic arguments of the combinator: the function arguments*)
    val fvars = rec_rewrites |> hd |> Thm.concl_of |> HOLogic.dest_Trueprop |>
      HOLogic.dest_eq |> fst |> strip_comb |> snd |> take_prefix is_Var;
    (*variables occurring only in the premises: bool-valued ones are the
      invariants, the others the freshness context*)
    val (pvars, ctxtvars) = List.partition
      (equal HOLogic.boolT o body_type o snd)
      (subtract (op =)
        (Term.add_vars (Thm.concl_of (hd rec_rewrites)) [])
        (fold_rev (Term.add_vars o Logic.strip_assums_concl)
           (Thm.prems_of (hd rec_rewrites)) []));
    (*the actual function arguments, taken from the first definition*)
    val cfs = defs' |> hd |> snd |> strip_comb |> snd |>
      curry (List.take o swap) (length fvars) |> map (Thm.cterm_of lthy');
    (*default invariants are constantly True.
      NOTE(review): nth defs' i assumes defs' has one entry per datatype of
      descr, in order; get_fns adds no entry for datatypes without a
      function -- confirm this case cannot reach here.*)
    val invs' = (case invs of
        NONE => map (fn (i, _) =>
          Abs ("x", fastype_of (snd (nth defs' i)), \<^term>\<open>True\<close>)) descr
      | SOME invs' => map (prep_term lthy') invs');
    (*the freshness context defaults to the unit value*)
    val inst = (map (#1 o dest_Var) fvars ~~ cfs) @
      (map #1 pvars ~~ map (Thm.cterm_of lthy') invs') @
      (case ctxtvars of
         [ctxtvar] => [(#1 ctxtvar,
           Thm.cterm_of lthy' (the_default HOLogic.unit (Option.map (prep_term lthy') fctxt)))]
       | _ => []);
    (*for each user equation: its constructor arguments, its premises, and
      the matching rec rewrite rule instantiated accordingly*)
    val rec_rewrites' = map (fn eq =>
      let
        val (i, j, cargs) = mk_idx eq
        val th = nth (nth rec_rewritess i) j;
        val cargs' = th |> Thm.concl_of |> HOLogic.dest_Trueprop |>
          HOLogic.dest_eq |> fst |> strip_comb |> snd |> List.last |>
          strip_comb |> snd
      in (cargs, Logic.strip_imp_prems eq,
        infer_instantiate lthy' (inst @
          (map (#1 o dest_Var) cargs' ~~ map (Thm.cterm_of lthy' o Free) cargs)) th)
      end) eqns';

    (*premises shared by all instantiated rules*)
    val prems = foldr1 (common_prefix op aconv) (map (Thm.prems_of o #3) rec_rewrites');
    val cprems = map (Thm.cterm_of lthy') prems;
    val asms = map Thm.assume cprems;
    (*remaining premises of each rule, generalized over the constructor
      arguments and put under the premises of the user equation*)
    val premss = map (fn (cargs, eprems, eqn) =>
      map (fn t => fold_rev (Logic.all o Free) cargs (Logic.list_implies (eprems, t)))
        (List.drop (Thm.prems_of eqn, length prems))) rec_rewrites';
    val cpremss = map (map (Thm.cterm_of lthy')) premss;
    val asmss = map (map Thm.assume) cpremss;

    (*derive one (generalized) equation from its rule and the assumptions*)
    fun mk_eqn ((cargs, eprems, eqn), asms') =
      let
        val ceprems = map (Thm.cterm_of lthy') eprems;
        val asms'' = map Thm.assume ceprems;
        val ccargs = map (Thm.cterm_of lthy' o Free) cargs;
        val asms''' = map (fn th => implies_elim_list
          (forall_elim_list ccargs th) asms'') asms'
      in
        implies_elim_list eqn (asms @ asms''') |>
        implies_intr_list ceprems |>
        forall_intr_list ccargs
      end;

    (*all premises imply the conjunction of all equations*)
    val rule_prems = cprems @ flat cpremss;
    val rule = implies_intr_list rule_prems
      (Conjunction.intr_balanced (map mk_eqn (rec_rewrites' ~~ asmss)));

    val goals = map (fn ((cargs, _, _), eqn) =>
      (fold_rev (Logic.all o Free) cargs eqn, [])) (rec_rewrites' ~~ eqns');

  in
    (*state the equations as goals; after the proof they are noted under the
      user-given names and collectively as simps.  The initial refinement
      unfolds the definitions and applies rule, so that rule_prems are what
      remains to be proved.*)
    lthy' |>
    Variable.add_fixes (map fst lsrs) |> snd |>
    Proof.theorem NONE
      (fn thss => fn goal_ctxt =>
        let
          val simps = Proof_Context.export goal_ctxt lthy' (flat thss);
          val (simps', lthy'') =
            fold_map Local_Theory.note (names_atts' ~~ map single simps) lthy';
        in
          lthy''
          |> Local_Theory.note
            ((qualify (Binding.name "simps"), @{attributes [simp, nitpick_simp]}), maps snd simps')
          |> snd
        end)
      [goals] |>
    Proof.refine_singleton (Method.Basic (fn ctxt => fn _ =>
      CONTEXT_TACTIC
       (rewrite_goals_tac ctxt defs_thms THEN
        compose_tac ctxt (false, rule, length rule_prems) 1)))
  end;

in

(*already-checked input vs. input that still has to be read*)
val primrec = gen_primrec Specification.check_multi_specs (K I);
val primrec_cmd = gen_primrec Specification.read_multi_specs Syntax.read_term;

end;


(* outer syntax *)

(*Optional parenthesized options:
    (invariant: t1 ... tn freshness_context: t)
    (invariant: t1 ... tn)
    (freshness_context: t)
  yielding (invariants option, freshness context option).*)

val freshness_context = Parse.reserved "freshness_context";
val invariant = Parse.reserved "invariant";

(*parse with scan unless the input starts a new option keyword*)
fun unless_flag scan = Scan.unless ((freshness_context || invariant) -- \<^keyword>\<open>:\<close>) scan;

val parser1 = (freshness_context -- \<^keyword>\<open>:\<close>) |-- unless_flag Parse.term >> SOME;
val parser2 = (invariant -- \<^keyword>\<open>:\<close>) |--
    (Scan.repeat1 (unless_flag Parse.term) >> SOME) -- Scan.optional parser1 NONE ||
  (parser1 >> pair NONE);
val options =
  Scan.optional (\<^keyword>\<open>(\<close> |-- Parse.!!! (parser2 --| \<^keyword>\<open>)\<close>)) (NONE, NONE);

val _ =
  Outer_Syntax.local_theory_to_proof \<^command_keyword>\<open>nominal_primrec\<close>
    "define primitive recursive functions on nominal datatypes"
    (options -- Parse.vars -- Parse.for_fixes -- Parse_Spec.where_multi_specs
      >> (fn ((((invs, fctxt), vars), params), specs) =>
        primrec_cmd invs fctxt vars params specs));

end;