1structure FullUnify :> FullUnify =
2struct
3
4open HolKernel
5
(* Type-level environment: maps the name of a unification type variable to
   the hol_type it has been bound to. *)
type tyenv = (string, hol_type) Binarymap.dict

(* The environment with no bindings, keyed by string comparison. *)
val empty_tyenv = Binarymap.mkDict String.compare : tyenv

(* Raised (by tmvwalk) when a bound term unification variable is looked up
   at a type that is not unified with the type recorded for its binding. *)
exception UNIF_ERROR of string

(* One-level view of a hol_type.  Type variables are split into unification
   variables (UVar) and ordinary, fixed type variables (Tyvar) according to
   a caller-supplied predicate; Tyop arguments stay as raw hol_types. *)
datatype tyrep =
    UVar of string
  | Tyvar of string
  | Tyop of {Thy : string, Name : string, Args : hol_type list}

(* One-level view of a term.  Variables are split into unification
   variables (tmUVar) and ordinary variables (tmVar), each carrying the
   tyrep view of its type; applications and abstractions keep their
   immediate subterms as raw terms. *)
datatype tmrep =
    tmUVar of string * tyrep
  | tmVar of string * tyrep
  | uConst of {Thy : string, Name : string, Ty : tyrep}
  | uCOMB of term * term
  | uLAMB of term * term
20
(* dest_type P ty: the one-level tyrep view of ty.
   - a type variable whose name satisfies P becomes UVar (a unification
     variable);
   - any other type variable becomes Tyvar;
   - any other type is split by dest_thy_type into a Tyop.
   The vartype test is done up front with is_vartype rather than by catching
   the HOL_ERR from dest_vartype.  With a handler, a HOL_ERR raised by P
   itself would be mistaken for "not a type variable", and dest_thy_type
   would then be applied to a vartype, producing a misleading failure. *)
fun dest_type P ty =
  if is_vartype ty then
    let
      val nm = dest_vartype ty
    in
      if P nm then UVar nm else Tyvar nm
    end
  else
    let
      val {Thy, Tyop = nm, Args} = dest_thy_type ty
    in
      Tyop {Thy = Thy, Name = nm, Args = Args}
    end
31
(* mkty tyr: rebuild a hol_type from its one-level tyrep view (the inverse
   of dest_type).  UVar and Tyvar both become plain type variables. *)
fun mkty tyr =
  case tyr of
      Tyop {Thy, Name, Args} => mk_thy_type {Thy = Thy, Tyop = Name, Args = Args}
    | UVar nm => mk_vartype nm
    | Tyvar nm => mk_vartype nm
36
(* dest_term {tmP, tyP} tm: the one-level tmrep view of tm.
   Variables whose name satisfies tmP become tmUVar, other variables become
   tmVar; variable and constant types are viewed through dest_type tyP.
   Applications and abstractions are returned with their raw subterms.
   (Shadows HolKernel.dest_term, which is therefore called qualified.) *)
fun dest_term {tmP, tyP} tm =
  let
    fun tyrep_of ty = dest_type tyP ty
  in
    case HolKernel.dest_term tm of
        COMB pair => uCOMB pair
      | LAMB pair => uLAMB pair
      | CONST {Thy, Name, Ty} =>
          uConst {Thy = Thy, Name = Name, Ty = tyrep_of Ty}
      | VAR (nm, ty) =>
          (if tmP nm then tmUVar else tmVar) (nm, tyrep_of ty)
  end
45
(* mktm tmr: rebuild a term from its one-level tmrep view (the inverse of
   dest_term).  Unification and ordinary variables both become plain HOL
   variables whose type is rebuilt with mkty. *)
fun mktm tmr =
  case tmr of
      uConst {Name, Thy, Ty} => mk_thy_const {Name = Name, Thy = Thy, Ty = mkty Ty}
    | uLAMB (bv, body) => mk_abs (bv, body)
    | uCOMB (f, x) => mk_comb (f, x)
    | tmVar (nm, tyr) => mk_var (nm, mkty tyr)
    | tmUVar (nm, tyr) => mk_var (nm, mkty tyr)
51
(* tyvwalk P e s: resolve the unification type variable s through e.
   s is looked up in e.  If its binding is itself a unification variable,
   that variable is resolved in turn.  The result is either a non-UVar
   tyrep or an unbound UVar. *)
