1(*  Title:      HOL/Statespace/state_fun.ML
2    Author:     Norbert Schirmer, TU Muenchen
3*)
4
(*Interface of structure StateFun: simplification support for statespaces
  built from the constants StateFun.lookup and StateFun.update.*)
signature STATE_FUN =
sig
  (*fully qualified names of the constants StateFun.lookup / StateFun.update*)
  val lookupN : string
  val updateN : string

  (*Constructor (mk_constr) and destructor (mk_destr) terms for values of
    the given type, assembled from untyped syntax constants named
    StateFun.<Name>V and StateFun.the_<Name>V respectively; presumably the
    injection into / projection out of the statespace value type.
    Raise TYPE for types that cannot be handled.*)
  val mk_constr : theory -> typ -> term
  val mk_destr : theory -> typ -> term

  (*simproc for terms of the form lookup d n (update d' c m v s)*)
  val lookup_simproc : simproc
  (*simproc for chains of updates update d c n v s*)
  val update_simproc : simproc
  (*simproc for existential statements whose body is an equation between a
    lookup on the bound state and a term (either orientation)*)
  val ex_lookup_eq_simproc : simproc
  (*HOL_ss extended with StateFun.ex_id; initial contents of the simpset
    that ex_lookup_eq_simproc takes from the context data*)
  val ex_lookup_ss : simpset
  (*simproc for P & Q that only simplifies Q if P does not become False*)
  val lazy_conj_simproc : simproc
  (*simplification tactic with a small simpset for list/char equalities*)
  val string_eq_simp_tac : Proof.context -> int -> tactic
end;
20
(*Implementation of STATE_FUN.*)
structure StateFun: STATE_FUN =
struct

(*fully qualified names of the lookup / update constants*)
val lookupN = @{const_name StateFun.lookup};
val updateN = @{const_name StateFun.update};

(*Extract the string denoted by a HOL string literal term. Fails on other
  terms (callers such as mk_name guard it with try).*)
val sel_name = HOLogic.dest_string;
28
(*Variable name derived from term t: the string t denotes if it is a HOL
  string literal, else the name of a Free or Const, else "x" followed by
  the number i.*)
fun mk_name i t =
  let
    fun fallback (Free (x, _)) = x
      | fallback (Const (x, _)) = x
      | fallback _ = "x" ^ string_of_int i;
  in
    (case try sel_name t of
      NONE => fallback t
    | SOME name => name)
  end;
37
local

(*rules used to assemble the result theorem of lazy_conj_simproc; each is
  instantiated with OF against the rewrite equations of the conjuncts*)
val conj1_False = @{thm conj1_False};
val conj2_False = @{thm conj2_False};
val conj_True = @{thm conj_True};
val conj_cong = @{thm conj_cong};

(*is the term syntactically the HOL constant False?*)
fun isFalse (Const (@{const_name False}, _)) = true
  | isFalse _ = false;

(*is the term syntactically the HOL constant True?*)
fun isTrue (Const (@{const_name True}, _)) = true
  | isTrue _ = false;

in

(*Simproc for conjunctions P & Q that rewrites the conjuncts lazily with the
  simplifier context it is invoked in:
    - P is rewritten to P' first; if P' is False the result is
      conj1_False OF [P == P'] and Q is never simplified;
    - otherwise Q is rewritten to Q'; if Q' is False the result is
      conj2_False OF [Q == Q'];
    - if both P' and Q' are True the result is conj_True OF [P == P', Q == Q'];
    - if neither conjunct changed (P aconv P', Q aconv Q') the result is NONE;
    - otherwise conj_cong OF [P == P', Q == Q'].*)
val lazy_conj_simproc =
  Simplifier.make_simproc @{context} "lazy_conj_simp"
   {lhss = [@{term "P & Q"}],
    proc = fn _ => fn ctxt => fn ct =>
      (case Thm.term_of ct of
        Const (@{const_name HOL.conj},_) $ P $ Q =>
          let
            val P_P' = Simplifier.rewrite ctxt (Thm.cterm_of ctxt P);
            val P' = P_P' |> Thm.prop_of |> Logic.dest_equals |> #2;
          in
            if isFalse P' then SOME (conj1_False OF [P_P'])
            else
              let
                (*Q is only simplified when P did not become False*)
                val Q_Q' = Simplifier.rewrite ctxt (Thm.cterm_of ctxt Q);
                val Q' = Q_Q' |> Thm.prop_of |> Logic.dest_equals |> #2;
              in
                if isFalse Q' then SOME (conj2_False OF [Q_Q'])
                else if isTrue P' andalso isTrue Q' then SOME (conj_True OF [P_P', Q_Q'])
                else if P aconv P' andalso Q aconv Q' then NONE
                else SOME (conj_cong OF [P_P', Q_Q'])
              end
           end
      | _ => NONE)};

