1(*  Title:      HOL/Statespace/distinct_tree_prover.ML
2    Author:     Norbert Schirmer, TU Muenchen
3*)
4
(* Interface of a prover for distinctness of the elements stored in a
   binary tree term built from Tip and Node (type 'a tree). *)
signature DISTINCT_TREE_PROVER =
sig
  (* One step of a path from the root of a tree: descend into the left
     or into the right subtree of a Node. *)
  datatype direction = Left | Right
  (* mk_tree e T xs: a balanced Tip/Node tree term with element type T
     whose in-order elements are the images of xs under e; the boolean
     field of every Node is False. *)
  val mk_tree : ('a -> term) -> typ -> 'a list -> term
  (* The elements of a Tip/Node tree term, in order; raises TERM on any
     other term. *)
  val dest_tree : term -> term list
  (* find_tree e t: SOME path from the root of t to the node holding e,
     or NONE if e does not occur; raises TERM if t is not a tree. *)
  val find_tree : term -> term -> direction list option

  (* The lemma neq_to_eq_False (instantiated with a proof of x ~= y to
     rewrite x = y, presumably to False). *)
  val neq_to_eq_False : thm
  (* distinctTreeProver ctxt dist_thm x_path y_path: from dist_thm, an
     all_distinct theorem about a tree, proves that the elements at
     x_path and y_path differ.  Raises TERM when the two paths are equal. *)
  val distinctTreeProver : Proof.context -> thm -> direction list -> direction list -> thm
  (* neq_x_y ctxt x y name: looks up the theorem called name in ctxt and
     tries to prove x ~= y from it; NONE if the theorem, x or y is not
     found. *)
  val neq_x_y : Proof.context -> term -> term -> string -> thm option
  (* A solver that runs distinctTree_tac with the given theorem names. *)
  val distinctFieldSolver : string list -> solver
  (* Solves a subgoal of the form ~ (x = y) with the first of the named
     distinctness theorems that yields a proof; fails otherwise. *)
  val distinctTree_tac : string list -> Proof.context -> int -> tactic
  (* distinct_implProver ctxt dist_thm ct: combines dist_thm with a
     subtractProver result via lemma subtract_Some_all_distinct; expects
     the tree as the argument of the argument of ct (presumably
     Trueprop (all_distinct t)) -- TODO confirm. *)
  val distinct_implProver : Proof.context -> thm -> cterm -> thm
  (* subtractProver ctxt t ct dist_thm: proof by recursion over the tree
     t (ct being its certified form) using lemmas subtract_Tip and
     subtract_Node; each element of t is located in the tree of
     dist_thm. *)
  val subtractProver : Proof.context -> term -> cterm -> thm -> thm
  (* A simproc on x = y that uses neq_x_y with the given theorem names. *)
  val distinct_simproc : string list -> simproc

  (* discharge ctxt prems rule: instantiates rule by naive first-order
     matching of its premises against prems and discharges them in
     order. *)
  val discharge : Proof.context -> thm list -> thm -> thm
end;
23
structure DistinctTreeProver : DISTINCT_TREE_PROVER =
struct

(* Lemma neq_to_eq_False, re-exported for clients of this structure. *)
val neq_to_eq_False = @{thm neq_to_eq_False};

(* One step of a path from the root of a tree to one of its nodes. *)
datatype direction = Left | Right;

(* The type T tree. *)
fun treeT T = Type (@{type_name tree}, [T]);


(** Building and inspecting tree terms **)

(* mk_tree' e T n xs builds a balanced tree term of element type T from
 * the n elements of xs (n must be length xs).  The element at index
 * (n - 1) div 2 becomes the root, the elements before it form the left
 * subtree and the elements after it the right subtree.  Elements are
 * turned into terms by e; every Node carries the boolean False.
 *)
fun mk_tree' _ T _ [] = Const (@{const_name Tip}, treeT T)
  | mk_tree' e T n xs =
      let
        val m = (n - 1) div 2;
        (* m < n = length xs, so the part after the first m elements is non-empty *)
        val (xsl, x :: xsr) = chop m xs;
        val l = mk_tree' e T m xsl;
        val r = mk_tree' e T (n - (m + 1)) xsr;
      in
        Const (@{const_name Node}, treeT T --> T --> HOLogic.boolT --> treeT T --> treeT T) $
          l $ e x $ @{term False} $ r
      end;

fun mk_tree e T xs = mk_tree' e T (length xs) xs;

(* In-order list of the elements of a Tip/Node tree term. *)
fun dest_tree (Const (@{const_name Tip}, _)) = []
  | dest_tree (Const (@{const_name Node}, _) $ l $ e $ _ $ r) = dest_tree l @ e :: dest_tree r
  | dest_tree t = raise TERM ("dest_tree", [t]);


(** Locating elements **)

(* Exhaustive search (left subtree before right subtree) for a node whose
 * element is alpha-equivalent to e; returns the path to it.
 *)
fun lin_find_tree _ (Const (@{const_name Tip}, _)) = NONE
  | lin_find_tree e (Const (@{const_name Node}, _) $ l $ x $ _ $ r) =
      if e aconv x
      then SOME []
      else
        (case lin_find_tree e l of
          SOME path => SOME (Left :: path)
        | NONE => Option.map (cons Right) (lin_find_tree e r))
  | lin_find_tree _ t = raise TERM ("find_tree: input not a tree", [t]);

(* Binary search for e, assuming the tree is ordered by order. *)
fun bin_find_tree _ _ (Const (@{const_name Tip}, _)) = NONE
  | bin_find_tree order e (Const (@{const_name Node}, _) $ l $ x $ _ $ r) =
      (case order (e, x) of
        EQUAL => SOME []
      | LESS => Option.map (cons Left) (bin_find_tree order e l)
      | GREATER => Option.map (cons Right) (bin_find_tree order e r))
  | bin_find_tree _ _ t = raise TERM ("find_tree: input not a tree", [t]);

(* Try binary search by Term_Ord.fast_term_ord first; since the tree need
 * not be ordered that way, fall back to exhaustive search on failure.
 *)
fun find_tree e t =
  (case bin_find_tree Term_Ord.fast_term_ord e t of
    NONE => lin_find_tree e t
  | x => x);


(* split_common_prefix xs ys = (common prefix, rest of xs, rest of ys). *)
fun split_common_prefix xs [] = ([], xs, [])
  | split_common_prefix [] ys = ([], [], ys)
  | split_common_prefix (xs as (x :: xs')) (ys as (y :: ys')) =
      if x = y
      then let val (ps, xs'', ys'') = split_common_prefix xs' ys' in (x :: ps, xs'', ys'') end
      else ([], xs, ys);


(** Instantiation and naive matching **)

(* Wrapper around Thm.instantiate.  The type instantiations of instTs are
 * also applied to the left-hand sides (the schematic variables) of insts,
 * so that the types of those variables agree with the type-instantiated
 * theorem.
 *)
fun instantiate ctxt instTs insts =
  let
    val instTs' = map (fn (T, U) => (dest_TVar (Thm.typ_of T), Thm.typ_of U)) instTs;
    fun substT x = (case AList.lookup (op =) instTs' x of NONE => TVar x | SOME T' => T');
    fun mapT_and_recertify ct =
      (Thm.cterm_of ctxt (Term.map_types (Term.map_type_tvar substT) (Thm.term_of ct)));
    val insts' = map (apfst mapT_and_recertify) insts;
  in
    Thm.instantiate
     (map (apfst (dest_TVar o Thm.typ_of)) instTs,
      map (apfst (dest_Var o Thm.term_of)) insts')
  end;

fun tvar_clash ixn S S' =
  raise TYPE ("Type variable has two distinct sorts", [TVar (ixn, S), TVar (ixn, S')], []);

(* Binding of type variable ixn in tye; raises TYPE if it is bound with a
 * different sort.
 *)
fun lookup (tye, (ixn, S)) =
  (case AList.lookup (op =) tye ixn of
    NONE => NONE
  | SOME (S', T) => if S = S' then SOME T else tvar_clash ixn S S');

(* Naive first-order type matching: a type variable is bound at its first
 * occurrence and later occurrences are not checked against that binding.
 * Raises Type.TYPE_MATCH on a mismatch of type constructors or free type
 * variables.
 *)
val naive_typ_match =
  let
    fun match (TVar (v, S), T) subs =
          (case lookup (subs, (v, S)) of
            NONE => ((v, (S, T)) :: subs)
          | SOME _ => subs)
      | match (Type (a, Ts), Type (b, Us)) subs =
          if a <> b then raise Type.TYPE_MATCH
          else matches (Ts, Us) subs
      | match (TFree x, TFree y) subs =
          if x = y then subs else raise Type.TYPE_MATCH
      | match _ _ = raise Type.TYPE_MATCH
    and matches (T :: Ts, U :: Us) subs = matches (Ts, Us) (match (T, U) subs)
      | matches _ subs = subs;
  in match end;


(* Expects that relevant type variables are already contained in term
 * variables.  The first instantiation of each variable is returned
 * without further checking.
 *)
fun naive_cterm_first_order_match (t, ct) env =
  let
    fun mtch (env as (tyinsts, insts)) =
      fn (Var (ixn, T), ct) =>
          (case AList.lookup (op =) insts ixn of
            NONE => (naive_typ_match (T, Thm.typ_of_cterm ct) tyinsts, (ixn, ct) :: insts)
          | SOME _ => env)
       | (f $ t, ct) =>
          let val (cf, ct') = Thm.dest_comb ct;
          in mtch (mtch env (f, cf)) (t, ct') end
       | _ => env;
  in mtch env (t, ct) end;


(* Instantiate rule by matching its premises against prems (in order) and
 * discharge those premises with prems.
 *)
fun discharge ctxt prems rule =
  let
    val (tyinsts, insts) =
      fold naive_cterm_first_order_match (Thm.prems_of rule ~~ map Thm.cprop_of prems) ([], []);
    val tyinsts' =
      map (fn (v, (S, U)) => ((v, S), Thm.ctyp_of ctxt U)) tyinsts;
    val insts' =
      map (fn (idxn, ct) => ((idxn, Thm.typ_of_cterm ct), ct)) insts;
    val rule' = Thm.instantiate (tyinsts', insts') rule;
  in fold Thm.elim_implies prems rule' end;

local

(* Schematic variables l, x, r of lemma in_set_root, extracted as cterms
 * so that they can be instantiated.
 *)
val (l_in_set_root, x_in_set_root, r_in_set_root) =
  let
    val (Node_l_x_d, r) =
      Thm.cprop_of @{thm in_set_root}
      |> Thm.dest_comb |> #2
      |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2 |> Thm.dest_comb;
    val (Node_l, x) = Node_l_x_d |> Thm.dest_comb |> #1 |> Thm.dest_comb;
    val l = Node_l |> Thm.dest_comb |> #2;
  in (l, x, r) end;

(* Schematic variables x, r of lemma in_set_left. *)
val (x_in_set_left, r_in_set_left) =
  let
    val (Node_l_x_d, r) =
      Thm.cprop_of @{thm in_set_left}
      |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2
      |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2 |> Thm.dest_comb;
    val x = Node_l_x_d |> Thm.dest_comb |> #1 |> Thm.dest_comb |> #2;
  in (x, r) end;

(* Schematic variables x, l of lemma in_set_right. *)
val (x_in_set_right, l_in_set_right) =
  let
    val (Node_l, x) =
      Thm.cprop_of @{thm in_set_right}
      |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2
      |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2
      |> Thm.dest_comb |> #1 |> Thm.dest_comb |> #1
      |> Thm.dest_comb;
    val l = Node_l |> Thm.dest_comb |> #2;
  in (x, l) end;

in
(*
1. First get the paths x_path and y_path of x and y in the tree.
2. For the common prefix descend into the tree according to the path,
   using the lemmas all_distinct_left/right.
3. If one rest path is empty use distinct_left/right,
   otherwise distinct_left_right.
Raises TERM if both rest paths are empty, i.e. if x_path = y_path.
*)

fun distinctTreeProver ctxt dist_thm x_path y_path =
  let
    (* Follow ps from the tree of thm, deriving all_distinct of the subtree. *)
    fun dist_subtree [] thm = thm
      | dist_subtree (p :: ps) thm =
          let
            val rule =
              (case p of Left => @{thm all_distinct_left} | Right => @{thm all_distinct_right});
          in dist_subtree ps (discharge ctxt [thm] rule) end;

    val (ps, x_rest, y_rest) = split_common_prefix x_path y_path;
    val dist_subtree_thm = dist_subtree ps dist_thm;
    val subtree = Thm.cprop_of dist_subtree_thm |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2;
    val (_, [l, _, _, r]) = Drule.strip_comb subtree;

    (* Membership of the element reached by ps from tree (a Node cterm),
       derived with the lemmas in_set_root/in_set_left/in_set_right. *)
    fun in_set ps tree =
      let
        val (_, [l, x, _, r]) = Drule.strip_comb tree;
        val xT = Thm.ctyp_of_cterm x;
      in
        (case ps of
          [] =>
            instantiate ctxt
              [(Thm.ctyp_of_cterm x_in_set_root, xT)]
              [(l_in_set_root, l), (x_in_set_root, x), (r_in_set_root, r)] @{thm in_set_root}
        | Left :: ps' =>
            let
              val in_set_l = in_set ps' l;
              val in_set_left' =
                instantiate ctxt
                  [(Thm.ctyp_of_cterm x_in_set_left, xT)]
                  [(x_in_set_left, x), (r_in_set_left, r)] @{thm in_set_left};
            in discharge ctxt [in_set_l] in_set_left' end
        | Right :: ps' =>
            let
              val in_set_r = in_set ps' r;
              val in_set_right' =
                instantiate ctxt
                  [(Thm.ctyp_of_cterm x_in_set_right, xT)]
                  [(x_in_set_right, x), (l_in_set_right, l)] @{thm in_set_right};
            in discharge ctxt [in_set_r] in_set_right' end)
      end;

    (* Membership in the left or right subtree of the common subtree. *)
    fun in_set' [] = raise TERM ("distinctTreeProver", [])
      | in_set' (Left :: ps) = in_set ps l
      | in_set' (Right :: ps) = in_set ps r;

    (* The root of the common subtree differs from an element of its left
       or right subtree. *)
    fun distinct_lr node_in_set Left =
          discharge ctxt [dist_subtree_thm, node_in_set] @{thm distinct_left}
      | distinct_lr node_in_set Right =
          discharge ctxt [dist_subtree_thm, node_in_set] @{thm distinct_right};

    (* swap = true when the proved inequality has y on the left-hand side *)
    val (swap, neq) =
      (case x_rest of
        [] =>
          let val y_in_set = in_set' y_rest;
          in (false, distinct_lr y_in_set (hd y_rest)) end
      | xr :: _ =>
          (case y_rest of
            [] =>
              let val x_in_set = in_set' x_rest;
              in (true, distinct_lr x_in_set (hd x_rest)) end
          | _ :: _ =>
              let
                val x_in_set = in_set' x_rest;
                val y_in_set = in_set' y_rest;
              in
                (case xr of
                  Left =>
                    (false,
                      discharge ctxt [dist_subtree_thm, x_in_set, y_in_set] @{thm distinct_left_right})
                | Right =>
                    (true,
                      discharge ctxt [dist_subtree_thm, y_in_set, x_in_set] @{thm distinct_left_right}))
              end));
  in if swap then discharge ctxt [neq] @{thm swap_neq} else neq end;


(* Proof about deleting the element at path p :: ps (or the root for [])
 * from the tree of dist_thm, via delete_root/delete_left/delete_right.
 *)
fun deleteProver _ dist_thm [] = @{thm delete_root} OF [dist_thm]
  | deleteProver ctxt dist_thm (p :: ps) =
      let
        val dist_rule =
          (case p of Left => @{thm all_distinct_left} | Right => @{thm all_distinct_right});
        val dist_thm' = discharge ctxt [dist_thm] dist_rule;
        val del_rule = (case p of Left => @{thm delete_left} | Right => @{thm delete_right});
        val del = deleteProver ctxt dist_thm' ps;
      in discharge ctxt [dist_thm, del] del_rule end;

local
  (* Type variable and schematic variable of lemma subtract_Tip that are
     instantiated in the Tip case of subtractProver. *)
  val (alpha, v) =
    let
      val ct =
        @{thm subtract_Tip} |> Thm.cprop_of |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2
        |> Thm.dest_comb |> #2;
      val [alpha] = ct |> Thm.ctyp_of_cterm |> Thm.dest_ctyp;
    in (dest_TVar (Thm.typ_of alpha), #1 (dest_Var (Thm.term_of ct))) end;
in

(* Recursion over the tree t (ct is its certified form): each element of
 * t is located in and deleted from the tree of dist_thm, and the left and
 * right subtrees are handled recursively.  Raises Option.Option if an
 * element of t does not occur in that tree.
 *)
fun subtractProver ctxt (Const (@{const_name Tip}, T)) _ dist_thm =
      let
        val ct' = dist_thm |> Thm.cprop_of |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2;
        val [alphaI] = #2 (dest_Type T);
      in
        Thm.instantiate
          ([(alpha, Thm.ctyp_of ctxt alphaI)],
           [((v, treeT alphaI), ct')]) @{thm subtract_Tip}
      end
  | subtractProver ctxt (Const (@{const_name Node}, _) $ _ $ x $ _ $ _) ct dist_thm =
      let
        val ct' = dist_thm |> Thm.cprop_of |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2;
        val (_, [cl, _, _, cr]) = Drule.strip_comb ct;
        val ps = the (find_tree x (Thm.term_of ct'));
        val del_tree = deleteProver ctxt dist_thm ps;
        val dist_thm' = discharge ctxt [del_tree, dist_thm] @{thm delete_Some_all_distinct};
        val sub_l = subtractProver ctxt (Thm.term_of cl) cl dist_thm';
        val sub_r =
          subtractProver ctxt (Thm.term_of cr) cr
            (discharge ctxt [sub_l, dist_thm'] @{thm subtract_Some_all_distinct_res});
      in discharge ctxt [del_tree, sub_l, sub_r] @{thm subtract_Node} end;

end;

(* Combine dist_thm with the subtraction result for the tree found as the
 * argument of the argument of ct.
 *)
fun distinct_implProver ctxt dist_thm ct =
  let
    val ctree = ct |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2;
    val sub = subtractProver ctxt (Thm.term_of ctree) ctree dist_thm;
  in @{thm subtract_Some_all_distinct} OF [sub, dist_thm] end;

(* Try to prove x ~= y from the distinctness theorem called name.  Returns
 * NONE if the theorem does not exist, x or y does not occur in its tree,
 * or x and y are the same element (equal paths), in which case nothing can
 * be proved and distinctTreeProver would raise TERM.
 *)
fun neq_x_y ctxt x y name =
  (let
    val dist_thm = the (try (Proof_Context.get_thm ctxt) name);
    val ctree = Thm.cprop_of dist_thm |> Thm.dest_comb |> #2 |> Thm.dest_comb |> #2;
    val tree = Thm.term_of ctree;
    val x_path = the (find_tree x tree);
    val y_path = the (find_tree y tree);
  in
    if x_path = y_path then NONE
    else SOME (distinctTreeProver ctxt dist_thm x_path y_path)
  end handle Option.Option => NONE);

(* Solve a subgoal ~ (x = y) using the first named theorem that works. *)
fun distinctTree_tac names ctxt = SUBGOAL (fn (goal, i) =>
  (case goal of
    Const (@{const_name Trueprop}, _) $
        (Const (@{const_name Not}, _) $ (Const (@{const_name HOL.eq}, _) $ x $ y)) =>
      (case get_first (neq_x_y ctxt x y) names of
        SOME neq => resolve_tac ctxt [neq] i
      | NONE => no_tac)
  | _ => no_tac));

fun distinctFieldSolver names =
  mk_solver "distinctFieldSolver" (distinctTree_tac names);

(* Rewrite x = y via neq_to_eq_False when x ~= y can be proved from one of
 * the named theorems.
 *)
fun distinct_simproc names =
  Simplifier.make_simproc @{context} "DistinctTreeProver.distinct_simproc"
   {lhss = [@{term "x = y"}],
    proc = fn _ => fn ctxt => fn ct =>
      (case Thm.term_of ct of
        Const (@{const_name HOL.eq}, _) $ x $ y =>
          Option.map (fn neq => @{thm neq_to_eq_False} OF [neq])
            (get_first (neq_x_y ctxt x y) names)
      | _ => NONE)};

end;

end;
368