1(*  Title:      HOL/Tools/Function/function_context_tree.ML
2    Author:     Alexander Krauss, TU Muenchen
3
4Construction and traversal of trees of nested contexts along a term.
5*)
6
signature FUNCTION_CONTEXT_TREE =
sig
  (* Poor man's contexts: fixed variables and assumed hypotheses only. *)
  type ctxt = (string * typ) list * thm list

  (* Tree of nested contexts along a term. *)
  type ctx_tree

  (* Declared congruence rules used by the function package. *)
  val get_function_congs: Proof.context -> thm list
  val add_function_cong: thm -> Context.generic -> Context.generic
  val cong_add: attribute
  val cong_del: attribute

  (* Building a tree and instantiating its recursion variable. *)
  val mk_tree: term -> term -> Proof.context -> term -> ctx_tree
  val inst_tree: Proof.context -> term -> term -> ctx_tree -> ctx_tree

  (* Moving terms and theorems out of / into a poor man's context. *)
  val export_term: ctxt -> term -> term
  val export_thm: Proof.context -> ctxt -> thm -> thm
  val import_thm: Proof.context -> ctxt -> thm -> thm

  (* Traversals of a tree. *)
  val traverse_tree:
    (ctxt -> term -> (ctxt * thm) list -> (ctxt * thm) list * 'a -> (ctxt * thm) list * 'a) ->
    ctx_tree -> 'a -> 'a
  val rewrite_by_tree:
    Proof.context -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list
end
37
38structure Function_Context_Tree : FUNCTION_CONTEXT_TREE =
39struct
40
(* Poor man's context: the fixed variables and the assumed hypotheses. *)
type ctxt = ((string * typ) list) * (thm list)
42
43open Function_Common
44open Function_Lib
45
(* Generic context data: the declared function congruence rules.
   add_function_cong stores them with their theory context trimmed. *)
structure FunctionCongs = Generic_Data
(
  type T = thm list
  val empty: T = []
  val extend = I
  val merge = Thm.merge_thms
);
53
(* All declared congruence rules, transferred into ctxt (they are stored
   with their theory context trimmed). *)
fun get_function_congs ctxt =
  map (Thm.transfer' ctxt) (FunctionCongs.get (Context.Proof ctxt))

(* Declares a congruence rule; the stored copy has its context trimmed. *)
fun add_function_cong thm =
  FunctionCongs.map (Thm.add_thm (Thm.trim_context thm))
59
60
(* congruence rules: attributes adding / removing a declared congruence rule.
   Both first normalize the theorem with safe_mk_meta_eq. *)

val cong_add =
  Thm.declaration_attribute (fn thm => add_function_cong (safe_mk_meta_eq thm))
val cong_del =
  Thm.declaration_attribute (fn thm => FunctionCongs.map (Thm.del_thm (safe_mk_meta_eq thm)))
65
66
(* Dependencies between the branches (premises) of a congruence rule, keyed
   by branch index. *)
type depgraph = int Int_Graph.T

(* Tree of nested contexts along a term (as built by mk_tree):
   - Leaf t: a subterm in which the recursion variable does not occur;
   - Cong (rule, deps, branches): a subterm decomposed by an instantiated
     congruence rule; every branch pairs its local context with its subtree;
   - RCall (call, arg_tree): a recursive call together with the tree of its
     argument. *)
datatype ctx_tree =
    Leaf of term
  | Cong of thm * depgraph * (ctxt * ctx_tree) list
  | RCall of term * ctx_tree
73
74
(* Maps "Trueprop (A = B)" to "B": the right-hand side of a HOL equation
   wrapped in Trueprop.  (The previous comment wrongly said "A"; the code
   has always taken the second component of HOLogic.dest_eq.) *)
fun rhs_of t = #2 (HOLogic.dest_eq (HOLogic.dest_Trueprop t))
77
78
79(*** Dependency analysis for congruence rules ***)
80
(* For a premise of a congruence rule: strips the outer quantifiers (via
   dest_all_all), splits the rest into assumptions and conclusion, and returns
   the schematic variables of the assumptions and those of the conclusion. *)
fun branch_vars t =
  let
    val (_, body) = dest_all_all t
    val (prems, concl) = Logic.strip_horn body
    val prem_vars = fold Term.add_vars prems []
    val concl_vars = Term.add_vars concl []
  in
    (prem_vars, concl_vars)
  end
88
(* Dependency graph between the premises ("branches") of a congruence rule.
   Nodes are premise indices.  There is an edge i -> j (i <> j) when the
   assumptions of premise i share a schematic variable with the conclusion of
   premise j; fold_deps then handles j before i.  Edges are added with
   add_edge_acyclic, so the graph is meant to stay acyclic. *)
fun cong_deps crule =
  let
    val indexed = map_index (apsnd branch_vars) (Thm.prems_of crule)

    fun add_dep (i, (assm_vars, _)) (j, (_, concl_vars)) =
      if i <> j andalso not (null (inter (op =) assm_vars concl_vars))
      then Int_Graph.add_edge_acyclic (i, j)
      else I

    val with_nodes = fold (fn (i, _) => Int_Graph.new_node (i, i)) indexed Int_Graph.empty
  in
    fold_product add_dep indexed indexed with_nodes
  end
100
(* Built-in congruence rules, tried after the declared ones (see mk_tree):
   "cong" and "ext", each turned into a meta equation. *)
val default_congs =
  [@{thm "cong"}, @{thm "ext"}]
  |> map (fn thm => thm RS eq_reflection)
103
(* Called on the INSTANTIATED branches (premises) of a congruence rule.
   Focuses on the outer parameters (fixing them in a new context), splits the
   remainder into assumptions and conclusion, and returns
   (extended context, fixed variables, assumptions, rhs of the conclusion). *)
fun mk_branch ctxt t =
  let
    val ((fixed, body), ctxt') = Variable.focus NONE t ctxt
    val (prems, concl) = Logic.strip_horn body
    val fixes = map snd fixed
  in
    (ctxt', fixes, prems, rhs_of concl)
  end
112
(* Finds the first congruence rule (in list order) whose conclusion matches
   the equation "t[h/fvar] == t", i.e. t with fvar rewritten to h on the left
   and t itself on the right.  Returns the rule instantiated by the matcher,
   its dependency graph, and its instantiated, focused branches (mk_branch).
   Raises Fail "No cong rule found!" if no rule matches.

   The target equation does not depend on the candidate rule.  It is
   therefore built once instead of once per rule, as the previous
   handler-recursive version did. *)
fun find_cong_rule _ _ _ [] _ = raise General.Fail "No cong rule found!"
  | find_cong_rule ctxt fvar h rules t =
      let
        val thy = Proof_Context.theory_of ctxt
        val target = Logic.mk_equals (Pattern.rewrite_term thy [(fvar, h)] [] t, t)

        (* SOME (instantiated rule, deps, branches) if the rule matches the
           target; NONE on a match failure, so the next rule is tried *)
        fun try_rule (r, dep) =
          let
            val (concl, prems) = (Thm.concl_of r, Thm.prems_of r)
            val subst = Pattern.match thy (concl, target) (Vartab.empty, Vartab.empty)
            val branches =
              map (mk_branch ctxt o Envir.beta_norm o Envir.subst_term subst) prems
            val inst =
              map (fn v => (#1 v, Thm.cterm_of ctxt (Envir.subst_term subst (Var v))))
                (Term.add_vars concl [])
          in
            SOME (infer_instantiate ctxt inst r, dep, branches)
          end
          handle Pattern.MATCH => NONE
      in
        (case get_first try_rule rules of
          SOME result => result
        | NONE => raise General.Fail "No cong rule found!")
      end
129
130
(* Builds the context tree of term t.  fvar is the recursion variable and h
   the term substituted for it when matching congruence rules.
   - A subterm "fvar $ arg" becomes an RCall node over the tree of arg.
   - A subterm without fvar becomes a Leaf.
   - Any other subterm is decomposed with the first matching rule among the
     declared congruence rules followed by default_congs.  Each branch's
     assumptions are assumed as theorems, and its subterm is built in the
     context extended by mk_branch. *)
fun mk_tree fvar h ctxt t =
  let
    (* FIXME: the rules and their dependency graphs could be cached in the
       theory instead of being recomputed on every call *)
    val rules =
      map (fn rule => (rule, cong_deps rule)) (get_function_congs ctxt @ default_congs)

    fun contains_fvar u = exists_subterm (fn v => v = fvar) u

    fun build cx u =
      (case u of
        a $ b => if a = fvar then RCall (u, build cx b) else decompose cx u
      | _ => decompose cx u)
    and decompose cx u =
      if not (contains_fvar u) then Leaf u
      else
        let
          val (rule, dep, branches) = find_cong_rule cx fvar h rules u
          fun branch_tree (cx', fixes, assms, sub) =
            ((fixes, map (fn a => Thm.assume (Thm.cterm_of cx a)) assms),
             build cx' sub)
        in
          Cong (rule, dep, map branch_tree branches)
        end
  in
    build ctxt t
  end
159
(* Replaces the variable fvar by the term f throughout a context tree:
   - recursive-call terms are substituted directly;
   - congruence rules are generalized over fvar and then instantiated with f;
   - branch assumptions are substituted and assumed afresh.
   Leaves are returned unchanged. *)
fun inst_tree ctxt fvar f tr =
  let
    val cfvar = Thm.cterm_of ctxt fvar
    val cf = Thm.cterm_of ctxt f

    fun subst_fvar t = subst_bound (f, abstract_over (fvar, t))
    fun inst_rule thm = Thm.forall_elim cf (Thm.forall_intr cfvar thm)
    fun reassume thm = Thm.assume (Thm.cterm_of ctxt (subst_fvar (Thm.prop_of thm)))

    fun walk (leaf as Leaf _) = leaf
      | walk (Cong (rule, deps, branches)) =
          Cong (inst_rule rule, deps, map walk_branch branches)
      | walk (RCall (t, sub)) = RCall (subst_fvar t, walk sub)
    and walk_branch ((fixes, assms), sub) =
      ((fixes, map reassume assms), walk sub)
  in
    walk tr
  end
181
182
(* Poor man's contexts (fixes and assumptions only): concatenation of two
   contexts, with the entries of the first one coming first. *)
fun compose (fixes1, assms1) (fixes2, assms2) =
  (fixes1 @ fixes2, assms1 @ assms2)
185
(* Closes term t over a poor man's context: the assumptions become
   meta-implication premises (in list order) and the fixes become outer
   meta-quantifiers, giving "!!fixes. assumes ==> t". *)
fun export_term (fixes, assumes) t =
  let
    val implied =
      fold_rev (fn thm => fn body => Logic.mk_implies (Thm.prop_of thm, body)) assumes t
  in
    fold_rev (fn v => fn body => Logic.all (Free v) body) fixes implied
  end
189
(* Theorem counterpart of export_term: discharges the assumptions (last one
   innermost) and then generalizes over the fixed variables. *)
fun export_thm ctxt (fixes, assumes) thm =
  let
    val discharged = fold_rev (fn a => Thm.implies_intr (Thm.cprop_of a)) assumes thm
  in
    fold_rev (fn v => Thm.forall_intr (Thm.cterm_of ctxt (Free v))) fixes discharged
  end
193
(* Inverse of export_thm: specializes the outer quantifiers to the fixed
   variables and then discharges the premises with the given theorems, in
   list order. *)
fun import_thm ctxt (fixes, athms) thm =
  let
    val specialized = fold (fn v => Thm.forall_elim (Thm.cterm_of ctxt (Free v))) fixes thm
  in
    fold Thm.elim_implies athms specialized
  end
197
198
(* Folds over the nodes of graph G in dependency order.  Every node is
   computed exactly once (memoized in a table), and only after all of its
   immediate successors.  f receives a lookup function for the values
   computed so far, the node, and the accumulator, and returns the node's
   value together with the new accumulator.  The result pairs the list of all
   node values (in the order produced by Inttab.fold) with the final
   accumulator.  G is expected to be acyclic (cong_deps builds it with
   add_edge_acyclic); a cycle would make the recursion diverge. *)
fun fold_deps G f x =
  let
    fun visit i (tab, acc) =
      if Inttab.defined tab i then (tab, acc)
      else
        let
          val (tab', acc') = Int_Graph.Keys.fold visit (Int_Graph.imm_succs G i) (tab, acc)
          val (v, acc'') = f (fn j => the (Inttab.lookup tab' j)) i acc'
        in
          (Inttab.update (i, v) tab', acc'')
        end

    val (tab, acc) = fold visit (Int_Graph.keys G) (Inttab.empty, x)
  in
    (Inttab.fold (fn (_, v) => fn vs => v :: vs) tab [], acc)
  end
217
(* Generic traversal of a context tree, threading an accumulator.
   rcOp is called at every RCall with the accumulated context, the call term,
   the theorems made available by dependency branches, and the (list, acc)
   result of traversing the call's argument.  Leaves contribute nothing.
   At a Cong node the branches are visited in dependency order (fold_deps).
   Each branch is traversed in the composed context, and sees the outputs of
   the branches it depends on prepended to the inherited ones.  Its outputs
   have the branch's local context composed in front.  Only the final
   accumulator is returned. *)
fun traverse_tree rcOp tr =
  let
    fun go _ (Leaf _) _ acc = ([], acc)
      | go cx (RCall (t, sub)) used acc = rcOp cx t used (go cx sub used acc)
      | go cx (Cong (_, deps, branches)) used acc =
          let
            fun step lookup i acc =
              let
                val (local_cx, sub) = nth branches i
                val used' =
                  Int_Graph.Keys.fold_rev (append o lookup) (Int_Graph.imm_succs deps i) used
                val (results, acc') = go (compose cx local_cx) sub used' acc
              in
                (* FIXME: Right order of composition? *)
                (map (apfst (compose local_cx)) results, acc')
              end
            val (per_branch, acc') = fold_deps deps step acc
          in
            (flat per_branch, acc')
          end
  in
    snd o go ([], []) tr []
  end
241
(* Turns a context tree into a meta equation by walking it bottom-up.
   - Leaf t yields the reflexivity "t == t".
   - RCall consumes the head (lRi, ha) of the list x, rewrites the call's
     argument with the subtree's equation "inner", and combines lRi, the
     imported ha and the instantiated ih into "h a' == f a" (see the inline
     comments below).
   - Cong solves each premise of the congruence rule with the exported
     equation of its branch, in dependency order, and composes them with the
     rule via COMP.
   Returns the equation together with the unconsumed rest of x.
   NOTE(review): the meaning of lRi/ha/ih is inferred from the inline
   comments only ("(a', h a') : G", "h a = f a"); confirm against the caller.
   The RCall clause expects an application term and a non-empty x; otherwise
   Match/Bind is raised. *)
fun rewrite_by_tree ctxt h ih x tr =
  let
    (* fix/h_as: fixed variables and assumptions accumulated along the path *)
    fun rewrite_help _ _ x (Leaf t) = (Thm.reflexive (Thm.cterm_of ctxt t), x)
      | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
        let
          (* equation for the argument, and the next (lRi, ha) pair of x *)
          val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)

          (* import ha into the current context, then rewrite inside it with inner *)
          val iha = import_thm ctxt (fix, h_as) ha (* (a', h a') : G *)
            |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
                                                    (* (a, h a) : G   *)
          val inst_ih = Thm.instantiate' [] [SOME (Thm.cterm_of ctxt arg)] ih
          val eq = Thm.implies_elim (Thm.implies_elim inst_ih lRi) iha (* h a = f a *)

          (* chain "h a' == h a" (congruence of h with inner) with "h a == f a" *)
          val h_a'_eq_h_a = Thm.combination (Thm.reflexive (Thm.cterm_of ctxt h)) inner
          val h_a_eq_f_a = eq RS eq_reflection
          val result = Thm.transitive h_a'_eq_h_a h_a_eq_f_a
        in
          (result, x')
        end
      | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
        let
          (* processes branch i; lu looks up the results of earlier branches *)
          fun sub_step lu i x =
            let
              val ((fixes, assumes), st) = nth branches i
              (* flipped equations of the branches this one depends on, as
                 meta rewrite rules; trivial (reflexive) ones are dropped *)
              val used = map lu (Int_Graph.immediate_succs deps i)
                |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
                |> filter_out Thm.is_reflexive

              (* simplify this branch's assumptions with those equations *)
              val assumes' = map (simplify (put_simpset HOL_basic_ss ctxt addsimps used)) assumes

              val (subeq, x') =
                rewrite_help (fix @ fixes) (h_as @ assumes') x st
              (* close over the branch context, using the ORIGINAL assumptions *)
              val subeq_exp =
                export_thm ctxt (fixes, assumes) (HOLogic.mk_obj_eq subeq)
            in
              (subeq_exp, x')
            end
          val (subthms, x') = fold_deps deps sub_step x
        in
          (* resolve the branch theorems against the premises of crule *)
          (fold_rev (curry op COMP) subthms crule, x')
        end
  in
    rewrite_help [] [] x tr
  end
286
287end
288