1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6theory Match_Abbreviation
7
8imports Main
9
10keywords "match_abbreviation" :: thy_decl
11  and "reassoc_thm" :: thy_decl
12
13begin
14
15text \<open>Splicing components of terms and saving as abbreviations.
16See the example at the bottom for explanation/documentation.
17\<close>
18
19ML \<open>
20structure Match_Abbreviation = struct
21
(* Build the binary application of the constant named cons to x and y.
   The constant is created with type dummyT. *)
fun app_cons_dummy cons x y =
  let val head = Const (cons, dummyT)
  in head $ x $ y end
24
(* Abstract t over x, but only if x occurs in t (compared with aconv);
   otherwise return t unchanged. *)
fun lazy_lam x t =
  let fun is_x u = u aconv x
  in if Term.exists_subterm is_x t then lambda x t else t end
27
(* Apply f underneath an abstraction.
   The bound variable is replaced by a Free whose name is obtained from
   Variable.variant_fixes, f is run on the opened body in the extended
   context, and the result is abstracted over that Free again.
   With lazy = true the re-abstraction is skipped when the Free no longer
   occurs in the result (see lazy_lam).
   Raises TERM if the argument is not an abstraction. *)
fun abs_dig_f ctxt lazy f (Abs (nm, T, body)) =
      let
        val (fresh_nms, inner_ctxt) = Variable.variant_fixes [nm] ctxt
        val x = Free (hd fresh_nms, T)
        val opened = betapply (Abs (nm, T, body), x)
        val result = f inner_ctxt opened
        val close = if lazy then lazy_lam else lambda
      in close x result end
  | abs_dig_f _ _ _ t = raise TERM ("abs_dig_f: not abs", [t])
36
(* Search tm for the first position at which get succeeds.
   For an application the whole term is tried first, then the function
   part, then the argument. Abstractions are entered via abs_dig_f in lazy
   mode; get is not tried on the abstraction itself. Any other term is
   handed to get directly. Failure is signalled by exception Option. *)
fun find_term1 ctxt get tm =
  let
    fun search u = find_term1 ctxt get u
  in
    case tm of
      f $ x => (get ctxt tm handle Option => (search f handle Option => search x))
    | Abs _ => abs_dig_f ctxt true (fn ctxt' => find_term1 ctxt' get) tm
    | _ => get ctxt tm
  end
43
(* Raise the TERM error used when pattern pat has no match in t. *)
fun not_found pat t =
  let val culprits = [pat, t]
  in raise TERM ("pattern not found", culprits) end
45
(* Like find_term1, but a failed search (exception Option) is turned into
   the TERM error produced by not_found for pat and t. *)
fun find_term ctxt get pat t =
  let fun search () = find_term1 ctxt get t
  in search () handle Option => not_found pat t end
48
(* Abstract t over its Vars and over those Frees that are not fixed in ctxt.
   The variables are folded with lambda in the order in which they are
   collected from ord_t; variables of ord_t that do not occur in t are
   skipped. *)
fun lambda_frees_vars ctxt ord_t t =
  let
    fun unfixed_free u =
      is_Free u andalso not (Variable.is_fixed ctxt (Term.term_name u))
    fun abstractable u = unfixed_free u orelse is_Var u
    (* distinct abstractable atoms of a term *)
    fun collect tm =
      fold_aterms (fn u => if abstractable u then insert (op =) u else I) tm []
    val pool = collect ord_t
    val present = collect t
    val ordered = filter (fn v => member (op =) present v) pool
  in fold lambda ordered t end
57
(* Read the pattern strings pats in a context extended with the names of
   the given fixes, then export the resulting terms back to ctxt. *)
fun parse_pat_fixes ctxt fixes pats =
  let
    val fix_names = map (fn (b, _, _) => Binding.name_of b) fixes
    val inner_ctxt = snd (Variable.add_fixes fix_names ctxt)
  in
    pats
    |> Syntax.read_terms inner_ctxt
    |> Variable.export_terms inner_ctxt ctxt
  end
63
(* Implementation of the reassoc_thm command.
   The term rhs (parsed with the given fixes) is rewritten using only the
   supplied theorems as simp rules; the symmetric form of the resulting
   equation is noted under name and printed. Returns the updated context. *)
fun add_reassoc name rhs fixes thms_info ctxt =
  let
    val simp_rules = Attrib.eval_thms ctxt thms_info
    val rhs_ct = Thm.cterm_of ctxt (singleton (parse_pat_fixes ctxt fixes) rhs)
    val simp_ctxt = clear_simpset ctxt addsimps simp_rules
    val rew = Thm.symmetric (Simplifier.rewrite simp_ctxt rhs_ct)
    val lthy = snd (Local_Theory.note ((name, []), [rew]) ctxt)
    val header = Pretty.str (Binding.name_of name ^ ":\n")
  in
    Pretty.writeln (Pretty.block [header, Thm.pretty_thm lthy rew]);
    lthy
  end
