(* ========================================================================= *)
(* ADDITIONS TO HOL TERM AND TYPE MATCHING.                                  *)
(* Created by Joe Hurd, June 2002.                                           *)
(* ========================================================================= *)

(*
*)
structure matchTools :> matchTools =
struct

open HolKernel Parse boolLib;

(* A type substitution, and a (term substitution, type substitution) pair.  *)
type tySubst = (hol_type, hol_type) subst;
type Subst = (term, term) subst * tySubst;

(* ------------------------------------------------------------------------- *)
(* Chatting.                                                                 *)
(* ------------------------------------------------------------------------- *)

local
  open mlibUseful;
  val module = "matchTools";
in
  (* Register this module with the mlibUseful tracing machinery. *)
  val () = add_trace {module = module, alignment = I}
  (* chatting l: is tracing for this module enabled at level l? *)
  fun chatting l = tracing {module = module, level = l};
  (* chat s: emit s via trace and return true (presumably so it can be
     written as "chatting l andalso chat s"). *)
  fun chat s = (trace s; true)
  (* Errors raised by this module are tagged with the module name. *)
  val ERR = mk_HOL_ERR module;
  fun BUG f m = Bug (f ^ ": " ^ m);
end;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* assert b e: do nothing if b holds, otherwise raise the exception e. *)
fun assert b e = if b then () else raise e;

(* zipwith f [x1, ..., xn] [y1, ..., yn] = [f x1 y1, ..., f xn yn].
   Raises ERR "zipwith" when the two lists have different lengths. *)
fun zipwith f =
  let
    fun z l [] [] = l
      | z l (x :: xs) (y :: ys) = z (f x y :: l) xs ys
      | z _ _ _ = raise ERR "zipwith" "lists different lengths";
  in
    fn xs => fn ys => rev (z [] xs ys)
  end;

(* chain [x1, x2, ..., xn] = [(x1, x2), (x2, x3), ..., (x(n-1), xn)];
   lists with fewer than two elements give []. *)
fun chain [] = []
  | chain [_] = []
  | chain (x :: (xs as y :: _)) = (x, y) :: chain xs;

(* check_redexes p sub: true iff every redex of sub satisfies p. *)
fun check_redexes p = List.all (fn {redex, residue} => p redex);

(* residue_map f sub: apply f to every residue, leaving the redexes alone. *)
fun residue_map f = map (fn {redex, residue} => redex |-> f residue);

