1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6theory Datatype_Schematic
7
8imports
9  MLUtils
10  TermPatternAntiquote
11begin
12
13text \<open>
14  Introduces a method for improving unification outcomes for schematics with
15  datatype expressions as parameters.
16
17  There are two variants:
    1. In cases where a schematic is applied to a constant like @{term True},
       we wrap the constant in the identity function `ds_id` to avoid some
       undesirable unification candidates.
20
21    2. In cases where a schematic is applied to a constructor expression like
22       @{term "Some x"} or @{term "(x, y)"}, we supply selector expressions
23       like @{term "the"} or @{term "fst"} to provide more unification
       candidates.  This is only done if the parameter that would be selected
       (e.g. @{term x} in @{term "Some x"}) contains bound variables which the
       schematic does not have as parameters.
27
28  In the "constructor expression" case, we let users supply additional
29  constructor handlers via the `datatype_schematic` attribute. The method uses
30  rules of the following form:
31
32    @{term "\<And>x1 x2 x3. getter (constructor x1 x2 x3) = x2"}
33
34  These are essentially simp rules for simple "accessor" primrec functions,
35  which are used to turn schematics like
36
37    @{text "?P (constructor x1 x2 x3)"}
38
39  into
40
41    @{text "?P' x2 (constructor x1 x2 x3)"}.
42\<close>
43
ML \<open>
  (* Source position of this declaration, which sits directly after the
     introductory text of this theory.  Error messages about malformed
     datatype_schematic rules link to this position, so that users land
     next to the description of the expected rule format. *)
  val usage_pos = @{here}
\<close>
50
(* Identity function used as a marker: the lemma wrap_ds_id rewrites a term
   x to ds_id x, and the datatype_schem method uses that rewrite to replace
   bare constant arguments of schematic applications by ds_id applications. *)
definition ds_id :: "'a \<Rightarrow> 'a" where
  "ds_id = (\<lambda>x. x)"
55
(* Stated as x = ds_id x so that, used as a left-to-right rewrite rule,
   it wraps a term in ds_id. *)
lemma wrap_ds_id: "x = ds_id x"
  unfolding ds_id_def by simp
59
60ML \<open>
structure Datatype_Schematic = struct

(* Equality on table entries: same argument index, same accessor name and
   alpha-equivalent theorem statements.  Used to avoid duplicates when
   inserting and merging entries, and to find the entry to remove when
   deleting one. *)
fun eq ((idx1, name1, thm1), (idx2, name2, thm2)) =
  idx1 = idx2 andalso
  name1 = name2 andalso
  (Thm.full_prop_of thm1) aconv (Thm.full_prop_of thm2);

structure Datatype_Schematic_Data = Generic_Data
(
  \<comment> \<open>
    Keys are names of datatype constructors (like @{const Cons}), values are
    lists of `(idx, function_name, thm)` triples.

    - `function_name` is the name of an "accessor" function that accesses part
      of the constructor specified by the key (so the accessor @{const hd} is
      related to the constructor/key @{const Cons}).

    - `thm` is a theorem showing that the function accesses one of the
      arguments to the constructor (like @{thm list.sel(1)}).

    - `idx` is the index of the constructor argument that the accessor
      accesses.  (e.g. since `hd` accesses the first argument, `idx = 0`; since
      `tl` accesses the second argument, `idx = 1`).
  \<close>
  type T = ((int * string * thm) list) Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  (* Merges the entry lists key by key, using eq to avoid duplicates. *)
  val merge = Symtab.merge_list eq;
);

(* Builds a declaration attribute from an update function m.  m receives a
   proof context derived from the current generic context, the attributed
   theorem, and the current table, and returns the updated table. *)
fun gen_att m =
  Thm.declaration_attribute (fn thm => fn context =>
    Datatype_Schematic_Data.map (m (Context.proof_of context) thm) context);

(* gathers schematic applications from the goal. no effort is made
   to normalise bound variables here, since we'll always be comparing
   elements within a compound application which will be at the same
   level as regards lambdas. *)
(* Result: (head, args) pairs for every application whose head is a
   schematic variable, also looking under abstractions.  Each application
   is consed in front of the pairs collected from its own head and
   arguments, so an application precedes those nested inside it. *)
fun gather_schem_apps (f $ x) insts = let
    val (f, xs) = strip_comb (f $ x)
    val insts = fold (gather_schem_apps) (f :: xs) insts
  in if is_Var f then (f, xs) :: insts else insts end
  | gather_schem_apps (Abs (_, _, t)) insts
    = gather_schem_apps t insts
  | gather_schem_apps _ insts = insts

