1structure FullUnify :> FullUnify =
2struct
3
4open HolKernel boolSyntax
structure Env =
struct

  (* An environment is a pair of triangular maps: type-variable names to
     types, and term variables to terms.  A binding's right-hand side may
     itself mention bound variables, so lookups chase chains of bindings
     (see lookup_ty0 and lookup_tm).

     A simple invariant such as "term bindings preserve type" is too
     strong.  Unifying

         x:'a    with    SUC

     requires :'a to map to :num -> num, so x and SUC have different types
     until the type map is applied.  The intended invariant is instead

         term-variable v maps to term t  ==>
         type_of (sigma v) = type_of (sigma t)

     where sigma is the type instantiation given by the type map. *)
  type t = (string, hol_type) Binarymap.dict * (term, term) Binarymap.dict

  (* Projections onto the type-variable map and the term-variable map. *)
  fun triTY ((tyS, _) : t) = tyS
  fun triTM ((_, tmS) : t) = tmS

  (* State computations over an environment that may fail (optmonad). *)
  type 'a EM = (t, 'a) optmonad.optmonad

  (* The environment with no bindings. *)
  val empty : t =
      (Binarymap.mkDict String.compare, Binarymap.mkDict Term.compare)

  (* Resolve ty through the type map tym.  A bound type variable is replaced
     by the (recursively resolved) type it maps to; an unbound one is kept;
     a compound type has each argument resolved. *)
  fun lookup_ty0 tym ty =
      if not (is_vartype ty) then
        let
          val {Thy, Tyop, Args} = dest_thy_type ty
        in
          mk_thy_type {Thy = Thy, Tyop = Tyop,
                       Args = map (lookup_ty0 tym) Args}
        end
      else
        case Binarymap.peek (tym, dest_vartype ty) of
            SOME bound => lookup_ty0 tym bound
          | NONE => ty

  (* lookup_ty0 against the type map of an environment. *)
  fun lookup_ty (E : t) ty = lookup_ty0 (triTY E) ty

  (* Instantiate every type variable of tm by its resolution in E. *)
  fun instE (E : t) tm =
      let
        fun resolve a = {redex = a, residue = lookup_ty0 (triTY E) a}
      in
        Term.inst (map resolve (type_vars_in_term tm)) tm
      end

  (* Fully apply E to tm0: first instantiate its types, then replace each
     bound free variable by its (recursively looked-up) binding. *)
  fun lookup_tm E tm0 =
      let
        val tm = instE E tm0
      in
        case dest_term tm of
            VAR _ =>
              (case Binarymap.peek (triTM E, tm) of
                   SOME bound => lookup_tm E bound
                 | NONE => tm)
          | CONST _ => tm
          | COMB (rator, rand) => mk_comb (lookup_tm E rator, lookup_tm E rand)
          | LAMB (bv, body) =>
              let
                (* Inside the abstraction bv denotes the binder, not any free
                   variable of the same name, so hide any binding for it. *)
                val inner =
                    #1 (Binarymap.remove (triTM E, bv))
                    handle Binarymap.NotFound => triTM E
              in
                (* NOTE(review): bv is not renamed; if a substituted term
                   contains bv free, it will be captured here. *)
                mk_abs (bv, lookup_tm (triTY E, inner) body)
              end
      end

  (* Bind type-variable name s to ty; fails if s is already bound. *)
  fun add_tybind (s, ty) : unit EM =
      fn (tym, tmm) =>
         case Binarymap.peek (tym, s) of
             SOME _ => NONE
           | NONE => SOME ((Binarymap.insert (tym, s, ty), tmm), ())

  (* Bind term variable v to tm; fails if v is already bound. *)
  fun add_tmbind (v, tm) : unit EM =
      fn (tym, tmm) =>
         case Binarymap.peek (tmm, v) of
             SOME _ => NONE
           | NONE => SOME ((tym, Binarymap.insert (tmm, v, tm)), ())

  (* Run m from the empty environment, discarding the final environment. *)
  fun fromEmpty (m : 'a EM) =
      case m empty of
          NONE => NONE
        | SOME (_, result) => SOME result
end (* Env struct *)
81
(* Read-only computations: resolve a type / term against the current
   environment, leaving the environment unchanged. *)
fun getty ty : hol_type Env.EM = fn E => SOME (E, Env.lookup_ty E ty)
fun gettm tm : term Env.EM = fn E => SOME (E, Env.lookup_tm E tm)
84
infix >*
(* Combine two computations with optmonad.lift2, pairing their results. *)
fun m1 >* m2 = optmonad.lift2 (fn a => fn b => (a, b)) m1 m2
87
(* unify_types ctys (ty1, ty2): extend the environment's type map so that
   ty1 and ty2 resolve to the same type, or fail.  Type variables listed in
   ctys are treated as constants and are never bound.  Both sides are
   resolved against the current environment before being compared. *)
fun unify_types ctys (ty1, ty2) : unit Env.EM =
  let
    open optmonad
    val op>>- = op>-
    (* Only type variables outside ctys may receive a binding. *)
    fun bindable ty = is_vartype ty andalso not (Lib.mem ty ctys)
    (* Put a bindable variable on the left whenever the left is not one. *)
    fun orient (pair as (a, b)) = if bindable a then pair else (b, a)
    fun solve (a, b) =
        if not (is_vartype a) then
          if is_vartype b then fail
          else
            let
              val {Args = argsA, Tyop = opA, Thy = thyA} = dest_thy_type a
              val {Args = argsB, Tyop = opB, Thy = thyB} = dest_thy_type b
            in
              if thyA = thyB andalso opA = opB then
                mmap recurse (ListPair.zip (argsA, argsB)) >> return ()
              else fail
            end
        else if a = b then return ()
        (* occurs check, then refuse to bind a constant type variable *)
        else if Lib.mem a (type_vars b) orelse Lib.mem a ctys then fail
        else Env.add_tybind (dest_vartype a, b)
    and recurse (l, r) = lift orient (getty l >* getty r) >>- solve
  in
    recurse (ty1, ty2)
  end
119
(* unify ctys ctms (t1, t2): extend the environment with type and term
   bindings intended to make t1 and t2 equal after lookup, or fail.
   ctys lists type variables that must not be instantiated; ctms lists term
   variables that must not be bound.  Matching abstractions are handled by
   replacing both binders with a shared genvar, recorded in bvs. *)
fun unify ctys ctms (t1, t2) : unit Env.EM =
  let
    open optmonad
    val op>>- = op>-
    fun recurse bvs (tm10, tm20) : unit Env.EM =
        let
          (* Does tm have a free occurrence of one of the genvars standing
             for an enclosing binder?  A free variable must never be bound
             to such a term, or the bound variable would escape its scope
             (e.g. \x. f x against \y. v must not bind v to f x). *)
          fun mentions_bound tm = List.exists (fn bv => free_in bv tm) bvs
          fun k (tm1, tm2) =
              case (dest_term tm1, dest_term tm2) of
                  (VAR _, VAR _) => if tm1 ~~ tm2 then return ()
                                    else if tmem tm1 bvs orelse tmem tm2 bvs
                                    then
                                      fail
                                    else if tmem tm1 ctms then
                                      if tmem tm2 ctms then fail
                                      else Env.add_tmbind (tm2, tm1)
                                    else Env.add_tmbind (tm1, tm2)
                | (VAR _, _) =>
                    (* occurs check, constant variable, bound variable, or
                       a right-hand side that mentions a bound variable *)
                    if free_in tm1 tm2 orelse tmem tm1 ctms orelse
                       tmem tm1 bvs orelse mentions_bound tm2
                    then
                      fail
                    else Env.add_tmbind (tm1, tm2)
                | (CONST _, CONST _) => if tm1 ~~ tm2 then return () else fail
                | (COMB(f1,x1), COMB(f2,x2)) =>
                  recurse bvs (f1,f2) >> recurse bvs (x1,x2)
                | (LAMB(bv1,bod1), LAMB(bv2,bod2)) =>
                  let
                    (* the types already agree: unify_types ran on the whole
                       abstractions before lookup instantiated them *)
                    val gv = genvar (type_of bv1)
                  in
                    recurse (gv::bvs)
                            (subst [bv1 |-> gv] bod1, subst [bv2 |-> gv] bod2)
                  end
                | _ => fail
          (* Put a variable on the left if either side is one. *)
          fun flip (p as (t1, t2)) =
              if is_var t1 then p else if is_var t2 then (t2,t1) else p
        in
          unify_types ctys (type_of tm10, type_of tm20) >>
          lift flip (gettm tm10 >* gettm tm20) >>- k
        end
  in
    recurse [] (t1, t2)
  end
161
(* Turn the chained bindings of E into two explicit instantiation lists
   (type substitution, term substitution), with every right-hand side fully
   looked up in E. *)
fun collapse0 E =
    let
      (* Keys of the type map are raw names from dest_vartype; the trace
         presumably silences HOL's complaint about their format. *)
      val quiet_mk_vartype = trace ("Vartype Format Complaint", 0) mk_vartype
      fun ty_entry (name, ty, acc) =
          {redex = quiet_mk_vartype name, residue = Env.lookup_ty E ty} :: acc
      fun tm_entry (v, tm, acc) =
          {redex = v, residue = Env.lookup_tm E tm} :: acc
    in
      (Binarymap.foldl ty_entry [] (Env.triTY E),
       Binarymap.foldl tm_entry [] (Env.triTM E))
    end
175
(* collapse0 as an Env.EM computation that leaves the environment as is. *)
fun collapse E = let val substs = collapse0 E in SOME (E, substs) end
177
178
179end
180