fun tyvwalk P e s =
  case Binarymap.peek (e, s) of
      NONE => UVar s
    | SOME ty =>
        let
          val tyr = dest_type P ty
        in
          case tyr of
              UVar next => tyvwalk P e next
            | _ => tyr
        end
58
(* tywalk P E tyr: resolve tyr one level through E.  Only a UVar is looked
   up (via tyvwalk); any other tyrep is returned unchanged. *)
fun tywalk P E (UVar vnm) = tyvwalk P E vnm
  | tywalk _ _ tyr = tyr
63
(* tyoc P E tyrep v: occurs check.  True when the unification type variable
   v occurs in tyrep once bindings in E are followed, descending into the
   arguments of type operators. *)
fun tyoc P E tyrep v =
  let
    fun occurs ty = tyoc P E (dest_type P ty) v
  in
    case tywalk P E tyrep of
        Tyop {Args, ...} => List.exists occurs Args
      | Tyvar _ => false
      | UVar v' => v = v'
  end
69
(* tyunify0 P E (tyr1, tyr2): unify two tyrep views under E.  Returns SOME
   of the extended environment, or NONE if the types do not unify.
   Both sides are walked first.  Two distinct unification variables are
   linked; a unification variable is bound to any other type, provided it
   does not occur in it.  Fixed type variables must have equal names.
   Type operators must agree on theory and name, and their arguments are
   then unified pairwise by tyunifyl. *)
fun tyunify0 P E (tyr1, tyr2) =
  let
    (* bind unification variable v to tyr, unless v occurs in tyr *)
    fun bind v tyr =
      if tyoc P E tyr v then NONE
      else SOME (Binarymap.insert (E, v, mkty tyr))
  in
    case (tywalk P E tyr1, tywalk P E tyr2) of
        (UVar v1, UVar v2) =>
          if v1 = v2 then SOME E
          else SOME (Binarymap.insert (E, v1, mk_vartype v2))
      | (UVar v1, t2) => bind v1 t2
      | (t1, UVar v2) => bind v2 t1
      | (Tyvar s1, Tyvar s2) => if s1 = s2 then SOME E else NONE
      | (Tyop {Thy = thy1, Name = n1, Args = a1},
         Tyop {Thy = thy2, Name = n2, Args = a2}) =>
          if thy1 = thy2 andalso n1 = n2 then tyunifyl P E (a1, a2)
          else NONE
      | _ => NONE
  end

(* tyunifyl P E (tys1, tys2): unify two lists of hol_types pairwise,
   threading the environment left to right.  Lists of different lengths
   fail. *)
and tyunifyl P E (ty1 :: rest1, ty2 :: rest2) =
      (case tyunify0 P E (dest_type P ty1, dest_type P ty2) of
           SOME E' => tyunifyl P E' (rest1, rest2)
         | NONE => NONE)
  | tyunifyl P E ([], []) = SOME E
  | tyunifyl _ _ _ = NONE
91
(* qtyrep ty: view ty with every type variable treated as a UVar, so that
   walking follows any binding the environment has for it. *)
fun qtyrep ty = dest_type (K true) ty

(* unified tyE (tyr1, tyr2): true when tyr1 and tyr2 are identical once
   bindings in tyE are followed.  Every type variable is walked, and the
   comparison recurses into type-operator arguments. *)
fun unified tyE (tyr1, tyr2) =
  let
    val walk = tywalk (K true) tyE
  in
    case (walk tyr1, walk tyr2) of
        (Tyop {Thy = thy1, Name = n1, Args = a1},
         Tyop {Thy = thy2, Name = n2, Args = a2}) =>
          thy1 = thy2 andalso n1 = n2 andalso
          ListPair.all (fn (x, y) => unified tyE (qtyrep x, qtyrep y)) (a1, a2)
      | (w1, w2) => w1 = w2
  end
100
(* tmvwalk P E (s, tyr): resolve the term unification variable s through
   E's term bindings.  If s is bound and the recorded type is unified with
   tyr, the bound term is viewed with dest_term; a unification-variable
   binding is resolved in turn.  An unbound s is returned as tmUVar (s, tyr).
   Raises UNIF_ERROR when the recorded type and tyr are not unified. *)
