(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

theory Datatype_Schematic

imports
  MLUtils
  TermPatternAntiquote
begin

text \<open>
  Introduces a method for improving unification outcomes for schematics with
  datatype expressions as parameters.

  There are two variants:
  1. In cases where a schematic is applied to a constant like @{term True},
  we wrap the constant to avoid some undesirable unification candidates.

  2. In cases where a schematic is applied to a constructor expression like
  @{term "Some x"} or @{term "(x, y)"}, we supply selector expressions
  like @{term "the"} or @{term "fst"} to provide more unification
  candidates. This is only done if the parameter that would be selected (e.g.
  @{term x} in @{term "Some x"}) contains bound variables which the
  schematic does not have as parameters.

  In the "constructor expression" case, we let users supply additional
  constructor handlers via the `datatype_schematic` attribute (or, for a
  single invocation, via `datatype_schem add: rules`). The method uses
  rules of the following form:

    @{term "\<And>x1 x2 x3. getter (constructor x1 x2 x3) = x2"}

  These are essentially simp rules for simple "accessor" primrec functions,
  which are used to turn schematics like

    @{text "?P (constructor x1 x2 x3)"}

  into

    @{text "?P' x2 (constructor x1 x2 x3)"}.
\<close>

ML \<open>
  \<comment> \<open>
    Anchor used to link error messages back to the documentation above.
  \<close>
  val usage_pos = @{here};
\<close>

(* Identity function used only as a marker: constant arguments of schematic
   applications are rewritten to `ds_id c` (via wrap_ds_id below) so that
   unification does not treat them as plain constants. *)
definition
  ds_id :: "'a \<Rightarrow> 'a"
where
  "ds_id = (\<lambda>x. x)"

(* Oriented so that rewriting left-to-right wraps a term in ds_id. *)
lemma wrap_ds_id:
  "x = ds_id x"
  by (simp add: ds_id_def)

ML \<open>
structure Datatype_Schematic = struct

(* Equality on accessor-table entries: same argument index, same accessor
   name and alpha-equivalent theorem statements. Used for merging tables and
   for inserting/removing entries without duplicates. *)
fun eq ((idx1, name1, thm1), (idx2, name2, thm2)) =
  idx1 = idx2 andalso
  name1 = name2 andalso
  (Thm.full_prop_of thm1) aconv (Thm.full_prop_of thm2);

structure Datatype_Schematic_Data = Generic_Data
(
  \<comment> \<open>
    Keys are names of datatype constructors (like @{const Cons}), values are
    `(idx, function_name, thm)`.

    - `function_name` is the name of an "accessor" function that accesses part
      of the constructor specified by the key (so the accessor @{const hd} is
      related to the constructor/key @{const Cons}).

    - `thm` is a theorem showing that the function accesses one of the
      arguments to the constructor (like @{thm list.sel(1)}).

    - `idx` is the index of the constructor argument that the accessor
      accesses. (e.g. since `hd` accesses the first argument, `idx = 0`; since
      `tl` accesses the second argument, `idx = 1`).
  \<close>
  type T = ((int * string * thm) list) Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  val merge = Symtab.merge_list eq;
);

(* Lifts a table transformer `m ctxt thm table` into a declaration attribute
   that applies it to the accessor table of the current context. *)
fun gen_att m =
  Thm.declaration_attribute (fn thm => fn context =>
    Datatype_Schematic_Data.map (m (Context.proof_of context) thm) context);

(* gathers schematic applications from the goal. no effort is made
   to normalise bound variables here, since we'll always be comparing
   elements within a compound application which will be at the same
   level as regards lambdas.
   Every application whose head is a schematic variable contributes a pair
   `(head, args)`, prepended to `insts`. Heads, arguments and abstraction
   bodies are all searched recursively. *)
fun gather_schem_apps (f $ x) insts = let
    val (f, xs) = strip_comb (f $ x)
    val insts = fold (gather_schem_apps) (f :: xs) insts
  in if is_Var f then (f, xs) :: insts else insts end
  | gather_schem_apps (Abs (_, _, t)) insts
    = gather_schem_apps t insts
  | gather_schem_apps _ insts = insts

(* get_first with the list argument first. *)
fun sfirst xs f = get_first f xs

(* Finds the first schematic application `?v a0 ... an` in `prop` with these
   properties:
   - some argument `a_idx` is an application of a constant C that has
     registered accessor entries;
   - for one of those entries, the constructor argument it selects is
     "interesting". It must not already be an argument of ?v. It must also
     contain a loose bound variable that is not itself an argument of ?v.
   Returns SOME (?v, idx, accessor constant, accessor rule), or NONE.
   Any exception while inspecting a candidate is swallowed by `try`, and the
   search moves on. Such exceptions include:
   - a non-constant head (dest_Const);
   - too few arguments (nth);
   - no interesting entry (the). *)
fun get_action ctxt prop = let
    val schem_insts = gather_schem_apps prop [];
    val actions = Datatype_Schematic_Data.get (Context.Proof ctxt);
    (* The accessor constant, typed as a function from the constructor's
       result type to the type of its i-th argument. *)
    fun mk_sel selname T i = let
        val (argTs, resT) = strip_type T
      in Const (selname, resT --> nth argTs i) end
  in
    sfirst schem_insts
      (fn (var, xs) => sfirst (Library.tag_list 0 xs)
        (try (fn (idx, x) => let
            val (c, ys) = strip_comb x
            val (fname, T) = dest_Const c
            val acts = Symtab.lookup_list actions fname
            (* An argument that is not already passed to the schematic,
               with a loose bound variable the schematic cannot see. *)
            fun interesting arg = not (member Term.aconv_untyped xs arg)
              andalso exists (fn i => not (member (=) xs (Bound i)))
                (Term.loose_bnos arg)
          in the (sfirst acts (fn (i, selname, thms) => if interesting (nth ys i)
            then SOME (var, idx, mk_sel selname T i, thms) else NONE))
          end)))
  end

(* Subgoal tactic. If get_action finds (?v, idx, sel, thm) in the subgoal:
   - instantiate ?v with `\<lambda>x1 ... xn. ?v' (sel xk) x1 ... xn`, where xk
     is the argument at position idx and ?v' is a new schematic whose index
     is one above the goal state's maxidx;
   - beta-normalise;
   - simplify with only the accessor rule, so that `sel (C ...)` reduces to
     the selected constructor argument.
   Fails (no_tac) when nothing matches. Note that the inner `t` is the goal
   state, shadowing the subgoal term. *)
fun get_bound_tac ctxt = SUBGOAL (fn (t, i) => case get_action ctxt t of
    SOME (Var ((nm, ix), T), idx, sel, thm) => (fn t => let
        val (argTs, _) = strip_type T
        val ix2 = Thm.maxidx_of t + 1
        val xs = map (fn (i, T) => Free ("x" ^ string_of_int i, T))
          (Library.tag_list 1 argTs)
        val nx = sel $ nth xs idx
        val v' = Var ((nm, ix2), fastype_of nx --> T)
        val inst_v = fold lambda (rev xs) (betapplys (v' $ nx, xs))
        val t' = Drule.infer_instantiate ctxt
          [((nm, ix), Thm.cterm_of ctxt inst_v)] t
        val t'' = Conv.fconv_rule (Thm.beta_conversion true) t'
      in safe_full_simp_tac (clear_simpset ctxt addsimps [thm]) i t'' end)
  | _ => no_tac)

(* True if some application in the term (at any depth, including under
   abstractions) has a schematic head and at least one constant argument.
   Used to prune subterms where wrapping may have nothing to do. *)
fun id_applicable (f $ x) = let
    val (f, xs) = strip_comb (f $ x)
    val here = is_Var f andalso exists is_Const xs
  in here orelse exists id_applicable (f :: xs) end
  | id_applicable (Abs (_, _, t)) = id_applicable t
  | id_applicable _ = false

(* Combination conversion in which cv1/cv2 signal "no change" by raising
   Option. A side that raised Option is kept unchanged (reflexive). Option is
   re-raised only if both sides raised it. *)
fun combination_conv cv1 cv2 ct =
  let
    val (ct1, ct2) = Thm.dest_comb ct
    val r1 = SOME (cv1 ct1) handle Option => NONE
    val r2 = SOME (cv2 ct2) handle Option => NONE
    fun mk _ (SOME res) = res
      | mk ct NONE = Thm.reflexive ct
  in case (r1, r2) of
      (NONE, NONE) => raise Option
    | _ => Thm.combination (mk ct1 r1) (mk ct2 r2)
  end

(* Meta-equation `x \<equiv> ds_id x`. *)
val wrap = mk_meta_eq @{thm wrap_ds_id}

(* Rewrites a constant `c` (not of type unit) to `ds_id c`. Raises Option
   for anything else. The context argument is ignored. *)
fun wrap_const_conv _ ct = if is_Const (Thm.term_of ct)
    andalso fastype_of (Thm.term_of ct) <> @{typ unit}
  then Conv.rewr_conv wrap ct
  else raise Option

(* Applies `conv ctxt` to every argument of an application spine and to its
   head, using the Option-aware combination_conv above. *)
fun combs_conv conv ctxt ct = case Thm.term_of ct of
    _ $ _ => combination_conv (combs_conv conv ctxt) (conv ctxt) ct
  | _ => conv ctxt ct

(* Wraps constant arguments of schematic applications in ds_id.
   - Abstractions: descends into the body.
   - Applications with a non-schematic head: recurses into the arguments and
     head, skipping the subterm when id_applicable finds nothing.
   - Applications with a schematic head: wraps only their direct constant
     arguments.
   Raises Option when there is nothing to wrap. *)
fun wrap_conv ctxt ct = case Thm.term_of ct of
    Abs _ => Conv.sub_conv wrap_conv ctxt ct
  | f $ x => if is_Var (head_of f) then combs_conv wrap_const_conv ctxt ct
    else if not (id_applicable (f $ x)) then raise Option
    else combs_conv wrap_conv ctxt ct
  | _ => raise Option

(* CONVERSION that fails with no_tac, instead of propagating the exception,
   when the conversion raises Option. *)
fun CONVERSION_opt conv i t = CONVERSION conv i t
  handle Option => no_tac t

exception Datatype_Schematic_Error of Pretty.T;

(* Pretty-prints `text` with entity markup that links to position `pos`. *)
fun apply_pos_markup pos text =
  let
    val props = Position.def_properties_of pos;
    val markup = Markup.properties props (Markup.entity "" "");
  in Pretty.mark_str (markup, text) end;

(* Error for a theorem that is not an accessor rule. The message shows the
   theorem and links back to the documentation at usage_pos. *)
fun invalid_accessor ctxt thm : exn =
  Datatype_Schematic_Error ([
    Pretty.str "Bad input theorem '",
    Syntax.pretty_term ctxt (Thm.full_prop_of thm),
    Pretty.str "'. Click ",
    apply_pos_markup usage_pos "*here*",
    Pretty.str " for info on the required rule format." ]
    |> Pretty.paragraph);

local
  (* Decomposes an accessor rule `f (C v1 ... vn) = vk` into
     (name of f, name of C, position of vk among v1 ... vn). The position is
     presumably 0-based, matching the `idx` convention of the table; confirm
     against ListExtras.find_index.
     Raises an exception if the theorem does not have exactly this shape. For
     example: f or C is not a constant, a constructor argument is not a
     schematic variable, or the right-hand side is not one of them. *)
  fun dest_accessor' thm =
    case (thm |> Thm.full_prop_of |> HOLogic.dest_Trueprop) of
      @{term_pat "?fun_name ?data_pat = ?rhs"} =>
        let
          val fun_name = Term.dest_Const fun_name |> fst;
          val (data_const, data_args) = Term.strip_comb data_pat;
          val data_vars = data_args |> map (Term.dest_Var #> fst);
          val rhs_var = rhs |> Term.dest_Var |> fst;
          val data_name = Term.dest_Const data_const |> fst;
          val rhs_idx = ListExtras.find_index (curry op = rhs_var) data_vars |> the;
        in (fun_name, data_name, rhs_idx) end;
in
  (* Like dest_accessor', but any failure is reported as invalid_accessor. *)
  fun dest_accessor ctxt thm =
    case try dest_accessor' thm of
      SOME x => x
    | NONE => raise invalid_accessor ctxt thm;
end

(* Adds the entry for an accessor rule, keyed by the constructor it
   destructs. Raises Datatype_Schematic_Error on a malformed rule. *)
fun add_rule ctxt thm data =
  let
    val (fun_name, data_name, idx) = dest_accessor ctxt thm;
    val entry = (data_name, (idx, fun_name, thm));
  in Symtab.insert_list eq entry data end;

(* Removes the entry for an accessor rule. Raises Datatype_Schematic_Error
   on a malformed rule. *)
fun del_rule ctxt thm data =
  let
    val (fun_name, data_name, idx) = dest_accessor ctxt thm;
    val entry = (data_name, (idx, fun_name, thm));
  in Symtab.remove_list eq entry data end;

(* Attribute versions of add_rule / del_rule. *)
val add = gen_att add_rule;
val del = gen_att del_rule;

(* Wraps constant schematic arguments in the given subgoal; no_tac if there
   is nothing to wrap. *)
fun wrap_tac ctxt = CONVERSION_opt (wrap_conv ctxt)

(* Applies get_bound_tac repeatedly, to the subgoal and to all subgoals it
   produces, then wraps constants where possible. REPEAT_ALL_NEW needs at
   least one success, so this fails if get_bound_tac does not apply. *)
fun tac1 ctxt = REPEAT_ALL_NEW (get_bound_tac ctxt) THEN' (TRY o wrap_tac ctxt)

(* The method's tactic: tac1, or only the constant wrapping if get_bound_tac
   does not apply at all. *)
fun tac ctxt = tac1 ctxt ORELSE' wrap_tac ctxt

(* Parser for the `add:` method modifier, which declares extra accessor rules
   in the context the method runs in. *)
val add_section =
  Args.add -- Args.colon >> K (Method.modifier add @{here});

(* Method syntax: `datatype_schem` with optional `add:` sections. *)
val method =
  Method.sections [add_section] >> (fn _ => fn ctxt => Method.SIMPLE_METHOD' (tac ctxt));

end
\<close>

(* Registers the datatype_schematic attribute, with add and del variants. *)
setup \<open>
  Attrib.setup
    @{binding "datatype_schematic"}
    (Attrib.add_del Datatype_Schematic.add Datatype_Schematic.del)
    "Accessor rules to fix datatypes in schematics"
\<close>

method_setup datatype_schem = \<open>
  Datatype_Schematic.method
\<close>

(* Default accessor rules: the selector rules of pairs and options. *)
declare prod.sel[datatype_schematic]
declare option.sel[datatype_schematic]
(* Default accessor rules for lists. list.sel(1) is the hd rule for Cons;
   (3) is presumably the tl rule for Cons. (2), tl [] = [], is not of
   accessor form. *)
declare list.sel(1,3)[datatype_schematic]

(* Worked examples for the datatype_schem method. *)
locale datatype_schem_demo begin

(* The bound y sits inside Some, a list and a pair, and the schematic also
   receives the constant True. *)
lemma handles_nested_constructors:
  "\<exists>f. \<forall>y. f True (Some [x, (y, z)]) = y"
  apply (rule exI)
  apply (rule allI)
  apply datatype_schem
  by (rule refl)

datatype foo = basic nat int | another nat

primrec get_basic_0 where
  "get_basic_0 (basic m k) = m"

primrec get_nat where
    "get_nat (basic n _) = n"
  | "get_nat (another n) = n"

(* Only the first-argument accessor of basic is registered, and only
   locally through notes. *)
lemma selectively_exposing_datatype_arugments:
  notes get_basic_0.simps[datatype_schematic]
  shows "\<exists>x. \<forall>a b. x (basic a b) = a"
  apply (rule exI)
  apply (rule allI)+
  apply datatype_schem \<comment> \<open>Only exposes `a` to the schematic.\<close>
  apply (rule refl)
  done

(* Accessor rules passed directly to the method; get_nat.simps has one
   equation per constructor of foo. *)
lemma method_handles_primrecs_with_two_constructors:
  shows "\<exists>x. \<forall>a b. x (basic a b) = a"
  apply (rule exI)
  apply (rule allI)+
  apply (datatype_schem add: get_nat.simps)
  apply (rule refl)
  done

end

end