74
(* Apply adj at the first position of a term where it succeeds, rebuilding
   the surrounding term. adj signals "not here" by raising Option.
   For an application: try the whole term; failing that, dig into the
   function part and, if repeat is set, additionally try the argument
   (keeping it unchanged when nothing is found there); if the function part
   fails, dig into the argument alone. Abstractions are entered via
   abs_dig_f (non-lazy). Option escapes when no position is found. *)
fun dig_f ctxt repeat adj (tm as f $ x) =
      let
        fun dig u = dig_f ctxt repeat adj u
        (* only consulted after the function part has been rewritten *)
        fun dig_arg () = if repeat then (dig x handle Option => x) else x
        fun dig_inside () =
          let val f' = dig f
          in f' $ dig_arg () end
      in
        adj ctxt tm handle Option => (dig_inside () handle Option => f $ dig x)
      end
  | dig_f ctxt repeat adj (a as Abs _) =
      abs_dig_f ctxt false (fn ctxt' => dig_f ctxt' repeat adj) a
  | dig_f ctxt _ adj t = adj ctxt t
83
(* Rewrite t with the pattern pair rew_pair, using dig_f to locate the
   rewrite position(s). Raises the not_found error for the left-hand
   pattern when dig_f finds no matching position. *)
fun do_rewrite ctxt repeat rew_pair t =
  let
    val thy = Proof_Context.theory_of ctxt
    fun rewrite_here _ u =
      (case Pattern.match_rew thy u rew_pair of
        SOME (u', _) => u'
      | NONE => raise Option)
  in
    dig_f ctxt repeat rewrite_here t
      handle Option => not_found (fst rew_pair) t
  end
90
(* Narrow down to nested subterms matching the given patterns in turn, then
   apply f at the innermost selected subterm. With no patterns, f is applied
   to t directly. A pattern without a match yields the not_found error. *)
fun select_dig ctxt pats f t =
  (case pats of
    [] => f ctxt t
  | p :: rest =>
      let
        val thy = Proof_Context.theory_of ctxt
        fun at_match ctxt' u =
          if not (Pattern.matches thy (p, u)) then raise Option
          else select_dig ctxt' rest f u
      in
        dig_f ctxt false at_match t handle Option => not_found p t
      end)
97
(* Strip all outermost abstractions (via abs_dig_f in lazy mode), apply f
   to the remaining body, and re-abstract on the way back. *)
fun ext_dig_lazy ctxt f t =
  (case t of
    Abs _ => abs_dig_f ctxt true (fn ctxt' => ext_dig_lazy ctxt' f) t
  | _ => f ctxt t)
101
(* Print the term t under the heading nm and return t unchanged. *)
fun report_adjust ctxt nm t =
  let
    val heading = Pretty.str (nm ^ ", have:\n")
    val body = Syntax.pretty_term ctxt t
    val _ = Pretty.writeln (Pretty.block [heading, body])
  in t end
106
(* Apply one parsed adjustment to the term t, print the intermediate result
   via report_adjust, and return the adjusted term.

   An adjustment has the shape (((kind, in_selects), terms), fixes), as
   produced by adjust_parser:
   - "select": terms is a single pattern; the first matching subterm
     (in find_term order) becomes the new term.
   - "retype_consts": terms name constants; every occurrence of these
     constants has its type reset to dummyT and the term is re-checked.
   - "rewrite" / "rewrite1": terms is a pattern pair [from, to]; instances
     of from are rewritten to to, optionally restricted to the subterms
     selected by the in_selects patterns. "rewrite1" rewrites only once.
   Any other shape is reported with error. *)
fun do_adjust ctxt ((("select", []), [p]), fixes) t = let
    val p = singleton (parse_pat_fixes ctxt fixes) p
    val thy = Proof_Context.theory_of ctxt
    (* succeed only on a subterm that matches the pattern *)
    fun get _ t = if Pattern.matches thy (p, t) then t else raise Option
    val t = find_term ctxt get p t
  in report_adjust ctxt "Selected" t end
  | do_adjust ctxt ((("retype_consts", []), consts), []) t = let
    (* name of the constant at the head of a (possibly abstracted) term *)
    fun get_constname (Const (s, _)) = s
      | get_constname (Abs (_, _, t)) = get_constname t
      | get_constname (f $ _) = get_constname f
      | get_constname _ = raise Option
    fun get_constname2 t = get_constname t
      handle Option => raise TERM ("do_adjust: no constant", [t])
    val cnames = map (get_constname2 o Syntax.read_term ctxt) consts
      |> Symtab.make_set
    (* forget the type of the named constants; check_term fills it in again *)
    fun adj (Const (cn, T)) = if Symtab.defined cnames cn
        then Const (cn, dummyT) else Const (cn, T)
      | adj t = t
    val t = Syntax.check_term ctxt (Term.map_aterms adj t)
  in report_adjust ctxt "Adjusted types" t end
  | do_adjust ctxt (((r, in_selects), [from, to]), fixes) t = if
        r = "rewrite1" orelse r = "rewrite" then let
    (* plain "rewrite" is the repeating variant *)
    val repeat = r <> "rewrite1"
    val sel_pats = map (fn (p, fixes) => singleton (parse_pat_fixes ctxt fixes) p)
        in_selects
    val rewrite_pair = case parse_pat_fixes ctxt fixes [from, to]
      of [f, t] => (f, t) | _ => error ("do_adjust: unexpected length")
    val t = ext_dig_lazy ctxt (fn ctxt => select_dig ctxt sel_pats
        (fn ctxt => do_rewrite ctxt repeat rewrite_pair)) t
    (* the label follows the mode: the "(repeated)" suffix belongs to the
       repeating variant, not to rewrite1 *)
  in report_adjust ctxt (if repeat then "Rewrote (repeated)" else "Rewrote") t end
  else error ("do_adjust: unexpected: " ^ r)
  | do_adjust _ args _ = error ("do_adjust: unexpected: " ^ @{make_string} args)
139
(* Same-style type operation: replace each schematic type variable
   (a, i) by a TFree named a ^ "_var_" ^ i with the same sort; other
   atomic types raise Same.SAME. *)
fun unvarify_types_same ty =
  let
    fun freeze (TVar ((a, i), S)) = TFree (a ^ "_var_" ^ string_of_int i, S)
      | freeze _ = raise Same.SAME
  in Term_Subst.map_atypsT_same freeze ty end
144
(* Apply unvarify_types_same to all types occurring in the term tm. *)
fun unvarify_types tm =
  Same.commit (Term_Subst.map_types_same unvarify_types_same) tm
147
(* Implementation of the match_abbreviation command.
   The initial term (init ctxt) is abstracted over its free and schematic
   variables, its schematic type variables are turned into TFrees, and it is
   type-checked. The adjustments are then applied in order and the result
   is declared as an abbreviation called name in the given syntax mode.
   The defining equation is printed; the updated context is returned. *)
fun match_abbreviation mode name init adjusts int ctxt =
  let
    val start = init ctxt
    val closed =
      Syntax.check_term ctxt (unvarify_types (lambda_frees_vars ctxt start start))
    val result = fold (do_adjust ctxt) adjusts closed
    val eq = Logic.mk_equals (Free (Binding.name_of name, fastype_of result), result)
    val decl = SOME (name, NONE, Mixfix.NoSyn)
    val lthy = Specification.abbreviation mode decl [] eq int ctxt
  in
    Pretty.writeln (Syntax.pretty_term lthy eq);
    lthy
  end
160
(* Evaluate the single theorem reference thm_info in ctxt and apply f to it. *)
fun from_thm f thm_info ctxt =
  f (singleton (Attrib.eval_thms ctxt) thm_info)
164
(* Parse the string term_str as a term in ctxt using Syntax.parse_term. *)
fun from_term term_str ctxt =
  let val parse = Syntax.parse_term ctxt
  in parse term_str end
166
(* Parser for the term specifier: the keyword "in" followed by either
   "concl" and a theorem, "thm_prop" and a theorem, or a term.
   The result is a function from a context to the initial term. *)
val init_term_parse =
  let
    val of_concl = Parse.reserved "concl" |-- Parse.thm >> from_thm Thm.concl_of
    val of_prop = Parse.reserved "thm_prop" |-- Parse.thm >> from_thm Thm.prop_of
    val of_term = Parse.term >> from_term
  in
    Parse.$$$ "in" |-- (of_concl || of_prop || of_term)
  end
172
(* Parser for a pattern pair written as: term "to" term.
   Yields the two terms as a two-element list. *)
val term_to_term =
  Parse.term -- (Parse.reserved "to" |-- Parse.term)
    >> (fn (lhs, rhs) => [lhs, rhs])
175
(* Parser for an optional parenthesised "for" clause; yields [] when the
   clause is absent. *)
val p_for_fixes =
  let val bracketed = Parse.$$$ "(" |-- Parse.for_fixes --| Parse.$$$ ")"
  in Scan.optional bracketed [] end
178
(* Parser for a non-empty "and"-separated list of adjustments.
   Every alternative yields the shape (((kind, in_selects), terms), fixes)
   consumed by do_adjust; components that an adjustment kind does not use
   are filled with []. *)
val adjust_parser =
  let
    val select_adj =
      Parse.reserved "select" -- Scan.succeed []
        -- (Parse.term >> single) -- p_for_fixes
    val retype_adj =
      Parse.reserved "retype_consts" -- Scan.succeed []
        -- Scan.repeat Parse.term -- Scan.succeed []
    (* one "in" restriction of a rewrite: a pattern with its own fixes *)
    val in_select = Parse.$$$ "in" |-- Parse.term -- p_for_fixes
    val rewrite_adj =
      (Parse.reserved "rewrite1" || Parse.reserved "rewrite")
        -- Scan.repeat in_select
        -- term_to_term -- p_for_fixes
  in
    Parse.and_list1 (select_adj || retype_adj || rewrite_adj)
  end
187
(* Register the match_abbreviation command. Examples and documentation
   follow after this ML block. *)
val _ =
  let
    val spec =
      Parse.syntax_mode -- Parse.binding -- init_term_parse -- adjust_parser
    fun run (((mode, name), init), adjusts) =
      match_abbreviation mode name init adjusts
  in
    Outer_Syntax.local_theory' @{command_keyword match_abbreviation}
      "setup abbreviation for subterm of theorem"
      (spec >> run)
  end;
196
(* Register the reassoc_thm command: a binding, a term with optional fixes,
   and a list of theorems, handed to add_reassoc. *)
val _ =
  let
    val spec =
      Parse.binding -- Parse.term -- p_for_fixes -- Scan.repeat Parse.thm
    fun run (((name, rhs), fixes), thms) = add_reassoc name rhs fixes thms
  in
    Outer_Syntax.local_theory @{command_keyword reassoc_thm}
      "store a reassociate-theorem"
      (spec >> run)
  end;
203end
204\<close>
205
206text \<open>
207The match/abbreviate command. There are examples of all elements below,
208and an example involving monadic syntax in the theory Match-Abbreviation-Test.
209
Each invocation consists of the match_abbreviation keyword, a syntax mode
(e.g. (input)), an abbreviation name, a term specifier introduced by "in",
and a list of adjustment specifiers separated by "and".
212
213A term specifier can be term syntax or the conclusion or proposition of
214some theorem. Examples below.
215
216Each adjustment is a select, a rewrite, or a constant retype.
217
218The select adjustment picks out the part of the term matching the
219pattern (examples below). It picks the first match point, ordered in
220term order with compound terms before their subterms and functions
221before their arguments.
222
223The rewrite adjustment uses a pattern pair, and rewrites instances
224of the first pattern into the second. The match points are found in
225the same order as select. The "in" specifiers (examples below)
226limit the rewriting to within some matching subterm, specified with
227pattern in the same way as select. The rewrite1 variant only
228rewrites once, at the first matching site.
229
The rewrite mechanism can be used to replace terms with terms
of different types. The retype_consts adjustment can then be used
to repair the term by resetting the types of all instances of
the named constants. This is used below with list constructors,
to assemble a new list with a different element type.
235\<close>
236
237experiment begin
238
239text \<open>Fetching part of the statement of a theorem.\<close>
240match_abbreviation (input) fixp_thm_bit
241  in thm_prop fixp_induct_tailrec
242  select "X \<equiv> Y" (for X Y)
243
244text \<open>Ditto conclusion.\<close>
245match_abbreviation (input) rev_simps_bit
246  in concl rev.simps(2)
247  select "X" (for X)
248
249text \<open>Selecting some conjuncts and reorienting an equality.\<close>
250match_abbreviation (input) conjunct_test
251  in "(P \<and> Q \<and> P \<and> P \<and> P \<and> ((1 :: nat) = 2) \<and> Q \<and> Q, [Suc 0, 0])"
252  select "Q \<and> Z" (for Z)
253  and rewrite "x = y" to "y = x" (for x y)
254  and rewrite in "x = y & Z" (for x y Z)
255    "A \<and> B" to "A" (for A B)
256
257text \<open>The relevant reassociate theorem, that rearranges a
258conjunction like the above to group the elements selected.\<close>
259reassoc_thm conjunct_test_reassoc
260  "conjunct_test P Q \<and> Z" (for P Q Z)
261  conj_assoc
262
263text \<open>Selecting some elements of a list, and then replacing
264tuples with equalities, and adjusting the type of the list constructors
265so the new term is type correct.\<close>
266match_abbreviation (input) list_test
267  in "[(Suc 1, Suc 2), (4, 5), (6, 7), (8, 9), (10, 11), (x, y), (6, 7),
268    (18, 19), a, a, a, a, a, a, a]"
269  select "(4, V) # xs" (for V xs)
270  and rewrite "(x, y)" to "(y, x)" (for x y)
271  and rewrite1 in "(9, V) # xs" (for V xs) in "(7, V) # xs" (for V xs)
272    "x # xs" to "[x]" (for x xs)
273  and rewrite "(x, y)" to "x = y" (for x y)
274  and retype_consts Cons Nil
275
276end
277
278end
279