fun tmvwalk (P as {tyP, tmP}) (E as {tyE, tmE}) (p as (s, tyr)) =
  let
    val tyr' = tywalk tyP tyE tyr
  in
    case Binarymap.peek (tmE, s) of
        NONE => tmUVar p
      | SOME (ty, tm) =>
          if not (unified tyE (dest_type tyP ty, tyr')) then
            raise UNIF_ERROR ("Variable " ^ s ^ " has two distinct types")
          else
            (case dest_term P tm of
                 tmUVar next => tmvwalk P E next
               | tmr => tmr)
  end
114
(* tmwalk P E tmr: resolve tmr one level through E.  Only a tmUVar is
   looked up (via tmvwalk); any other tmrep is returned unchanged. *)
fun tmwalk P E (tmUVar p) = tmvwalk P E p
  | tmwalk _ _ tmr = tmr
119
(* tmoc P E v tmrep: occurs check for terms.  True when the term
   unification variable named v occurs in tmrep once bindings in E are
   followed.  The check descends into both sides of an application and into
   the body of an abstraction; constants and ordinary variables never
   contain v.
   NOTE(review): a bound variable whose name satisfies tmP would itself be
   viewed as a unification variable here - confirm callers avoid such
   binder names. *)
fun tmoc P E v tmrep =
  let
    fun occurs_in tm = tmoc P E v (dest_term P tm)
  in
    case tmwalk P E tmrep of
        uCOMB (t1, t2) => occurs_in t1 orelse occurs_in t2
      | uLAMB (_, body) => occurs_in body
      | tmUVar (v', _) => v = v'
      | tmVar _ => false
      | uConst _ => false
  end
128
(* utype_of P E tmr: the tyrep view of the HOL type of the term that tmr
   rebuilds to.  The environment E is not consulted: the type is not walked
   here (presumably E is kept for a uniform calling convention). *)
fun utype_of {tyP, tmP} E tmr =
  let
    val tm = mktm tmr
  in
    dest_type tyP (type_of tm)
  end
131
(* Local shorthand for Binarymap.insert, used by tmunify. *)
fun insert args = Binarymap.insert args
133
(* tmunify P E0 (tmr1, tmr2): try to unify the term views tmr1 and tmr2
   under E0 = {tyE, tmE}.  Returns SOME of the extended environment, or NONE
   on failure.
   The types of the two terms are unified first (in tyE).  The terms are
   then walked and compared one level at a time:
     - a term unification variable is bound to the other side, subject to
       an occurs check; two distinct unification variables are linked;
     - ordinary variables must have equal names, and constants equal name
       and theory; their types were already unified above;
     - applications are unified component-wise, threading the environment;
     - abstractions first try a pattern solution.  If one body is  f bv,
       where bv is that abstraction's binder and f walks to an unbound
       unification variable, f is bound to the whole other abstraction.
       Otherwise the two bodies are unified with both binders replaced by a
       common fresh variable.
   May raise UNIF_ERROR (from tmwalk) when a bound unification variable is
   used at a type not unified with its recorded type. *)
fun tmunify P (E0 as {tyE, tmE}) (tmr1, tmr2) =
  case tyunify0 (#tyP P) (#tyE E0) (utype_of P E0 tmr1, utype_of P E0 tmr2) of
      NONE => NONE
    | SOME tyE =>
      let
        val E as {tyE, tmE} = {tyE = tyE, tmE = #tmE E0}
      in
        case (tmwalk P E tmr1, tmwalk P E tmr2) of
            (tmUVar (s1, tyr1), tmUVar (s2, _)) =>
              if s1 = s2 then SOME E
              else
                SOME {tyE = tyE,
                      tmE = insert (tmE, s1, (mkty tyr1, mk_var (s2, mkty tyr1)))}
          | (tmUVar (s1, tyr1), tmr2) =>
              if tmoc P E s1 tmr2 then NONE
              else SOME {tyE = tyE, tmE = insert (tmE, s1, (mkty tyr1, mktm tmr2))}
          | (tmr1, tmUVar (s2, tyr2)) =>
              if tmoc P E s2 tmr1 then NONE
              else SOME {tyE = tyE, tmE = insert (tmE, s2, (mkty tyr2, mktm tmr1))}
          | (tmVar (s1, _), tmVar (s2, _)) => if s1 = s2 then SOME E else NONE
          | (uConst {Name = n1, Thy = thy1, ...}, uConst {Name = n2, Thy = thy2, ...}) =>
              if n1 = n2 andalso thy1 = thy2 then SOME E else NONE
          | (uCOMB (t1, t2), uCOMB (u1, u2)) =>
              (case tmunify P E (dest_term P t1, dest_term P u1) of
                   NONE => NONE
                 | SOME E' => tmunify P E' (dest_term P t2, dest_term P u2))
          | (uLAMB (p1 as (bv1, bod1)), uLAMB (p2 as (bv2, bod2))) =>
            let
              (* Structural fallback: replace each binder by a fresh variable
                 and unify the bodies.  The binder types are only unified in
                 tyE, not necessarily equal, and subst fails when a redex and
                 its residue differ in type.  So each side gets a variable of
                 its own binder's type, sharing one fresh name.  The two
                 occurrences then meet as tmVars with equal names, and their
                 types are unified by the recursive call.
                 NOTE(review): nothing prevents a unification variable from
                 being bound to a term containing the fresh variable, i.e.
                 the bound variable can escape its scope - confirm whether
                 callers rely on this never happening. *)
              fun fo () =
                let
                  val gv1 = genvar (type_of bv1)
                  val gv2 = mk_var (#1 (dest_var gv1), type_of bv2)
                  val bod1' = subst [bv1 |-> gv1] bod1
                  val bod2' = subst [bv2 |-> gv2] bod2
                in
                  tmunify P E (dest_term P bod1', dest_term P bod2')
                end
              (* Pattern case: if bod1 is  f bv1  and f walks to an unbound
                 unification variable s1, bind s1 to  \bv2. bod2  (NONE if
                 s1 occurs in bod2).  Otherwise continue with k ().
                 NOTE(review): bv1 is not checked for freeness in f; this
                 only matters if a binder is itself named like a
                 unification variable. *)
              fun test ((bv1, bod1), (bv2, bod2)) k =
                case dest_term P bod1 of
                    uCOMB (f, x) =>
                      if aconv x bv1 then
                        (case tmwalk P E (dest_term P f) of
                             tmUVar (s1, _) =>
                               if tmoc P E s1 (dest_term P bod2) then NONE
                               else
                                 SOME {tyE = tyE,
                                       tmE = insert (tmE, s1,
                                                     (type_of bv1 --> type_of bod1,
                                                      mk_abs (bv2, bod2)))}
                           | _ => k ())
                      else k ()
                  | _ => k ()
            in
              (* try the pattern case in both directions before falling back *)
              test (p1, p2) (fn () => test (p2, p1) fo)
            end
          | _ => NONE
      end
189
(* tyunify P E0 (ty1, ty2): unify two hol_types under E0, treating type
   variables whose names satisfy P as unification variables.  Returns SOME
   of the extended environment, or NONE on failure. *)
fun tyunify P E0 (ty1, ty2) =
  let
    val view = dest_type P
  in
    tyunify0 P E0 (view ty1, view ty2)
  end
192
(* tywalkstar P E tyr: fully resolve tyr through E.  Bindings are followed
   at the top and, recursively, inside every type-operator argument. *)
fun tywalkstar P E tyr =
  case tywalk P E tyr of
      Tyop {Thy, Name, Args} =>
        let
          fun resolve ty = mkty (tywalkstar P E (dest_type P ty))
        in
          Tyop {Thy = Thy, Name = Name, Args = map resolve Args}
        end
    | walked => walked
199
(* tmwalkstar P E tmr: fully resolve tmr through E.  Term bindings are
   followed at the top and recursively in every subterm, and every variable
   and constant type is fully resolved with tywalkstar.
   NOTE(review): abstractions are rebuilt with mk_abs after substitution,
   so a free variable introduced by a binding could be captured by an
   enclosing binder of the same name - confirm this cannot arise for the
   intended inputs. *)
fun tmwalkstar P E tmr =
  let
    val tyws = tywalkstar (#tyP P) (#tyE E)
    fun tmws tm = mktm (tmwalkstar P E (dest_term P tm))
  in
    case tmwalk P E tmr of
        uConst {Name, Thy, Ty} => uConst {Name = Name, Thy = Thy, Ty = tyws Ty}
      | uLAMB (bv, bod) => uLAMB (tmws bv, tmws bod)
      | uCOMB (t1, t2) => uCOMB (tmws t1, tmws t2)
      | tmVar (s, tyr) => tmVar (s, tyws tyr)
      | tmUVar (s, tyr) => tmUVar (s, tyws tyr)
  end
212
(* tycollapse P E: an environment with the same keys as E, in which every
   unification type variable maps to its fully resolved type. *)
fun tycollapse P E =
  let
    fun resolve (vnm, _) = mkty (tywalkstar P E (UVar vnm))
  in
    Binarymap.map resolve E
  end
215
(* tmcollapse P E: collapse both halves of E into a final substitution.
   The type half is collapsed with tycollapse.  Each term unification
   variable in tmE is mapped to its fully resolved term; the recorded type
   is used only to build the variable being resolved and is dropped from
   the result. *)
fun tmcollapse (P as {tyP, ...}) (E as {tyE, tmE}) =
  let
    val collapsedTys = tycollapse tyP tyE
    fun resolve (vnm, (ty, _)) =
      mktm (tmwalkstar P E (tmUVar (vnm, dest_type tyP ty)))
  in
    {tyE = collapsedTys, tmE = Binarymap.map resolve tmE}
  end
226
227(* test
228
229(* types *)
230
231val uvar_v = String.isPrefix "'v"
232fun tf ty1 ty2 =
233  Option.map (Binarymap.listItems o tycollapse uvar_v)
234             (tyunify uvar_v empty_tyenv (ty1, ty2))
235
236val  r1 = tf ``:'a list`` ``:'v``
237val  r2 = tf ``:'v1 list`` ``:'v2``
238val  r3 = tf ``:'v1 # ('v2 list)`` ``:'v3 # 'v3``
239val  r4 = tf ``:'v1 # ('v2 list)`` ``:'v3 # 'v1``
240val  r5 = tf ``:'v1 list`` ``:'v3 # 'v1``
241val  r6 = tf ``:'v1 list`` ``:'v1``
242val  r7 = tf ``:'v4 # ('v5 list)`` ``:'v3 # 'v4``
243val  r8 = tf ``:'v4 # ('a -> 'v4)`` ``:'v3 # 'v3``
244val  r9 = tf ``:'v4 list`` ``:('a -> 'v5) list``
245val r10 = tf ``:('a -> 'v5) list`` ``:'v4``
246val r11 = tf ``:'a`` ``:'v1``
247val r12 = tf ``:'a`` ``:'v1 list``
248val r13 = tf ``:'a -> bool`` ``:'a -> 'v``
249
250val tmvar = String.isPrefix "uv"
251val P = {tmP = tmvar, tyP = uvar_v}
252type 'a dict = (string,'a) Binarymap.dict
253val E0 = {tyE = Binarymap.mkDict String.compare : hol_type dict,
254          tmE = Binarymap.mkDict String.compare : (hol_type * term) dict}
255
256fun tmf f (tm1,tm2) =
257  Option.map
258    ((fn {tmE, tyE} => (Binarymap.listItems tyE, Binarymap.listItems tmE)) o f)
259    (tmunify P E0 (dest_term P tm1, dest_term P tm2))
260
261val t1 = tmf (tmcollapse P) (``f:'a->'v1``, ``uv:'a->bool``)
262val t2 = tmf I (``(f:'v1->'v2) uv2``, ``(uv:'a -> bool) x``)
val t3 = tmf I (``(f:'v1->bool) uv2``, ``(uv:'a -> 'v2) x``)
val t4 = tmf I (``(uvf:'v1->'v2) uv2``, ``CONS (h:'a) t``)
val t5 = tmf I (``?x. x /\ T``, ``?y:'v1. uv y``)
val t6 = tmf I (``?y:'v1. uv y``, ``?x:bool. x /\ T``)
val t7 = tmf I (``uv1 (x:'v1) (uv:'v2) : 'v3``, ``x /\ x``)
val t8 = tmf I (``uv1 (uv:'v1) (uv:'v1) : 'v3``, ``x /\ y``) (* NONE *)
269
270*)
271
272
273end
274