structure resolve_then :> resolve_then =
struct

(* Resolution of one theorem against another: the conclusion of th1 is
   unified with an assumption of th2 (an antecedent conjunct, or its negated
   conclusion), and the combined theorem is handed to a theorem-tactic.  See
   resolve_then at the end of this structure. *)

open HolKernel boolSyntax Drule thmpos_dtype

(* match_subterm pat t returns a subterm of t (the first one find_term
   reaches) that is an instance of pat; it fails if there is none, which is
   why callers test it with can. *)
fun match_subterm pat = find_term (can (match_term pat))

(* UDISCH' avoidnames th performs one opening step on th:

   - If th is  G |- l ==> r  (checked with dest_imp_only, so a negation is not
     treated as an implication here), l is broken down into leaves:
     conjunctions are split, and each existential is stripped by choosing a
     witness variable whose name is a variant of the bound name not in the
     names-to-avoid list.  The leaves become new assumptions, giving
     G, leaves |- r.
   - Otherwise, if th is  G |- !x. P x,  x is specialised to a variable whose
     name is a variant of x's name not in avoidnames.

   Returns (new theorem, new assumptions in left-to-right order, avoidnames
   extended with every name chosen).  Raises HOL_ERR when th is neither an
   implication nor universally quantified. *)
fun UDISCH' avoidnames th =
    let
      val (ant, _) = dest_imp_only (concl th)
      (* buildcth nms t = (leaves |- t, leaves, nms'), where nms' extends nms
         with the names of the existential witnesses introduced *)
      fun buildcth nms0 t =
          case Lib.total dest_conj t of
              SOME (cl, cr) =>
              let
                val (lth, lcs, nms1) = buildcth nms0 cl
                val (rth, rcs, nms2) = buildcth nms1 cr
              in
                (CONJ lth rth, lcs @ rcs, nms2)
              end
            | NONE =>
              (case Lib.total dest_exists t of
                   NONE => (ASSUME t, [t], nms0)
                 | SOME (bv, bod) =>
                   let
                     val (bvnm, bvty) = dest_var bv
                     val cnm = Lexis.gen_variant Lexis.tmvar_vary nms0 bvnm
                     val c = mk_var(cnm, bvty)
                     val (bodth, bodcs, nms1) =
                         buildcth (cnm::nms0) (subst [bv |-> c] bod)
                   in
                     (* from  leaves |- bod[c/bv]  derive  leaves |- ?bv. bod *)
                     (EXISTS(t, c) bodth, bodcs, nms1)
                   end)
      val (antth, conjuncts, nms') = buildcth avoidnames ant
    in
      (PROVE_HYP antth (UNDISCH th), conjuncts, nms')
    end handle HOL_ERR _ =>
      let
        val (bv, _) = dest_forall (concl th)
        val (bvnm, bvty) = dest_var bv
        val newnm = Lexis.gen_variant Lexis.tmvar_vary avoidnames bvnm
      in
        (SPEC (mk_var(newnm, bvty)) th, [], newnm::avoidnames)
      end

(* UDALL nms th applies UDISCH' repeatedly until it no longer applies.
   Returns the final theorem, every assumption introduced along the way (in
   the order introduced), and the accumulated list of names to avoid. *)
fun UDALL nms0 th0 =
    case Lib.total (UDISCH' nms0) th0 of
        NONE => (th0, [], nms0)
      | SOME (th', cs1, nms') =>
        let
          val (th, cs2, nms) = UDALL nms' th'
        in
          (th, cs1 @ cs2, nms)
        end

(* DISCHl [t1, ..., tn] (G |- c)  =  G - {t1, ..., tn} |- t1 /\ ... /\ tn ==> c

   Moves a bunch of hypotheses from a theorem into an implication, conjoining
   them all rather than creating iterated implications.  Returns th unchanged
   when the list is empty.

   The assumed conjunction is split with CONJ_LIST into exactly n pieces
   (matching the right-associated list_mk_conj), not with CONJUNCTS: if some
   ti were itself a conjunction, CONJUNCTS would split it further and ti would
   then never be discharged from th. *)
fun DISCHl tms th =
    if null tms then th
    else
      let
        val cjt = list_mk_conj tms
        val cjths = CONJ_LIST (length tms) (ASSUME cjt)
      in
        th |> rev_itlist PROVE_HYP cjths
           |> DISCH cjt
      end

(* liftconcl th turns  G |- p  into  G, ~p |- F  (where p is not a negation),
   and  G |- ~p  into  G, p |- F.  Also returns the new hypothesis (~p or p
   respectively). *)
fun liftconcl th =
    let
      val c = concl th
    in
      let
        val c0 = dest_neg c
      in
        (UNDISCH th, c0)
      end handle HOL_ERR _ =>
                 let val h = mk_neg c
                 in
                   (* ~p |- p = F, so EQ_MP with G |- p gives G, ~p |- F *)
                   (EQ_MP (EQF_INTRO (ASSUME h)) th, h)
                 end
    end

(* Example inputs for interactive testing:
     val th2 = prim_recTheory.LESS_REFL
     val th1 = arithmeticTheory.LESS_TRANS
*)

(* resolve_then mpos ttac th1 th2 g

   Resolves th1 against th2 and passes the result to ttac on goal g.

   Both theorems are first passed through GEN_TYVARIFY (presumably renaming
   their type variables apart) and GEN_ALL, then opened up with UDALL:
   universal quantifiers become fresh variables and antecedents are split into
   assumptions (cs1 for th1, cs2 for th2).  Term and type variables free in
   the hypotheses of th1/th2 are treated as fixed by the unifier, and their
   names are avoided when choosing fresh variable names.

   The opened conclusion of th1 is then unified with candidate terms from th2
   selected by mpos:
     Any    - each element of cs2 in turn;
     Pos f  - the single element f cs2;
     Pat q  - each element of cs2 containing a subterm that matches one of the
              patterns parsed from q (in the context of the goal's free
              variables); an element is retried once per matching pattern;
     Concl  - the negated conclusion of th2, as computed by liftconcl.
   For each unifier, the instantiated th1 discharges the matched assumption of
   the instantiated th2.  Opening assumptions that survive in the result and
   were not hypotheses of th1/th2 are conjoined back as an antecedent
   (DISCHl), the result is closed with GEN_ALL, and ttac is applied to it.
   If unification fails, or ttac fails with HOL_ERR, the next candidate (if
   any) is tried.

   Raises HOL_ERR with message "No unifier" when no candidate succeeds. *)
fun resolve_then mpos ttac th1 th2 (g as (asl,w)) =
    let
      val th1 = GEN_ALL (GEN_TYVARIFY th1)
      val th2 = GEN_ALL (GEN_TYVARIFY th2)
      (* variables free in the hypotheses must not be instantiated *)
      val fixed_tyl =
          HOLset.listItems (HOLset.union(hyp_tyvars th1, hyp_tyvars th2))
      val fixed_tms = HOLset.union(hyp_frees th1, hyp_frees th2)
      val fixed_tml = HOLset.listItems fixed_tms
      val hyps = HOLset.union(hypset th1, hypset th2)
      val badnames =
          HOLset.foldl (fn (v, A) => HOLset.add(A, #1 (dest_var v)))
                       (HOLset.empty String.compare)
                       fixed_tms
      (* th2 is opened avoiding every name chosen for th1, so the variables
         of the two theorems stay distinct *)
      val (th1_ud, cs1, nms1) = UDALL (HOLset.listItems badnames) th1
      val (th2_opened, cs2, _) = UDALL nms1 th2
      val (th2_ud, con) =
          case mpos of
              Concl => liftconcl th2_opened
            | _ => (th2_opened, T)
      fun INSTT (tyi,tmi) th = th |> INST_TYPE tyi |> INST tmi
      fun instt (tyi,tmi) t = t |> Term.inst tyi |> Term.subst tmi
      open optmonad
      (* re-discharge the opening assumptions (instantiated by sigma) that
         still appear in th and were not hypotheses of the inputs *)
      fun postprocess sigma th =
          let
            val thhyps = hypset th
            val dhyps0 = map (instt sigma) (cs1 @ cs2) |> op_mk_set aconv
            val dhyps =
                List.filter (fn t => not (HOLset.member(hyps,t)) andalso
                                     HOLset.member(thhyps, t))
                            dhyps0
          in
            DISCHl dhyps th |> GEN_ALL
          end
      (* try t k: unify t with th1's opened conclusion and apply ttac to the
         resolvent; on unification failure, or HOL_ERR from ttac, call k *)
      fun try t k =
          case FullUnify.Env.fromEmpty
                 (FullUnify.unify fixed_tyl fixed_tml (t, concl th1_ud) >>
                  FullUnify.collapse)
           of
              NONE => k()
            | SOME sigma =>
              let
                val kth =
                    PROVE_HYP (INSTT sigma th1_ud) (INSTT sigma th2_ud) |>
                    postprocess sigma
              in
                ttac kth g handle HOL_ERR _ => k()
              end
      val fail = mk_HOL_ERR "resolve_then" "resolve_then" "No unifier"
    in
      case mpos of
          Any =>
          let
            fun doit [] = raise fail
              | doit (c :: cs) = try c (fn _ => doit cs)
          in
            doit cs2
          end
        | Pos f => try (f cs2) (fn _ => raise fail)
        | Pat q =>
          let
            open TermParse
            val pats =
                prim_ctxt_termS Parse.Absyn (Parse.term_grammar())
                                (HOLset.listItems (FVL (w::asl) empty_tmset))
                                q
            (* for each candidate c, walk the pattern sequence; every pattern
               matching a subterm of c triggers a try on c, and once the
               patterns are exhausted the next candidate starts afresh *)
            fun doit _ [] = raise fail
              | doit ps (cs as c :: cs') =
                case seq.cases ps of
                    NONE => doit pats cs'
                  | SOME (pat, rest) =>
                    if can (match_subterm pat) c then
                      try c (fn _ => doit rest cs)
                    else doit rest cs
          in
            doit pats cs2
          end
        | Concl => try con (fn _ => raise fail)
    end

end (* struct *)
180