structure match_goal :> match_goal =
struct

open HolKernel boolLib Streams

(* True iff some entry of the map m has value v.  A local exception stops the
   traversal as soon as a witness is found.  The handler must live inside the
   `let` that declares the exception: outside it, `Found` would not be in scope
   and the handler pattern would silently become a catch-all variable pattern
   (swallowing e.g. Interrupt). *)
fun Redblackmap_contains (m,v) =
  let
    exception Found
  in
    (Redblackmap.app (fn (_,v') => if v = v' then raise Found else ()) m;
     false)
    handle Found => true
  end

val ERR = Feedback.mk_HOL_ERR"match_goal";

(* Where a pattern may match:
   - Assumption (SOME nm): an assumption, recorded under the name nm.  Within
     one match_tac call a name always denotes the same assumption, and two
     distinct names never denote the same assumption.
   - Assumption NONE: any assumption; which one is not recorded.
   - Conclusion: the goal's conclusion.
   - Anything: the conclusion or any assumption (not recorded). *)
datatype name =
    Assumption of string option
  | Conclusion
  | Anything

type pattern = term quotation

(* (where to look, what to match, whole).  When `whole` is true the pattern
   must match the entire assumption/conclusion; when false it may match any
   subterm of it. *)
type matcher = name * pattern * bool

               (* (name, assumption number) *)
type named_thms = (string, int) Redblackmap.dict
val (empty_named_thms:named_thms) = Redblackmap.mkDict String.compare;

(* If you want to be able to treat certain type variables as constants (e.g.,
   if they are in the goal) then you need to keep track of the avoid_tys/tyIds
   thing that raw_match uses. But we are not doing that. *)
type named_tms =
  ((term,term) subst) * ((hol_type,hol_type) subst)

val empty_named_tms : named_tms = ([],[])

type data = named_thms * named_tms

(* True iff variable v's name starts with "_".  Such pattern variables are
   wildcards: their bindings are dropped by umatch.  Raises if v is not a
   variable, which should not happen for substitution redexes. *)
fun is_underscore v =
  case total dest_var v of
    NONE => raise(ERR"umatch""unexpected non-variable binding") (* should not happen *)
  | SOME (s,_) => String.isPrefix "_" s

(* True iff variable's name ends in "_".  These "user" variables are the only
   non-wildcard pattern variables allowed to be instantiated; match_tac hands
   their bindings to the tactic under the name without the trailing "_". *)
val is_uvar = String.isSuffix "_" o #1 o dest_var

(* Match pattern pat against object term ob, extending the accumulated
   substitutions (tmS,tyS).  avoid_tms is passed to raw_match as its set of
   fixed term variables.  Fails by raising end_of_stream (i.e. contributes no
   result to the enclosing stream) when:
   - raw_match fails;
   - some bound pattern variable is neither a "_"-prefixed wildcard nor a
     "_"-suffixed user variable; or
   - the identity-matched variables reported by raw_match are not all
     contained in avoid_tms.
   Wildcard bindings are removed from the returned term substitution. *)
fun umatch avoid_tms ((tmS,tyS):named_tms) pat ob : named_tms =
  let
    val ((tmS',tmIds'),(tyS',_)) = raw_match [] avoid_tms pat ob (tmS,tyS)
                                   handle HOL_ERR _ => raise end_of_stream
    val tmS'' = List.filter (not o is_underscore o #redex) tmS'
    val _ = assert_exn (List.all (is_uvar o #redex)) tmS'' end_of_stream
    val _ = assert_exn (curry HOLset.isSubset tmIds') avoid_tms end_of_stream
  in
    (tmS'', tyS')
  end

(* Lazily enumerate every successful umatch of pat against ob or one of its
   subterms: ob itself first, then (for an application) all results from the
   operator followed by all results from the operand, or (for an abstraction)
   all results from the body, with the bound variable added to avoid_tms. *)
fun umatch_subterms avoid_tms (ntms:named_tms) pat ob : unit -> named_tms stream =
  stream_append
    (fn () => Stream (umatch avoid_tms ntms pat ob, empty_stream))
    (fn () =>
      (case dest_term ob of
         COMB (t1,t2) =>
           stream_append
             (umatch_subterms avoid_tms ntms pat t1)
             (umatch_subterms avoid_tms ntms pat t2)
       | LAMB (v,b) =>
           umatch_subterms (HOLset.add (avoid_tms,v)) ntms pat b
       | _ => empty_stream)
      ())

(* Parse a matcher's pattern quotation in the context of the terms fvs
   (the goal's free variables, as supplied by match_tac). *)
fun preprocess_matcher fvs =
  fn (nm,q,b):matcher => (nm, Parse.parse_in_context fvs q, b)

(* A tactic parameterised by two lookups supplied by match_tac: the first maps
   an assumption name to that assumption (as an ASSUMEd theorem), the second
   maps a user-variable name (without its trailing "_") to its binding. *)
type mg_tactic = (string -> thm) * (string -> term) -> tactic

(* Extend the match state (nths,ntms) in every way the single parsed matcher
   (nm,pat,whole) allows against goal (asl,w), producing a lazy stream of the
   resulting states. *)
fun match_single fvs ((asl,w):goal)
  ((nm,pat,whole):name * term * bool) ((nths,ntms):data) : unit -> data stream =
  let
    (* Record that assumption number i satisfies the (optional) name l.
       NONE means the assignment is inconsistent: l is already bound to a
       different assumption, or assumption i already carries another name. *)
    fun add_nth NONE _ = SOME nths
      | add_nth (SOME l) i =
        (case Redblackmap.peek(nths,l) of
           NONE =>
             if Redblackmap_contains(nths,i)
             then NONE
             else SOME (Redblackmap.insert(nths,l,i))
         | SOME j => if i = j then SOME nths else NONE)

    (* Skip (empty stream) when add_nth rejected the assignment. *)
    fun protect f NONE = K empty_stream
      | protect f (SOME x) = f x

    (* Match the whole term / any subterm, pairing results with names. *)
    fun umatch_sing nths w =
      (fn() => Stream ((nths, umatch fvs ntms pat w),empty_stream))
    fun umatch_many nths w =
      stream_map (fn ntms => (nths,ntms)) (umatch_subterms fvs ntms pat w)

    val match_one = if whole then umatch_sing else umatch_many
  in
    case nm of
      Conclusion => match_one nths w
    | Assumption l =>
        stream_append_list (Lib.mapi (protect match_one o add_nth l) asl)
    | Anything =>
        stream_append_list (List.map (match_one nths) (w::asl))
  end

(* Drop the last character of a string (used to strip a user variable's
   trailing "_"). *)
val tr1 = Substring.string o Substring.trimr 1 o Substring.full

(* Search for an assignment satisfying all matchers ms (in order) and apply
   mtac, instantiated with that assignment, to the goal.  The first assignment
   for which the instantiated tactic succeeds (does not raise HOL_ERR) wins.
   Raises ERR "match_tac" "no match" if there is none. *)
fun match_tac (ms:matcher list,mtac:mg_tactic) (g as (asl,w):goal) : goal list * validation =
  let
    (* Run mtac under one complete assignment; an empty stream on HOL_ERR. *)
    fun try_tactic ((thms,tms):data) : (unit -> (goal list * validation) stream) =
      let
        (* An unknown name fails with HOL_ERR (like the term lookup via assoc
           below) so the search moves on, rather than letting
           Redblackmap.NotFound escape match_tac. *)
        fun lookup_assum s =
          case Redblackmap.peek (thms,s) of
            SOME i => ASSUME (List.nth (asl,i))
          | NONE => raise ERR "match_tac" ("no assumption named " ^ s)
        val (tmS,tyS) = tms
        (* Fold type instantiations into the term substitution, then key each
           binding by its user-variable name minus the trailing "_". *)
        val s =
          #1 (norm_subst ((tmS,empty_tmset),(tyS,[])))
          |> List.map (fn {redex,residue} => (tr1 (#1 (dest_var redex)), residue))
        val r = mtac (lookup_assum, Lib.C assoc s) g
      in
        (fn () => Stream (r,empty_stream))
      end
      handle HOL_ERR _ => empty_stream

    (* Thread the match state through each matcher, then try the tactic. *)
    fun search [] d = try_tactic d
      | search (m::rest) d = stream_flat (stream_map (search rest) (m d))

    val fvs = FVL (w::asl) empty_tmset

    val matches =
      map (match_single fvs g o preprocess_matcher (HOLset.listItems fvs)) ms
  in
    (case search matches (empty_named_thms,empty_named_tms) () of
       Stream (x,_) => x)
    handle end_of_stream => raise ERR "match_tac" "no match"
  end

(* Try each (matchers, tactic) pair in turn; the first that succeeds wins. *)
val first_match_tac = FIRST o List.map match_tac

(* match_tac with a single matcher. *)
fun match1_tac (x,t) = match_tac ([x],t)

(* Remove an assumption whose statement is alpha-equivalent (aconv) to the
   conclusion of th. *)
fun kill_asm th = first_x_assum((K ALL_TAC) o assert (aconv (concl th) o concl))

(* drule_thm th imp = mp_tac (MATCH_MP imp th): use th to discharge the
   antecedent of imp and push the result onto the goal as an antecedent. *)
fun drule_thm th = mp_tac o Lib.C MATCH_MP th

(* Matcher constructors.  Letters: a = assumption (named when given a name,
   `u` = unnamed), c = conclusion, ac = anything, b = match any subterm
   (whole = false); without b the whole term must match.  Permuted spellings
   are aliases of the same constructor. *)
structure mg = struct
  type pattern = pattern
  type matcher = matcher

  fun a nm p = (Assumption (SOME nm),p,true)

  fun ua p = (Assumption NONE,p,true)
  val au = ua

  fun ab nm p = (Assumption (SOME nm),p,false)
  val ba = ab

  fun uab p = (Assumption NONE,p,false)
  val uba = uab
  val aub = uab
  val abu = uab
  val bau = uab
  val bua = uab

  fun c p = (Conclusion,p,true)

  fun cb p = (Conclusion,p,false)
  val bc = cb

  fun ac p = (Anything,p,true)
  val ca = ac

  fun acb p = (Anything,p,false)
  val abc = acb
  val bca = acb
  val cba = acb
  val cab = acb
  val bac = acb
end

end
200