(* FullUnify: syntactic unification of HOL types and terms.

   Unification variables are ordinary HOL variables whose names satisfy a
   caller-supplied predicate: [tyP] for type variables (names as returned by
   dest_vartype, i.e. including the leading quote) and [tmP] for term
   variables.  Environments are Binarymap dictionaries keyed by variable name:
     - a type environment maps a type-uvar name to a hol_type;
     - a term environment maps a term-uvar name to a pair
       (type of the variable, term it is bound to).
   Bindings may mention other bound variables (they are not kept fully
   resolved); the "walk" functions chase such chains one level at a time,
   while the "walkstar"/"collapse" functions resolve bindings everywhere. *)
structure FullUnify :> FullUnify =
struct

open HolKernel

type tyenv = (string, hol_type) Binarymap.dict
val empty_tyenv : tyenv = Binarymap.mkDict String.compare

exception UNIF_ERROR of string

(* One-level view of a hol_type: a unification type variable, a rigid type
   variable, or a type operator applied to (unconverted) argument types. *)
datatype tyrep = UVar of string
               | Tyvar of string
               | Tyop of {Thy:string, Name:string, Args : hol_type list}

(* One-level view of a term: a unification variable, a rigid variable, a
   constant, an application, or an abstraction (bound var, body). *)
datatype tmrep = tmUVar of string * tyrep
               | tmVar of string * tyrep
               | uConst of {Thy : string, Name : string, Ty : tyrep}
               | uCOMB of term * term
               | uLAMB of term * term

(* Classify [ty]: type variables whose name satisfies [P] become UVar, other
   type variables Tyvar, and every other type a Tyop. *)
fun dest_type P ty =
  let
    val nm = dest_vartype ty
  in
    if P nm then UVar nm else Tyvar nm
  end handle HOL_ERR _ =>
    let val {Thy,Tyop=nm,Args} = dest_thy_type ty
    in
      Tyop{Thy = Thy, Name = nm, Args = Args}
    end

(* Inverse of dest_type: rebuild a hol_type from its one-level view. *)
fun mkty (UVar s) = mk_vartype s
  | mkty (Tyvar s) = mk_vartype s
  | mkty (Tyop{Thy,Name,Args}) =
      mk_thy_type {Thy = Thy, Tyop = Name, Args = Args}

(* Classify a term one level deep; variable/constant types are classified
   with [tyP], variable names with [tmP]. *)
fun dest_term {tmP, tyP} tm =
  case HolKernel.dest_term tm of
      VAR (nm, ty) => if tmP nm then tmUVar(nm, dest_type tyP ty)
                      else tmVar(nm, dest_type tyP ty)
    | CONST{Thy,Name,Ty} => uConst {Thy = Thy, Name = Name,
                                    Ty = dest_type tyP Ty}
    | LAMB p => uLAMB p
    | COMB p => uCOMB p