(*simp_tac with a minimal simpset: HOL_basic_ss plus list/char injectivity
  and distinctness rules, cong_exp_iff_simps and simp_thms, together with
  lazy_conj_simproc and the congruence rule block_conj_cong. Judging by the
  name, presumably meant for deciding equalities of string literals.*)
fun string_eq_simp_tac ctxt =
  simp_tac (put_simpset HOL_basic_ss ctxt
    addsimps @{thms list.inject list.distinct char.inject
      cong_exp_iff_simps simp_thms}
    addsimprocs [lazy_conj_simproc]
    |> Simplifier.add_cong @{thm block_conj_cong});

end;
85
(*Default simpset for lookup_simproc / update_simproc: HOL_basic_ss with
  list/char injectivity and distinctness, simp_thms and the StateFun
  lookup/update laws, the lazy conjunction simproc, the distinct-name
  solver and the block_conj_cong congruence rule.*)
val lookup_ss =
  let
    val rules =
      @{thms list.inject char.inject list.distinct simp_thms
        StateFun.lookup_update_id_same StateFun.id_id_cancel
        StateFun.lookup_update_same StateFun.lookup_update_other};
    val base_ctxt =
      put_simpset HOL_basic_ss @{context}
        addsimps rules
        addsimprocs [lazy_conj_simproc]
        addSolver StateSpace.distinctNameSolver;
  in simpset_of (fold Simplifier.add_cong @{thms block_conj_cong} base_ctxt) end;

(*Default simpset for ex_lookup_eq_simproc: HOL_ss plus StateFun.ex_id.*)
val ex_lookup_ss =
  let val base_ctxt = put_simpset HOL_ss @{context} addsimps @{thms StateFun.ex_id}
  in simpset_of base_ctxt end;
