1structure utilsLib :> utilsLib =
2struct
3
4open HolKernel boolLib bossLib
5open state_transformerTheory
6open wordsLib integer_wordLib bitstringLib
7
(* Exception constructor and warning printer, both tagged with this
   library's name ("utilsLib") as the originating structure. *)
val ERR = Feedback.mk_HOL_ERR "utilsLib"
val WARN = Feedback.HOL_WARNING "utilsLib"
10
(* Shadow Parse so that the Type and Term parsers used by this file are built
   from wordsTheory.words_grammars (via parse_from_grammars), rather than
   from whatever grammar is current when the file is loaded. *)
structure Parse =
struct
   open Parse
   val (Type,Term) = parse_from_grammars wordsTheory.words_grammars
end
open Parse
17
18(* ------------------------------------------------------------------------- *)
19
(* cache size cmp f
   Memoise f, keying stored results with the ordering cmp.

   If size > 0 the table is bounded and holds at most `size` results.  When
   it is full, the entry inserted earliest is evicted (FIFO; a hit does not
   refresh an entry).  If size <= 0 the table is unbounded.

   Each application `cache size cmp f` allocates its own private table. *)
fun cache size cmp f =
   let
      val d = ref (Redblackmap.mkDict cmp)
      (* keys in insertion order; always the same key set as !d *)
      val k = ref []
      val finite = 0 < size
      (* drop the oldest entry to make room for one more *)
      fun evict () =
         case List.getItem (!k) of
            SOME (h, t) => (d := fst (Redblackmap.remove (!d, h)); k := t)
          | NONE => raise ERR "cache" "empty"
   in
      fn v =>
         case Redblackmap.peek (!d, v) of
            SOME r => r
          | NONE =>
               let
                  val r = f v
               in
                  (* Evict BEFORE inserting, so the table never exceeds
                     `size` entries.  The former test `size < numItems`
                     let it grow to size + 1. *)
                  if finite
                     then (if size <= Redblackmap.numItems (!d)
                              then evict ()
                           else ()
                           ; k := !k @ [v])
                  else ()
                  ; d := Redblackmap.insert (!d, v, r)
                  ; r
               end
   end
47
48(* ------------------------------------------------------------------------- *)
49
(* partitions l: every partition of the elements of l into non-empty blocks.
   For h :: t, each partition of t yields one partition per block (h added
   to that block), followed by the partition with [h] as a new block.
   partitions [] is [] (not [[]]). *)
fun partitions [] = []
  | partitions [x] = [[[x]]]
  | partitions (h :: t) =
      let
         (* all ways of adding h to exactly one existing block of p,
            in block order *)
         fun extend p =
            let
               fun go _ [] = []
                 | go pre (blk :: post) =
                     List.revAppend (pre, (h :: blk) :: post)
                     :: go (blk :: pre) post
            in
               go [] p
            end
         val rest = partitions t
      in
         List.concat (List.map extend rest) @ List.map (fn p => [h] :: p) rest
      end
66
(* classes eq l: group the elements of l into equivalence classes under eq.
   Each element is compared with the head (representative) of the existing
   classes.  The resulting order of classes and of their members follows
   the insertion procedure below. *)
fun classes eq =
   let
      (* put x in the first class whose representative it is eq to,
         otherwise start a new singleton class *)
      fun insert x =
         let
            fun scan seen [] = [x] :: seen
              | scan seen (cls :: rest) =
                  if eq (x, List.hd cls)
                     then seen @ ((x :: cls) :: rest)
                  else scan (cls :: seen) rest
         in
            scan []
         end
   in
      List.foldl (fn (x, acc) => insert x acc) []
   end
85
86(* ------------------------------------------------------------------------- *)
87
local
   (* Split a list into consecutive 8-element chunks, accumulated in reverse
      order.  A final chunk shorter than 8 elements is kept as-is. *)
   fun chunks acc [] = acc
     | chunks acc l =
         case (SOME (List.take (l, 8), List.drop (l, 8))
               handle Subscript => NONE) of
            SOME (byte, rest) => chunks (byte :: acc) rest
          | NONE => l :: acc
in
   (* Reverse the order of the 8-element groups of l (e.g. bits -> bytes). *)
   fun rev_endian l = List.concat (chunks [] l)
end
97
98(* ------------------------------------------------------------------------- *)
99
local
   (* Index of the first element satisfying P, or the list's length if none. *)
   fun find_pos P =
      let
         fun scan i [] = i
           | scan i (x :: xs) = if P x then i else scan (i + 1) xs
      in
         scan 0
      end
in
   (* process_option P g s d l f
      Partition l by P.  Map the selected options with g and deduplicate:
      no value gives d, one value v gives f v, more than one is an error.
      Returns that result paired with the options not selected by P. *)
   fun process_option P g s d l f =
      let
         val (selected, others) = List.partition P l
      in
         case Lib.mk_set (List.map g selected) of
            [] => (d, others)
          | [pos] => (f pos, others)
          | _ => raise ERR "process_option" ("More than one " ^ s ^ " option.")
      end
   (* process_opt opt: process_option where an option is selected when it
      occurs in any of the lists of opt, and g is the index of the first
      list of opt containing it. *)
   fun process_opt opt =
      let
         val all = List.concat opt
      in
         process_option (fn x => Lib.mem x all)
           (fn option => find_pos (Lib.mem option) opt)
      end
end
125
(* Print " * <name> options:" followed by the first entry of each option
   list, comma separated.  With SOME w the header is right-padded to width w.
   Otherwise the entries go on a new, tab-indented line. *)
fun print_options bk name l =
   let
      val header = " * " ^ name ^ " options:"
      val lead =
         case bk of
            SOME width => StringCvt.padRight #" " width header
          | NONE => header ^ "\n\t"
      val entries = String.concat (Lib.commafy (List.map List.hd l))
   in
      TextIO.print (lead ^ entries ^ "\n")
   end
133
134(* ------------------------------------------------------------------------- *)
135
(* maximal cmp f l = (m, rest)
   m is the first element of l whose key f m is greatest under cmp, and rest
   is l with that occurrence removed (order preserved).  Fails on []. *)
fun maximal (cmp: 'a cmp) f =
   let
      (* best = (prefix, key, elem, suffix): elem is the current winner,
         prefix the elements before it (reversed), suffix those after it.
         seen = all elements visited so far, reversed. *)
      fun scan (prefix, _, elem, suffix) _ [] =
            (elem, List.revAppend (prefix, suffix))
        | scan (best as (_, key, _, _)) seen (x :: rest) =
            let
               val kx = f x
               val best' =
                  if cmp (kx, key) = General.GREATER
                     then (seen, kx, x, rest)
                  else best
            in
               scan best' (x :: seen) rest
            end
   in
      fn [] => raise ERR "maximal" "empty"
       | x :: rest => scan ([], f x, x, rest) [x] rest
   end

(* As maximal, but for the first element with the least key. *)
fun minimal cmp = maximal (Lib.flip_cmp cmp)
155
156(* ------------------------------------------------------------------------- *)
157
(* padLeft c n l: prefix l with copies of c up to total length n.
   (List.tabulate raises Size when l is already longer than n.) *)
fun padLeft c n l =
   let
      val missing = n - List.length l
   in
      List.tabulate (missing, fn _ => c) @ l
   end
(* fun padRight c n l = l @ List.tabulate (n - List.length l, fn _ => c) *)
160
(* pick flags l2: the elements of l2 whose corresponding flag is true.  The
   lists are paired with Lib.zip.  An empty flag list returns l2 unchanged
   and emits a warning. *)
fun pick [] l2 = (WARN "pick" "not picking"; l2)
  | pick flags l2 =
      List.foldr (fn ((keep, x), acc) => if keep then x :: acc else acc) []
        (Lib.zip flags l2)
168
(* A list of alternative substitutions, each a list of {redex, residue}. *)
type cover = {redex: term frag list, residue: term} list list

(* augment (v, xs) cover: for each x in xs (outer loop), every substitution
   of cover extended with the binding v |-> x. *)
fun augment (v, l1) l2 =
   List.concat
     (List.map
        (fn x => let val bind = v |-> x in List.map (fn c => bind :: c) l2 end)
        l1)
173
(* zipLists f ls: apply f to each "column" of the lists ls (the i-th
   elements of every list).  Stops when the FIRST list is exhausted.  Raises
   Empty when ls = [] or when another list is shorter than the first. *)
fun zipLists f =
   let
      fun go acc lists =
         case List.hd lists of
            [] => List.map f (List.rev acc)
          | _ => go (List.map List.hd lists :: acc) (List.map List.tl lists)
   in
      go []
   end
182
(* Word literals of width w, one for each integer in the list. *)
fun list_mk_wordii w = List.map (fn n => wordsSyntax.mk_wordii (n, w))

(* Fixed-width (w) bitstring terms for 0 .. m-1, as produced by
   bitstringSyntax.padded_fixedwidth_of_num. *)
fun tab_fixedwidth m w =
   let
      fun bits n = bitstringSyntax.padded_fixedwidth_of_num (Arbnum.fromInt n, w)
   in
      List.tabulate (m, bits)
   end
188
local
   (* Run a substring splitter on a whole string, returning plain strings. *)
   fun onSubstrings split s =
      let
         val (a, b) = split (Substring.full s)
      in
         (Substring.string a, Substring.string b)
      end
in
   (* Split before the first character satisfying P. *)
   fun splitAtChar P = onSubstrings (Substring.splitl (fn c => not (P c)))
   (* Split after the first n characters. *)
   fun splitAtPos n = onSubstrings (fn ss => Substring.splitAt (ss, n))
end
195
(* Character-wise case conversion of a string. *)
val lowercase = String.map Char.toLower
val uppercase = String.map Char.toUpper

(* Remove every character for which Char.isSpace holds. *)
val removeSpaces = String.implode o List.filter (not o Char.isSpace) o String.explode
201
(* term_to_string with Globals.linewidth temporarily set to 1000. *)
fun long_term_to_string tm =
   Lib.with_flag (Globals.linewidth, 1000) Hol_pp.term_to_string tm

(* Join the lines with "\n" into a single-fragment quotation. *)
fun strings_to_quote (l : string list) : string frag list =
   [QUOTE (String.concatWith "\n" l)]
208
(* Left / right-hand side of a theorem's (equational) conclusion. *)
fun lhsc th = boolSyntax.lhs (Thm.concl th)
fun rhsc th = boolSyntax.rhs (Thm.concl th)

(* The term that bossLib.EVAL reduces tm to. *)
fun eval tm = rhsc (bossLib.EVAL tm)

(* Domain and range of a function type. *)
fun dom ty = #1 (Type.dom_rng ty)
fun rng ty = #2 (Type.dom_rng ty)
214
local
   (* Rewrite with De Morgan (plus REWRITE_CONV's basic rewrites).  QCONV
      yields a reflexive theorem instead of failing when nothing changes. *)
   val push_neg = Conv.QCONV (REWRITE_CONV [boolTheory.DE_MORGAN_THM])
in
   (* ~tm, with the negation pushed inwards where possible. *)
   fun mk_negation tm =
      boolSyntax.rhs (Thm.concl (push_neg (boolSyntax.mk_neg tm)))
end
220
local
   (* "_foo" : ty  ~~>  "xfoo" : ty *)
   fun x_var (name, ty) =
      Term.mk_var ("x" ^ String.extract (name, 1, NONE), ty)
   (* substitution entry renaming a variable whose name starts with '_' *)
   fun rename v =
      case Lib.total Term.dest_var v of
         NONE => NONE
       | SOME (nm_ty as (nm, _)) =>
           if String.sub (nm, 0) <> #"_"
              then NONE
           else SOME (v |-> x_var nm_ty)
   (* "10_x" ~~> "1;0;_;X" *)
   fun list_body s =
      String.implode (Lib.separate #";" (String.explode (uppercase s)))
in
   (* pattern s: parse the characters of s (upper-cased) as a HOL list
      "[c1;c2;...]", then rename free variables named "_..." to "x...". *)
   fun pattern s =
      let
         val tm = Parse.Term [HOLPP.QUOTE ("[" ^ list_body s ^ "]")]
         val sb = List.mapPartial rename (Term.free_vars tm)
      in
         Term.subst sb tm
      end
end
237
local
   fun strip acc t =
      case Lib.total wordsSyntax.dest_word_add t of
         SOME (l, r) => strip ((true, r) :: acc) l
       | NONE =>
           (case Lib.total wordsSyntax.dest_word_sub t of
               SOME (l, r) => strip ((false, r) :: acc) l
             | NONE => (t, acc))
in
   (* Flatten a left-nested word expression t0 op1 r1 ... opn rn (op = + or
      -) into (t0, [(b1, r1), ..., (bn, rn)]), where bi is true for + and
      false for -. *)
   fun strip_add_or_sub t = strip [] t
end
249
(* The head of the lhs of the first conjunct of a theorem
   |- (!xs. f args = ...) /\ ... , i.e. the function f. *)
fun get_function th =
   let
      val first = List.hd (boolSyntax.strip_conj (Thm.concl th))
      val (_, body) = boolSyntax.strip_forall first
   in
      #1 (boolSyntax.strip_comb (boolSyntax.lhs body))
   end
253
(* A theorem is vacuous if some hypothesis is F or the conclusion is T. *)
fun vacuous thm =
   let
      val (hyps, concl) = Thm.dest_thm thm
   in
      List.exists Feq hyps orelse Teq concl
   end
260
(* Insert thm into an LVTermNet, keyed by the term f thm. *)
fun add_to_rw_net f (thm: thm, net) = LVTermNet.insert (net, ([], f thm), thm)

(* Net of all thms, each keyed by f. *)
fun mk_rw_net f thms = List.foldl (add_to_rw_net f) LVTermNet.empty thms

(* Theorems whose keys match tm; fails when there are none. *)
fun find_rw net tm : thm list =
   case List.map snd (LVTermNet.match (net, tm)) of
      [] => raise ERR "find_rw" "not found"
    | thms => thms
269
270(* ---------------------------- *)
271
local
   (* Compset: numeral reduction (reduceLib.num_compset) extended with
      UNCURRY, o_THM and the FOR / BIND / UNIT definitions. *)
   val cmp = reduceLib.num_compset ()
   val () = computeLib.add_thms
              [pairTheory.UNCURRY, combinTheory.o_THM,
               state_transformerTheory.FOR_def,
               state_transformerTheory.BIND_DEF,
               state_transformerTheory.UNIT_DEF] cmp
   val FOR_CONV = computeLib.CBV_CONV cmp
   (* Decimal text of i as a term quotation.  NOTE(review): a negative int
      prints with SML's "~" sign -- presumably only naturals are passed. *)
   fun term_frag_of_int i = [QUOTE (Int.toString i)]: term frag list
in
   (* for_thm (h, l): FOR_def with function equalities expanded
      (FUN_EQ_CONV), specialised to the numerals h, l and the variables a, s.
      The rhs is then evaluated with FOR_CONV and the remaining free
      variables are generalised. *)
   fun for_thm (h, l) =
      state_transformerTheory.FOR_def
      |> Conv.CONV_RULE (Conv.DEPTH_CONV Conv.FUN_EQ_CONV)
      |> Q.SPECL [term_frag_of_int h, term_frag_of_int l, `a`, `s`]
      |> Conv.RIGHT_CONV_RULE FOR_CONV
      |> Drule.GEN_ALL
end
289
290(* ---------------------------- *)
291
292(* Variant of UNDISCH
293   [..] |- a1 /\ ... /\ aN ==> t    |->
294   [.., a1, .., aN] |- t
295*)
296
local
   (* Prove the conjunction tm from assumptions of its conjuncts ps,
      following the actual nesting of /\ in tm (left- or right-associated).
      Previously the antecedent was curried with repeated AND_IMP rewrites
      down the right spine only, so a left-nested antecedent such as
      (a /\ b) /\ c made the rewrite fail. *)
   fun prove_conj ps tm =
      if Lib.op_mem Term.aconv tm ps
         then Thm.ASSUME tm
      else
         let
            val (l, r) = boolSyntax.dest_conj tm
         in
            Thm.CONJ (prove_conj ps l) (prove_conj ps r)
         end
in
   (* STRIP_UNDISCH (A |- a1 /\ ... /\ aN ==> t) = A u {a1, ..., aN} |- t,
      where a1 .. aN are the conjuncts given by boolSyntax.strip_conj.
      Fails if the conclusion is not an implication. *)
   fun STRIP_UNDISCH th =
      let
         val ant = fst (boolSyntax.dest_imp (Thm.concl th))
         val ps = boolSyntax.strip_conj ant
      in
         Thm.MP th (prove_conj ps ant)
      end
end
312
(* Store th in the current theory under name. *)
fun save_as name th = Theory.save_thm (name, th)

(* save_as after STRIP_UNDISCH has moved the antecedent's conjuncts into
   the hypotheses. *)
fun usave_as name th = save_as name (STRIP_UNDISCH th)

(* Prove goal t with tac, then usave_as the result under s. *)
fun ustore_thm (s, t, tac) = usave_as s (Q.prove (t, tac))
316
local
  (* Names of the theorems saved via save_thms since the last reset_thms.
     New names are prepended, so the list is in reverse order of saving. *)
  val names = ref ([] : string list)
  (* Save a theorem in the current theory and record its name. *)
  fun add (n, th) = (names := n :: !names; Theory.save_thm (n, th))
  val add_list = List.map add
in
  (* Forget the recorded names (theorems already saved are unaffected). *)
  fun reset_thms () = names := []
  (* save_thms name l: a single theorem is saved as `name`, several as
     name_0, name_1, ...; an empty list is an error.  Returns the saved
     theorems. *)
  fun save_thms name l =
    add_list
     (case l of
         [] => raise ERR "save_thms" "empty"
       | [th] => [(name, th)]
       | _ => ListPair.zip
                 (List.tabulate
                    (List.length l, fn i => name ^ "_" ^ Int.toString i), l))
  (* Adjoin to the exported theory a declaration `val rwts : string list`
     whose value is the recorded name list, in the order held in `names`. *)
  fun adjoin_thms () =
    Theory.adjoin_to_theory
      { sig_ps = SOME (fn _ => PP.add_string ("val rwts : string list")),
        struct_ps =
          SOME (fn _ =>
                   PP.block PP.INCONSISTENT 12 (
                     [PP.add_string "val rwts = ["] @
                     PP.pr_list (PP.add_string o Lib.quote)
                                [PP.add_string ",", PP.add_break (1, 0)]
                                (!names) @
                     [PP.add_string "]", PP.add_newline]
                   )
               )
      }
end
346
347
348(* Variant of UNDISCH
349   [..] |- T ==> t    |->   [..] |- t
350   [..] |- F ==> t    |->   [..] |- T
351   [..] |- p ==> t    |->   [.., p] |- t
352*)
353
local
   val thms = Drule.CONJUNCTS (Q.SPEC `t` boolTheory.IMP_CLAUSES)
   (* the IMP_CLAUSES conjunct used for  T ==> t  (first conjunct) *)
   val T_imp = Drule.GEN_ALL (hd thms)
   (* the IMP_CLAUSES conjunct used for  F ==> t  (third conjunct) *)
   val F_imp = Drule.GEN_ALL (List.nth (thms, 2))
   (* needed after instantiating a negated variable hypothesis with F *)
   val NT_imp = DECIDE ``(~F ==> t) = t``
   val T_imp_rule = Conv.CONV_RULE (Conv.REWR_CONV T_imp)
   val F_imp_rule = Conv.CONV_RULE (Conv.REWR_CONV F_imp)
   val NT_imp_rule = Conv.CONV_RULE (Conv.REWR_CONV NT_imp)
   (* SOME v when tm1 is ~v for a variable v not occurring in tm2 *)
   fun dest_neg_occ_var tm1 tm2 =
      case Lib.total boolSyntax.dest_neg tm1 of
         SOME v => if Term.is_var v andalso not (Term.var_occurs v tm2)
                      then SOME v
                   else NONE
       | NONE => NONE
in
   (* ELIM_UNDISCH (A |- l ==> r):
        l = T                          : A |- r
        l = F                          : A |- T
        l a variable v not in r        : A[T/v] |- r
        l = ~v, v a variable not in r  : A[F/v] |- r
        otherwise                      : UNDISCH, i.e. A u {l} |- r
      Fails if the conclusion is not an implication (per boolSyntax.dest_imp). *)
   fun ELIM_UNDISCH thm =
      case Lib.total boolSyntax.dest_imp (Thm.concl thm) of
         SOME (l, r) =>
            if Teq l then T_imp_rule thm
            else if Feq l then F_imp_rule thm
            else if Term.is_var l andalso not (Term.var_occurs l r)
               then T_imp_rule (Thm.INST [l |-> boolSyntax.T] thm)
            else (case dest_neg_occ_var l r of
                     (* Instantiating v := F leaves ~F ==> r, which only
                        NT_imp matches.  The previous use of F_imp_rule
                        (pattern F ==> t) always failed here. *)
                     SOME v => NT_imp_rule (Thm.INST [v |-> boolSyntax.F] thm)
                   | NONE => Drule.UNDISCH thm)
       | NONE => raise ERR "ELIM_UNDISCH" ""
end
381
(* Discharge each term of tms in turn, so the LAST one ends up outermost:
   LIST_DISCH [a, b] th concludes  b ==> a ==> concl th. *)
fun LIST_DISCH tms thm = List.foldl (fn (tm, th) => Thm.DISCH tm th) thm tms
383
384(* ---------------------------- *)
385
local
   (* Rewrites applied to each discharged hypothesis: NOT_CLAUSES, De Morgan,
      and GSYM AND_IMP_INTRO (presumably turning a /\ b ==> c into
      a ==> b ==> c). *)
   val rl =
      REWRITE_RULE [boolTheory.NOT_CLAUSES, GSYM boolTheory.AND_IMP_INTRO,
                    boolTheory.DE_MORGAN_THM]
   (* hypothesis shapes that get canonicalised *)
   val pats = [``~ ~a: bool``, ``a /\ b``, ``~(a \/ b)``]
   fun mtch tm = List.exists (fn p => Lib.can (Term.match_term p) tm) pats
in
   (* Canonicalise every hypothesis of the form ~~a, a /\ b or ~(a \/ b).
      The hypothesis is discharged and rewritten with rl, and the resulting
      antecedents are undischarged again with ELIM_UNDISCH.
      NOTE(review): rl rewrites the whole implication (conclusion included),
      and `repeat ELIM_UNDISCH` keeps going while the conclusion is an
      implication.  An implicational conclusion of thm may therefore also be
      pulled into the hypotheses -- confirm this is intended. *)
   fun HYP_CANON_RULE thm =
      let
         val hs = List.filter mtch (Thm.hyp thm)
      in
         List.foldl
           (fn (h, t) => repeat ELIM_UNDISCH (rl (Thm.DISCH h t))) thm hs
      end
end
401
(* Apply rule r to hypothesis tm: discharge tm, transform the implication
   with r, then restore the antecedent with ELIM_UNDISCH. *)
fun HYP_RULE r tm thm = ELIM_UNDISCH (r (Thm.DISCH tm thm))

(* Apply rule r to every hypothesis satisfying P. *)
fun PRED_HYP_RULE r P thm =
   List.foldl (fn (h, th) => HYP_RULE r h th) thm (List.filter P (Thm.hyp thm))

(* Apply rule r to every hypothesis matching the pattern pat. *)
fun MATCH_HYP_RULE r pat = PRED_HYP_RULE r (Lib.can (Term.match_term pat))

(* Apply rule r to all hypotheses. *)
fun ALL_HYP_RULE r = PRED_HYP_RULE r (fn _ => true)
418
local
   (* apply conversion c to the antecedent (left operand) of the conclusion *)
   fun on_antecedent c = Conv.CONV_RULE (Conv.LAND_CONV c)
in
   (* Conversion-based counterparts of the *_HYP_RULE functions. *)
   fun HYP_CONV_RULE c = HYP_RULE (on_antecedent c)
   fun PRED_HYP_CONV_RULE c = PRED_HYP_RULE (on_antecedent c)
   fun MATCH_HYP_CONV_RULE c = MATCH_HYP_RULE (on_antecedent c)
   fun ALL_HYP_CONV_RULE c = ALL_HYP_RULE (on_antecedent c)
   (* c on the conclusion, then on every hypothesis *)
   fun FULL_CONV_RULE c thm = ALL_HYP_CONV_RULE c (Conv.CONV_RULE c thm)
end
428
429(* ---------------------------- *)
430
(* CBV_CONV with compset cmp, failing (CHANGED_CONV) if the term is unchanged *)
fun CHANGE_CBV_CONV cmp =
   let
      val cbv = computeLib.CBV_CONV cmp
   in
      Conv.CHANGED_CONV cbv
   end
433
local
   (* rewrite theorems right-to-left with WORD_NEG_1 *)
   val rule = PURE_REWRITE_RULE [SYM wordsTheory.WORD_NEG_1]
   val and_thms = rule wordsTheory.WORD_AND_CLAUSES
   val or_thms  = rule wordsTheory.WORD_OR_CLAUSES
   val xor_thms = rule wordsTheory.WORD_XOR_CLAUSES
   (* algebraic clean-up rewrites for conditionals and word operations *)
   val alpha_rwts =
      [boolTheory.COND_ID, wordsTheory.WORD_SUB_RZERO,
       wordsTheory.WORD_ADD_0, wordsTheory.WORD_MULT_CLAUSES,
       and_thms, or_thms, xor_thms, wordsTheory.WORD_EXTRACT_ZERO2,
       wordsTheory.w2w_0, wordsTheory.WORD_SUB_REFL, wordsTheory.SHIFT_ZERO]
   (* Evaluate UINT_MAX in the left operand (else the right one).  Then
      rewrite with the first two conjuncts of each of the and/or/xor clause
      theorems, failing unless that rewrite changes the term. *)
   val UINT_MAX_LOGIC_CONV =
     let
       fun get th = List.take (Drule.CONJUNCTS (Drule.SPEC_ALL th), 2)
     in
       (Conv.LAND_CONV wordsLib.UINT_MAX_CONV
        ORELSEC Conv.RAND_CONV wordsLib.UINT_MAX_CONV)
       THENC Conv.CHANGED_CONV
               (PURE_REWRITE_CONV
                  (List.concat (List.map get [and_thms, or_thms, xor_thms])))
     end
   val WALPHA_CONV = REWRITE_CONV alpha_rwts
in
   (* Simplify ground word expressions: clean-up rewrites, then
      word / integer-word ground evaluation at every subterm, then the
      UINT_MAX logic step, then the clean-up rewrites again. *)
   val WGROUND_CONV =
      WALPHA_CONV
      THENC Conv.DEPTH_CONV (wordsLib.WORD_GROUND_CONV ORELSEC
                             integer_wordLib.INT_WORD_GROUND_CONV)
      THENC Conv.DEPTH_CONV UINT_MAX_LOGIC_CONV
      THENC WALPHA_CONV
end
463
(* conv applied n times in sequence (ALL_CONV when n = 0). *)
fun NCONV n conv = Lib.funpow n (fn c => conv THENC c) Conv.ALL_CONV

(* Simplification with srw_ss () plus extra theorems: conversion and rule. *)
fun SRW_CONV thms = SIMP_CONV (srw_ss ()) thms
fun SRW_RULE thms = Conv.CONV_RULE (SRW_CONV thms)

(* srw_ss () extended with WORD_EXTRACT_ss. *)
val EXTRACT_CONV = SIMP_CONV (srw_ss () ++ wordsLib.WORD_EXTRACT_ss) []

(* bool_ss extended with PRED_SET_ss: conversion and rule. *)
val SET_CONV = SIMP_CONV (bool_ss ++ pred_setLib.PRED_SET_ss) []
val SET_RULE = Conv.CONV_RULE SET_CONV

(* Rewrite with o_THM. *)
val o_RULE = REWRITE_RULE [combinTheory.o_THM]

(* METIS_PROVE / METIS_TAC with the "metis" trace set to 0. *)
fun qm thms = Feedback.trace ("metis", 0) (metisLib.METIS_PROVE thms)
fun qm_tac thms = Feedback.trace ("metis", 0) (metisLib.METIS_TAC thms)
474
475(* ---------------------------- *)
476
477(* mk_cond_exhaustive_thm i
478   generates a theorem of the form:
479
480 |-  !x : i word v0 v1 ... v(2^i).
481        (if x = 0w then v0
482         else if x = 1w then v1
483           ...
484         else v(2^i)) =
485        (if x = 0w then v0
486         else if x = 1w then v1
487           ...
488         else v(2^i - 1))
489
490*)
491
fun mk_cond_exhaustive_thm i =
  let
    (* the proof case-splits over every value of an i-bit word, so the size
       is capped (presumably for proof cost) *)
    val _ = i < 7 orelse
            raise ERR "mk_cond_exhaustive_thm" "word size must be < 7"
    val ty = wordsSyntax.mk_int_word_type i
    (* n = 2^i *)
    val n = Word.toInt (Word.<< (0w1, Word.fromInt i))
    (* v0 ... vn *)
    val vars = List.tabulate
                (n + 1, fn j => Term.mk_var ("v" ^ Int.toString j, Type.alpha))
    val x = Term.mk_var ("x", ty)
    (* fold (j, e) [u0, ..., uk] builds the term
         if x = (j-k)w then u0 else ... else if x = jw then uk else e
       (paired with a counter; only the snd is used) *)
    val fold =
      List.foldr
        (fn (v, (j, t)) =>
          (j - 1,
           boolSyntax.mk_cond
             (boolSyntax.mk_eq (x, wordsSyntax.mk_wordii (j, i)), v, t)))
    (* lhs: tests 0w .. (n-1)w, final else vn *)
    val l = fold (n - 1, List.last vars) (Lib.butlast vars)
    val vars = Lib.butlast vars
    (* rhs: tests 0w .. (n-2)w, final else v(n-1) *)
    val r = fold (n - 2, List.last vars) (Lib.butlast vars)
    val th = Tactical.prove
               (boolSyntax.mk_eq (snd l, snd r),
                wordsLib.Cases_on_word_value `^x` THEN bossLib.simp [])
  in
    Drule.GEN_ALL th
  end
516
517(* ---------------------------- *)
518
519
(* accessor_update_fns ty: for each field s of the record type ty (in the
   order of TypeBase.fields_of), the pair
     (Tyop_s, Tyop_s_fupd (K v))
   of the field accessor constant and the field update applied to K v, where
   v is a variable named "v" of the field's type. *)
fun accessor_update_fns ty =
  let
    val {Thy, Tyop, ...} = Type.dest_thy_type ty
  in
    List.map
      (fn (s, {ty = fld_ty, ...}) =>
         let
           val v = Term.mk_var ("v", fld_ty)
           (* K v : fld_ty -> fld_ty *)
           val kv = Term.inst [Type.beta |-> fld_ty]
                      (boolSyntax.mk_icomb (combinSyntax.K_tm, v))
         in
           (Term.prim_mk_const {Name = Tyop ^ "_" ^ s, Thy = Thy},
            Term.mk_comb
              (Term.prim_mk_const
                 {Name = Tyop ^ "_" ^ s ^ "_fupd", Thy = Thy}, kv))
         end)
      (TypeBase.fields_of ty)
  end
(* only the accessors / only the update functions *)
val accessor_fns = List.map fst o accessor_update_fns
val update_fns = List.map snd o accessor_update_fns
540
(* Apply cnv to every term and conjoin the resulting theorems. *)
fun map_conv (cnv: conv) tms = Drule.LIST_CONJ (List.map cnv tms)
542
local
   (* a binary f distributes over a conditional in its first argument *)
   val thm2l =
      qm [] ``!f:'a -> 'b -> 'c.
                f (if b then x else y) z = (if b then f x z else f y z)``
   (* ... and in its second argument *)
   val thm2r =
      qm [] ``!f:'a -> 'b -> 'c.
                f z (if b then x else y) = (if b then f z x else f z y)``
   (* tm : ty -> ty -> bool  or  tm : ty -> ty -> ty *)
   fun is_binop tm =
      case boolSyntax.strip_fun (Term.type_of tm) of
         ([ty1, ty2], ty3) =>
            ty1 = ty2 andalso (ty3 = Type.bool orelse ty3 = ty1)
       | _ => false
   (* Conditional-distribution theorem(s) for tm: both argument positions
      for a binop, otherwise COND_RAND.  Each is rewritten with o_THM and
      re-generalised. *)
   fun spec_thm tm =
      let
         val rule = Drule.GEN_ALL o o_RULE o Drule.ISPEC tm
      in
         if is_binop tm
            then Thm.CONJ (rule thm2l) (rule thm2r)
         else rule boolTheory.COND_RAND
      end
in
   (* conjunction of spec_thm for each term in the list *)
   val mk_cond_rand_thms = map_conv spec_thm
end
566
local
   (* a conditional over ((), s) pairs *)
   val COND_UPDATE0 = Q.prove(
      `!b s1 : 'a s2.
        (if b then ((), s1) else ((), s2)) = ((), if b then s1 else s2)`,
      RW_TAC std_ss [])
   (* a conditional over applications of an update-like f to (K v) and a state *)
   val COND_UPDATE1 = Q.prove(
      `!f : ('a -> 'b) -> 'c -> 'd b v1 v2 s1 s2.
         (if b then f (K v1) s1 else f (K v2) s2) =
         f (K (if b then v1 else v2)) (if b then s1 else s2)`,
      Cases_on `b` THEN REWRITE_TAC [])
   (* the same for function updates (=+), including one-sided updates *)
   val COND_UPDATE2 = Q.prove(
      `(!b a x y f : 'a -> 'b.
         (if b then (a =+ x) f else (a =+ y) f) =
         (a =+ if b then x else y) f) /\
       (!b a y f : 'a -> 'b.
         (if b then f else (a =+ y) f) = (a =+ if b then f a else y) f) /\
       (!b a x f : 'a -> 'b.
         (if b then (a =+ x) f else f) = (a =+ if b then x else f a) f)`,
      REPEAT CONJ_TAC
      THEN Cases
      THEN REWRITE_TAC [combinTheory.APPLY_UPDATE_ID])
   val COND_UPDATE3 = qm [] ``!b. (if b then T else F) = b``
   (* For one (accessor t1, update t2 = fupd (K v)) pair: COND_UPDATE1 for
      fupd, plus the two variants where one branch writes back the field's
      current value (t1 s1 or t1 s2).  Those variants are simplified with
      id_thm: fupd (K (t1 s1)) s1 = s1. *)
   fun mk_cond_update_thm component_equality (t1, t2) =
      let
         val thm = Drule.ISPEC (boolSyntax.rator t2) COND_UPDATE1
         val thm0 = Drule.SPEC_ALL thm
         val v = hd (Term.free_vars t2)
         val (v1, v2, s1, s2) =
            case boolSyntax.strip_forall (Thm.concl thm) of
               ([_, v1, v2, s1, s2], _) => (v1, v2, s1, s2)
             | _ => raise ERR "mk_cond_update_thms" ""
         val s1p = Term.mk_comb (t1, s1)
         val s2p = Term.mk_comb (t1, s2)
         val id_thm =
            Tactical.prove(
               boolSyntax.mk_eq
                  (Term.subst [v |-> s1p] (Term.mk_comb (t2, s1)), s1),
               SRW_TAC [] [component_equality])
         val rule = Drule.GEN_ALL o REWRITE_RULE [id_thm]
         val thm1 = rule (Thm.INST [v1 |-> s1p] thm0)
         val thm2 = rule (Thm.INST [v2 |-> s2p] thm0)
      in
         [thm, thm1, thm2]
      end
   (* all such theorems for the fields of record type ty, proved using the
      type's <Tyop>_component_equality theorem *)
   fun cond_update_thms ty =
      let
         val {Thy, Tyop, ...} = Type.dest_thy_type ty
         val component_equality = DB.fetch Thy (Tyop ^ "_component_equality")
      in
        List.concat
          (List.map (mk_cond_update_thm component_equality)
             (accessor_update_fns ty))
      end
in
   (* Generic conditional-update rewrites plus the per-field ones for each
      record type in l. *)
   fun mk_cond_update_thms l =
      [boolTheory.COND_ID, COND_UPDATE0, COND_UPDATE2, COND_UPDATE3] @
      List.concat (List.map cond_update_thms l)
end
625
626(*
627  Conversion for rewriting instances of:
628
629    f (case x of .. => y1 | .. => y2 | .. => yn)
630
631  to
632
633    case x of .. => f y1 | .. => f y2 | .. => f yn
634*)
635
local
  (* Closes each per-constructor goal: beta-reduce the argument of f on the
     left and the case branch on the right, then finish by reflexivity. *)
  val tac =
    CONV_TAC (Conv.FORK_CONV
                (Conv.RAND_CONV Drule.LIST_BETA_CONV, Drule.LIST_BETA_CONV))
    THEN REFL_TAC
  (* Rewrite  f (c x b1 ... bn)  to  case_c x (\vs1. f e1) ... (\vsn. f en),
     where bi = \vsi. ei.  The equation is proved on each call, for a fresh
     scrutinee variable, and then applied with REWR_CONV.  Fails unless the
     operator of tm is rand_f (the same constant, or the identical term). *)
  fun CASE_RAND_CONV1 rand_f tm =
    let
      val (f, arg) = Term.dest_comb tm
      val _ = Term.same_const rand_f f orelse Term.term_eq rand_f f orelse
              raise ERR "CASE_RAND_CONV" ""
      (* c is not checked to be the case constant of x's type; presumably
         the proof below fails when it is not *)
      val (c, x, l) =
        case boolSyntax.strip_comb arg of
           (c, x :: l) => (c, x, l)
         | _ => raise ERR "CASE_RAND_CONV" ""
      val ty = Term.type_of x
      val case_c = TypeBase.case_const_of ty
      (* push f under each branch's binders *)
      val l' =
        List.map
          (fn t => let
                     val (vs, b) = boolSyntax.strip_abs t
                   in
                     boolSyntax.list_mk_abs (vs, Term.mk_comb (f, b))
                   end) l
      (* generic scrutinee, fresh with respect to the new branches *)
      val fvs = List.concat (List.map Term.free_vars l')
      val x' = Term.variant fvs (Term.mk_var ("x", ty))
      val th =
        Tactical.prove
          (boolSyntax.mk_eq
            (Term.mk_comb (f, Term.list_mk_comb (c, x' :: l)),
             boolSyntax.list_mk_icomb (case_c, x' :: l')),
           Cases_on `^x'`
           THEN ONCE_REWRITE_TAC [TypeBase.case_def_of ty]
           THEN tac
          )
    in
      Conv.REWR_CONV th tm
    end
  (* f (literal_case (\v. if v = x then a else b) y) =
     literal_case (\v. if v = x then f a else f b) y *)
  val literal_case_rand = Q.prove(
    `!f : 'a -> 'b x : 'c y a b.
       f (literal_case (\v. if v = x then a else b) y) =
       literal_case (\v. if v = x then f a else f b) y`,
    SIMP_TAC std_ss [boolTheory.literal_case_DEF, boolTheory.COND_RAND])
in
  (* CASE_RAND_CONV f: top-down, move applications of f inside case
     expressions (datatype case constants, or a literal_case over a single
     if-then-else). *)
  fun CASE_RAND_CONV f =
    let
      val cnv = Conv.REWR_CONV (Drule.ISPEC f literal_case_rand)
    in
      Conv.TOP_DEPTH_CONV (cnv ORELSEC CASE_RAND_CONV1 f)
    end
end
688
689(* Substitution allowing for type match *)
690
local
   (* instantiate the residue's type variables so its type matches the
      redex's type *)
   fun align {redex, residue} =
      let
         val theta = Type.match_type (Term.type_of residue) (Term.type_of redex)
      in
         redex |-> Term.inst theta residue
      end
in
   (* Term.subst after type-aligning every residue with its redex. *)
   fun match_subst s = Term.subst (List.map align s)
end
701
702(*
703fun match_mk_eq (a, b) =
704   let
705      val m = Type.match_type (Term.type_of b) (Term.type_of a)
706   in
707      boolSyntax.mk_eq (a, Term.inst m b)
708   end
709
710fun mk_eq_contexts (a, l) = List.map (fn b => [match_mk_eq (a, b)]) l
711*)
712
(* avoid_name_clashes tm2 tm1: rename those free variables of tm1 whose
   names also occur among tm2's free variables.  Fresh variants come from
   Term.numvariant and avoid all free variables of tm2, the non-clashing
   free variables of tm1 and each other. *)
fun avoid_name_clashes tm2 tm1 =
   let
      val fvs1 = Term.free_vars tm1
      val fvs2 = Term.free_vars tm2
      val names2 = List.map (#1 o Term.dest_var) fvs2
      fun clashes v = Lib.mem (#1 (Term.dest_var v)) names2
      val (clashing, others) = List.partition clashes fvs1
      fun rename (v, (theta, avoid)) =
         let
            val v' = Term.numvariant avoid v
         in
            ((v |-> v') :: theta, v' :: avoid)
         end
      val (theta, _) = List.foldl rename ([], fvs2 @ others) clashing
   in
      Term.subst theta tm1
   end
731
local
   (* name of the update function for field f of record type s *)
   fun mk_fupd s f = s ^ "_" ^ f ^ "_fupd"
   (* constant name of the operator of an application *)
   val name = fst o Term.dest_const o fst o Term.dest_comb
in
   (* mk_state_id_thm eqthm lists: the record type ty is the type of
      eqthm's first universally quantified variable.  For each list of
      field names l = [f1, f2, ..., fn] this proves, and generalises,
        f1_fupd (K (f1 s)) s' = s'
        where s' = f2_fupd (K v2) (... (fn_fupd (K vn) s))
      (with the v's renamed apart).  The results are conjoined.
      NOTE(review): the accessor comes from `hd l` but the update from
      `hd l1`.  If f1 has no matching update function these refer to
      different fields; an empty l raises Empty. *)
   fun mk_state_id_thm eqthm =
      let
         val ty = Term.type_of (fst (boolSyntax.dest_forall (Thm.concl eqthm)))
         fun mk_thm l =
            let
               val {Tyop, Thy, ...} = Type.dest_thy_type ty
               val mk_f = mk_fupd Tyop
               val fns = update_fns ty
               fun get s = List.find (fn f => name f = mk_f s) fns
               val l1 = List.mapPartial get l
               val s = Term.mk_var ("s", ty)
               val h = hd l1
               val id = Term.prim_mk_const {Thy = Thy, Name = Tyop ^ "_" ^ hd l}
               (* h with its K-argument variable replaced by (accessor s) *)
               val id =
                  Term.subst [hd (Term.free_vars h) |-> Term.mk_comb (id, s)] h
               (* remaining updates applied to s, with renamed variables *)
               val after = List.foldr
                              (fn (f, tm) =>
                                 let
                                    val f1 = avoid_name_clashes tm f
                                 in
                                    Term.mk_comb (f1, tm)
                                 end) s (tl l1)
               val goal = boolSyntax.mk_eq (Term.mk_comb (id, after), after)
            in
               Drule.GEN_ALL (Tactical.prove (goal, bossLib.SRW_TAC [] [eqthm]))
            end
      in
         Drule.LIST_CONJ o List.map mk_thm
      end
end
766
767(* ---------------------------- *)
768
769(* Rewrite tm using theorem thm, instantiating free variables from hypotheses
770   as required *)
771
local
   (* Make a theorem usable as a rewrite: equations are kept, |- ~p
      becomes |- p = F, and any other |- p becomes |- p = T. *)
   fun TRY_EQ_FT thm =
      if boolSyntax.is_eq (Thm.concl thm)
         then thm
      else (Drule.EQF_INTRO thm handle HOL_ERR _ => Drule.EQT_INTRO thm)
in
   (* Rewrite tm once with the equation thm, after instantiating thm's type
      and term variables (hypotheses included) so that its lhs matches tm. *)
   fun INST_REWRITE_CONV1 thm =
      let
         val mtch = Term.match_term (boolSyntax.lhs (Thm.concl thm))
      in
         fn tm => PURE_ONCE_REWRITE_CONV [Drule.INST_TY_TERM (mtch tm) thm] tm
                  handle HOL_ERR _ => raise ERR "INST_REWRITE_CONV1" ""
      end
   (* Repeated top-down rewriting with the theorems l (conjuncts split,
      universally specialised, made equational).  Theorems without
      hypotheses are used via REWR_CONV.  Theorems with hypotheses are
      instantiated from the match first, which may add instantiated
      hypotheses to the result. *)
   fun INST_REWRITE_CONV l =
      let
         val thms =
            l |> List.map (Drule.CONJUNCTS o Drule.SPEC_ALL)
              |> List.concat
              |> List.map (TRY_EQ_FT o Drule.SPEC_ALL)
         val net = mk_rw_net lhsc thms
         (* net candidates for tm, split into (unconditional, conditional);
            find_rw fails when there are none *)
         fun candidates tm = List.partition (List.null o Thm.hyp) (find_rw net tm)
         (* A net match is only an approximate filter, so the first
            candidate may fail to rewrite.  Try every candidate in order
            (unconditional first) instead of giving up after the first. *)
         fun rewrite tm =
            let
               val (plain, conditional) = candidates tm
            in
               Conv.FIRST_CONV
                 (List.map Conv.REWR_CONV plain @
                  List.map INST_REWRITE_CONV1 conditional) tm
            end
      in
         Conv.REDEPTH_CONV rewrite
      end
   fun INST_REWRITE_RULE thm = Conv.CONV_RULE (INST_REWRITE_CONV thm)
end
803
804(* ---------------------------- *)
805
806(*
807  Given two theorems of the form:
808
809    [..., tm, ...] |- a
810    [..., ~tm, ...] |- a
811
812  produce theorem of the form
813
814    [...] |- a
815*)
816
local
   (* (b ==> a) /\ (~b ==> a) <=> a, with the conjuncts in either order;
      fails unless it applies *)
   val merge =
      Conv.CONV_RULE
         (Conv.CHANGED_CONV
             (REWRITE_CONV [DECIDE ``((b ==> a) /\ (~b ==> a)) <=> a``,
                            DECIDE ``((~b ==> a) /\ (b ==> a)) <=> a``]))
   (* did discharging remove at least one hypothesis? *)
   fun dropped_hyp original th =
      List.length (Thm.hyp th) < List.length original
   (* Discharge tm if it is a hypothesis of thm, otherwise ~tm; fail if
      neither is. *)
   fun SMART_DISCH tm thm =
      let
         val hs = Thm.hyp thm
         val pos = Thm.DISCH tm thm
      in
         if dropped_hyp hs pos
            then pos
         else
            let
               val neg = Thm.DISCH (boolSyntax.mk_neg tm) thm
            in
               if dropped_hyp hs neg
                  then neg
               else raise ERR "SMART_DISCH" "Term not in hypotheses"
            end
      end
in
   (* Combine [.., tm, ..] |- a and [.., ~tm, ..] |- a (either order) into
      [...] |- a. *)
   fun MERGE_CASES tm thm1 thm2 =
      merge (Thm.CONJ (SMART_DISCH tm thm1) (SMART_DISCH tm thm2))
end
849
850(* ---------------------------- *)
851
local
   (* strip negations, then take the lhs of an equation if there is one *)
   fun base t =
      case Lib.total boolSyntax.dest_neg t of
         SOME s => base s
       | NONE => Option.getOpt (Lib.total boolSyntax.lhs t, t)
   (* does the base of t occur (up to alpha) within r? *)
   fun mentions r t = Lib.can (HolKernel.find_term (aconv (base t))) r
   (* number of theorems actually changed in the current run *)
   val modified = ref 0
   (* Discharge the hypotheses that mention one of tms, rewrite with tms
      assumed, apply conv and undischarge.  Vacuous results are dropped. *)
   fun specialize (conv, tms) thm =
      let
         val relevant =
            List.filter (fn h => List.exists (mentions h) tms) (Thm.hyp thm)
         val sthm =
            Drule.UNDISCH_ALL
              (Conv.CONV_RULE conv
                 (REWRITE_RULE (List.map ASSUME tms) (LIST_DISCH relevant thm)))
      in
         if vacuous sthm
            then NONE
         else (Portable.inc modified; SOME sthm)
      end
      handle Conv.UNCHANGED => SOME thm
in
   (* Specialise each theorem under the context (conv, tms), printing a
      one-line progress report: input count -> output count (modified). *)
   fun specialized msg ctms thms =
      let
         val count = Int.toString o List.length
         val () = print ("Specializing " ^ msg ^ ": " ^ count thms ^ " -> ")
         val () = modified := 0
         val kept = List.mapPartial (specialize ctms) thms
      in
         print (count kept ^ "(" ^ Int.toString (!modified) ^ ")\n")
         ; kept
      end
end
885
886(* ---------------------------- *)
887
888(* case split theorem. For example: split_conditions applied to
889
890     |- q = ((if b then x else y), c)
891
892   gives theorems
893
894     [[~b] |- q = (y, c), [b] |- q = (x, c)]
895*)
896
local
   (* prove q and undischarge it, so its antecedent becomes a hypothesis *)
   fun p q = Drule.UNDISCH (Q.prove(q, RW_TAC bool_ss []))
   (* Case-split lemmas.  Suffix t: the conditional is the whole term;
      l / r: it is the left / right component of a pair.  Prefix x: then
      branch (hyp b); y: else branch (hyp ~b); z: else branch of a
      conditional guarded by ~b (hyp b). *)
   val split_xt = p `b ==> ((if b then x else y) = x: 'a)`
   val split_yt = p `~b ==> ((if b then x else y) = y: 'a)`
   val split_zt = p `b ==> ((if ~b then x else y) = y: 'a)`
   val split_xl = p `b ==> (((if b then x else y), c) = (x, c): 'a # 'b)`
   val split_yl = p `~b ==> (((if b then x else y), c) = (y, c): 'a # 'b)`
   val split_zl = p `b ==> (((if ~b then x else y), c) = (y, c): 'a # 'b)`
   val split_xr = p `b ==> ((c, (if b then x else y)) = (c, x): 'b # 'a)`
   val split_yr = p `~b ==> ((c, (if b then x else y)) = (c, y): 'b # 'a)`
   val split_zr = p `b ==> ((c, (if ~b then x else y)) = (c, y): 'b # 'a)`
   val vb = Term.mk_var ("b", Type.bool)
   (* rewrite the rhs of an equational theorem once with thm *)
   fun REWR_RULE thm = Conv.RIGHT_CONV_RULE (Conv.REWR_CONV thm)
   (* [b] |- (if b then x else y) = x   /   [~b] |- (if b then x else y) = y *)
   fun cond_true b = Thm.INST [vb |-> b] split_xt
   fun cond_false b = Thm.INST [vb |-> b] split_yt
   (* Find a conditional in tm: the first pair component, else the second,
      else (for a non-pair) tm itself.  Returns the matching x/y/z lemmas
      and the conditional's (guard, then, else). *)
   fun split_cond tm =
      case Lib.total pairSyntax.dest_pair tm of
         SOME (a, b) =>
          (case Lib.total boolSyntax.dest_cond a of
              SOME bxy => SOME (split_xl, split_yl, split_zl, bxy)
            | NONE => (case Lib.total boolSyntax.dest_cond b of
                          SOME bxy => SOME (split_xr, split_yr, split_zr, bxy)
                        | NONE => NONE))
       | NONE => Lib.total
                     (fn t => (split_xt, split_yt, split_zt,
                               boolSyntax.dest_cond t)) tm
in
   (* Repeatedly split the conditional on the theorem's rhs into a
      then-case (hyp b) and an else-case (hyp ~b, or nb when the guard is
      ~nb), collecting all leaf theorems (else-cases first, as in the
      example above). *)
   val split_conditions =
      let
         fun loop a t =
            case split_cond (rhsc t) of
               SOME (splitx, splity, splitz, (b, x, y)) =>
                  let
                     val ty = Term.type_of x
                     val vx = Term.mk_var ("x", ty)
                     val vy = Term.mk_var ("y", ty)
                     (* instantiate a lemma to guard cb and branches x, y *)
                     fun s cb = Drule.INST_TY_TERM
                                 ([vb |-> cb, vx |-> x, vy |-> y],
                                  [Type.alpha |-> ty])
                     val (split_yz, nb) =
                        case Lib.total boolSyntax.dest_neg b of
                           SOME nb => (splitz, nb)
                         | NONE => (splity, b)
                  in
                     loop (loop a (REWR_RULE (s b splitx) t))
                                  (REWR_RULE (s nb split_yz) t)
                  end
             | NONE => t :: a
      in
         loop []
      end
   (* paths [b1, ..., bn]: for each i, the list
        [cond_true b1, ..., cond_true b(i-1), cond_false bi] *)
   fun paths [] = []
     | paths (h :: t) =
         [[cond_false h]] @ (List.map (fn p => cond_true h :: p) (paths t))
end
952
953(* ---------------------------- *)
954
955(* Support for rewriting/evaluation *)
956
(* Core rewrites for evaluation: state-monad FOR/BIND, function update,
   combinators, pairs and options.
   NOTE: datatype_rewrites drops the first two entries (FOR_def, BIND_DEF)
   with List.drop, so keep those two at the head of this list. *)
val basic_rewrites =
   [state_transformerTheory.FOR_def,
    state_transformerTheory.BIND_DEF,
    combinTheory.APPLY_UPDATE_THM,
    combinTheory.K_o_THM,
    combinTheory.K_THM,
    combinTheory.o_THM,
    pairTheory.FST,
    pairTheory.SND,
    pairTheory.pair_case_thm,
    pairTheory.CURRY_DEF,
    optionTheory.option_case_compute,
    optionTheory.IS_SOME_DEF,
    optionTheory.THE_DEF]
971
local
   (* Conversion for "a IN s". A set-comprehension s is handled by
      SET_SPEC_CONV; any other s by IN_CONV, using conv on sub-terms. *)
   fun in_conv conv tm =
      case Lib.total pred_setSyntax.dest_in tm of
         NONE => raise ERR "in_conv" "not an IN term"
       | SOME (_, coll) =>
            if pred_setSyntax.is_set_spec coll
               then pred_setLib.SET_SPEC_CONV tm
            else pred_setLib.IN_CONV conv tm
in
   (* Extend cmp with basic_rewrites, then install IN and INSERT
      conversions (arity 2) that evaluate sub-terms by call-by-value
      over cmp itself. *)
   fun add_base_datatypes cmp =
      let
         val eval = computeLib.CBV_CONV cmp
      in
         computeLib.add_thms basic_rewrites cmp;
         List.app (fn entry => computeLib.add_conv entry cmp)
            [(pred_setSyntax.in_tm, 2, in_conv eval),
             (pred_setSyntax.insert_tm, 2, pred_setLib.INSERT_CONV eval)]
      end
end
991
local
   (* All pairs (a, b) with a occurring before b in l, in list order.
      Taken from src/datatype/EnumType.sml *)
   fun gen_triangle l =
      let
         fun gen_row i [] acc = acc
           | gen_row i (h::t) acc = gen_row i t ((i,h)::acc)
         fun doitall [] acc = acc
           | doitall (h::t) acc = doitall t (gen_row h t acc)
      in
         List.rev (doitall l [])
      end
   (* TypeBase rewrites for ty. When the first simpset conversion of ty is
      a "const_eq_CONV", also return the conjunction of pairwise
      constructor inequalities (and its symmetric form), each inequality
      derived with that conversion. With fewer than two constructors there
      are no pairs; LIST_CONJ would fail on the empty list, so only the
      plain rewrites are returned. *)
   fun datatype_rewrites1 ty =
      case TypeBase.simpls_of ty of
        {convs = [], rewrs = r} => r
      | {convs = {conv = c, name = n, ...} :: _, rewrs = r} =>
            if String.isSuffix "const_eq_CONV" n
               then (case gen_triangle (TypeBase.constructors_of ty) of
                        [] => r
                      | pairs =>
                           let
                              val neq = Drule.EQF_ELIM o
                                        c (K Conv.ALL_CONV) [] o
                                        boolSyntax.mk_eq
                              val l = pairs
                                      |> List.map neq
                                      |> Drule.LIST_CONJ
                           in
                              [l, GSYM l] @ r
                           end)
            else r
in
   (* Rewrites for the nullary types named in l of theory thy. When extra
      is true, basic_rewrites is prepended without its first two entries
      (FOR_def and BIND_DEF). *)
   fun datatype_rewrites extra thy l =
      let
         fun typ name = Type.mk_thy_type {Thy = thy, Args = [], Tyop = name}
      in
         (if extra then List.drop (basic_rewrites, 2) else []) @
         List.concat (List.map (datatype_rewrites1 o typ) l)
      end
end
1029
(* Register the TypeBase datatype information of every type in l with the
   compset cmp. Option.valOf raises Option.Option for a type that
   TypeBase.fetch does not know. *)
fun add_datatypes l cmp =
   List.app
      (fn ty =>
          computeLib.add_datatype_info cmp (Option.valOf (TypeBase.fetch ty)))
      l
1036
(* Description of a theory used to build rewrite sets and compsets:
     Thy - the theory name;
     T   - names of nullary types of Thy (see theory_types);
     C   - names of theorems of Thy to fetch (see theory_rewrites);
     N   - integers i for which the theorems "bitify<i>_def",
           "boolify<i>_n2w" and "boolify<i>_v2w" are fetched. *)
type inventory = {C: string list, N: int list, T: string list, Thy: string}
1038
(* The nullary types of theory Thy named in the inventory's T field. *)
fun theory_types ({Thy = thy, T = names, ...}: inventory) =
   List.map (fn name => Type.mk_thy_type {Thy = thy, Args = [], Tyop = name})
      names
1045
(* Remove from the C field every theorem name "<name>_def" with name in
   names; the other fields are returned unchanged. *)
fun filter_inventory names ({Thy = thy, C = c, N = n, T = t}: inventory) =
   let
      val excluded = List.map (fn nm => nm ^ "_def") names
      fun keep thm_name = not (Lib.mem thm_name excluded)
   in
      {Thy = thy, C = List.filter keep c, N = n, T = t}
   end
1052
local
   (* The bitify/boolify theorem names indexed by i. *)
   fun bool_bit_thms i =
      let
         val w = Int.toString i
      in
         ["bitify" ^ w ^ "_def", "boolify" ^ w ^ "_n2w", "boolify" ^ w ^ "_v2w"]
      end
   (* Name of the head constant on the left-hand side of the first
      conjunct (after stripping universal quantifiers) of thm. *)
   fun get_name thm =
      thm |> Thm.concl
          |> boolSyntax.strip_conj
          |> List.hd
          |> boolSyntax.strip_forall |> snd
          |> boolSyntax.lhs
          |> HolKernel.strip_comb |> fst
          |> Term.dest_const |> fst
in
   (* Theorems of theory Thy named in the inventory, excluding the "_def"
      theorems of constants that thms already define, plus the
      bitify/boolify theorems for each index in N, followed by thms. *)
   fun theory_rewrites (thms, i: inventory) =
      let
         val {Thy = thy, C = c, N = n, ...} =
            filter_inventory (List.map get_name thms) i
         val names = c @ List.concat (List.map bool_bit_thms n)
      in
         List.map (DB.fetch thy) names @ thms
      end
end
1074
(* Extend cmp first with the datatypes of inventory i, then with the
   theory rewrites of x. *)
fun add_theory (x as (_, i)) cmp =
   let
      val () = add_datatypes (theory_types i) cmp
   in
      computeLib.add_thms (theory_rewrites x) cmp
   end
1078
(* Pass the theory rewrites of x to computeLib.add_funs. *)
val add_to_the_compset = computeLib.add_funs o theory_rewrites
1080
(* A fresh words compset extended with the base datatypes and rewrites,
   and then with the datatypes and rewrites of theory x. *)
fun theory_compset x =
   let
      val cmp = wordsLib.words_compset ()
      val () = add_base_datatypes cmp
      val () = add_theory x cmp
   in
      cmp
   end
1087
1088(* ---------------------------- *)
1089
1090(* Help prove theorems of the form:
1091
1092|- rec'r (bit_field_insert h l w (reg'r q)) = q with <| ? := ?; ... |>
1093
where "r" names a register (record) type in the theory "thy".
1095
1096*)
1097
local
   (* Evaluate an FCP bit-index term with BBLAST_CONV; fail on any other
      term so that DEPTH_CONV leaves it untouched. *)
   fun EXTRACT_BIT_CONV tm =
      if fcpSyntax.is_fcp_index tm
         then blastLib.BBLAST_CONV tm
      else Conv.NO_CONV tm
   val bit_field_insert_tm =
      ``bit_field_insert a b (w: 'a word) : 'b word -> 'b word``
in
   (* Conversion for goals about register r of theory thy. It rewrites
      with:
        - COND_ID;
        - mk_cond_rand_thms applied to bit_field_insert and the
          accessor/update functions of the types r and thy ^ "_state"
          (mk_cond_rand_thms and accessor_update_fns are defined elsewhere
          in this file; presumably they move these functions through
          if-then-else);
        - the datatype rewrites of r and thy ^ "_state".
      It then evaluates FCP bit indexes and word-bit index expressions
      everywhere. NOTE(review): r and thy ^ "_state" are assumed to be
      record types of thy. *)
   fun BIT_FIELD_INSERT_CONV thy r =
      let
         val s = thy ^ "_state"
         val ty1 = Type.mk_thy_type {Thy = thy, Tyop = r, Args = []}
         val ty2 = Type.mk_thy_type {Thy = thy, Tyop = s, Args = []}
         val au = accessor_update_fns ty1 @ accessor_update_fns ty2
         val au = op @ (ListPair.unzip au)
      in
         REWRITE_CONV
           ([boolTheory.COND_ID,
             mk_cond_rand_thms (bit_field_insert_tm :: au)] @
             datatype_rewrites true thy [r, s])
         THENC Conv.DEPTH_CONV EXTRACT_BIT_CONV
         THENC Conv.DEPTH_CONV (wordsLib.WORD_BIT_INDEX_CONV true)
      end
   (* Tactic, parameterised by a quotation q naming the variable to case
      split on. It fetches reg'<r>_def, rec'<r>_def and
      <r>_component_equality from thy. It then cases on q, unfolds reg',
      simplifies with BIT_FIELD_INSERT_CONV, unfolds rec' and
      bit_field_insert_def, simplifies again, splits conjunctions, and
      finishes each conjunct with BBLAST_TAC. *)
   fun REC_REG_BIT_FIELD_INSERT_TAC thy r =
      let
         val cnv = BIT_FIELD_INSERT_CONV thy r
         val f = DB.fetch thy
         val reg' = f ("reg'" ^ r ^ "_def")
         val rec' = f ("rec'" ^ r ^ "_def")
         val eq = f (r ^ "_component_equality")
      in
         fn q =>
            Cases_on q
            THEN TRY STRIP_TAC
            THEN REWRITE_TAC [reg']
            THEN CONV_TAC cnv
            THEN BETA_TAC
            THEN REWRITE_TAC [rec', eq, wordsTheory.bit_field_insert_def]
            THEN CONV_TAC cnv
            THEN REPEAT CONJ_TAC
            THEN blastLib.BBLAST_TAC
      end
end
1141
1142(* Make a theorem of the form
1143
1144|- !x. reg'r x = x.? @@ x.?
1145
1146*)
1147
local
   (* Given an equation h of the form  x = y, build the substitution
      x |-> (rator y) v, i.e. replace x by the function part of y applied
      to v. *)
   fun mk_component_subst v =
      fn h =>
         let
            val (x, y) = boolSyntax.dest_eq h
         in
            x |-> Term.mk_comb (Term.rator y, v)
         end
in
   (* Prove  |- !v. reg'<r> v = m'  for register r of theory thy.
      After SPEC_ALL, the right-hand side of reg'<r>_def is destructured
      as  f v (\vs. m)  (NOTE(review): presumably a case expression over
      the record fields; confirm). Each conjunct of <r>_accessors is
      specialised to vs, flipped with SYM, and turned into a substitution
      that replaces a component variable by an accessor applied to v.
      m' is m under these substitutions. The equation is proved with
      REC_REG_BIT_FIELD_INSERT_TAC, cased on v, then generalised.
      get_function and rhsc are defined elsewhere in this file. *)
   fun mk_reg_thm thy r =
      let
         val ftch = DB.fetch thy
         val reg' = ftch ("reg'" ^ r ^ "_def")
         val a = ftch (r ^ "_accessors")
         val ((_, v), (vs, m)) =
            reg'
            |> Drule.SPEC_ALL
            |> rhsc
            |> Term.dest_comb
            |> (Term.dest_comb ## boolSyntax.strip_abs)
         val mk_s = mk_component_subst v o Thm.concl o SYM o Drule.SPECL vs
         val tm = Term.subst (List.map mk_s (Drule.CONJUNCTS a)) m
      in
         Tactical.prove
            (boolSyntax.mk_eq (Term.mk_comb (get_function reg', v), tm),
             REC_REG_BIT_FIELD_INSERT_TAC thy r `^v`)
         |> Drule.GEN_ALL
      end
end
1177
1178(* ---------------------------- *)
1179
(* Rewriting of "Run ast st" for each shape of the "instruction" datatype
   of a theory, using that theory's Run_def and dfn'* constants. *)
local
   val dr = Type.dom_rng o Term.type_of
   val dom = fst o dr
   val rng = snd o dr
   (* The constant "dfn'<l>" of theory thy for a constructor constant tm,
      where (l, r) = splitAtChar (#"@") of its name. Accepted when r is
      empty or when the text after r's first character parses as an
      integer; otherwise ERR is raised. prim_mk_const also raises if no
      such dfn' constant exists. *)
   fun mk_def thy tm =
      let
         val name = fst (Term.dest_const tm)
         val (l, r) = splitAtChar (Lib.equal #"@") name
      in
         if r = "" orelse
            Option.isSome (Int.fromString (String.extract (r, 1, NONE)))
            then Term.prim_mk_const {Thy = thy, Name = "dfn'" ^ l}
         else raise ERR "mk_def" ""
      end
   (* Representative terms of type ty, built from its constructors:
        - constructors that already have type ty, as-is;
        - constructors with a dfn' constant, applied to a variable x;
        - every other constructor, applied to each term built
          recursively for its argument type.
      On a mk_comb failure the offending terms are printed and ERR is
      raised. *)
   fun buildAst thy ty =
      let
         val cs = TypeBase.constructors_of ty
         val (t0, n) = List.partition (Lib.equal ty o Term.type_of) cs
         val (t1, n) = List.partition (Lib.can (mk_def thy)) n
         val t1 =
            List.map (fn t => Term.mk_comb (t, Term.mk_var ("x", dom t))) t1
         val n =
            List.map (fn t =>
                        let
                           val l = buildAst thy (dom t)
                        in
                           List.map (fn x => Term.mk_comb (t, x)
                           handle HOL_ERR {origin_function = "mk_comb", ...} =>
                             (Parse.print_term t; print "\n";
                              Parse.print_term x; raise ERR "buildAst" "")) l
                        end) n
      in
         t0 @ t1 @ List.concat n
      end
   (* Whether tm is an application whose argument is alpha-equivalent to x. *)
   fun is_call x tm =
      case Lib.total Term.rand tm of
        SOME y => x ~~ y
      | NONE => false
   (* The innermost argument reached by repeatedly taking rand. *)
   fun leaf tm =
      case Lib.total Term.rand tm of
        SOME y => leaf y
      | NONE => tm
   (* Run theorem for a closed instruction term ast, proved by pv. *)
   fun run_thm0 pv thy ast =
      let
         val f = mk_def thy (leaf ast)
      in
         pv (if Term.type_of f = oneSyntax.one_ty orelse
                rng f = oneSyntax.one_ty
                then `!s. Run ^ast s = s`
             else `!s. Run ^ast s = ^f s`) : thm
      end
   (* Run theorem for an instruction term ast whose first free variable x
      is the argument of the constructor that has a dfn' constant. *)
   fun run_thm pv thy ast =
      let
         val x = hd (Term.free_vars ast)
         val tm = Term.rator (HolKernel.find_term (is_call x) ast)
         val f = boolSyntax.mk_icomb (mk_def thy tm, x)
      in
         pv (if Term.type_of f = oneSyntax.one_ty
                then `!s. Run ^ast s = s`
             else `!s. Run ^ast s = ^f s`) : thm
      end
   (* One Run theorem per representative instruction term of theory thy,
      each proved by simplification with Run_def. *)
   fun run_rwts thy =
      let
         val ty = Type.mk_thy_type {Thy = thy, Args = [], Tyop = "instruction"}
         val (arg0, args) =
            List.partition (List.null o Term.free_vars) (buildAst thy ty)
         val tac = SIMP_TAC (srw_ss()) [DB.fetch thy "Run_def"]
         fun pv q = Q.prove (q, tac)
      in
         List.map (run_thm0 pv thy) arg0 @ List.map (run_thm pv thy) args
      end
   fun run_tm thy = Term.prim_mk_const {Thy = thy, Name = "Run"}
in
   (* The term  Run ast st  for theory thy. *)
   fun mk_run (thy, st) = fn ast => Term.list_mk_comb (run_tm thy, [ast, st])
   (* Run_CONV (thy, st) ast  gives  |- !st. Run ast st = ...  by
      rewriting with the Run theorems of thy. These theorems are built once
      per application of Run_CONV to (thy, st). *)
   fun Run_CONV (thy, st) =
      Thm.GEN st o PURE_REWRITE_CONV (run_rwts thy) o mk_run (thy, st)
end
1259
1260(* ---------------------------- *)
1261
local
   val rwts = [pairTheory.UNCURRY, combinTheory.o_THM, combinTheory.K_THM]
   (* Split theorems into (hypothesis-free, with-hypotheses). *)
   val no_hyp = List.partition (List.null o Thm.hyp)
   (* Compset conversion deciding equality of words via word_eq_CONV. *)
   val add_word_eq =
      computeLib.add_conv (``$= :'a word -> 'a word -> bool``, 2,
                           bitstringLib.word_eq_CONV)
   (* Turn substitutions whose redexes are quotations into term
      substitutions, parsing the redexes in the context of tm's free
      variables. *)
   fun context_subst tm =
      let
         val f = Parse.parse_in_context (Term.free_vars tm)
      in
         List.map (List.map (fn {redex, residue} => f redex |-> residue))
      end
   (* Extra conversion run in every STEP rewrite iteration. This is shared
      mutable state: setStepConv affects all later STEP calls. *)
   val step_conv = ref Conv.ALL_CONV
in
   fun resetStepConv () = step_conv := Conv.ALL_CONV
   fun setStepConv c = step_conv := c
   (* STEP (datatype_thms, st) l ctms s tm  returns a list of theorems
      produced by repeatedly rewriting  tm st  (or tm itself when it
      cannot be applied to st):
        - datatype_thms: function from extra theorems to a rewrite list;
        - l:    theorems; hypothesis-free ones go into a fresh num compset
                (with rwts and word equality), the others are used via
                INST_REWRITE_CONV (defined elsewhere in this file);
        - ctms: contexts, each a list of terms normalised with
                datatype_thms [] and then ASSUMEd, so each result carries
                that context as hypotheses;
        - s:    quotation substitutions applied via match_subst (defined
                elsewhere in this file).
      One theorem is returned per (context, substitution) combination;
      an empty ctms or s counts as a single empty choice. *)
   fun STEP (datatype_thms, st) =
      let
         val DATATYPE_CONV = REWRITE_CONV (datatype_thms [])
         fun fix_datatype tm = rhsc (Conv.QCONV DATATYPE_CONV tm)
         val SAFE_ASSUME = Thm.ASSUME o fix_datatype
      in
         fn l => fn ctms => fn s => fn tm =>
            let
               val (nh, h) = no_hyp l
               val c = INST_REWRITE_CONV h
               val cmp = reduceLib.num_compset ()
               val () = ( computeLib.add_thms (rwts @ nh) cmp
                        ; add_word_eq cmp )
               (* rwt: the ASSUMEd context theorems, also used as rewrites *)
               fun cnv rwt =
                  Conv.REPEATC
                    (Conv.TRY_CONV (CHANGE_CBV_CONV cmp)
                     THENC REWRITE_CONV (datatype_thms (rwt @ h))
                     THENC (!step_conv)
                     THENC c)
               val stm = Term.mk_comb (tm, st) handle HOL_ERR _ => tm
               val sbst = context_subst stm s
               fun cnvs rwt =
                  if List.null sbst
                     then [cnv rwt stm]
                  else List.map (fn sub => cnv rwt (match_subst sub stm)) sbst
               val ctxts = List.map (List.map SAFE_ASSUME) ctms
            in
               if List.null ctxts
                  then cnvs []
               else List.concat (List.map cnvs ctxts)
            end
      end
end
1311
1312end
1313