1(* ========================================================================= *)
2(* HIGHER-ORDER UTILITY FUNCTIONS                                            *)
3(* Joe Hurd, 10 June 2001                                                    *)
4(* ========================================================================= *)
5
6structure subtypeUseful :> subtypeUseful =
7struct
8
9open Susp HolKernel Parse Hol_pp boolLib BasicProvers pred_setTheory;
10
11infixr 0 oo ++ << || THEN THENC ORELSEC THENR ORELSER ## thenf orelsef;
12infix 1 >> |->;
13
14val op++ = op THEN;
15val op<< = op THENL;
16val op|| = op ORELSE;
17
18(* ------------------------------------------------------------------------- *)
19(* Basic ML datatypes/functions.                                             *)
20(* ------------------------------------------------------------------------- *)
21
22type 'a thunk = unit -> 'a;
23(* type (''a, 'b) cache = (''a, 'b) Polyhash.hash_table; *)
24type 'a susp = 'a Susp.susp;
25type ppstream = Portable.ppstream;
26type ('a, 'b) maplet = {redex : 'a, residue : 'b};
27type ('a, 'b) subst = ('a, 'b) Lib.subst;
28
29(* Error handling *)
30
31exception BUG_EXN of
32  {origin_structure : string, origin_function : string, message : string};
33
34fun ERR f s = HOL_ERR
35  {origin_structure = "subtypeUseful", origin_function = f, message = s};
36
37fun BUG f s = BUG_EXN
38  {origin_structure = "subtypeUseful", origin_function = f, message = s};
39
40fun BUG_to_string (BUG_EXN {origin_structure, origin_function, message}) =
41  ("\nBUG discovered by " ^ origin_structure ^ " at " ^
42   origin_function ^ ":\n" ^ message ^ "\n")
43  | BUG_to_string _ = raise BUG "print_BUG" "not a BUG_EXN";
44
45fun err_BUG s (h as HOL_ERR _) =
46  (print (exn_to_string h); BUG s "should never fail")
47  | err_BUG s _ =
48  raise BUG "err_BUG" ("not a HOL_ERR (called from " ^ s ^ ")");
49
50(* Success and failure *)
51
52fun assert b e = if b then () else raise e;
53fun try f a = f a
54  handle (h as HOL_ERR _) => (print (exn_to_string h); raise h)
55       | (b as BUG_EXN _) => (print (BUG_to_string b); raise b)
56       | e => (print "\ntry: strange exception raised\n"; raise e);
57fun total f x = SOME (f x) handle HOL_ERR _ => NONE;
58fun can f = Option.isSome o total f;
59fun partial (e as HOL_ERR _) f x = (case f x of SOME y => y | NONE => raise e)
60  | partial _ _ _ = raise BUG "partial" "must take a HOL_ERR";
61
62(* Exception combinators *)
63
64fun nof x = raise ERR "nof" "never succeeds";
65fun allf x = x;
66fun op thenf (f, g) x = g (f x);
67fun op orelsef (f, g) x = f x handle HOL_ERR _ => g x;
68fun tryf f = f orelsef allf;
69fun repeatf f x = ((f thenf repeatf f) orelsef allf) x;
70fun repeatplusf f = f thenf repeatf f;
71fun firstf [] _ = raise ERR "firstf" "out of combinators"
72  | firstf (f :: rest) x = (f orelsef firstf rest) x;
73
74(* Combinators *)
75
76fun A f x = f x;
77fun C f x y = f y x;
78fun I x = x;
79fun K x y = x;
80fun N 0 _ x = x | N 1 f x = f x | N n f x = N (n - 1) f (f x);
81fun S f g x = f x (g x);
82fun W f x = f x x;
83fun f oo g = fn x => f o (g x);
84
85(* Pairs *)
86
87infix 3 ##
88fun (f ## g) (x, y) = (f x, g y);
89fun D x = (x, x);
90fun Df f = f ## f;
91fun fst (x,_) = x;
92fun snd (_,y) = y;
93fun add_fst x y = (x, y);
94fun add_snd x y = (y, x);
95fun curry f x y = f (x, y);
96fun uncurry f (x, y) = f x y;
97fun equal x y = (x = y);
98
99fun pair_to_string fst_to_string snd_to_string (a, b) =
100  "(" ^ fst_to_string a ^ ", " ^ snd_to_string b ^ ")";
101
102(* Ints *)
103
104val plus = curry op+;
105val multiply = curry op*;
106val succ = plus 1;
107
108(* Strings *)
109
110val concat = curry op^;
111val int_to_string = Int.toString;
112val string_to_int =
113  partial (ERR "string_to_int" "couldn't convert string") Int.fromString;
114
115fun mk_string_fn name args = name ^ String.concat (map (concat "_") args);
116fun dest_string_fn name s =
117  (case String.tokens (fn #"_" => true | _ => false) s of []
118     => raise ERR "pure_dest_fn" "empty string"
119   | f::args => (assert (f = name) (ERR "dest_fn" "wrong name"); args));
120fun is_string_fn name = can (dest_string_fn name);
121
122(* --------------------------------------------------------------------- *)
123(* Tools for debugging.                                                  *)
124(* --------------------------------------------------------------------- *)
125
126(* Timing *)
127
128local
129  fun iterate f a 0 = ()
130    | iterate f a n = (f a; iterate f a (n - 1))
131in
132  fun time_n n f a = time (iterate f a) n
133end;
134
135(* Test cases *)
136
137fun tt f = (time o try) f;
138fun tt2 f = tt o f;
139fun tt3 f = tt2 o f;
140fun tt4 f = tt3 o f;
141
142fun ff f =
143  try
144  (fn x =>
145   case (time o total o try) f x of NONE => ()
146   | SOME _ => raise ERR "ff" "f should not have succeeded!");
147fun ff2 f = ff o f;
148fun ff3 f = ff2 o f;
149fun ff4 f = ff3 o f;
150
151(* --------------------------------------------------------------------- *)
152(* Useful imperative features.                                           *)
153(* --------------------------------------------------------------------- *)
154
155(* Fresh integers *)
156
157local
158  val counter = ref 0
159in
160  fun new_int ()
161    = let val c = !counter
162          val _ = counter := c + 1
163      in c end
164end;
165
166(* Random numbers *)
167
168val random_generator = Random.newgen ();
169fun random_integer n = Random.range (0, n) random_generator;
170fun random_real () = Random.random random_generator;
171
172(* Function cacheing *)
173
174(* fun new_cache () : (''a, 'b) cache =
175   Polyhash.mkPolyTable (10000, ERR "cache" "not found"); *)
176
177(* fun cache_lookup c (a, b_thk) =
178  (case Polyhash.peek c a of SOME b => b
179   | NONE =>
180     let
181       val b = b_thk ()
182       val _ = Polyhash.insert c (a, b)
183     in
184       b
185    end); *)
186
187(* fun cachef f =
188  let
189    val c = new_cache ()
190  in
191    fn a => cache_lookup c (a, fn () => f a)
192  end; *)
193
194(* Lazy operations *)
195
196fun pair_susp a b = delay (fn () => (force a, force b));
197
198fun susp_map f s = delay (fn () => f (force s));
199
200(* --------------------------------------------------------------------- *)
201(* Options.                                                              *)
202(* --------------------------------------------------------------------- *)
203
204val is_some = Option.isSome;
205fun grab (SOME x) = x | grab NONE = raise ERR "grab" "NONE";
206fun o_pair (SOME x, y) = SOME (x, y) | o_pair _ = NONE;
207fun pair_o (x, SOME y) = SOME (x, y) | pair_o _ = NONE;
208fun o_pair_o (SOME x, SOME y) = SOME (x, y) | o_pair_o _ = NONE;
209val app_o = Option.map;
210fun o_app f = curry (app_o (uncurry A) o o_pair) f
211fun o_app_o f = curry (app_o (uncurry A) o o_pair_o) f
212fun partial_app_o f = Option.join o app_o f;
213fun partial_o_app f = Option.join o o_app f;
214fun partial_o_app_o f = Option.join o o_app_o f;
215fun option_to_list NONE = [] | option_to_list (SOME s) = [s];
216
217(* --------------------------------------------------------------------- *)
218(* Lists.                                                                *)
219(* --------------------------------------------------------------------- *)
220
221fun cons x = curry op:: x;
222fun append l = curry op@ l;
223fun wrap a = [a];
224fun unwrap [a] = a | unwrap _ = raise ERR "unwrap" "not a singleton list";
225fun fold _ b [] = b | fold f b (h::t) = f h (fold f b t);
226fun trans _ s [] = s | trans f s (h::t) = trans f (f h s) t;
227fun partial_trans _ s [] = SOME s
228  | partial_trans f s (h::t) = partial_app_o (C (partial_trans f) t) (f h s);
229fun first _ [] = raise ERR "first" "no items satisfy"
230  | first f (h::t) = if f h then h else first f t;
231fun partial_first _ [] = NONE
232  | partial_first f (h::t) = (case f h of NONE => partial_first f t | s => s);
233val forall = List.all;
234val exists = List.exists;
235val index = Lib.index;
236fun nth n l = List.nth (l, n);
237val split_after = Lib.split_after;
238fun assoc x = snd o first (equal x o fst);
239fun rev_assoc x = fst o first (equal x o snd);
240
241val map = List.map;
242val partial_map = List.mapPartial;
243
244fun zip_aux _ [] [] = []
245  | zip_aux f (x::xs) (y::ys) = f (x, y) (zip_aux f xs ys)
246  | zip_aux _ _ _ = raise ERR "zip" "lists different lengths";
247fun zip xs ys = zip_aux cons xs ys;
248fun zipwith f xs ys = zip_aux (cons o (uncurry f)) xs ys;
249fun partial_zipwith f xs ys = zip_aux
250  (fn (x, y) => case f x y of NONE => I | SOME s => cons s) xs ys;
251
252fun cart_aux f xs ys =
253  let
254    val xs' = rev xs
255    val ys' = rev ys
256  in
257    trans (fn x => C (trans (fn y => f (x, y))) ys') [] xs'
258  end;
259fun cart xs ys = cart_aux cons xs ys;
260fun cartwith f xs ys = cart_aux (cons o uncurry f) xs ys;
261fun partial_cartwith f xs ys =
262  cart_aux (fn (x, y) => case f x y of NONE => I | SOME s => cons s) xs ys;
263
264fun list_to_string _ [] = "[]"
265  | list_to_string elt_to_string (h :: t) =
266  trans (fn x => fn y => y ^ ", " ^ elt_to_string x)
267  ("[" ^ elt_to_string h) t ^ "]";
268
269(* --------------------------------------------------------------------- *)
270(* Lists as sets.                                                        *)
271(* --------------------------------------------------------------------- *)
272
273fun subset s t = forall (C mem t) s;
274
275fun distinct [] = true
276  | distinct (x :: rest) = not (mem x rest) andalso distinct rest;
277
278fun union2 (a, b) (c, d) = (union a c, union b d);
279
280(* --------------------------------------------------------------------- *)
281(* Rotations, permutations and sorting.                                  *)
282(* --------------------------------------------------------------------- *)
283
284(* Rotations of a list---surprisingly useful *)
285
286local
287  fun rot res _ [] = res
288    | rot res seen (h :: t) = rot ((h, t @ rev seen) :: res) (h :: seen) t
289in
290  fun rotations l = rev (rot [] [] l)
291end;
292
293fun rotate i = nth i o rotations;
294
295fun rotate_random l = rotate (random_integer (length l)) l;
296
297(* Permutations of a list *)
298
299fun permutations [] = [[]]
300  | permutations l =
301  (flatten o map (fn (h, t) => map (cons h) (permutations t)) o rotations) l;
302
303fun permute [] [] = []
304  | permute (i :: is) (xs as _ :: _) = (op:: o (I ## permute is) o rotate i) xs
305  | permute _ _ = raise ERR "permute" "bad arguments (different lengths)";
306
307fun permute_random [] = []
308  | permute_random l = (op:: o (I ## permute_random) o rotate_random) l;
309
310(* Finding the minimal element of a list, wrt some order. *)
311
312local
313  fun min_acc _ best [] = best
314    | min_acc f best (h :: t) = min_acc f (if f best h then best else h) t
315in
316  fun min _ [] = raise ERR "min" "empty list"
317    | min f (h :: t) = min_acc f h t
318end;
319
320(* Merge (for the following merge-sort, but generally useful too). *)
321
322fun merge f [] al' = al'
323  | merge f al [] = al
324  | merge f (a::al) (a'::al') =
325  if f a a' then a::(merge f al (a'::al'))
326  else a'::(merge f (a::al) al');
327
328(* Order function here should be <= for a stable sort...              *)
329(* ...and I think < gives a reverse stable sort (but don't quote me). *)
330fun sort f l =
331  let
332    val n = length l
333  in
334    if n < 2 then l
335    else (uncurry (merge f) o Df (sort f) o split_after (n div 2)) l
336  end;
337
338local
339  fun find_min _ (_, []) = raise ERR "top_min" "no minimal element!"
340    | find_min f (a, x::b) =
341    (assert (f x x <> SOME false) (BUG "top_min" "order function says x > x!");
342     if forall (fn y => f x y <> SOME false) (a @ b) then (x, a @ b)
343     else find_min f (x::a, b))
344in
345  fun top_min f l = find_min f ([], l)
346end;
347
348fun top_sort f [] = []
349  | top_sort f l =
350  let
351    val (x, rest) = top_min f l
352  in
353    x::top_sort f rest
354  end;
355
356(* --------------------------------------------------------------------- *)
357(* Sums.                                                                 *)
358(* --------------------------------------------------------------------- *)
359
360datatype ('a, 'b) sum = LEFT of 'a | RIGHT of 'b;
361
362(* --------------------------------------------------------------------- *)
363(* Streams.                                                              *)
364(* --------------------------------------------------------------------- *)
365
366datatype ('a) stream = STREAM_NIL | STREAM_CONS of ('a * 'a stream thunk);
367
368fun stream_null STREAM_NIL = true
369  | stream_null (STREAM_CONS _) = false;
370
371fun dest_stream_cons STREAM_NIL = raise ERR "dest_stream_cons" "stream is nil"
372  | dest_stream_cons (STREAM_CONS c) = c;
373
374fun stream_hd s = fst (dest_stream_cons s);
375fun stream_tl s = snd (dest_stream_cons s);
376
377local
378  fun to_list res STREAM_NIL = res
379    | to_list res (STREAM_CONS (a, thk)) = to_list (a :: res) (thk ())
380in
381  fun stream_to_list s = rev (to_list [] s)
382end;
383
384fun stream_append s1 s2 () =
385  (case s1 () of STREAM_NIL => s2 ()
386   | STREAM_CONS (a, thk) => STREAM_CONS (a, stream_append thk s2));
387
388fun stream_concat ss = trans (C stream_append) (K STREAM_NIL) ss;
389
390(* --------------------------------------------------------------------- *)
391(* A generic tree type.                                                  *)
392(* --------------------------------------------------------------------- *)
393
394datatype ('a, 'b) tree = BRANCH of 'a * ('a, 'b) tree list | LEAF of 'b;
395
396fun tree_size (LEAF _) = 1
397  | tree_size (BRANCH (_, t)) = trans (plus o tree_size) 0 t;
398
399fun tree_fold f_b f_l (LEAF l) = f_l l
400  | tree_fold f_b f_l (BRANCH (p, s)) = f_b p (map (tree_fold f_b f_l) s);
401
402fun tree_trans f_b f_l state (LEAF l) = [f_l l state]
403  | tree_trans f_b f_l state (BRANCH (p, s)) =
404  flatten (map (tree_trans f_b f_l (f_b p state)) s);
405
406fun tree_partial_trans f_b f_l state (LEAF l) = option_to_list (f_l l state)
407  | tree_partial_trans f_b f_l state (BRANCH (p, s)) =
408  (case f_b p state of NONE => []
409   | SOME state' => flatten (map (tree_partial_trans f_b f_l state') s));
410
411(* --------------------------------------------------------------------- *)
412(* Pretty-printing helper-functions.                                     *)
413(* --------------------------------------------------------------------- *)
414
415fun pp_map f pp_a (ppstrm : ppstream) x : unit = pp_a ppstrm (f x);
416
417fun pp_string ppstrm =
418  let
419    val {add_string,add_break,begin_block,end_block,add_newline,...}
420      = Portable.with_ppstream ppstrm
421
422  in
423    fn s => (begin_block Portable.CONSISTENT 1;
424             add_string s;
425             end_block ())
426  end;
427
428fun pp_unknown ppstrm _ = pp_string ppstrm "_";
429
430fun pp_int ppstrm i = pp_string ppstrm (int_to_string i);
431
432fun pp_pair pp1 pp2 ppstrm =
433  let
434    val {add_string,add_break,begin_block,end_block,add_newline,...}
435      = Portable.with_ppstream ppstrm
436
437  in
438    fn (a, b) => (begin_block Portable.CONSISTENT 1;
439                  add_string "(";
440                  pp1 ppstrm a:unit;
441                  add_string ",";
442                  add_break (1, 0);
443                  pp2 ppstrm b:unit;
444                  add_string ")";
445                  end_block())
446  end;
447
448fun pp_list pp ppstrm =
449  let
450    val {add_string,add_break,begin_block,end_block,add_newline,...}
451      = Portable.with_ppstream ppstrm
452
453    val pp_elt = pp ppstrm
454
455    fun pp_seq [] = ()
456      | pp_seq (h::t) = (add_string ",";
457                         add_break (1, 0);
458                         pp_elt h:unit;
459                         pp_seq t)
460  in
461    fn l => (begin_block Portable.INCONSISTENT 1;
462             add_string "[";
463             (case l of [] => ()
464              | h::t => (pp_elt h; pp_seq t));
465             add_string "]";
466             end_block())
467  end;
468
469(* --------------------------------------------------------------------- *)
470(* Substitution operations.                                              *)
471(* --------------------------------------------------------------------- *)
472
473fun redex {redex, residue = _} = redex;
474fun residue {redex = _, residue} = residue;
475fun find_redex r = first (fn rr as {redex, residue} => r = redex);
476fun clean_subst s = filter (fn {redex, residue} => not (redex = residue)) s;
477fun subst_vars sub = map redex sub;
478fun maplet_map (redf, resf) {redex, residue} = (redf redex |-> resf residue);
479fun subst_map fg = map (maplet_map fg);
480fun redex_map f = subst_map (f, I);
481fun residue_map f = subst_map (I, f);
482
483fun is_renaming_subst vars sub =
484  let
485    val residues = map residue sub
486  in
487    forall (C mem vars) residues andalso distinct residues
488  end;
489
490fun invert_renaming_subst vars sub =
491  let
492    val _ =
493      assert (is_renaming_subst vars sub)
494      (ERR "invert_renaming_subst" "not a renaming subst, so not invertible")
495    fun inv {redex, residue} = residue |-> redex
496  in
497    map inv sub
498  end;
499
500(* --------------------------------------------------------------------- *)
501(* HOL-specific functions.                                               *)
502(* --------------------------------------------------------------------- *)
503
504type hol_type = Type.hol_type
505type term = Term.term
506type thm = Thm.thm
507type goal = term list * term
508type conv = term -> thm
509type rule = thm -> thm
510type validation = thm list -> thm
511type tactic = goal -> goal list * validation
512type thm_tactic = thm -> tactic
513type vars = term list * hol_type list
514type vterm = vars * term
515type vthm = vars * thm
516type type_subst = (hol_type, hol_type) subst
517type term_subst = (term, term) subst
518type substitution = (term, term) subst * (hol_type, hol_type) subst
519type ho_substitution = substitution * thm thunk
520type raw_substitution = (term_subst * term set) * (type_subst * hol_type list)
521type ho_raw_substitution = raw_substitution * thm thunk
522
523(* --------------------------------------------------------------------- *)
524(* General                                                               *)
525(* --------------------------------------------------------------------- *)
526
527(* A profile function counting both time and primitive inferences. *)
528
529fun profile f a =
530  let
531    val m = Count.mk_meter ()
532    val i = #prims(Count.read m)
533    val t = Time.now ()
534    val res = f a
535    val t' = Time.now ()
536    val i' = #prims(Count.read m)
537    val _ = print ("Time taken: " ^ Time.toString (Time.-(t', t)) ^ ".\n"
538                   ^ "Primitive inferences: " ^ Int.toString (i' - i) ^ ".\n")
539  in
540    res
541  end;
542
543(* Parsing in the context of a goal, a la the Q library. *)
544
545fun parse_with_goal t (asms, g) =
546  let
547    val ctxt = free_varsl (g::asms)
548  in
549    Parse.parse_in_context ctxt t
550  end;
551
552(* --------------------------------------------------------------------- *)
553(* Term/type substitutions.                                              *)
554(* --------------------------------------------------------------------- *)
555
556val empty_subst = ([], []) : substitution;
557
558val type_inst = type_subst;
559val inst_ty = inst;
560fun pinst (tm_sub, ty_sub) = subst tm_sub o inst_ty ty_sub;
561
562fun type_subst_vars_in_set (sub : type_subst) vars =
563  subset (subst_vars sub) vars;
564
565fun subst_vars_in_set ((tm_sub, ty_sub) : substitution) (tm_vars, ty_vars) =
566  type_subst_vars_in_set ty_sub ty_vars andalso
567  subset (subst_vars tm_sub) (map (inst_ty ty_sub) tm_vars);
568
569(* Note: cyclic substitutions are right out! *)
570fun type_refine_subst ty1 ty2 : (hol_type, hol_type) subst =
571  ty2 @ (clean_subst o residue_map (type_inst ty2)) ty1;
572
573fun refine_subst (tm1, ty1) (tm2, ty2) =
574  (tm2 @ (clean_subst o subst_map (inst_ty ty2, pinst (tm2, ty2))) tm1,
575   type_refine_subst ty1 ty2);
576
577(*
578refine_subst
579([(``x:'b list`` |-> ``CONS (y:'b list) []``)],
580 [(``:'a`` |-> ``:'b list``)])
581([(``y:real list`` |-> ``[0:real]``)],
582 [(``:'b`` |-> ``:real``)]);
583
584refine_subst
585([(``x:'b list`` |-> ``[y : 'b]``)],
586 [(``:'a`` |-> ``:'b``)])
587([(``y:'a`` |-> ``z:'a``)],
588 [(``:'b`` |-> ``:'a``)]);
589*)
590
591fun type_vars_after_subst vars (sub : (hol_type, hol_type) subst) =
592  subtract vars (subst_vars sub);
593
594fun vars_after_subst (tm_vars, ty_vars) (tm_sub, ty_sub) =
595  (subtract (map (inst_ty ty_sub) tm_vars) (subst_vars tm_sub),
596   type_vars_after_subst ty_vars ty_sub);
597
598fun type_invert_subst vars (sub : (hol_type, hol_type) subst) =
599  invert_renaming_subst vars sub;
600
601fun invert_subst (tm_vars, ty_vars) (tm_sub, ty_sub) =
602  let
603    val _ =
604      assert (is_renaming_subst tm_vars tm_sub)
605      (ERR "invert_subst" "not a renaming term subst")
606    val ty_sub' = type_invert_subst ty_vars ty_sub
607    fun inv {redex, residue} =
608      inst_ty ty_sub' residue |-> inst_ty ty_sub' redex
609  in
610    (map inv tm_sub, ty_sub')
611  end;
612
613(* --------------------------------------------------------------------- *)
614(* Logic variables.                                                      *)
615(* --------------------------------------------------------------------- *)
616
617val empty_vars = ([], []) : vars;
618fun is_tyvar ((_, tyvars) : vars) ty = is_vartype ty andalso mem ty tyvars;
619fun is_tmvar ((tmvars, _) : vars) tm = is_var tm andalso mem tm tmvars;
620
621fun type_new_vars (vars : hol_type list) =
622  let
623    val gvars = map (fn _ => gen_tyvar ()) vars
624    val old_to_new = zipwith (curry op|->) vars gvars
625    val new_to_old = zipwith (curry op|->) gvars vars
626  in
627    (gvars, (old_to_new, new_to_old))
628  end;
629
630fun term_new_vars vars =
631  let
632    val gvars = map (genvar o type_of) vars
633    val old_to_new = zipwith (curry op|->) vars gvars
634    val new_to_old = zipwith (curry op|->) gvars vars
635  in
636    (gvars, (old_to_new, new_to_old))
637  end;
638
639fun new_vars (tm_vars, ty_vars) =
640  let
641    val (ty_gvars, (ty_old_to_new, ty_new_to_old)) = type_new_vars ty_vars
642    val (tm_gvars, (tm_old_to_new, tm_new_to_old)) = term_new_vars tm_vars
643    val old_to_new = refine_subst (tm_old_to_new, []) ([], ty_old_to_new)
644    val new_to_old = (tm_new_to_old, ty_new_to_old)
645  in
646    ((map (inst_ty ty_old_to_new) tm_gvars, ty_gvars), (old_to_new, new_to_old))
647  end;
648
649(* ------------------------------------------------------------------------- *)
650(* Bound variables.                                                          *)
651(* ------------------------------------------------------------------------- *)
652
653fun dest_bv bvs tm =
654  let
655    val _ = assert (is_var tm) (ERR "dest_bv" "not a var")
656  in
657    index (equal tm) bvs
658  end;
659fun is_bv bvs = can (dest_bv bvs);
660fun mk_bv bvs n : term = nth n bvs;
661
662(* --------------------------------------------------------------------- *)
663(* Types.                                                                *)
664(* --------------------------------------------------------------------- *)
665
666(* --------------------------------------------------------------------- *)
667(* Terms.                                                                *)
668(* --------------------------------------------------------------------- *)
669
670val type_vars_in_terms = trans (union o type_vars_in_term) [];
671
672local
673  fun dest (tm, args) =
674    let
675      val (a, b) = dest_comb tm
676    in
677      (a, b::args)
678    end
679in
680  fun list_dest_comb tm = repeat dest (tm, [])
681end;
682
683fun conjuncts tm =
684  if is_conj tm then
685    let
686      val (a, b) = dest_conj tm
687    in
688      a::(conjuncts b)
689    end
690  else [tm];
691
692fun dest_unaryop c tm =
693  let
694    val (a, b) = dest_comb tm
695    val _ = assert (fst (dest_const a) = c)
696      (ERR "dest_unaryop" "different const")
697  in
698    b
699  end;
700fun is_unaryop c = can (dest_unaryop c);
701
702fun dest_binop c tm =
703  let
704    val (a, b) = dest_comb tm
705  in
706    (dest_unaryop c a, b)
707  end;
708fun is_binop c = can (dest_binop c);
709
710val dest_imp = dest_binop "==>";
711val is_imp = can dest_imp;
712
713local
714  fun dest (vs, tm) = (C cons vs ## I) (dest_forall tm)
715in
716  val dest_foralls = repeat dest o add_fst []
717end;
718val mk_foralls = uncurry (C (trans (curry mk_forall)));
719
720fun spec s tm =
721  let
722    val (v, body) = dest_forall tm
723  in
724    subst [v |-> s] body
725  end;
726
727val specl = C (trans spec);
728
729fun var_match vars tm tm' =
730  let
731    val sub = match_term tm tm'
732    val _ = assert (subst_vars_in_set sub vars)
733      (ERR "var_match" "subst vars not contained in set")
734  in
735    sub
736  end;
737
738(* --------------------------------------------------------------------- *)
739(* Thms.                                                                 *)
740(* --------------------------------------------------------------------- *)
741
742val FUN_EQ = prove (``!f g. (f = g) = (!x. f x = g x)``, PROVE_TAC [EQ_EXT]);
743val SET_EQ = prove (``!s t. (s = t) = (!x. x IN s = x IN t)``,
744                    PROVE_TAC [SPECIFICATION, FUN_EQ]);
745
746val hyps = foldl (fn (h,t) => union (hyp h) t) [];
747
748val LHS = lhs o concl;
749val RHS = rhs o concl;
750
751local
752  fun fake_asm_op r th =
753    let
754      val h = rev (hyp th)
755    in
756      (N (length h) UNDISCH o r o C (foldl (uncurry DISCH)) h) th
757    end
758in
759  val INST_TY = fake_asm_op o INST_TYPE;
760  val PINST = fake_asm_op o INST_TY_TERM;
761end;
762
763(* --------------------------------------------------------------------- *)
764(* Conversions.                                                          *)
765(* --------------------------------------------------------------------- *)
766
767(* Conversionals *)
768
769fun CHANGED_CONV c tm =
770    let
771      val th = QCONV c tm
772    in
773      if rhs (concl th) = tm then raise ERR "CHANGED_CONV" "" else th
774    end;
775
776fun FIRSTC [] tm = raise ERR "FIRSTC" "ran out of convs"
777  | FIRSTC (c::cs) tm = (c ORELSEC FIRSTC cs) tm;
778
779fun TRYC c = QCONV (c ORELSEC ALL_CONV);
780
781fun REPEATPLUSC c = c THENC REPEATC c;
782
783fun REPEATC_CUTOFF 0 _ _ = raise ERR "REPEATC_CUTOFF" "cut-off reached"
784  | REPEATC_CUTOFF n c tm =
785  (case (SOME (QCONV c tm) handle HOL_ERR _ => NONE) of NONE
786     => QCONV ALL_CONV tm
787   | SOME eq_th => TRANS eq_th (REPEATC_CUTOFF (n - 1) c (RHS eq_th)));
788
789(* A conversional like DEPTH_CONV, but applies the argument conversion   *)
790(* at most once to each subterm                                          *)
791
792fun DEPTH_ONCE_CONV c tm = QCONV (SUB_CONV (DEPTH_ONCE_CONV c) THENC TRYC c) tm;
793
794fun FORALLS_CONV c tm =
795  QCONV (if is_forall tm then RAND_CONV (ABS_CONV (FORALLS_CONV c)) else c) tm;
796
797fun CONJUNCT_CONV c tm =
798  QCONV
799  (if is_conj tm then RATOR_CONV (RAND_CONV c) THENC RAND_CONV (CONJUNCT_CONV c)
800   else c) tm;
801
802(* Conversions *)
803
804fun EXACT_CONV exact tm = QCONV (if tm = exact then ALL_CONV else NO_CONV) tm;
805
806val NEGNEG_CONV = REWR_CONV (CONJUNCT1 NOT_CLAUSES);
807
808val FUN_EQ_CONV = REWR_CONV FUN_EQ;
809val SET_EQ_CONV = REWR_CONV SET_EQ;
810
811fun N_BETA_CONV 0 = QCONV ALL_CONV
812  | N_BETA_CONV n = RATOR_CONV (N_BETA_CONV (n - 1)) THENC TRYC BETA_CONV;
813
814local
815  val EQ_NEG_T = PROVE [] ``!a. (~a = T) = (a = F)``
816  val EQ_NEG_F = PROVE [] ``!a. (~a = F) = (a = T)``
817  val EQ_NEG_T_CONV = REWR_CONV EQ_NEG_T
818  val EQ_NEG_F_CONV = REWR_CONV EQ_NEG_F
819in
820  val EQ_NEG_BOOL_CONV = QCONV (EQ_NEG_T_CONV ORELSEC EQ_NEG_F_CONV);
821end;
822
823val GENVAR_ALPHA_CONV = W (ALPHA_CONV o genvar o type_of o bvar);
824val GENVAR_BVARS_CONV = DEPTH_ONCE_CONV GENVAR_ALPHA_CONV;
825
826fun ETA_EXPAND_CONV v tm = SYM (ETA_CONV (mk_abs (v, mk_comb (tm, v))));
827val GENVAR_ETA_EXPAND_CONV =
828  W (ETA_EXPAND_CONV o genvar o fst o dom_rng o type_of);
829
830(* --------------------------------------------------------------------- *)
831(* Rules.                                                                *)
832(* --------------------------------------------------------------------- *)
833
834fun op THENR (r1, r2) (th:thm) :thm = r2 (r1 th:thm);
835fun REPEATR r (th:thm) = REPEATR r (r th) handle HOL_ERR _ => th;
836fun op ORELSER (r1, r2) (th:thm):thm = r1 th handle HOL_ERR _ => r2 th;
837fun TRYR r = r ORELSER I;
838val ALL_RULE : rule = I;
839
840fun EVERYR [] = ALL_RULE
841  | EVERYR (r::rest) = r THENR EVERYR rest;
842
843local
844  val fir = prove
845    (``(!(x:'a). P x ==> Q x) ==> ((?x. P x) ==> (?x. Q x))``, PROVE_TAC [])
846in
847  val FORALL_IMP = HO_MATCH_MP fir
848end;
849
850val EQ_BOOL_INTRO = EQT_INTRO THENR CONV_RULE (REPEATC EQ_NEG_BOOL_CONV);
851
852val GENVAR_BVARS = CONV_RULE GENVAR_BVARS_CONV;
853
854val GENVAR_SPEC =
855  CONV_RULE (RAND_CONV GENVAR_ALPHA_CONV) THENR (snd o SPEC_VAR);
856
857val GENVAR_SPEC_ALL = REPEATR GENVAR_SPEC;
858
859local
860  fun mk th [] = th
861    | mk th (c :: rest) = mk (CONJ c th) rest
862    handle HOL_ERR _ => raise BUG "REV_CONJUNCTS" "panic"
863in
864  fun REV_CONJUNCTS [] = raise ERR "REV_CONJUNCTS" "empty list"
865    | REV_CONJUNCTS (th :: rest) = mk th rest
866end;
867
868fun REORDER_ASMS asms th0 =
869  let
870    val th1 = foldr (fn (h,t) => DISCH h t) th0 asms
871    val th2 = funpow (length asms) UNDISCH th1
872  in
873    th2
874  end;
875
876local
877  fun dest_c tm =
878    if is_comb tm then
879      let
880        val (a, b) = dest_comb tm
881      in
882        (I ## cons b) (dest_c a)
883      end
884    else (tm, [])
885
886  fun comb_beta eq_th x =
887    CONV_RULE (RAND_CONV BETA_CONV) (MK_COMB (eq_th, REFL x))
888in
889  fun NEW_CONST_RULE cvar_lvars th =
890    let
891      val (cvar, lvars) = (I ## rev) (dest_c cvar_lvars)
892      val sel_th =
893        CONV_RULE (RATOR_CONV (REWR_CONV EXISTS_DEF) THENC BETA_CONV) th
894      val pred = rator (concl sel_th)
895      val def_tm = list_mk_abs (lvars, rand (concl sel_th))
896      val def_th = ASSUME (mk_eq (cvar, def_tm))
897      val eq_th = MK_COMB (REFL pred, trans (C comb_beta) def_th lvars)
898    in
899      CONV_RULE BETA_CONV (EQ_MP (SYM eq_th) sel_th)
900    end
901end;
902
903val GENVAR_CONST_RULE =
904  W (NEW_CONST_RULE o genvar o type_of o bvar o rand o concl);
905
906local
907  fun zap _ _ [] = raise ERR "zap" "fresh out of asms"
908    | zap th checked (asm::rest) =
909    if is_eq asm then
910      let
911        val (v, def) = dest_eq asm
912      in
913        if is_var v andalso all (not o free_in v) (checked @ rest) then
914          MP (SPEC def (GEN v (DISCH asm th))) (REFL def)
915        else zap th (asm::checked) rest
916      end
917    else zap th (asm::checked) rest
918in
919  val ZAP_CONSTS_RULE = repeat (fn th => zap th [concl th] (hyp th))
920end;
921
922(* ------------------------------------------------------------------------- *)
923(* vthm operations                                                           *)
924(* ------------------------------------------------------------------------- *)
925
926fun thm_to_vthm th =
927  let
928    val tm = concl th
929
930    val c_tyvars = type_vars_in_term tm
931    val h_tyvars = type_vars_in_terms (hyp th)
932    val f_tyvars = subtract c_tyvars h_tyvars
933    val (f_tmvars, _) = dest_foralls tm
934    val f_vars = (f_tmvars, f_tyvars)
935
936    val (vars, (sub, _)) = new_vars f_vars
937  in
938    (vars, PINST sub (REPEATR (snd o SPEC_VAR) th))
939  end;
940
941fun vthm_to_thm (((vars, _), th) : vthm) = GENL vars th;
942
943fun clean_vthm ((tm_vars, ty_vars), th) =
944  let
945    val tms = concl th :: hyp th
946    val ty_vars' = intersect (type_vars_in_terms tms) ty_vars
947    val tm_vars' = intersect (free_varsl tms) tm_vars
948  in
949    ((tm_vars', ty_vars'), ZAP_CONSTS_RULE th)
950  end;
951
952fun var_GENVAR_SPEC ((tm_vars, ty_vars), th) : vthm =
953  let
954    val v = (genvar o type_of o fst o dest_forall o concl) th
955  in
956    ((v :: tm_vars, ty_vars), SPEC v th)
957  end;
958
959fun var_CONJUNCTS (vars, th) : vthm list =
960  map (add_fst vars) (CONJUNCTS th);
961
962fun var_MATCH_MP th : vthm -> vthm = (I ## MATCH_MP th);
963
964(* --------------------------------------------------------------------- *)
965(* Discharging assumptions on to the lhs of an implication:              *)
966(* DISCH_CONJ a : [a] UNION A |- P ==> Q   |->   A |- a /\ P ==> Q       *)
967(* UNDISCH_CONJ : A |- a /\ P ==> Q        |->   [a] UNION A |- P ==> Q  *)
968(* --------------------------------------------------------------------- *)
969
970val DISCH_CONJ_CONV = REWR_CONV AND_IMP_INTRO;
971fun DISCH_CONJ a th = CONV_RULE DISCH_CONJ_CONV (DISCH a th);
972fun DISCH_CONJUNCTS [] _ = raise ERR "DISCH_CONJ" "no assumptions!"
973  | DISCH_CONJUNCTS (a::al) th = foldl (uncurry DISCH_CONJ) (DISCH a th) al;
974fun DISCH_CONJUNCTS_ALL th = DISCH_CONJUNCTS (hyp th) th;
975fun DISCH_CONJUNCTS_FILTER f th = DISCH_CONJUNCTS (filter f (hyp th)) th;
976fun UNDISCH_CONJ_TAC a = UNDISCH_TAC a ++ CONV_TAC DISCH_CONJ_CONV;
977val UNDISCH_CONJUNCTS_TAC =
978  POP_ASSUM MP_TAC ++ REPEAT (POP_ASSUM MP_TAC ++ CONV_TAC DISCH_CONJ_CONV);
979
980val UNDISCH_CONJ_CONV = REWR_CONV (GSYM AND_IMP_INTRO)
981val UNDISCH_CONJ = CONV_RULE UNDISCH_CONJ_CONV THENR UNDISCH
982val UNDISCH_CONJUNCTS = REPEATR UNDISCH_CONJ THENR UNDISCH
983val DISCH_CONJ_TAC = CONV_TAC UNDISCH_CONJ_CONV ++ DISCH_TAC
984val DISCH_CONJUNCTS_TAC = REPEAT DISCH_CONJ_TAC ++ DISCH_TAC
985
986(* --------------------------------------------------------------------- *)
987(* Tacticals.                                                            *)
988(* --------------------------------------------------------------------- *)
989
990fun PURE_CONV_TAC conv :tactic = fn (asms,g) =>
991   let
992     val eq_th = QCONV conv g
993   in
994     ([(asms, RHS eq_th)], EQ_MP (SYM eq_th) o hd)
995   end;
996
997fun ASMLIST_CASES (t1:tactic) _ (g as ([], _)) = t1 g
998  | ASMLIST_CASES _ t2 (g as (x::_, _)) = t2 x g;
999
1000fun POP_ASSUM_TAC tac =
1001  ASMLIST_CASES tac
1002  (K (UNDISCH_CONJUNCTS_TAC
1003      ++ tac
1004      ++ TRY (DISCH_THEN (EVERY o map ASSUME_TAC o CONJUNCTS))));
1005
1006(*---------------------------------------------------------------------------
1007 * tac1 THEN1 tac2: A tactical like THEN that applies tac2 only to the
1008 *                  first subgoal of tac1
1009 *---------------------------------------------------------------------------*)
1010
1011fun op THEN1 (tac1 : tactic, tac2 : tactic) : tactic =
1012  fn g =>
1013  let
1014    val (gl, jf) = tac1 g
1015    val (h_g, t_gl) =
1016      case gl of []
1017        => raise ERR "THEN1" "goal completely solved by first tactic"
1018      | h :: t => (h, t)
1019    val (h_gl, h_jf) = tac2 h_g
1020    val _ =
1021      assert (null h_gl) (ERR "THEN1" "1st subgoal not solved by second tactic")
1022  in
1023    (t_gl, fn thl => jf (h_jf [] :: thl))
1024  end
1025  handle HOL_ERR{origin_structure,origin_function,message}
1026  => raise ERR "THEN1" (origin_structure^"."^origin_function^": "^message);
1027
1028val op>> = op THEN1;
1029
1030(*---------------------------------------------------------------------------
1031 * REVERSE tac: A tactical that reverses the list of subgoals of tac.
1032 *              Intended for use with THEN1 to pick the `easy' subgoal, e.g.:
1033 *              - CONJ_TAC THEN1 SIMP_TAC
1034 *                  if the first conjunct is easily dispatched
1035 *              - REVERSE CONJ_TAC THEN1 SIMP_TAC
1036 *                  if it is the second conjunct that yields.
1037 *---------------------------------------------------------------------------*)
1038
1039fun REVERSE tac g
1040  = let val (gl, jf) = tac g
1041    in (rev gl, jf o rev)
1042    end
1043    handle HOL_ERR{origin_structure,origin_function,message}
1044    => raise ERR "REVERSE" (origin_structure^"."^origin_function^": "^message);
1045
1046(* --------------------------------------------------------------------- *)
1047(* Tactics.                                                              *)
1048(* --------------------------------------------------------------------- *)
1049
1050val TRUTH_TAC = ACCEPT_TAC TRUTH;
1051
1052fun K_TAC _ = ALL_TAC;
1053
1054val KILL_TAC = POP_ASSUM_LIST K_TAC;
1055
1056fun CONJUNCTS_TAC g = TRY (CONJ_TAC << [ALL_TAC, CONJUNCTS_TAC]) g;
1057
1058val FUN_EQ_TAC = CONV_TAC (CHANGED_CONV (ONCE_DEPTH_CONV FUN_EQ_CONV));
1059val SET_EQ_TAC = CONV_TAC (CHANGED_CONV (ONCE_DEPTH_CONV SET_EQ_CONV));
1060
1061fun SUFF_TAC t (al, c)
1062  = let val tm = parse_with_goal t (al, c)
1063    in ([(al, mk_imp (tm, c)), (al, tm)],
1064        fn [th1, th2] => MP th1 th2
1065         | _ => raise ERR "SUFF_TAC" "panic")
1066    end;
1067
1068fun KNOW_TAC t = REVERSE (SUFF_TAC t);
1069
1070local
1071  val th1 = (prove (``!t. T ==> (F ==> t)``, PROVE_TAC []))
1072in
1073  val CHECK_ASMS_TAC :tactic =
1074    REPEAT (PAT_ASSUM T K_TAC)
1075    ++ REPEAT (PAT_ASSUM F (fn th => MP_TAC th ++ MATCH_MP_TAC th1))
1076end;
1077
1078(* --------------------------------------------------------------------- *)
1079(* EXACT_MP_TAC : thm -> tactic                                          *)
1080(*                                                                       *)
1081(* If the goal is (asms, g) then the supplied theorem should be of the   *)
1082(* form [..] |- g' ==> g                                                 *)
1083(*                                                                       *)
1084(* The tactic returns one subgoal of the form (asms, g')                 *)
1085(* --------------------------------------------------------------------- *)
1086
1087fun EXACT_MP_TAC mp_th :tactic =
1088  let
1089    val g' = fst (dest_imp (concl mp_th))
1090  in
1091    fn (asms, g) => ([(asms, g')], MP mp_th o hd)
1092  end;
1093
1094(* --------------------------------------------------------------------- *)
1095(* STRONG_CONJ_TAC : tactic                                              *)
1096(*                                                                       *)
1097(* If the goal is (asms, A /\ B) then the tactic returns two subgoals of *)
1098(* the form (asms, A) and (asms, A ==> B)                                *)
1099(* --------------------------------------------------------------------- *)
1100
1101local
1102  val th = prove (``!a b. a /\ (a ==> b) ==> a /\ b``, PROVE_TAC [])
1103in
1104  val STRONG_CONJ_TAC :tactic = MATCH_MP_TAC th ++ CONJ_TAC
1105end;
1106
1107(* --------------------------------------------------------------------- *)
1108(* FORWARD_TAC : (thm list -> thm list) -> tactic                        *)
1109(*                                                                       *)
1110(* Here is what happens when                                             *)
1111(*   FORWARD_TAC f                                                       *)
1112(* is applied to the goal                                                *)
1113(*   (asms, g).                                                          *)
1114(*                                                                       *)
1115(* 1. It calls the supplied inference function with the assumptions      *)
1116(*    to obtain a list of theorems.                                      *)
1117(*      ths = f (map ASSUME asms)                                        *)
1118(*    IMPORTANT: The assumptions of the theorems in ths must be either   *)
1119(*               in asms, or `definitions' of the form `new_var = body`. *)
1120(*                                                                       *)
1121(* 2. It returns one subgoal with the following form:                    *)
1122(*      (map concl ths, g)                                               *)
1123(*    i.e., the same goal, and a new assumption list that logically      *)
1124(*    follows from asms.                                                 *)
1125(*                                                                       *)
1126(* --------------------------------------------------------------------- *)
1127
1128fun forward_just ths th0 =
1129  let
1130    val th1 = foldr (fn (h,t) => DISCH (concl h) t) th0 ths
1131    val th2 = foldl (fn (h,t) => MP t h) th1 ths
1132  in
1133    th2
1134  end
1135
1136fun FORWARD_TAC f (asms, g:term) =
1137  let
1138    val ths = f (map ASSUME asms)
1139  in
1140    ([(map concl ths, g)],
1141       fn [th] => (REORDER_ASMS asms o ZAP_CONSTS_RULE o forward_just ths) th
1142        | _ => raise BUG "FORWARD_TAC" "justification function panic")
1143  end;
1144
1145(* --------------------------------------------------------------------- *)
1146(* A simple-minded CNF conversion.                                       *)
1147(* --------------------------------------------------------------------- *)
1148
1149local
1150  open simpLib
1151  infix ++
1152in
1153  val EXPAND_COND_CONV =
1154    QCONV (SIMP_CONV (pureSimps.pure_ss ++ boolSimps.COND_elim_ss) [])
1155end
1156
1157local
1158  val EQ_IFF = prove
1159    (``!a b. ((a:bool) = b) = ((a ==> b) /\ (b ==> a))``,
1160     BasicProvers.PROVE_TAC [])
1161in
1162  val EQ_IFF_CONV = QCONV (PURE_REWRITE_CONV [EQ_IFF])
1163end;
1164
1165local
1166  val IMP_DISJ = prove
1167    (``!a b. ((a:bool) ==> b) = ~a \/ b``,
1168     BasicProvers.PROVE_TAC [])
1169in
1170  val IMP_DISJ_CONV = QCONV (PURE_REWRITE_CONV [IMP_DISJ])
1171end;
1172
1173local
1174  val NEG_NEG = CONJUNCT1 NOT_CLAUSES
1175  val DE_MORGAN1
1176    = CONJUNCT1 (CONV_RULE (DEPTH_CONV FORALL_AND_CONV) DE_MORGAN_THM)
1177  val DE_MORGAN2
1178    = CONJUNCT2 (CONV_RULE (DEPTH_CONV FORALL_AND_CONV) DE_MORGAN_THM)
1179in
1180  val NNF_CONV = (QCONV o REPEATC o CHANGED_CONV)
1181    (REWRITE_CONV [NEG_NEG, DE_MORGAN1, DE_MORGAN2]
1182     THENC DEPTH_CONV (NOT_EXISTS_CONV ORELSEC NOT_FORALL_CONV))
1183end;
1184
1185val EXISTS_OUT_CONV = (QCONV o REPEATC o CHANGED_CONV o DEPTH_CONV)
1186  (LEFT_AND_EXISTS_CONV
1187   ORELSEC RIGHT_AND_EXISTS_CONV
1188   ORELSEC LEFT_OR_EXISTS_CONV
1189   ORELSEC RIGHT_OR_EXISTS_CONV
1190   ORELSEC CHANGED_CONV SKOLEM_CONV);
1191
1192val ANDS_OUT_CONV = (QCONV o REPEATC o CHANGED_CONV o DEPTH_CONV)
1193  (FORALL_AND_CONV
1194   ORELSEC REWR_CONV LEFT_OR_OVER_AND
1195   ORELSEC REWR_CONV RIGHT_OR_OVER_AND)
1196
1197val FORALLS_OUT_CONV = (QCONV o REPEATC o CHANGED_CONV o DEPTH_CONV)
1198  (LEFT_OR_FORALL_CONV
1199   ORELSEC RIGHT_OR_FORALL_CONV);
1200
1201val CNF_CONV =
1202 QCONV
1203 (DEPTH_CONV BETA_CONV
1204  THENC EXPAND_COND_CONV
1205  THENC EQ_IFF_CONV
1206  THENC IMP_DISJ_CONV
1207  THENC NNF_CONV
1208  THENC EXISTS_OUT_CONV
1209  THENC ANDS_OUT_CONV
1210  THENC FORALLS_OUT_CONV
1211  THENC REWRITE_CONV [GSYM DISJ_ASSOC, GSYM CONJ_ASSOC]);
1212
1213val CNF_RULE = CONV_RULE CNF_CONV;
1214
1215val CNF_EXPAND = CONJUNCTS o repeat GENVAR_CONST_RULE o CNF_RULE;
1216
1217val CNF_TAC = CCONTR_TAC THEN FORWARD_TAC (flatten o map CNF_EXPAND);
1218
1219(* --------------------------------------------------------------------- *)
1220(* ASM_MATCH_MP_TAC: adding MP-consequences to the assumption list.      *)
1221(* Does less than (EVERY (map ASSUME_TAC ths) ++ RES_TAC).               *)
1222(* --------------------------------------------------------------------- *)
1223
1224local
1225  val is_mp = is_imp o snd o dest_foralls o concl;
1226
1227  fun initialize mp_th =
1228    let
1229      val (vars, (asm, body)) = ((rev ## dest_imp) o dest_foralls o concl) mp_th
1230      val asms = conjuncts asm
1231    in
1232      case asms of [a] => ([], [mp_th])
1233      | _ =>
1234      let
1235        val mp_th' = (SPEC_ALL THENR UNDISCH_CONJUNCTS) mp_th
1236        val rots = rotations asms
1237        fun f (asm, rest) =
1238          (DISCH_CONJUNCTS rest THENR DISCH asm THENR GENL vars) mp_th'
1239      in
1240        (map f rots, [])
1241      end
1242    end
1243
1244  fun initialize_collect (m, s) th =
1245    let
1246      val (mx, sx) = initialize th
1247    in
1248      (mx @ m, sx @ s)
1249    end
1250
1251  val initializel = trans (C initialize_collect)
1252
1253  fun match1 (multi, single) th =
1254    let
1255      val do_match = partial_map (fn x => total (MATCH_MP x) th)
1256    in
1257      (do_match multi, do_match single)
1258    end
1259
1260  fun add_thm th (concls, ths) =
1261    let
1262      val tm = concl th
1263    in
1264      if mem tm concls then (concls, ths) else (tm :: concls, th :: ths)
1265    end
1266
1267  fun clean_add_thms ths = snd o trans add_thm (map concl ths, ths)
1268
1269  fun match 0 _ ths = ths
1270    | match n state ths =
1271    let
1272      val (m_res, s_res) = (Df flatten o unzip o map (match1 state)) ths
1273      val state' = initializel state m_res
1274      val s_res' = clean_add_thms ths s_res
1275    in
1276      match (n - 1) state' s_res'
1277    end;
1278in
1279  fun MATCH_MP_DEPTH n =
1280    match n o initializel ([], []) o filter is_mp
1281end;
1282
1283fun ASM_MATCH_MP_TAC_N depth ths =
1284  POP_ASSUM_LIST
1285  (EVERY o map ASSUME_TAC o rev o MATCH_MP_DEPTH depth ths)
1286
1287val ASM_MATCH_MP_TAC = ASM_MATCH_MP_TAC_N 10;
1288
1289end; (* probTools *)
1290
1291