(* get_first with the list argument first, so the function can follow. *)
fun sfirst xs f = get_first f xs

(* Finds the first opportunity to expose a constructor argument to a
   schematic in prop.  Searches the schematic applications ?P t0 ... tn
   found by gather_schem_apps, and within each the arguments tj in order,
   for an argument of the form C y0 ... ym where C has a registered entry
   (i, selname, thm) whose selected argument yi is "interesting": it is not
   already (up to aconv_untyped) one of the tj, and it has a loose bound
   variable that is not itself one of the tj.
   Returns SOME (?P, j, selector constant, thm), or NONE.  Any exception
   raised while inspecting an argument (head not a constant, index out of
   range, no interesting entry so that `the` fails) is absorbed by try, so
   that argument is simply skipped. *)
fun get_action ctxt prop = let
    val schem_insts = gather_schem_apps prop [];
    val actions = Datatype_Schematic_Data.get (Context.Proof ctxt);
    (* Selector constant typed from the constructor type T: from the
       constructor's result type to the type of its i-th argument. *)
    fun mk_sel selname T i = let
        val (argTs, resT) = strip_type T
      in Const (selname, resT --> nth argTs i) end
  in
    sfirst schem_insts
    (fn (var, xs) => sfirst (Library.tag_list 0 xs)
        (try (fn (idx, x) => let
            val (c, ys) = strip_comb x
            val (fname, T) = dest_Const c
            val acts = Symtab.lookup_list actions fname
            fun interesting arg = not (member Term.aconv_untyped xs arg)
                andalso exists (fn i => not (member (=) xs (Bound i)))
                    (Term.loose_bnos arg)
          in the (sfirst acts (fn (i, selname, thms) => if interesting (nth ys i)
            then SOME (var, idx, mk_sel selname T i, thms) else NONE))
          end)))
  end

(* Tactic on subgoal i.  When get_action finds (?P, idx, sel, thm), where
   ?P has argument types T1 ... Tn, it instantiates
     ?P := (% x1 ... xn. ?P' (sel xk) x1 ... xn)      with k = idx + 1
   where ?P' is a fresh schematic (its index is one above the state's
   maxidx).  It then beta-normalises the whole state and simplifies
   subgoal i with a simpset containing only thm, which is expected to
   rewrite sel (C ...) to the selected constructor argument.
   Fails (no_tac) when get_action finds nothing; also fails if the
   simplifier call fails. *)
fun get_bound_tac ctxt = SUBGOAL (fn (t, i) => case get_action ctxt t of
  SOME (Var ((nm, ix), T), idx, sel, thm) => (fn t => let
    val (argTs, _) = strip_type T
    val ix2 = Thm.maxidx_of t + 1
    (* Fresh free variables x1 ... xn standing for ?P's arguments; they are
       abstracted again below, so they do not remain in the result. *)
    val xs = map (fn (i, T) => Free ("x" ^ string_of_int i, T))
        (Library.tag_list 1 argTs)
    val nx = sel $ nth xs idx
    val v' = Var ((nm, ix2), fastype_of nx --> T)
    val inst_v = fold lambda (rev xs) (betapplys (v' $ nx, xs))
    val t' = Drule.infer_instantiate ctxt
        [((nm, ix), Thm.cterm_of ctxt inst_v)] t
    val t'' = Conv.fconv_rule (Thm.beta_conversion true) t'
  in safe_full_simp_tac (clear_simpset ctxt addsimps [thm]) i t'' end)
  | _ => no_tac)

(* True iff some subterm (also under abstractions) is an application of a
   schematic variable with at least one constant among its direct
   arguments.  Lets wrap_conv skip subterms with nothing to wrap. *)
fun id_applicable (f $ x) = let
    val (f, xs) = strip_comb (f $ x)
    val here = is_Var f andalso exists is_Const xs
  in here orelse exists id_applicable (f :: xs) end
  | id_applicable (Abs (_, _, t)) = id_applicable t
  | id_applicable _ = false

(* Combination conversion in which each side may signal "no change" by
   raising Option.  An unchanged side is filled in with a reflexivity
   theorem; Option is raised only when neither side changed. *)
fun combination_conv cv1 cv2 ct =
  let
    val (ct1, ct2) = Thm.dest_comb ct
    val r1 = SOME (cv1 ct1) handle Option => NONE
    val r2 = SOME (cv2 ct2) handle Option => NONE
    fun mk _ (SOME res) = res
      | mk ct NONE = Thm.reflexive ct
  in case (r1, r2) of
      (NONE, NONE) => raise Option
    | _ => Thm.combination (mk ct1 r1) (mk ct2 r2)
  end

(* wrap_ds_id as a meta-equation x == ds_id x, for Conv.rewr_conv. *)
val wrap = mk_meta_eq @{thm wrap_ds_id}

(* Rewrites a constant c to ds_id c, except for constants of type unit;
   raises Option ("no change") for anything else.  The context argument is
   unused; it only gives the function the ctxt -> conv shape that
   combs_conv expects. *)
fun wrap_const_conv _ ct = if is_Const (Thm.term_of ct)
        andalso fastype_of (Thm.term_of ct) <> @{typ unit}
    then Conv.rewr_conv wrap ct
    else raise Option

(* For an application spine h a1 ... an, applies conv ctxt to every
   argument and finally to the head h, combining the results with
   combination_conv (Option only if nothing changed).  For any other term,
   applies conv ctxt directly. *)
fun combs_conv conv ctxt ct = case Thm.term_of ct of
    _ $ _ => combination_conv (combs_conv conv ctxt) (conv ctxt) ct
  | _ => conv ctxt ct

(* Wraps constant arguments of schematic applications in ds_id.  Descends
   under abstractions and through applications that id_applicable reports
   as containing a candidate.  At an application whose head is a schematic
   variable, only its direct constant arguments are wrapped; terms nested
   inside its other arguments are not visited.  Raises Option when nothing
   is rewritten. *)
fun wrap_conv ctxt ct = case Thm.term_of ct of
    Abs _ => Conv.sub_conv wrap_conv ctxt ct
  | f $ x => if is_Var (head_of f) then combs_conv wrap_const_conv ctxt ct
    else if not (id_applicable (f $ x)) then raise Option
    else combs_conv wrap_conv ctxt ct
  | _ => raise Option

(* CONVERSION, but a conversion that raises Option (i.e. changes nothing)
   makes the tactic fail instead of propagating the exception. *)
fun CONVERSION_opt conv i t = CONVERSION conv i t
    handle Option => no_tac t

(* Raised (via invalid_accessor) for rules that do not have the accessor
   shape required by the datatype_schematic attribute. *)
exception Datatype_Schematic_Error of Pretty.T;

(* Renders text with entity markup carrying the definition properties of
   pos, so the text can be followed as a link to pos. *)
fun apply_pos_markup pos text =
  let
    val props = Position.def_properties_of pos;
    val markup = Markup.properties props (Markup.entity "" "");
  in Pretty.mark_str (markup, text) end;

(* Builds (does not raise) the error for a theorem that is not an accessor
   rule; the message links to usage_pos. *)
fun invalid_accessor ctxt thm : exn =
  Datatype_Schematic_Error ([
    Pretty.str "Bad input theorem '",
    Syntax.pretty_term ctxt (Thm.full_prop_of thm),
    Pretty.str "'. Click ",
    apply_pos_markup usage_pos "*here*",
    Pretty.str " for info on the required rule format." ] |> Pretty.paragraph);

local
  (* Expects thm to state  f (C v1 ... vn) = vk  where f and C are constants
     and v1 ... vn and vk are schematic variables.  Returns (name of f,
     name of C, position of vk among v1 ... vn).  ListExtras.find_index is
     presumably 0-based, matching the idx convention of the data table;
     confirm.  Any other shape raises (Match, TERM, Option, ...). *)
  fun dest_accessor' thm =
    case (thm |> Thm.full_prop_of |> HOLogic.dest_Trueprop) of
      @{term_pat "?fun_name ?data_pat = ?rhs"} =>
        let
          val fun_name = Term.dest_Const fun_name |> fst;
          val (data_const, data_args) = Term.strip_comb data_pat;
          val data_vars = data_args |> map (Term.dest_Var #> fst);
          val rhs_var = rhs |> Term.dest_Var |> fst;
          val data_name = Term.dest_Const data_const |> fst;
          val rhs_idx = ListExtras.find_index (curry op = rhs_var) data_vars |> the;
        in (fun_name, data_name, rhs_idx) end;
in
  (* dest_accessor' with every failure turned into the user-facing
     Datatype_Schematic_Error. *)
  fun dest_accessor ctxt thm =
    case try dest_accessor' thm of
      SOME x => x
    | NONE => raise invalid_accessor ctxt thm;
end

(* Adds the entry for thm under its constructor's name (no duplicates
   w.r.t. eq).  Raises Datatype_Schematic_Error for ill-shaped rules. *)
fun add_rule ctxt thm data =
  let
    val (fun_name, data_name, idx) = dest_accessor ctxt thm;
    val entry = (data_name, (idx, fun_name, thm));
  in Symtab.insert_list eq entry data end;

(* Removes the entry for thm.  Like add_rule, it raises
   Datatype_Schematic_Error if thm does not have the accessor shape. *)
fun del_rule ctxt thm data =
  let
    val (fun_name, data_name, idx) = dest_accessor ctxt thm;
    val entry = (data_name, (idx, fun_name, thm));
  in Symtab.remove_list eq entry data end;

(* Declaration attributes backing datatype_schematic (add / del). *)
val add = gen_att add_rule;
val del = gen_att del_rule;

(* Wraps constant arguments of schematic applications in subgoal i;
   fails if nothing was wrapped. *)
fun wrap_tac ctxt = CONVERSION_opt (wrap_conv ctxt)

(* Exposes constructor arguments repeatedly (on the subgoal and the
   subgoals it produces), then tries the constant wrapping. *)
fun tac1 ctxt = REPEAT_ALL_NEW (get_bound_tac ctxt) THEN' (TRY o wrap_tac ctxt)

(* Main tactic: tac1, or, if tac1 fails (presumably when get_bound_tac
   cannot make a first step), the constant wrapping alone. *)
fun tac ctxt = tac1 ctxt ORELSE' wrap_tac ctxt

(* Method syntax "add: thms": the listed facts receive the add attribute
   for this method invocation. *)
val add_section =
  Args.add -- Args.colon >> K (Method.modifier add @{here});

(* Method parser.  The value produced by Method.sections is ignored; the
   add: modifiers presumably reach the method through ctxt. *)
val method =
  Method.sections [add_section] >> (fn _ => fn ctxt => Method.SIMPLE_METHOD' (tac ctxt));

end
248\<close>
249
(* The datatype_schematic attribute: "thm [datatype_schematic]" registers an
   accessor rule, "thm [datatype_schematic del]" removes it. *)
attribute_setup datatype_schematic =
  \<open>Attrib.add_del Datatype_Schematic.add Datatype_Schematic.del\<close>
  "Accessor rules to fix datatypes in schematics"
256
(* Usage: "apply datatype_schem", or "apply (datatype_schem add: thms)" to
   supply additional accessor rules for this invocation only. *)
method_setup datatype_schem = \<open>
  Datatype_Schematic.method
\<close> "improve unification of schematics applied to datatype expressions"
260
(* Built-in accessor rules: the selector equations of pairs, options and
   lists.  Of list.sel only facts 1 and 3 are used; fact 2 is presumably
   "tl [] = []", whose constructor has no argument to select. *)
declare
  prod.sel[datatype_schematic] and
  option.sel[datatype_schematic] and
  list.sel(1,3)[datatype_schematic]
264
(* Worked examples for the datatype_schem method and the
   datatype_schematic attribute. *)
locale datatype_schem_demo begin

(* Uses the built-in option/list/prod accessor rules to expose y, which is
   nested inside constructor applications, to the schematic. *)
lemma handles_nested_constructors:
  "\<exists>f. \<forall>y. f True (Some [x, (y, z)]) = y"
  apply (rule exI, rule allI)
  apply datatype_schem
  apply (rule refl)
  done

datatype foo =
    basic nat int
  | another nat

primrec get_basic_0 where
  "get_basic_0 (basic x0 x1) = x0"

primrec get_nat where
    "get_nat (basic x _) = x"
  | "get_nat (another z) = z"

(* An accessor rule declared locally via notes. *)
lemma selectively_exposing_datatype_arguments:
  notes get_basic_0.simps[datatype_schematic]
  shows "\<exists>x. \<forall>a b. x (basic a b) = a"
  apply (rule exI, (rule allI)+)
  apply datatype_schem \<comment> \<open>Only exposes `a` to the schematic.\<close>
  apply (rule refl)
  done

(* Backward-compatible alias for the previous, misspelled lemma name. *)
lemmas selectively_exposing_datatype_arugments =
  selectively_exposing_datatype_arguments

(* Accessor rules passed to the method directly; get_nat.simps contains one
   equation for each constructor of foo. *)
lemma method_handles_primrecs_with_two_constructors:
  shows "\<exists>x. \<forall>a b. x (basic a b) = a"
  apply (rule exI, (rule allI)+)
  apply (datatype_schem add: get_nat.simps)
  apply (rule refl)
  done

end
299
300end
301