(*  Title:      HOL/Tools/Old_Datatype/old_datatype_aux.ML
    Author:     Stefan Berghofer, TU Muenchen

Datatype package: auxiliary data structures and functions.
*)

signature OLD_DATATYPE_COMMON =
sig
  type config = {strict : bool, quiet : bool}
  val default_config : config
  datatype dtyp =
      DtTFree of string * sort
    | DtType of string * dtyp list
    | DtRec of int
  type descr = (int * (string * dtyp list * (string * dtyp list) list)) list
  type info =
   {index : int,
    descr : descr,
    inject : thm list,
    distinct : thm list,
    induct : thm,
    inducts : thm list,
    exhaust : thm,
    nchotomy : thm,
    rec_names : string list,
    rec_rewrites : thm list,
    case_name : string,
    case_rewrites : thm list,
    case_cong : thm,
    case_cong_weak : thm,
    split : thm,
    split_asm: thm}
  type spec = (binding * (string * sort) list * mixfix) * (binding * typ list * mixfix) list
end

signature OLD_DATATYPE_AUX =
sig
  include OLD_DATATYPE_COMMON

  val message : config -> string -> unit

  val store_thmss_atts : string -> string list -> attribute list list -> thm list list
    -> theory -> thm list list * theory
  val store_thmss : string -> string list -> thm list list -> theory -> thm list list * theory
  val store_thms_atts : string -> string list -> attribute list list -> thm list
    -> theory -> thm list * theory
  val store_thms : string -> string list -> thm list -> theory -> thm list * theory

  val split_conj_thm : thm -> thm list
  val mk_conj : term list -> term
  val mk_disj : term list -> term

  val app_bnds : term -> int -> term

  val ind_tac : Proof.context -> thm -> string list -> int -> tactic
  val exh_tac : Proof.context -> (string -> thm) -> int -> tactic

  exception Datatype
  exception Datatype_Empty of string
  val name_of_typ : typ -> string
  val dtyp_of_typ : (string * (string * sort) list) list -> typ -> dtyp
  val mk_Free : string -> typ -> int -> term
  val is_rec_type : dtyp -> bool
  val typ_of_dtyp : descr -> dtyp -> typ
  val dest_DtTFree : dtyp -> string * sort
  val dest_DtRec : dtyp -> int
  val strip_dtyp : dtyp -> dtyp list * dtyp
  val body_index : dtyp -> int
  val mk_fun_dtyp : dtyp list -> dtyp -> dtyp
  val get_nonrec_types : descr -> typ list
  val get_branching_types : descr -> typ list
  val get_arities : descr -> int list
  val get_rec_types : descr -> typ list
  val interpret_construction : descr -> (string * sort) list ->
    {atyp: typ -> 'a, dtyp: typ list -> int * bool -> string * typ list -> 'a} ->
    ((string * typ list) * (string * 'a list) list) list
  val unfold_datatypes : Proof.context -> descr -> info Symtab.table ->
    descr -> int -> descr list * int
  val find_shortest_path : descr -> int -> (string * int) option
end;

structure Old_Datatype_Aux : OLD_DATATYPE_AUX =
struct

(* datatype option flags *)

type config = {strict : bool, quiet : bool};
val default_config : config = {strict = true, quiet = false};

(*Emit a package progress message via writeln.
  A configuration with quiet = true suppresses the message; any other configuration
  (in particular default_config, where quiet = false) prints it.*)
fun message ({quiet = true, ...} : config) _ = ()
  | message _ s = writeln s;


(* store theorems in theory *)

(*For each type name tname, zipped positionally with its attribute list and theorem list,
  note the theorems under the binding name qualified by tname, with those attributes.
  Returns the noted theorem lists (in the order of tnames) and the updated theory.*)
fun store_thmss_atts name tnames attss thmss =
  fold_map (fn ((tname, atts), thms) =>
    Global_Theory.note_thms ""
      ((Binding.qualify true tname (Binding.name name), atts), [(thms, [])])
    #-> (fn (_, res) => pair res)) (tnames ~~ attss ~~ thmss);

(*Same as store_thmss_atts with an empty attribute list for every type name.*)
fun store_thmss name tnames = store_thmss_atts name tnames (replicate (length tnames) []);

(*Single-theorem variant of store_thmss_atts: every type name gets exactly one theorem.
  The pattern [res] assumes note_thms yields a one-element list for a one-element input.*)
fun store_thms_atts name tnames attss thms =
  fold_map (fn ((tname, atts), thm) =>
    Global_Theory.note_thms ""
      ((Binding.qualify true tname (Binding.name name), atts), [([thm], [])])
    #-> (fn (_, [res]) => pair res)) (tnames ~~ attss ~~ thms);

(*Same as store_thms_atts with an empty attribute list for every type name.*)
fun store_thms name tnames = store_thms_atts name tnames (replicate (length tnames) []);


(* split theorem thm_1 & ... & thm_n into n theorems *)

(*Repeatedly resolve with conjunct1 / conjunct2; as soon as resolution fails (exception THM)
  the remaining theorem is returned as the last component.*)
fun split_conj_thm th =
  ((th RS conjunct1) :: split_conj_thm (th RS conjunct2)) handle THM _ => [th];

(*Right-nested conjunction / disjunction of a list of terms.*)
val mk_conj = foldr1 (HOLogic.mk_binop \<^const_name>\<open>HOL.conj\<close>);
val mk_disj = foldr1 (HOLogic.mk_binop \<^const_name>\<open>HOL.disj\<close>);

(*Apply t to the loose bound variables Bound (i - 1), ..., Bound 0, in that order.*)
fun app_bnds t i = list_comb (t, map Bound (i - 1 downto 0));


(* instantiate induction rule *)

(*Instantiate the predicate variables of indrule from the conclusion of subgoal i and
  resolve the subgoal with the instantiated rule.
  The conclusions of indrule and of the subgoal are split into conjunctions and paired up
  positionally.  For each subgoal conjunct u:
    - if the first conjunct of indrule is an implication, u is split as lhs --> rhs, where
      lhs is expected to be an application _ $ t'; the predicate is rhs with t' abstracted;
    - otherwise the predicate is u abstracted over its unique free variable, considering
      only variables named in indnames when indnames is non-empty; if there is no unique
      such variable, that conjunct contributes no instantiation.
  The variable instantiated is the head of the corresponding conjunct of indrule.*)
fun ind_tac ctxt indrule indnames = CSUBGOAL (fn (cgoal, i) =>
  let
    val goal = Thm.term_of cgoal;
    val ts = HOLogic.dest_conj (HOLogic.dest_Trueprop (Thm.concl_of indrule));
    val ts' = HOLogic.dest_conj (HOLogic.dest_Trueprop (Logic.strip_imp_concl goal));
    val getP =
      if can HOLogic.dest_imp (hd ts)
      then apfst SOME o HOLogic.dest_imp
      else pair NONE;
    val flt =
      if null indnames then I
      else filter (member (op =) indnames o fst);
    fun abstr (t1, t2) =
      (case t1 of
        NONE =>
          (case flt (Term.add_frees t2 []) of
            [(s, T)] => SOME (absfree (s, T) t2)
          | _ => NONE)
      | SOME (_ $ t') => SOME (Abs ("x", fastype_of t', abstract_over (t', t2))));
    val insts =
      (ts ~~ ts') |> map_filter (fn (t, u) =>
        (case abstr (getP u) of
          NONE => NONE
        | SOME u' => SOME (t |> getP |> snd |> head_of |> dest_Var |> #1, Thm.cterm_of ctxt u')));
    val indrule' = infer_instantiate ctxt insts indrule;
  in resolve_tac ctxt [indrule'] i end);


(* perform exhaustive case analysis on last parameter of subgoal i *)

(*The exhaustion rule for the type constructor of the last parameter is obtained via
  exh_thm_of and lifted over subgoal i.  The variable at the head of lhs, where the last
  hypothesis of the lifted rule's first premise is matched as _ $ (_ $ lhs $ _), is
  instantiated with the projection onto the last parameter (a lambda over all parameters
  returning Bound 0); the result is then composed with subgoal i.*)
fun exh_tac ctxt exh_thm_of = CSUBGOAL (fn (cgoal, i) =>
  let
    val goal = Thm.term_of cgoal;
    val params = Logic.strip_params goal;
    val (_, Type (tname, _)) = hd (rev params);
    val exhaustion = Thm.lift_rule cgoal (exh_thm_of tname);
    val prem' = hd (Thm.prems_of exhaustion);
    val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem'));
    val exhaustion' =
      infer_instantiate ctxt
        [(#1 (dest_Var (head_of lhs)),
          Thm.cterm_of ctxt (fold_rev (fn (_, T) => fn t => Abs ("z", T, t)) params (Bound 0)))]
        exhaustion;
  in compose_tac ctxt (false, exhaustion', Thm.nprems_of exhaustion) i end);


(********************** Internal description of datatypes *********************)

(*DtTFree: a type variable;
  DtType: a type constructor applied to argument dtyps;
  DtRec i: an occurrence of the datatype stored under index i of a descr.*)
datatype dtyp =
    DtTFree of string * sort
  | DtType of string * dtyp list
  | DtRec of int;

(* information about datatypes *)

(* index, datatype name, type arguments, constructor name, types of constructor's arguments *)
type descr = (int * (string * dtyp list * (string * dtyp list) list)) list;

type info =
 {index : int,
  descr : descr,
  inject : thm list,
  distinct : thm list,
  induct : thm,
  inducts : thm list,
  exhaust : thm,
  nchotomy : thm,
  rec_names : string list,
  rec_rewrites : thm list,
  case_name : string,
  case_rewrites : thm list,
  case_cong : thm,
  case_cong_weak : thm,
  split : thm,
  split_asm: thm};

type spec = (binding * (string * sort) list * mixfix) * (binding * typ list * mixfix) list;

(*Free variable named s followed by the decimal digits of i.*)
fun mk_Free s T i = Free (s ^ string_of_int i, T);

(*Replace type variables according to substs (unmapped ones are kept) and shift every
  DtRec index by i.  The dtyps substituted for variables are inserted unshifted.*)
fun subst_DtTFree _ substs (T as DtTFree a) = the_default T (AList.lookup (op =) substs a)
  | subst_DtTFree i substs (DtType (name, ts)) = DtType (name, map (subst_DtTFree i substs) ts)
  | subst_DtTFree i _ (DtRec j) = DtRec (i + j);

exception Datatype;
exception Datatype_Empty of string;

(*Destructors; both raise Datatype on any other constructor.*)
fun dest_DtTFree (DtTFree a) = a
  | dest_DtTFree _ = raise Datatype;

fun dest_DtRec (DtRec i) = i
  | dest_DtRec _ = raise Datatype;

(*True iff a DtRec occurs anywhere inside the dtyp.*)
fun is_rec_type (DtType (_, dts)) = exists is_rec_type dts
  | is_rec_type (DtRec _) = true
  | is_rec_type _ = false;

(*Split a curried function dtyp T1 => ... => Tn => U into ([T1, ..., Tn], U).*)
fun strip_dtyp (DtType ("fun", [T, U])) = apfst (cons T) (strip_dtyp U)
  | strip_dtyp T = ([], T);

(*Index of the DtRec in the body of a (possibly functional) dtyp; raises Datatype if the
  body is not a DtRec.*)
val body_index = dest_DtRec o snd o strip_dtyp;

(*Inverse of strip_dtyp: build T1 => ... => Tn => U.*)
fun mk_fun_dtyp [] U = U
  | mk_fun_dtyp (T :: Ts) U = DtType ("fun", [T, mk_fun_dtyp Ts U]);

(*Underscore-separated name: the non-empty names of the argument types followed by the
  base name of the type constructor (or "x" if that base name is not an identifier).
  Non-Type types yield the empty string.*)
fun name_of_typ (Type (s, Ts)) =
      let val s' = Long_Name.base_name s in
        space_implode "_"
          (filter_out (equal "") (map name_of_typ Ts) @
            [if Symbol_Pos.is_identifier s' then s' else "x"])
      end
  | name_of_typ _ = "";

(*Convert a typ to a dtyp.  new_dts lists the datatypes being defined with their type
  variables; an occurrence of one of them becomes DtRec of its position in new_dts, but
  only if it is applied to exactly its own type variables (otherwise an error is raised).
  Schematic type variables are rejected.*)
fun dtyp_of_typ _ (TFree a) = DtTFree a
  | dtyp_of_typ _ (TVar _) = error "Illegal schematic type variable(s)"
  | dtyp_of_typ new_dts (Type (tname, Ts)) =
      (case AList.lookup (op =) new_dts tname of
        NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
      | SOME vs =>
          if map (try dest_TFree) Ts = map SOME vs then
            DtRec (find_index (curry op = tname o fst) new_dts)
          else error ("Illegal occurrence of recursive type " ^ quote tname));

(*Convert a dtyp back to a typ; DtRec i is resolved through the entry i of descr.*)
fun typ_of_dtyp _ (DtTFree a) = TFree a
  | typ_of_dtyp descr (DtRec i) =
      let val (s, ds, _) = the (AList.lookup (op =) descr i)
      in Type (s, map (typ_of_dtyp descr) ds) end
  | typ_of_dtyp descr (DtType (s, ds)) = Type (s, map (typ_of_dtyp descr) ds);

(* find all non-recursive types in datatype description *)

fun get_nonrec_types descr =
  map (typ_of_dtyp descr) (fold (fn (_, (_, _, constrs)) =>
    fold (fn (_, cargs) => union (op =) (filter_out is_rec_type cargs)) constrs) descr []);

(* get all recursive types in datatype description *)

fun get_rec_types descr = map (fn (_ , (s, ds, _)) =>
  Type (s, map (typ_of_dtyp descr) ds)) descr;

(* get all branching types *)

(*The domain types of all (possibly functional) constructor arguments, without duplicates.*)
fun get_branching_types descr =
  map (typ_of_dtyp descr)
    (fold
      (fn (_, (_, _, constrs)) =>
        fold (fn (_, cargs) => fold (strip_dtyp #> fst #> fold (insert op =)) cargs) constrs)
      descr []);

(*The distinct numbers of function arguments of recursive constructor arguments
  (0 for a direct recursive occurrence).*)
fun get_arities descr =
  fold
    (fn (_, (_, _, constrs)) =>
      fold (fn (_, cargs) =>
        fold (insert op =) (map (length o fst o strip_dtyp) (filter is_rec_type cargs))) constrs)
    descr [];

(* interpret construction of datatype *)

(*Translate descr to the typ level, giving each type variable the sort recorded in vs.
  Each datatype yields ((name, type arguments), constructors); each constructor argument is
  interpreted by
    - dtyp Ts (l, is_proper) (tyco, Ts') if its body is DtRec l, where Ts are its function
      argument types, tyco and Ts' the name and type arguments of datatype l, and is_proper
      holds iff all of Ts' are type variables;
    - atyp T otherwise, where T is the argument type.*)
fun interpret_construction descr vs {atyp, dtyp} =
  let
    val typ_of =
      typ_of_dtyp descr #>
      map_atyps (fn TFree (a, _) => TFree (a, the (AList.lookup (op =) vs a)) | T => T);
    fun interpT dT =
      (case strip_dtyp dT of
        (dTs, DtRec l) =>
          let
            val (tyco, dTs', _) = the (AList.lookup (op =) descr l);
            val Ts = map typ_of dTs;
            val Ts' = map typ_of dTs';
            val is_proper = forall (can dest_TFree) Ts';
          in dtyp Ts (l, is_proper) (tyco, Ts') end
      | _ => atyp (typ_of dT));
    fun interpC (c, dTs) = (c, map interpT dTs);
    fun interpD (_, (tyco, dTs, cs)) = ((tyco, map typ_of dTs), map interpC cs);
  in map interpD descr end;

(* unfold a list of mutually recursive datatype specifications *)
(* all types of the form DtType (dt_name, [..., DtRec _, ...]) *)
(* need to be unfolded *)

(*i is the next free datatype index.  A nested occurrence DtType (tname, dts) of a
  recursive argument is replaced by DtRec of a fresh copy of the registered datatype
  tname (type variables substituted by dts, indices shifted by i), and that copy is
  unfolded recursively.  Returns descr with nested occurrences replaced, followed by the
  unfolded nested descriptions, together with the next free index.
  Errors: tname not registered in dt_info, wrong number of type arguments, or a recursive
  occurrence in the domain of a function type (non-strictly positive).*)
fun unfold_datatypes ctxt orig_descr (dt_info : info Symtab.table) descr i =
  let
    fun typ_error T msg =
      error ("Non-admissible type expression\n" ^
        Syntax.string_of_typ ctxt (typ_of_dtyp (orig_descr @ descr) T) ^ "\n" ^ msg);

    fun get_dt_descr T i tname dts =
      (case Symtab.lookup dt_info tname of
        NONE =>
          typ_error T (quote tname ^ " is not registered as an old-style datatype and hence cannot \
            \be used in nested recursion")
      | SOME {index, descr, ...} =>
          let
            val (_, vars, _) = the (AList.lookup (op =) descr index);
            val subst = map dest_DtTFree vars ~~ dts
              handle ListPair.UnequalLengths =>
                typ_error T ("Type constructor " ^ quote tname ^
                  " used with wrong number of arguments");
          in
            (i + index,
              map (fn (j, (tn, args, cs)) =>
                (i + j, (tn, map (subst_DtTFree i subst) args,
                  map (apsnd (map (subst_DtTFree i subst))) cs))) descr)
          end);

    (* unfold a single constructor argument *)

    fun unfold_arg T (i, Ts, descrs) =
      if is_rec_type T then
        let val (Us, U) = strip_dtyp T in
          if exists is_rec_type Us then
            typ_error T "Non-strictly positive recursive occurrence of type"
          else
            (case U of
              DtType (tname, dts) =>
                let
                  val (index, descr) = get_dt_descr T i tname dts;
                  val (descr', i') =
                    unfold_datatypes ctxt orig_descr dt_info descr (i + length descr);
                in (i', Ts @ [mk_fun_dtyp Us (DtRec index)], descrs @ descr') end
            | _ => (i, Ts @ [T], descrs))
        end
      else (i, Ts @ [T], descrs);

    (* unfold a constructor *)

    fun unfold_constr (cname, cargs) (i, constrs, descrs) =
      let val (i', cargs', descrs') = fold unfold_arg cargs (i, [], descrs)
      in (i', constrs @ [(cname, cargs')], descrs') end;

    (* unfold a single datatype *)

    fun unfold_datatype (j, (tname, tvars, constrs)) (i, dtypes, descrs) =
      let val (i', constrs', descrs') = fold unfold_constr constrs (i, [], descrs)
      in (i', dtypes @ [(j, (tname, tvars, constrs'))], descrs') end;

    val (i', descr', descrs) = fold unfold_datatype descr (i, [], []);

  in (descr' :: descrs, i') end;

(* find shortest path to constructor with no recursive arguments *)

(*For datatype i of descr, return SOME (constructor name, depth) for a constructor of
  minimal depth, or NONE if every constructor leads back to a datatype already on the path
  is.  Depth is 0 for a constructor without recursive arguments, and otherwise 1 plus the
  maximum depth reached through its recursive arguments.*)
fun find_nonempty descr is i =
  let
    fun arg_nonempty (_, DtRec j) =
          if member (op =) is j
          then NONE
          else Option.map (Integer.add 1 o snd) (find_nonempty descr (j :: is) j)
      | arg_nonempty _ = SOME 0;
    fun max_inf (SOME m) (SOME n) = SOME (Integer.max m n)
      | max_inf _ _ = NONE;
    fun max xs = fold max_inf xs (SOME 0);
    val (_, _, constrs) = the (AList.lookup (op =) descr i);
    val xs =
      sort (int_ord o apply2 snd)
        (map_filter (fn (s, dts) => Option.map (pair s)
          (max (map (arg_nonempty o strip_dtyp) dts))) constrs)
  in if null xs then NONE else SOME (hd xs) end;

fun find_shortest_path descr i = find_nonempty descr [i] i;

end;