98
99
(*Generic context data:
    1. the simpset used by lookup_simproc and update_simproc,
    2. the simpset used by ex_lookup_eq_simproc,
    3. whether the statefun_simp attribute has already added lookup_simproc
       and update_simproc to the context's simpset.*)
structure Data = Generic_Data
(
  type T = simpset * simpset * bool;  (*lookup simpset, ex_lookup simpset, are simprocs installed*)
  val empty = (empty_ss, empty_ss, false);
  val extend = I;
  (*merge both simpsets componentwise; simprocs count as installed if they
    are installed in either context*)
  fun merge ((ss1, ex_ss1, b1), (ss2, ex_ss2, b2)) =
    (merge_ss (ss1, ss2), merge_ss (ex_ss1, ex_ss2), b1 orelse b2);
);

(*initialise the data with the default simpsets; simprocs not yet installed*)
val _ = Theory.setup (Context.theory_map (Data.put (lookup_ss, ex_lookup_ss, false)));
110
(*Simproc for lookup d n (update d' c m v s).
  The redex is not rewritten directly. Instead, mk_upds first builds a
  generalised "skeleton" of it:
    - the state below the update chain becomes the schematic variable ?s
      (index maxidx + 1);
    - every update value K_statefun v'' becomes K_statefun ?v with a fresh
      schematic ?v (indices from maxidx + 2 upwards). The exception is when
      v'' is lookup id n' _ where c is that same id constant and n = m = n';
      then the value is kept so that lookup_update_id_same can fire;
    - all other update values are kept unchanged.
  The skeleton is then rewritten with the lookup simpset (first component of
  Data) under simp_depth_limit 100. The resulting equation is returned
  unless it is trivial (both sides aconv).
  NOTE(review): its schematic variables are presumably instantiated by the
  simplifier when matching the equation against the redex; confirm.
  An Option.Option raised during the computation yields NONE.*)
val lookup_simproc =
  Simplifier.make_simproc @{context} "lookup_simp"
   {lhss = [@{term "lookup d n (update d' c m v s)"}],
    proc = fn _ => fn ctxt => fn ct =>
      (case Thm.term_of ct of (Const (@{const_name StateFun.lookup}, lT) $ destr $ n $
                   (s as Const (@{const_name StateFun.update}, uT) $ _ $ _ $ _ $ _ $ _)) =>
        (let
          (*type of the state: fifth argument type of update*)
          val (_::_::_::_::sT::_) = binder_types uT;
          (*fresh schematic indices start above the redex's maxidx*)
          val mi = Term.maxidx_of_term (Thm.term_of ct);
          (*returns the skeleton of an update chain together with the next
            unused ?v index; the chain below is processed first*)
          fun mk_upds (Const (@{const_name StateFun.update}, uT) $ d' $ c $ m $ v $ s) =
                let
                  val (_ :: _ :: _ :: fT :: _ :: _) = binder_types uT;
                  val vT = domain_type fT;
                  val (s', cnt) = mk_upds s;
                  val (v', cnt') =
                    (case v of
                      Const (@{const_name K_statefun}, KT) $ v'' =>
                        (case v'' of
                          (Const (@{const_name StateFun.lookup}, _) $
                            (d as (Const (@{const_name Fun.id}, _))) $ n' $ _) =>
                              if d aconv c andalso n aconv m andalso m aconv n'
                              then (v,cnt) (* Keep value so that
                                              lookup_update_id_same can fire *)
                              else
                                (Const (@{const_name StateFun.K_statefun}, KT) $
                                  Var (("v", cnt), vT), cnt + 1)
                        | _ =>
                          (Const (@{const_name StateFun.K_statefun}, KT) $
                            Var (("v", cnt), vT), cnt + 1))
                     | _ => (v, cnt));
                in (Const (@{const_name StateFun.update}, uT) $ d' $ c $ m $ v' $ s', cnt') end
            (*end of the chain: generalise the remaining state to ?s*)
            | mk_upds s = (Var (("s", mi + 1), sT), mi + 2);

          (*the skeleton redex (shadows the simproc argument ct)*)
          val ct =
            Thm.cterm_of ctxt
              (Const (@{const_name StateFun.lookup}, lT) $ destr $ n $ fst (mk_upds s));
          val basic_ss = #1 (Data.get (Context.Proof ctxt));
          val ctxt' = ctxt |> Config.put simp_depth_limit 100 |> put_simpset basic_ss;
          val thm = Simplifier.rewrite ctxt' ct;
        in
          (*no progress: both sides of the equation coincide*)
          if (op aconv) (Logic.dest_equals (Thm.prop_of thm))
          then NONE
          else SOME thm
        end
        handle Option.Option => NONE)
      | _ => NONE)};
157
158
local

(*first proof step for the skeleton equation in update_simproc (resolved
  before simp_tac); see StateFun.meta_ext*)
val meta_ext = @{thm StateFun.meta_ext};
(*simpset for proving the skeleton equation: HOL_ss plus update_apply,
  o_apply, list/char injectivity and distinctness, lazy_conj_simproc,
  StateSpace.distinct_simproc and the congruence rule block_conj_cong*)
val ss' =
  simpset_of (put_simpset HOL_ss @{context} addsimps
    (@{thm StateFun.update_apply} :: @{thm Fun.o_apply} :: @{thms list.inject} @ @{thms char.inject}
      @ @{thms list.distinct})
    addsimprocs [lazy_conj_simproc, StateSpace.distinct_simproc]
    |> fold Simplifier.add_cong @{thms block_conj_cong});

in

(*Simproc for update chains update d c n v s.
  mk_updterm abstracts the chain into two skeletons over bound variables,
  one per update value plus one for the seed state:
    - the original chain;
    - a simplified chain. For every name, only the outermost update
      survives, and it absorbs all inner updates of the same name
      (merge_upds).
  If at least one inner update was absorbed this way:
    - the skeleton equation (quantified over those variables) is proved by
      meta_ext followed by simp_tac with ss';
    - its right-hand side is further rewritten by asm_full_rewrite with the
      lookup simpset (first component of Data);
    - the transitive combination of both equations is returned.
  Otherwise the result is NONE. Both steps run with simp_depth_limit 100.*)
val update_simproc =
  Simplifier.make_simproc @{context} "update_simp"
   {lhss = [@{term "update d c n v s"}],
    proc = fn _ => fn ctxt => fn ct =>
      (case Thm.term_of ct of
        Const (@{const_name StateFun.update}, uT) $ _ $ _ $ _ $ _ $ _ =>
          let
            val (_ :: _ :: _ :: _ :: sT :: _) = binder_types uT;
              (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => ('n => 'v) => ('n => 'v)"*)
            (*result of mk_updterm at the end of the chain: both skeletons
              are Bound 0 (the state variable "s"), no updates recorded,
              nothing simplified yet*)
            fun init_seed s = (Bound 0, Bound 0, [("s", sT)], [], false);

            (*the term Fun.comp g f together with its type*)
            fun mk_comp f fT g gT =
              let val T = domain_type fT --> range_type gT
              in (Const (@{const_name Fun.comp}, gT --> fT --> T) $ g $ f, T) end;

            (*compose a non-empty list of (term, type) pairs with mk_comp*)
            fun mk_comps fs = foldl1 (fn ((f, fT), (g, gT)) => mk_comp g gT f fT) fs;

            (*record the components c, f (the value) and d of an update to
              name n, in front of those already recorded for n*)
            fun append n c cT f fT d dT comps =
              (case AList.lookup (op aconv) comps n of
                SOME gTs => AList.update (op aconv) (n, [(c, cT), (f, fT), (d, dT)] @ gTs) comps
              | NONE => AList.update (op aconv) (n, [(c, cT), (f, fT), (d, dT)]) comps);

            (*first element, middle elements and last element of a list*)
            fun split_list (x :: xs) = let val (xs', y) = split_last xs in (x, xs', y) end
              | split_list _ = error "StateFun.split_list";

            (*combine all recorded updates of n into one: c from the first
              entry, d from the last, and as value the mk_comps composition
              of the entries in between*)
            fun merge_upds n comps =
              let val ((c, cT), fs, (d, dT)) = split_list (the (AList.lookup (op aconv) comps n))
              in ((c, cT), fst (mk_comps fs), (d, dT)) end;

            (* mk_updterm already t returns a 5-tuple
                 (orig-term-skeleton, simplified-term-skeleton, vars, comps, b)
               where
                 - vars are the (name, type) pairs of the bound variables used in
                   both skeletons (one per update value, plus the seed state),
                 - comps records, per update name, the components of the updates
                   to that name (see append / merge_upds),
                 - b tells whether a simplification has occurred, i.e. whether an
                   update was dropped because an outer update of the same name
                   exists.
               "orig-term-skeleton = simplified-term-skeleton" is the desired
               simplification rule.
               The algorithm first walks down the updates to the seed state while
               memorising the updated names in the already-list. While walking up
               the updates again, the optimised term is constructed.
            *)
            fun mk_updterm already
                ((upd as Const (@{const_name StateFun.update}, uT)) $ d $ c $ n $ v $ s) =
                  let
                    fun rest already = mk_updterm already;
                    val (dT :: cT :: nT :: vT :: sT :: _) = binder_types uT;
                      (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) =>
                            ('n => 'v) => ('n => 'v)"*)
                  in
                    (*an outer update of the same name exists: drop this update
                      from the simplified skeleton, only record its components*)
                    if member (op aconv) already n then
                      (case rest already s of
                        (trm, trm', vars, comps, _) =>
                          let
                            val i = length vars;
                            val kv = (mk_name i n, vT);
                            val kb = Bound i;
                            val comps' = append n c cT kb vT d dT comps;
                          in (upd $ d $ c $ n $ kb $ trm, trm', kv :: vars, comps',true) end)
                    (*outermost update of this name: emit the merged update*)
                    else
                      (case rest (n :: already) s of
                        (trm, trm', vars, comps, b) =>
                          let
                            val i = length vars;
                            val kv = (mk_name i n, vT);
                            val kb = Bound i;
                            val comps' = append n c cT kb vT d dT comps;
                            val ((c', c'T), f', (d', d'T)) = merge_upds n comps';
                            val vT' = range_type d'T --> domain_type c'T;
                            val upd' =
                              Const (@{const_name StateFun.update},
                                d'T --> c'T --> nT --> vT' --> sT --> sT);
                          in
                            (upd $ d $ c $ n $ kb $ trm, upd' $ d' $ c' $ n $ f' $ trm', kv :: vars,
                              comps', b)
                          end)
                  end
              | mk_updterm _ t = init_seed t;

            val ctxt0 = Config.put simp_depth_limit 100 ctxt;
            (*for proving the skeleton equation*)
            val ctxt1 = put_simpset ss' ctxt0;
            (*for post-processing its right-hand side with the lookup simpset*)
            val ctxt2 = put_simpset (#1 (Data.get (Context.Proof ctxt0))) ctxt0;
          in
            (case mk_updterm [] (Thm.term_of ct) of
              (trm, trm', vars, _, true) =>
                let
                  val eq1 =
                    Goal.prove ctxt0 [] []
                      (Logic.list_all (vars, Logic.mk_equals (trm, trm')))
                      (fn _ => resolve_tac ctxt0 [meta_ext] 1 THEN simp_tac ctxt1 1);
                  val eq2 = Simplifier.asm_full_rewrite ctxt2 (Thm.dest_equals_rhs (Thm.cprop_of eq1));
                in SOME (Thm.transitive eq1 eq2) end
            | _ => NONE)
          end
      | _ => NONE)};

end;
264
265
local

(*applied to the proved rule when the equation in the input is written as
  x = lookup ...; see StateFun.swap_ex_eq*)
val swap_ex_eq = @{thm StateFun.swap_ex_eq};

(*does sel name a field of the record type T (including its "more" part)?*)
fun is_selector thy T sel =
  let val (flds, more) = Record.get_recT_fields thy T
  in member (fn (s, (n, _)) => n = s) (more :: flds) sel end;

in

(*Simproc for existential statements
    EX s. lookup d n s' = x    or    EX s. x = lookup d n s'
  where:
    - s' is the bound state s itself, or a record selector applied to it;
    - neither n nor x depends on s.
  It proves the generic rule
    !!n x. (EX s. lookup d n s' = x) == True
  by Record.split_simp_tac followed by simp_tac with the ex_lookup simpset
  (second component of Data, simp_depth_limit 100), and exports the rule.
  For the second orientation it additionally applies swap_ex_eq.
  The result is NONE in two cases:
    - the term does not have this shape (TERM raised by the destructors
      below);
    - an Option.Option is raised during the computation.*)
val ex_lookup_eq_simproc =
  Simplifier.make_simproc @{context} "ex_lookup_eq_simproc"
   {lhss = [@{term "Ex t"}],
    proc = fn _ => fn ctxt => fn ct =>
      let
        val thy = Proof_Context.theory_of ctxt;
        val t = Thm.term_of ct;

        val ex_lookup_ss = #2 (Data.get (Context.Proof ctxt));
        val ctxt' = ctxt |> Config.put simp_depth_limit 100 |> put_simpset ex_lookup_ss;
        (*split record variables, then simplify with the ex_lookup simpset*)
        fun prove prop =
          Goal.prove_global thy [] [] prop
            (fn _ => Record.split_simp_tac ctxt [] (K ~1) 1 THEN simp_tac ctxt' 1);

        (*Generalise a destructed equation: n and x are replaced by Bound 2
          and Bound 1, i.e. by the variables n and x quantified outside the
          Ex. Raises TERM if either of them depends on the existentially
          bound state (loose Bound 0). Returns the new equation, the types
          of x and n, and the swap flag.*)
        fun mkeq (swap, Teq, lT, lo, d, n, x, s) =
          let
            val (_ :: nT :: _) = binder_types lT;
            (*  ('v => 'a) => 'n => ('n => 'v) => 'a *)
            val x' = if not (Term.is_dependent x) then Bound 1 else raise TERM ("", [x]);
            val n' = if not (Term.is_dependent n) then Bound 2 else raise TERM ("", [n]);
            val sel' = lo $ d $ n' $ s;
          in (Const (@{const_name HOL.eq}, Teq) $ sel' $ x', hd (binder_types Teq), nT, swap) end;

        fun not_a_selector s =
          raise TERM ("StateFun.ex_lookup_eq_simproc: not a record selector", [s]);

        (*accept the bound state itself or a record selector applied to it*)
        fun dest_state (s as Bound 0) = s
          | dest_state (s as (Const (sel, sT) $ Bound 0)) =
              if is_selector thy (domain_type sT) sel then s
              else not_a_selector s
          | dest_state s = not_a_selector s;

        (*destruct lookup d n s = X (swap = false) or X = lookup d n s
          (swap = true); raises TERM for anything else*)
        fun dest_sel_eq
              (Const (@{const_name HOL.eq}, Teq) $
                ((lo as (Const (@{const_name StateFun.lookup}, lT))) $ d $ n $ s) $ X) =
              (false, Teq, lT, lo, d, n, X, dest_state s)
          | dest_sel_eq
              (Const (@{const_name HOL.eq}, Teq) $ X $
                ((lo as (Const (@{const_name StateFun.lookup}, lT))) $ d $ n $ s)) =
              (true, Teq, lT, lo, d, n, X, dest_state s)
          | dest_sel_eq _ = raise TERM ("", []);
      in
        (case t of
          Const (@{const_name Ex}, Tex) $ Abs (s, T, t) =>
            (let
              val (eq, eT, nT, swap) = mkeq (dest_sel_eq t);
              val prop =
                Logic.list_all ([("n", nT), ("x", eT)],
                  Logic.mk_equals (Const (@{const_name Ex}, Tex) $ Abs (s, T, eq), @{term True}));
              val thm = Drule.export_without_context (prove prop);
              val thm' = if swap then swap_ex_eq OF [thm] else thm
            in SOME thm' end handle TERM _ => NONE)
        | _ => NONE)
      end handle Option.Option => NONE};

end;
329
(*Names of the value-embedding syntax constants have the shape
  "StateFun." ^ base_prfx ^ s ^ "V".*)
val val_sfx = "V";
val val_prfx = "StateFun.";
fun deco base_prfx s = String.concat [val_prfx, base_prfx, s, val_sfx];
333
(*Capitalise the first character of a string; "" is returned unchanged.*)
fun mkUpper str =
  if String.size str = 0 then ""
  else String.str (Char.toUpper (String.sub (str, 0))) ^ String.extract (str, 1, NONE);
338
(*Name fragment for a type: the fragments of its type arguments
  (recursively, in order) followed by the capitalised base name of the type
  constructor; for type variables just their capitalised base name.*)
fun mkName T =
  (case T of
    Type (c, Ts) => implode (map mkName Ts) ^ mkUpper (Long_Name.base_name c)
  | TFree (a, _) => mkUpper (Long_Name.base_name a)
  | TVar ((a, _), _) => mkUpper (Long_Name.base_name a));
342
(*Is the type constructor name registered as a datatype (queried through
  BNF_LFP_Compat with the Keep_Nesting option)?*)
fun is_datatype thy name =
  (case BNF_LFP_Compat.get_info thy [BNF_LFP_Compat.Keep_Nesting] name of
    SOME _ => true
  | NONE => false);
344
(*Map function for the datatype with the given type name, as an untyped
  syntax constant: List.map for lists, StateFun.map_<base name> otherwise.*)
fun mk_map type_name =
  let
    val c =
      if type_name = @{type_name List.list} then @{const_name List.map}
      else "StateFun.map_" ^ Long_Name.base_name type_name;
  in Syntax.const c end;
347
(*Build the constructor (or destructor) term for values of a type out of
  untyped syntax constants whose names are produced by deco.
    comp: combines an outer constant with the term built for the type's
          arguments; mk_constr and mk_destr pass the two argument orders of
          Fun.comp
    prfx: prefix passed on to deco ("" or "the_")
  Cases:
    - type constructor without arguments: the constant for its base name;
    - function type A1 => ... => An => R: the constant "<A1>...<An>Fun",
      combined via comp with n nested applications of StateFun.map_fun to
      the term for R;
    - datatype with exactly one argument A: the constant for the datatype,
      combined via comp with its map function (mk_map) applied to the term
      for A;
    - any other type constructor: one constant whose name lists all type
      arguments (they are not embedded recursively).
  Raises TYPE for datatypes with more than one argument and for type
  variables.*)
fun gen_constr_destr _ prfx _ (Type (T, [])) =
      Syntax.const (deco prfx (mkUpper (Long_Name.base_name T)))
  | gen_constr_destr comp prfx thy (T as Type (@{type_name fun}, _)) =
      let val (argTs, rangeT) = strip_type T;
      in
        comp
          (Syntax.const (deco prfx (implode (map mkName argTs) ^ "Fun")))
          (fold (fn x => fn y => x $ y)
            (replicate (length argTs) (Syntax.const "StateFun.map_fun"))
            (gen_constr_destr comp prfx thy rangeT))
      end
  | gen_constr_destr comp prfx thy (T' as Type (T, argTs)) =
      if is_datatype thy T
      then (*datatype args are recursively embedded into val*)
        (case argTs of
          [argT] =>
            comp
              (Syntax.const (deco prfx (mkUpper (Long_Name.base_name T))))
              (mk_map T $ gen_constr_destr comp prfx thy argT)
        | _ => raise TYPE ("StateFun.gen_constr_destr", [T'], []))
      else (*type args are not recursively embedded into val*)
        Syntax.const (deco prfx (implode (map mkName argTs) ^ mkUpper (Long_Name.base_name T)))
  | gen_constr_destr _ _ _ T = raise TYPE ("StateFun.gen_constr_destr", [T], []);
371
local

(*untyped syntax for Fun.comp f g*)
fun mk_comp_syntax f g = Syntax.const @{const_name Fun.comp} $ f $ g;

in

(*constructor term: each outer constant is composed after the term built
  for the type's arguments*)
fun mk_constr thy T = gen_constr_destr mk_comp_syntax "" thy T;

(*destructor term: "the_"-prefixed constants, composed in the opposite
  order*)
fun mk_destr thy T = gen_constr_destr (fn f => fn g => mk_comp_syntax g f) "the_" thy T;

end;
374
(*Attribute statefun_simp: adds the theorem to one of the two simpsets kept
  in Data. If its conclusion has the form _ $ (Ex _) it goes into the
  ex_lookup simpset, otherwise into the lookup simpset. The first time it is
  used in a context (flag in Data still false), it also adds lookup_simproc
  and update_simproc to the context's simpset; it then records the flag as
  true.*)
val _ =
  Theory.setup
    (Attrib.setup @{binding statefun_simp}
      (Scan.succeed (Thm.declaration_attribute (fn thm => fn context =>
        let
          val ctxt = Context.proof_of context;
          val (lookup_ss, ex_lookup_ss, simprocs_active) = Data.get context;
          val (lookup_ss', ex_lookup_ss') =
            (case Thm.concl_of thm of
              (_ $ ((Const (@{const_name Ex}, _) $ _))) =>
                (lookup_ss, simpset_map ctxt (Simplifier.add_simp thm) ex_lookup_ss)
            | _ =>
                (simpset_map ctxt (Simplifier.add_simp thm) lookup_ss, ex_lookup_ss));
          (*install the simprocs only once per context*)
          val activate_simprocs =
            if simprocs_active then I
            else Simplifier.map_ss (fn ctxt => ctxt addsimprocs [lookup_simproc, update_simproc]);
        in
          context
          |> activate_simprocs
          |> Data.put (lookup_ss', ex_lookup_ss', true)
        end)))
      "simplification in statespaces");
397
398end;
399