(* Inverse of dest_term: rebuild a term from its one-level view. *)
fun mktm (tmUVar p) = mk_var ((I ## mkty) p)
  | mktm (tmVar p) = mk_var ((I ## mkty) p)
  | mktm (uCOMB p) = mk_comb p
  | mktm (uLAMB p) = mk_abs p
  | mktm (uConst {Name,Thy,Ty}) = mk_thy_const {Name=Name,Thy=Thy,Ty=mkty Ty}

(* Follow the chain of type bindings starting at uvar name [s] in [e].
   Returns the first non-UVar view reached, or UVar of the last (unbound)
   name in the chain. *)
fun tyvwalk P e s =
  case Binarymap.peek(e, s) of
      SOME ty => (case dest_type P ty of
                      UVar vnm => tyvwalk P e vnm
                    | tyr => tyr)
    | NONE => UVar s

(* Walk [tyr] if it is a uvar; any other view is returned unchanged. *)
fun tywalk P E tyr =
  case tyr of
      UVar vnm => tyvwalk P E vnm
    | _ => tyr

(* Occurs check: does type uvar [v] occur in [tyrep] once bindings in [E]
   are followed? *)
fun tyoc P E tyrep v =
  case tywalk P E tyrep of
      UVar v' => v' = v
    | Tyvar _ => false
    | Tyop{Args,...} => List.exists (fn ty => tyoc P E (dest_type P ty) v) Args

(* Unify two type views under [E]; SOME of the extended environment on
   success, NONE on clash or occurs-check failure. *)
fun tyunify0 P E (tyr1, tyr2) =
  case (tywalk P E tyr1, tywalk P E tyr2) of
      (UVar v1, UVar v2) => SOME (if v1 = v2 then E
                                  else Binarymap.insert(E,v1,mk_vartype v2))
    | (UVar v1, t2) => if tyoc P E t2 v1 then NONE
                       else SOME (Binarymap.insert(E,v1,mkty t2))
    | (t1, UVar v2) => if tyoc P E t1 v2 then NONE
                       else SOME (Binarymap.insert(E,v2,mkty t1))
    | (Tyvar s1, Tyvar s2) => if s1 = s2 then SOME E else NONE
    | (Tyop{Thy=thy1,Name=name1,Args = a1},
       Tyop{Thy=thy2,Name=name2,Args = a2}) =>
        if thy1 = thy2 andalso name1 = name2 then
          tyunifyl P E (a1,a2)
        else NONE
    | _ => NONE
(* Pairwise unification of two argument lists; NONE if lengths differ. *)
and tyunifyl P E ([],[]) = SOME E
  | tyunifyl P E (ty1::ty1s, ty2::ty2s) =
      (case tyunify0 P E (dest_type P ty1, dest_type P ty2) of
           NONE => NONE
         | SOME E' => tyunifyl P E' (ty1s, ty2s))
  | tyunifyl P E _ = NONE

(* View a type treating every type variable as a unification variable. *)
fun qtyrep ty = dest_type (K true) ty

(* Do [tyr1] and [tyr2] denote the same type once the bindings in [tyE] are
   resolved?  Every type variable is compared as a (possibly unbound) UVar.
   Callers build their views with the real type predicate, so a rigid
   variable can arrive as [Tyvar s] while the same variable reached through
   the environment appears as [UVar s]; normalising Tyvar to UVar first
   keeps the final structural comparison from reporting a false mismatch. *)
fun unified tyE (tyr1, tyr2) =
  let
    fun norm (Tyvar s) = tywalk (K true) tyE (UVar s)
      | norm tyr = tywalk (K true) tyE tyr
  in
    case (norm tyr1, norm tyr2) of
        (Tyop{Thy=thy1,Name=n1,Args=a1}, Tyop{Thy=thy2,Name=n2,Args=a2}) =>
          thy1 = thy2 andalso n1 = n2 andalso
          ListPair.all (fn (ty1,ty2) => unified tyE (qtyrep ty1, qtyrep ty2))
                       (a1,a2)
      | (r1, r2) => r1 = r2
  end

(* Follow the chain of term bindings for uvar [s] (with type view [tyr]).
   Raises UNIF_ERROR if the type recorded with the binding is not unified
   with the type of this occurrence.  Returns the first non-uvar view
   reached, or the original uvar view if [s] is unbound. *)
fun tmvwalk (P as {tyP, tmP}) (E as {tyE, tmE}) (p as (s,tyr)) =
  let
    val tyr' = tywalk tyP tyE tyr
  in
    case Binarymap.peek (tmE, s) of
        SOME (ty,tm) =>
          if unified tyE (dest_type tyP ty, tyr') then
            (case dest_term P tm of
                 tmUVar p => tmvwalk P E p
               | tmr => tmr)
          else raise UNIF_ERROR ("Variable "^s^" has two distinct types")
      | NONE => tmUVar p
  end

(* Walk [tmr] if it is a term uvar; any other view is returned unchanged. *)
fun tmwalk (P as {tyP, tmP}) (E as {tyE, tmE}) tmr =
  case tmr of
      tmUVar p => tmvwalk P E p
    | _ => tmr

(* Occurs check for term uvar [v] in [tmrep] under [E].  Uvars are matched
   by name only; an abstraction's bound variable is not treated specially,
   so the check is conservative when a binder reuses the uvar's name. *)
fun tmoc P E v tmrep =
  case tmwalk P E tmrep of
      tmUVar (v',_) => v = v'
    | tmVar _ => false
    | uConst _ => false
    | uCOMB(t1,t2) => tmoc P E v (dest_term P t1) orelse
                      tmoc P E v (dest_term P t2)
    | uLAMB(bv,body) => tmoc P E v (dest_term P body)

(* Type view of the term represented by [tmr]; the environment is unused. *)
fun utype_of {tyP,tmP} E tmr =
  dest_type tyP (type_of (mktm tmr))

val insert = Binarymap.insert

(* Unify two term views under environment [E0] = {tyE, tmE}.  The types of
   the two terms are unified first; then the walked views are compared
   structurally.  For two abstractions, the pattern \x. f x (f an unbound
   uvar, x the bound variable) binds f to the other abstraction; otherwise
   the bodies are unified with both bound variables renamed to a fresh name.
   Returns SOME of the extended environment, or NONE on failure. *)
fun tmunify P (E0 as {tyE,tmE}) (tmr1, tmr2) =
  case tyunify0 (#tyP P) (#tyE E0) (utype_of P E0 tmr1, utype_of P E0 tmr2) of
      NONE => NONE
    | SOME tyE =>
      let
        val E as {tyE,tmE} = {tyE = tyE, tmE = #tmE E0}
      in
        case (tmwalk P E tmr1, tmwalk P E tmr2) of
            (tmUVar (s1, tyr1), tmUVar (s2, tyr2)) =>
              if s1 = s2 then SOME E
              else
                SOME{tyE = tyE,
                     tmE = insert(tmE, s1, (mkty tyr1, mk_var(s2, mkty tyr1)))}
          | (tmUVar (s1, tyr1), tmr2) =>
              if tmoc P E s1 tmr2 then NONE
              else SOME{tyE=tyE, tmE = insert(tmE, s1, (mkty tyr1, mktm tmr2))}
          | (tmr1, tmUVar (s2, tyr2)) =>
              if tmoc P E s2 tmr1 then NONE
              else SOME{tyE=tyE, tmE = insert(tmE, s2, (mkty tyr2, mktm tmr1))}
          | (tmVar (s1,_), tmVar (s2,_)) => if s1 = s2 then SOME E else NONE
          | (uConst{Name=n1,Thy=thy1,...}, uConst{Name=n2,Thy=thy2,...}) =>
              (* constant types were already unified above *)
              if n1 = n2 andalso thy1 = thy2 then SOME E else NONE
          | (uCOMB(t1,t2), uCOMB(u1,u2)) =>
              (case tmunify P E (dest_term P t1, dest_term P u1) of
                   NONE => NONE
                 | SOME E' => tmunify P E' (dest_term P t2, dest_term P u2))
          | (uLAMB(p1 as (bv1,bod1)), uLAMB(p2 as (bv2,bod2))) =>
            let
              (* Fallback: unify the bodies with both binders renamed apart.
                 bv1 and bv2 have unified, but not necessarily identical,
                 types, and subst requires redex and residue to share a
                 type; so use one fresh name at each binder's own type
                 (variables are compared by name in this unifier). *)
              fun fo() =
                let val gv1 = genvar (type_of bv1)
                    val gv2 = mk_var (#1 (dest_var gv1), type_of bv2)
                    val bod1' = subst [bv1 |-> gv1] bod1
                    val bod2' = subst [bv2 |-> gv2] bod2
                in
                  tmunify P E (dest_term P bod1', dest_term P bod2')
                end
              (* If the first abstraction has the shape \bv. f bv with f an
                 unbound uvar, bind f to the second abstraction; otherwise
                 continue with [k]. *)
              fun test ((bv1,bod1), (bv2,bod2)) k =
                case dest_term P bod1 of
                    uCOMB(f,x) =>
                      if aconv x bv1 then
                        (case tmwalk P E (dest_term P f) of
                             tmUVar(s1,tyr1) =>
                               if tmoc P E s1 (dest_term P bod2) then NONE
                               else
                                 SOME{tyE = tyE,
                                      tmE = insert(tmE, s1,
                                                   (type_of bv1 --> type_of bod1,
                                                    mk_abs(bv2,bod2)))}
                           | _ => k())
                      else k()
                  | _ => k()
            in
              test (p1, p2) (fn () => test (p2, p1) fo)
            end
          | _ => NONE
      end

(* Unify two hol_types under type environment [E0]. *)
fun tyunify P E0 (ty1, ty2) =
  tyunify0 P E0 (dest_type P ty1, dest_type P ty2)

(* Resolve every binding in [E] throughout a type view. *)
fun tywalkstar P E tyr =
  case tywalk P E tyr of
      Tyop{Thy,Name,Args} =>
        Tyop{Thy = Thy, Name = Name,
             Args = map (mkty o tywalkstar P E o dest_type P) Args}
    | x => x

(* Resolve every term and type binding in [E] throughout a term view. *)
fun tmwalkstar P E tmr =
  let
    val tyws = tywalkstar (#tyP P) (#tyE E)
    val tmws = mktm o tmwalkstar P E o dest_term P
  in
    case tmwalk P E tmr of
        tmUVar(s, tyr) => tmUVar(s, tyws tyr)
      | tmVar(s, tyr) => tmVar(s, tyws tyr)
      | uCOMB(t1,t2) => uCOMB(tmws t1, tmws t2)
      | uLAMB (bv, bod) => uLAMB (tmws bv, tmws bod)
      | uConst {Name,Thy,Ty} => uConst{Name = Name, Thy = Thy, Ty = tyws Ty}
  end

(* Map each bound type uvar to its fully resolved type. *)
fun tycollapse P E =
  Binarymap.map (fn (vnm,_) => mkty (tywalkstar P E (UVar vnm))) E

(* Fully resolve both environments: tyE as in tycollapse, and each bound
   term uvar mapped to its fully resolved term (the recorded type is
   dropped from the result). *)
fun tmcollapse (P as {tyP,...}) (E as {tyE,tmE}) =
  let
    val tyE' = tycollapse tyP tyE
  in
    {tyE = tyE',
     tmE = Binarymap.map
             (fn (vnm, (ty,_)) =>
                 mktm (tmwalkstar P E (tmUVar(vnm,dest_type tyP ty))))
             tmE}
  end

(* test

(* types *)

val uvar_v = String.isPrefix "'v"
fun tf ty1 ty2 =
    Option.map (Binarymap.listItems o tycollapse uvar_v)
               (tyunify uvar_v empty_tyenv (ty1, ty2))

val r1 = tf ``:'a list`` ``:'v``
val r2 = tf ``:'v1 list`` ``:'v2``
val r3 = tf ``:'v1 # ('v2 list)`` ``:'v3 # 'v3``
val r4 = tf ``:'v1 # ('v2 list)`` ``:'v3 # 'v1``
val r5 = tf ``:'v1 list`` ``:'v3 # 'v1``
val r6 = tf ``:'v1 list`` ``:'v1``
val r7 = tf ``:'v4 # ('v5 list)`` ``:'v3 # 'v4``
val r8 = tf ``:'v4 # ('a -> 'v4)`` ``:'v3 # 'v3``
val r9 = tf ``:'v4 list`` ``:('a -> 'v5) list``
val r10 = tf ``:('a -> 'v5) list`` ``:'v4``
val r11 = tf ``:'a`` ``:'v1``
val r12 = tf ``:'a`` ``:'v1 list``
val r13 = tf ``:'a -> bool`` ``:'a -> 'v``

val tmvar = String.isPrefix "uv"
val P = {tmP = tmvar, tyP = uvar_v}
type 'a dict = (string,'a) Binarymap.dict
val E0 = {tyE = Binarymap.mkDict String.compare : hol_type dict,
          tmE = Binarymap.mkDict String.compare : (hol_type * term) dict}

fun tmf f (tm1,tm2) =
    Option.map
      ((fn {tmE, tyE} => (Binarymap.listItems tyE, Binarymap.listItems tmE)) o f)
      (tmunify P E0 (dest_term P tm1, dest_term P tm2))

val t1 = tmf (tmcollapse P) (``f:'a->'v1``, ``uv:'a->bool``)
val t2 = tmf I (``(f:'v1->'v2) uv2``, ``(uv:'a -> bool) x``)
val t3 = tmf I (``(f:'v1->bool) uv2``, ``(uv:'a -> 'v2) x``)
val t4 = tmf I (``(uvf:'v1->'v2) uv2``, ``CONS (h:'a) t``)
val t5 = tmf I (``?x. x /\ T``, ``?y:'v1. uv y``)
val t6 = tmf I (``?y:'v1. uv y``, ``?x:bool. x /\ T``)
val t7 = tmf I (``uv1 (x:'v1) (uv:'v2) : 'v3``, ``x /\ x``)
val t8 = tmf I (``uv1 (uv:'v1) (uv:'v1) : 'v3``, ``x /\ y``) (* NONE *)

*)

end