1(* ========================================================================= *)
2(* ADDITIONS TO HOL TERM AND TYPE MATCHING.                                  *)
3(* Created by Joe Hurd, June 2002.                                           *)
4(* ========================================================================= *)
5
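(*
   Matching and unification of HOL terms and types in which only the
   variables accepted by caller-supplied predicates may be instantiated;
   every other variable is treated as a locally constant name.
*)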
8structure matchTools :> matchTools =
9struct
10
11open HolKernel Parse boolLib;
12
(* A type substitution: a list of {redex, residue} bindings over HOL types. *)
type tySubst = (hol_type, hol_type) subst;

(* A combined substitution: a term part paired with a type part.  pinst
   applies the type part first and then the term part. *)
type Subst = (term, term) subst * (hol_type, hol_type) subst;
15
16(* ------------------------------------------------------------------------- *)
17(* Chatting.                                                                 *)
18(* ------------------------------------------------------------------------- *)
19
local
  open mlibUseful;
  val moduleName = "matchTools";
in
  (* Register this module with the mlibUseful tracing machinery. *)
  val () = add_trace {alignment = I, module = moduleName}

  (* Query mlibUseful's tracing for this module at the given level
     (presumably "is tracing at least this verbose?" -- confirm). *)
  fun chatting lvl = tracing {level = lvl, module = moduleName};

  (* Emit a trace message and return true, so it can sit inside boolean
     guards. *)
  fun chat msg = (trace msg; true)

  (* Build an exception via mk_HOL_ERR, tagged with this module's name. *)
  fun ERR func msg = mk_HOL_ERR moduleName func msg;

  (* Build an mlibUseful Bug exception whose text is "func: msg". *)
  fun BUG func msg = Bug (func ^ ": " ^ msg);
end;
30
31(* ------------------------------------------------------------------------- *)
32(* Helper functions.                                                         *)
33(* ------------------------------------------------------------------------- *)
34
(* Return () when the condition holds; otherwise raise the given exception. *)
fun assert true _ = ()
  | assert false e = raise e;
36
(* Combine two lists elementwise with the curried function f, applying f
   from left to right.  Raises ERR "zipwith" when the lists have different
   lengths (after f has been applied to the common prefix). *)
fun zipwith f =
  let
    fun loop acc (x :: xs, y :: ys) = loop (f x y :: acc) (xs, ys)
      | loop acc ([], []) = rev acc
      | loop _ _ = raise ERR "zipwith" "lists different lengths"
  in
    fn xs => fn ys => loop [] (xs, ys)
  end;
45
(* Pair every element with its successor: chain [a, b, c] = [(a, b), (b, c)].
   A list with fewer than two elements yields []. *)
fun chain (x :: (rest as y :: _)) = (x, y) :: chain rest
  | chain _ = [];
49
(* Does predicate p hold of every redex of the substitution? *)
fun check_redexes p sub =
  List.all (fn {redex = r, residue = _} => p r) sub;
51
(* Apply f to every residue of a substitution, leaving the redexes alone. *)
fun residue_map f sub =
  map (fn {redex = r, residue = v} => r |-> f v) sub;
53
54(* ------------------------------------------------------------------------- *)
55(* Basic operations on substitutions.                                        *)
56(* Public health warning: don't use cyclic substitutions!                    *)
57(* ------------------------------------------------------------------------- *)
58
(* Apply only the type part of a combined substitution to a term; the term
   part is ignored. *)
fun inst_ty (_, theta) = inst theta;
60
(* Apply a combined substitution to a term: instantiate types with tyS
   first, then substitute terms with tmS.  Because the term substitution runs
   after type instantiation, term redexes should be stated at their
   instantiated types (as refine_subst arranges). *)
fun pinst (tmS, tyS) tm = (subst tmS o inst tyS) tm;
62
local
  (* Run a theorem rule with every hypothesis temporarily moved into the
     conclusion.  The hypotheses are discharged, the rule is applied, and the
     result is undischarged once per hypothesis so that the (rewritten)
     antecedents return as hypotheses.
     NOTE(review): presumably this ensures the rule also acts on the
     hypotheses -- confirm against the kernel rule's behaviour. *)
  fun fake_asm_op rule th =
    let
      val hs = rev (hyp th)
      val closed = foldl (fn (h, acc) => DISCH h acc) th hs
    in
      funpow (length hs) UNDISCH (rule closed)
    end
in
  (* Type-instantiate a theorem with INST_TYPE, routing hypotheses through
     the conclusion. *)
  fun INST_TY theta = fake_asm_op (INST_TYPE theta);

  (* Instantiate both types and terms with INST_TY_TERM, in the same way. *)
  fun PINST sub = fake_asm_op (INST_TY_TERM sub);
end;
72
(* Compose two type substitutions.  Every residue of tyS is pushed through
   tyS', and tyS' is placed in front of the result.  When the domains are
   disjoint, the result acts like applying tyS and then tyS'.  If either
   argument is empty, the other one is returned as it is. *)
fun type_refine_subst (tyS : (hol_type, hol_type) subst) tyS' =
  case (tyS, tyS') of
    ([], _) => tyS'
  | (_, []) => tyS
  | _ =>
    let
      fun refine {redex = v, residue = ty} = v |-> type_subst tyS' ty
    in
      tyS' @ map refine tyS
    end;
77
(* Compose two combined substitutions so that pinst of the result acts like
   pinst sub followed by pinst sub' (assuming non-overlapping domains).
   - Term redexes of sub are re-typed with tyS'.
   - Term residues of sub are pushed through pinst sub'.
   - The type parts are composed with type_refine_subst.
   An empty argument short-circuits to the other argument. *)
fun refine_subst (sub as (tmS, tyS)) (sub' as (tmS', tyS')) =
  case (sub, sub') of
    (([], []), _) => sub'
  | (_, ([], [])) => sub
  | _ =>
    let
      fun lift {redex = v, residue = t} = inst tyS' v |-> pinst sub' t
    in
      (tmS' @ map lift tmS, type_refine_subst tyS tyS')
    end;
84
85(* ------------------------------------------------------------------------- *)
86(* Raw matching.                                                             *)
87(* ------------------------------------------------------------------------- *)
88
(* Accumulator threaded through raw_match.  Each half pairs a substitution
   with a collection of variables; raw_match_term passes those collections
   back to raw_match as its fixed term/type arguments. *)
type raw_subst =
  ((term, term) subst * term set) *
  ((hol_type, hol_type) subst * hol_type list);

(* The starting accumulator: no bindings and empty variable collections. *)
val empty_raw_subst : raw_subst =
  let
    val noTerms = ([], empty_tmset)
    val noTypes = ([], [])
  in
    (noTerms, noTypes)
  end;
93
(* Extend an accumulated raw substitution by matching tm1 against tm2.  The
   variable collections in the accumulator are given to raw_match as its
   fixed type/term arguments, and the two substitution lists are the
   starting substitution. *)
fun raw_match_term (termPart, typePart) tm1 tm2 =
  let
    val (tmS, tmIds) = termPart
    val (tyS, tyIds) = typePart
  in
    raw_match tyIds tmIds tm1 tm2 (tmS, tyS)
  end;
96
(* Match type ty1 against ty2 by term-matching two instances of the constant
   combin$K, taken at types bool -> ty1 -> bool and bool -> ty2 -> bool.
   This lets raw_match_term do the type matching. *)
fun raw_match_ty sub ty1 ty2 =
  let
    fun wrap ty =
      mk_thy_const {Thy = "combin", Name = "K", Ty = bool --> ty --> bool}
  in
    raw_match_term sub (wrap ty1) (wrap ty2)
  end;
103
(* Turn an accumulated raw_subst into an ordinary substitution pair using
   HOL's norm_subst. *)
fun finalize_subst raw = norm_subst raw;
105
106(* ------------------------------------------------------------------------- *)
107(* Operations on types containing "locally constant" variables.              *)
108(* ------------------------------------------------------------------------- *)
109
(* A predicate recognising the type variables that may be instantiated;
   type variables failing it are treated as locally constant. *)
type tyVars = hol_type -> bool;

(* Match type ty against ty' with match_type.  The match succeeds only if
   every type variable it instantiates satisfies tyvarP.
   Failures of match_type propagate.  If a locally constant type variable
   would be instantiated, raises ERR "vmatch_type" "lconst ty". *)
fun vmatch_type tyvarP ty ty' =
  let
    val tyS = match_type ty ty'
    (* Use a descriptive message (formerly empty), matching vmatch's check. *)
    val () = assert (check_redexes tyvarP tyS) (ERR "vmatch_type" "lconst ty")
  in
    tyS
  end;
119
(* Unify a work list of type pairs, extending the type substitution sub.
   Only types accepted by tyvarP may be bound.  tyvarP is assumed to hold
   only of type variables -- confirm with callers.  Any other type variable
   is a locally constant type that unifies only with itself.
   Each pair is instantiated with the current substitution before it is
   compared.  New bindings are composed in with type_refine_subst.
   Raises ERR "unify_type" in three cases:
   - an occurs-check failure;
   - a clash with a locally constant variable;
   - different type operators. *)
fun vunifyl_type tyvarP =
  let
    (* Instantiate the next pair with the current substitution and unify. *)
    fun unify sub [] = sub
      | unify sub ((ty, ty') :: work) =
      unify' sub work (type_subst sub ty) (type_subst sub ty')
    (* Unify two already-instantiated types, then continue with work. *)
    and unify' sub work ty ty' =
      if ty = ty' then unify sub work
      else if tyvarP ty then
        if type_var_in ty ty' then raise ERR "unify_type" "occurs"
        else unify (type_refine_subst sub [ty |-> ty']) work
      else if tyvarP ty' then unify' sub work ty' ty
      else if is_vartype ty orelse is_vartype ty' then
        raise ERR "unify_type" "locally constant variable"
      else
        let
          (* Compare the defining theory as well as the operator name, so
             same-named type operators from different theories do not
             unify.  The CONST case of vunifyl likewise compares Thy/Name. *)
          val {Thy = thy, Tyop = tyop, Args = args} = dest_thy_type ty
          val {Thy = thy', Tyop = tyop', Args = args'} = dest_thy_type ty'
        in
          if thy = thy' andalso tyop = tyop' andalso length args = length args'
          then unify sub (zip args args' @ work)
          else raise ERR "unify_type" "different type constructors"
        end
  in
    unify
  end;
145
(* Unify all types in a list with one another, each with its successor,
   starting from the empty type substitution. *)
fun vunify_type tyvarP tys = vunifyl_type tyvarP [] (chain tys);
147
148(* ------------------------------------------------------------------------- *)
149(* Operations on terms containing "locally constant" variables.              *)
150(* ------------------------------------------------------------------------- *)
151
(* A predicate recognising the term variables that may be instantiated;
   variables failing it are treated as locally constant. *)
type tmVars = term -> bool;

(* The term-variable and type-variable predicates, used together. *)
type Vars = (term -> bool) * tyVars;
154
(* Names of the free variables of tm that satisfy tmvarP.  Each name appears
   once, in ascending String.compare order. *)
fun vfree_names tmvarP tm =
  let
    fun addName (v, names) = Binaryset.add (names, fst (dest_var v))
    val chosen = filter tmvarP (free_vars tm)
  in
    Binaryset.listItems (foldl addName (Binaryset.empty String.compare) chosen)
  end;
162
(* The free term variables of tm accepted by tmvarP, paired with the type
   variables of tm accepted by tyvarP. *)
fun vfree_vars (tmvarP, tyvarP) tm =
  let
    val tmvs = filter tmvarP (free_vars tm)
    val tyvs = filter tyvarP (type_vars_in_term tm)
  in
    (tmvs, tyvs)
  end;
165
(* Match tm against tm' with match_term.  Every instantiated type variable
   must then satisfy tyvarP, and every instantiated term variable must
   satisfy tmvarP; the type check is made first.
   Raises ERR "vmatch_term" "lconst ty" or "lconst tm" when a locally
   constant variable would be instantiated.  Failures of match_term
   propagate. *)
fun vmatch (tmvarP, tyvarP) tm tm' =
  let
    val result as (tmS, tyS) = match_term tm tm'
  in
    assert (check_redexes tyvarP tyS) (ERR "vmatch_term" "lconst ty");
    assert (check_redexes tmvarP tmS) (ERR "vmatch_term" "lconst tm");
    result
  end;
175
(* First-order unification of a work list of term pairs, extending the
   combined substitution sub.  Only term variables accepted by tmvarP and
   type variables accepted by tyvarP may be bound.  Term variables are
   identified by NAME: two variables with the same name are unified by
   unifying their types.  Lambda abstractions are not supported.
   Raises ERR "unify_term" on failure, and type-level failures propagate
   from vunify_type. *)
fun vunifyl (tmvarP, tyvarP) =
  let
    val varname = fst o dest_var
    (* Name-based occurs check: does some free variable of the second term
       share v's name?  Types are ignored. *)
    fun occurs v = List.exists (equal (varname v) o varname) o free_vars
    val pure_unify_type = vunify_type tyvarP
    (* Unify a list of types and compose the result into sub. *)
    fun unify_type sub tyL = refine_subst sub ([], pure_unify_type tyL)
    (* Bind variable v to tm: unify their types first, then compose the
       type-instantiated binding into sub. *)
    fun unify_var_type sub v tm =
      let val s = pure_unify_type [type_of v, type_of tm]
      in refine_subst sub ([inst s v |-> inst s tm], s)
      end
    (* Take the next pair, apply the current substitution, and unify it. *)
    fun unify sub [] = sub
      | unify sub ((tm, tm') :: work) =
      unify' sub work (pinst sub tm) (pinst sub tm')
    (* Unify two already-instantiated terms, then continue with work. *)
    and unify' sub work tm tm' =
      if aconv tm tm' then unify sub work
      else if tmvarP tm then
        if tmvarP tm' andalso varname tm = varname tm' then
          unify (unify_type sub [type_of tm, type_of tm']) work
        else if occurs tm tm' then raise ERR "unify_term" "occurs"
        else unify (unify_var_type sub tm tm') work
      else if tmvarP tm' then unify' sub work tm' tm
      else
        (* Neither side is a bindable variable: compare structure. *)
        case (dest_term tm, dest_term tm') of (VAR (n, ty), VAR (n', ty'))
          => if n <> n' then raise ERR "unify_term" "different variables"
             else unify (unify_type sub [ty, ty']) work
        | (CONST {Thy, Name, Ty}, CONST {Thy = Thy', Name = Name', Ty = Ty'})
          => if Thy <> Thy' orelse Name <> Name' then
               raise ERR "unify_term" "different constants"
             else unify (unify_type sub [Ty, Ty']) work
        (* The operators are unified now; the operands are queued and get the
           refined substitution applied when they are taken from work. *)
        | (COMB (a, b), COMB (a', b')) => unify' sub ((b, b') :: work) a a'
        | (LAMB _, LAMB _) => raise ERR "unify_term" "can't deal with lambda"
        | _ => raise ERR "unify_term" "different structure"
  in
    unify
  end;
211
(* Unify all terms in a list with one another, each with its successor,
   starting from the empty combined substitution. *)
fun vunify varP tms = vunifyl varP ([], []) (chain tms);
213
local
  (* A fresh variable name, numbered by mlibUseful.new_int. *)
  fun new_name () = "XXfrozenXX" ^ int_to_string (mlibUseful.new_int ())
  (* Pair each tmvarP-accepted free-variable name of a term with a fresh
     name; new_name is called in the sorted order given by vfree_names. *)
  fun correspond tmvarP = map (fn n => (n, new_name ())) o vfree_names tmvarP;
  (* Invert a name correspondence. *)
  fun revc c = map (fn (a, b) => (b, a)) c;
  (* Rename the free variables of tm whose names appear in c, keeping each
     variable's type. *)
  fun csub c tm =
    let
      fun g v (_, y) = v |-> mk_var (y, type_of v)
      fun f v = Option.map (g v) (assoc1 (fst (dest_var v)) c)
      val tmS = List.mapPartial f (free_vars tm)
    in
      subst tmS tm
    end;
in
  (* Match tm against tm', allowing type variables to be unified.
     The bindable term variables of tm' are frozen by renaming them to fresh
     names, tm is unified with the frozen term, and the original names are
     restored in the residues of the term substitution.  Only residues are
     renamed back, and type variables are not frozen, so tyS may instantiate
     type variables of either term.
     NOTE(review): freezing only works if tmvarP rejects the fresh
     "XXfrozenXX..." names; if tmvarP accepts every variable, the frozen
     variables can still be bound -- confirm with callers. *)
  fun vmatch_uty (varP as (tmvarP, _)) tm tm' =
    let
      val c = correspond tmvarP tm'
      val gtm' = csub c tm'
      val (tmS, tyS) = vunify varP [tm, gtm']
      val sub = (residue_map (csub (revc c)) tmS, tyS)
    in
      sub
    end;
end;
237
238end
239