(* ------------------------------------------------------------------------- *)
(* Basic operations on substitutions.                                        *)
(* Public health warning: don't use cyclic substitutions!                    *)
(* ------------------------------------------------------------------------- *)

(* inst_ty (tmS, tyS) tm: instantiate only the type part tyS in tm. *)
fun inst_ty (_, tyS) tm = inst tyS tm;

(* pinst (tmS, tyS) tm: instantiate types by tyS first, then substitute terms
   by tmS.  The redexes of tmS therefore name variables at their already
   type-instantiated types. *)
fun pinst (tmS, tyS) tm = subst tmS (inst tyS tm);

(* Theorem-level versions.  fake_asm_op r th discharges every hypothesis of
   th into the conclusion, applies the rule r, and then undischarges the same
   number of antecedents, so that r also acts on what were the hypotheses.
   NOTE(review): presumably a workaround for how the underlying rules treat
   hypotheses; confirm whether plain INST_TYPE / INST_TY_TERM now suffice. *)
local
  fun fake_asm_op r th =
    let val h = rev (hyp th)
    in (funpow (length h) UNDISCH o r o C (foldl (uncurry DISCH)) h) th
    end
in
  val INST_TY = fake_asm_op o INST_TYPE;
  val PINST = fake_asm_op o INST_TY_TERM;
end;

(* type_refine_subst tyS tyS': compose two type substitutions, "first tyS,
   then tyS'".  Each residue of tyS is pushed through tyS', and the bindings
   of tyS' are kept for the type variables that tyS does not bind.
   A variable bound by both must keep its (refined) tyS binding: tyS replaces
   it before tyS' ever sees it.  Such tyS' bindings are therefore dropped,
   instead of being left behind as duplicate redexes.  With duplicates, the
   result would depend on which binding type_subst happens to pick first. *)
fun type_refine_subst [] tyS' : (hol_type, hol_type) subst = tyS'
  | type_refine_subst tyS [] = tyS
  | type_refine_subst tyS tyS' =
    let
      fun bound_by_tyS ty =
        List.exists (fn {redex, residue = _} => redex = ty) tyS
      val kept =
        List.filter (fn {redex, residue = _} => not (bound_by_tyS redex)) tyS'
      val refined =
        map (fn {redex, residue} => redex |-> type_subst tyS' residue) tyS
    in
      kept @ refined
    end;

(* refine_subst sub sub': compose two (term, type) substitution pairs, with
   the intent that pinst (refine_subst sub sub') tm behaves like
   pinst sub' (pinst sub tm).  The redexes of tmS are instantiated by tyS'
   and their residues pushed through sub'.  Bindings of tmS' whose redex
   coincides with one of those refined redexes are dropped, so the result
   never binds the same variable twice.  The type parts are composed by
   type_refine_subst. *)
fun refine_subst ([], []) sub' = sub'
  | refine_subst sub ([], []) = sub
  | refine_subst (tmS, tyS) (tmS', tyS') =
    let
      fun f {redex, residue} = inst tyS' redex |-> pinst (tmS', tyS') residue
      val refined = map f tmS
      fun bound_by_refined v =
        List.exists (fn {redex, residue = _} => aconv redex v) refined
      val kept =
        List.filter (fn {redex, residue = _} => not (bound_by_refined redex))
          tmS'
    in
      (kept @ refined, type_refine_subst tyS tyS')
    end;

(* ------------------------------------------------------------------------- *)
(* Raw matching.                                                             *)
(* ------------------------------------------------------------------------- *)

(* The accumulator shape used with HOL's raw_match: a term substitution
   paired with a set of terms, and a type substitution paired with a list of
   types. *)
type raw_subst =
  ((term, term) subst * term set) * ((hol_type, hol_type) subst * hol_type list);

val empty_raw_subst : raw_subst = (([], empty_tmset), ([], []));

(* raw_match_term sub tm1 tm2: call HOL's raw_match on tm1 and tm2.  The
   list/set components of sub are raw_match's first two arguments, and its
   two substitutions are the starting accumulator.
   NOTE(review): confirm against the raw_match documentation that passing
   the accumulated sets as raw_match's "fixed" arguments is intended. *)
fun raw_match_term ((tmS, tmIds), (tyS, tyIds)) tm1 tm2 =
  raw_match tyIds tmIds tm1 tm2 (tmS, tyS);

(* raw_match_ty sub ty1 ty2: match ty1 against ty2.  It matches the constants
   combin$K : bool -> ty1 -> bool and combin$K : bool -> ty2 -> bool, which
   differ only in the ty1/ty2 position. *)
local
  fun mk_tm ty =
    mk_thy_const {Thy = "combin", Name = "K", Ty = bool --> ty --> bool};
in
  fun raw_match_ty sub ty1 ty2 = raw_match_term sub (mk_tm ty1) (mk_tm ty2);
end;

(* Turn an accumulated raw substitution into an ordinary pair via HOL's
   norm_subst. *)
val finalize_subst = norm_subst;

(* ------------------------------------------------------------------------- *)
(* Operations on types containing "locally constant" variables.              *)
(* ------------------------------------------------------------------------- *)

(* A predicate picking out the type variables that may be instantiated; type
   variables failing it are treated as locally constant. *)
type tyVars = hol_type -> bool;

(* vmatch_type tyvarP ty ty': match ty against ty'.  Fails with
   ERR "vmatch_type" if the match binds a type variable that does not satisfy
   tyvarP. *)
fun vmatch_type tyvarP ty ty' =
  let
    val tyS = match_type ty ty'
    val () = assert (check_redexes tyvarP tyS) (ERR "vmatch_type" "lconst ty")
  in
    tyS
  end;

(* vunifyl_type tyvarP sub work: unify every pair of types in the work list,
   extending sub.  Only types satisfying tyvarP may be bound; other type
   variables behave like constants.  Raises ERR "unify_type" on an
   occurs-check failure, a clash involving a locally constant variable, or a
   type-constructor clash. *)
fun vunifyl_type tyvarP =
  let
    fun unify sub [] = sub
      | unify sub ((ty, ty') :: work) =
        unify' sub work (type_subst sub ty) (type_subst sub ty')
    and unify' sub work ty ty' =
      if ty = ty' then unify sub work
      else if tyvarP ty then
        if type_var_in ty ty' then raise ERR "unify_type" "occurs"
        else unify (type_refine_subst sub [ty |-> ty']) work
      else if tyvarP ty' then unify' sub work ty' ty
      else if is_vartype ty orelse is_vartype ty' then
        raise ERR "unify_type" "locally constant variable"
      else
        let
          val (f, args) = dest_type ty
          val (f', args') = dest_type ty'
        in
          if f = f' andalso length args = length args' then
            unify sub (zip args args' @ work)
          else raise ERR "unify_type" "different type constructors"
        end
  in
    unify
  end;

(* vunify_type tyvarP [ty1, ..., tyn]: unify all the types in the list,
   starting from the empty substitution. *)
fun vunify_type tyvarP = vunifyl_type tyvarP [] o chain;

(* ------------------------------------------------------------------------- *)
(* Operations on terms containing "locally constant" variables.              *)
(* ------------------------------------------------------------------------- *)

(* Predicates picking out the term (and type) variables that may be
   instantiated; everything else is treated as locally constant. *)
type tmVars = term -> bool;
type Vars = tmVars * tyVars;

(* vfree_names tmvarP tm: the names of the free variables of tm that satisfy
   tmvarP, without duplicates, in String.compare order. *)
fun vfree_names tmvarP tm =
  let
    val names = map (fst o dest_var) (filter tmvarP (free_vars tm))
    open Binaryset
  in
    listItems (addList (empty String.compare, names))
  end;

(* vfree_vars (tmvarP, tyvarP) tm: the free term variables of tm satisfying
   tmvarP, paired with the type variables of tm satisfying tyvarP. *)
fun vfree_vars (tmvarP, tyvarP) tm =
  (filter tmvarP (free_vars tm), filter tyvarP (type_vars_in_term tm));

(* vmatch (tmvarP, tyvarP) tm tm': match tm against tm'.  Fails with
   ERR "vmatch_term" if the match binds a type variable outside tyvarP or a
   term variable outside tmvarP. *)
fun vmatch (tmvarP, tyvarP) tm tm' =
  let
    val sub = match_term tm tm'
    val (tmS, tyS) = sub
    val () = assert (check_redexes tyvarP tyS) (ERR "vmatch_term" "lconst ty")
    val () = assert (check_redexes tmvarP tmS) (ERR "vmatch_term" "lconst tm")
  in
    sub
  end;

(* vunifyl (tmvarP, tyvarP) sub work: first-order unification of every pair
   of terms in the work list, extending the substitution pair sub.
   - Variables are compared by name.  Two tmvarP variables with the same name
     are unified by unifying their types.
   - Failures raise ERR "unify_term", or ERR "unify_type" from the type
     unifier.
   - Lambda abstractions are not supported. *)
fun vunifyl (tmvarP, tyvarP) =
  let
    val varname = fst o dest_var
    (* Name-based occurs check: does a variable named like v occur free? *)
    fun occurs v = List.exists (equal (varname v) o varname) o free_vars
    val pure_unify_type = vunify_type tyvarP
    fun unify_type sub tyL = refine_subst sub ([], pure_unify_type tyL)
    (* Bind v to tm after unifying their types. *)
    fun unify_var_type sub v tm =
      let val s = pure_unify_type [type_of v, type_of tm]
      in refine_subst sub ([inst s v |-> inst s tm], s)
      end
    fun unify sub [] = sub
      | unify sub ((tm, tm') :: work) =
        unify' sub work (pinst sub tm) (pinst sub tm')
    and unify' sub work tm tm' =
      if aconv tm tm' then unify sub work
      else if tmvarP tm then
        if tmvarP tm' andalso varname tm = varname tm' then
          unify (unify_type sub [type_of tm, type_of tm']) work
        else if occurs tm tm' then raise ERR "unify_term" "occurs"
        else unify (unify_var_type sub tm tm') work
      else if tmvarP tm' then unify' sub work tm' tm
      else
        case (dest_term tm, dest_term tm') of
          (VAR (n, ty), VAR (n', ty')) =>
            if n <> n' then raise ERR "unify_term" "different variables"
            else unify (unify_type sub [ty, ty']) work
        | (CONST {Thy, Name, Ty}, CONST {Thy = Thy', Name = Name', Ty = Ty'}) =>
            if Thy <> Thy' orelse Name <> Name' then
              raise ERR "unify_term" "different constants"
            else unify (unify_type sub [Ty, Ty']) work
        | (COMB (a, b), COMB (a', b')) => unify' sub ((b, b') :: work) a a'
        | (LAMB _, LAMB _) => raise ERR "unify_term" "can't deal with lambda"
        | _ => raise ERR "unify_term" "different structure"
  in
    unify
  end;

(* vunify varP [tm1, ..., tmn]: unify all the terms in the list, starting
   from the empty substitution pair. *)
fun vunify varP = vunifyl varP ([], []) o chain;

(* vmatch_uty: matching implemented by unification.
   1. The free tmvarP variables of tm' are "frozen" by renaming them to fresh
      names.
   2. tm is unified with the frozen term.
   3. The residues are renamed back.
   Only term variables are frozen; type variables of tm' may still be
   instantiated.
   NOTE(review): freezing relies on tmvarP rejecting the fresh
   "XXfrozenXX..." variables.  If tmvarP accepts every variable, they can
   still be bound; confirm with callers. *)
local
  (* Assumes mlibUseful.new_int yields a different integer on each call. *)
  fun new_name () = "XXfrozenXX" ^ int_to_string (mlibUseful.new_int ())
  (* Pair each free tmvarP-variable name of a term with a fresh name. *)
  fun correspond tmvarP = map (fn n => (n, new_name ())) o vfree_names tmvarP;
  (* Swap the two components of every pair in a renaming. *)
  fun revc c = map (fn (a, b) => (b, a)) c;
  (* csub c tm: rename the free variables of tm according to c, matching on
     variable names and keeping each variable's type. *)
  fun csub c tm =
    let
      fun g v (_, y) = v |-> mk_var (y, type_of v)
      fun f v = Option.map (g v) (assoc1 (fst (dest_var v)) c)
      val tmS = List.mapPartial f (free_vars tm)
    in
      subst tmS tm
    end;
in
  fun vmatch_uty (varP as (tmvarP, _)) tm tm' =
    let
      val c = correspond tmvarP tm'
      val gtm' = csub c tm'
      val (tmS, tyS) = vunify varP [tm, gtm']
      val sub = (residue_map (csub (revc c)) tmS, tyS)
    in
      sub
    end;
end;

end