1structure Overload :> Overload =
2struct
3
4open HolKernel Lexis
5infix ##
6
7(* invariant on the type overloaded_op_info;
8     base_type is the anti-unification of all the types in the actual_ops
9     list
10   invariant on the overload_info list:
11     all members of the list have non-empty actual_ops lists
12*)
13
(* a constant identified only by its name and home theory *)
type nthy_rec = {Name : string, Thy : string}

(* forget the Ty field of a {Name,Ty,Thy} record, keeping Name and Thy *)
fun lose_constrec_ty {Name = nm, Thy = thy, Ty = _} = {Name = nm, Thy = thy}
17
(* Parsing-direction data stored for one overloaded name:
     base_type  - a type generalising the types of all of actual_ops
                  (maintained by the functions below via anti_unify/au_tml)
     actual_ops - the terms the name may stand for, in preference order
     tyavoids   - type variables collected from the free variables of
                  actual_ops (see tmlist_tyvs) *)
type overloaded_op_info = {base_type : Type.hol_type, actual_ops : term list,
                           tyavoids : Type.hol_type list}
20
21(* the overload info is thus a pair:
22   * first component is for the "parsing direction"; it's a map from
23     identifier name to an overloaded_op_info record.
24   * second component is for the "printing direction"; it takes constant
25     specifications {Name,Thy} records, and returns the preferred
26     identifier. If no entry exists, the constant should be printed in
27     thy$constant name form.
28*)
29
30
type printmap_data = term * string * int
  (* the term is the lambda abstraction provided by the user, the
     string is the name to be used for it in the printing process, and
     the int is the 'timestamp' *)
(* order printmap_data values by term first and by name second; the
   timestamp component takes no part in the comparison *)
fun pmdata_compare ((tm1, nm1, _), (tm2, nm2, _)) =
    let
      val term_order = Term.compare (tm1, tm2)
    in
      if term_order = EQUAL then String.compare (nm1, nm2) else term_order
    end
(* Timestamp generator backed by two private counters.
     pos_tstamp true   yields 1, 2, 3, ...
     pos_tstamp false  yields 0, ~1, ~2, ...
   Each call returns the current value of the chosen counter and then
   advances that counter. *)
local
  val next_pos = ref 1
  val next_neg = ref 0
  (* hand back the current contents of r, then move it on by delta *)
  fun take_and_step r delta =
      let val current = !r in r := current + delta; current end
in
  val pos_tstamp : bool -> int =
   fn wantPositive =>
      if wantPositive then take_and_step next_pos 1
      else take_and_step next_neg ~1
end
(* the next positive timestamp *)
fun tstamp () = pos_tstamp true
47
(* sets of printmap_data, ordered by pmdata_compare; this structure is the
   argument given to LVTermNetFunctor below *)
structure PMDataSet = struct
  type value = printmap_data
  type t = value HOLset.set
  val empty = HOLset.empty pmdata_compare
  val insert = HOLset.add
  val fold = HOLset.foldl
  val numItems = HOLset.numItems
  val listItems = HOLset.listItems
  (* rebuild, starting from the empty set, the subset of s satisfying P *)
  fun filter P s =
      let
        fun keep_if_ok (item, kept) =
            if P item then HOLset.add (kept, item) else kept
      in
        HOLset.foldl keep_if_ok empty s
      end
end
61
(* term net whose stored values are PMDataSet sets of printmap_data *)
structure PrintMap = LVTermNetFunctor(PMDataSet)

(* first component: map from overloaded name to its overloaded_op_info;
   second component: the PrintMap net used when going from terms to names *)
type overload_info =
     ((string,overloaded_op_info) Binarymap.dict * PrintMap.lvtermnet)
66
(* the second (PrintMap) component of an overload_info, untouched *)
fun raw_print_map (oinfo : overload_info) = #2 oinfo
68
(* lexicographic order on {Name,Thy} records: theory first, then name *)
fun nthy_rec_cmp ({Name = n1, Thy = thy1}, {Name = n2, Thy = thy2}) =
    case String.compare (thy1, thy2) of
      EQUAL => String.compare (n1, n2)
    | decided => decided
71
(* the overload_info with no names in its dictionary and an empty PrintMap *)
val null_oinfo : overload_info =
  (Binarymap.mkDict String.compare, PrintMap.empty)
74
(* all (name, info) entries of the first component, as Binarymap lists them *)
fun oinfo_ops (opdict, _) = Binarymap.listItems opdict

(* List ({Name,Thy}, name) pairs for the entries of the PrintMap whose term
   is up-to-date and can be taken apart by dest_thy_const.  Entries whose
   term is out-of-date, or for which dest_thy_const raises HOL_ERR, are
   left out. *)
fun print_map (_, pm) = let
  fun collect (_, (tm, name, _), sofar) =
      if not (Theory.uptodate_term tm) then sofar
      else
        ((lose_constrec_ty (dest_thy_const tm), name) :: sofar
         handle HOL_ERR _ => sofar)
in
  PrintMap.fold collect [] pm
end
85
(* Replace the value of the first entry with key k in an association list;
   if no entry has that key, add (k,v) at the end. *)
fun update_assoc k v alist =
    case alist of
      [] => [(k, v)]
    | (entry as (k', _)) :: rest =>
        if k = k' then (k, v) :: rest
        else entry :: update_assoc k v rest
89
(* raised by the functions below with a human-readable message, e.g. when a
   named constant does not exist or a term is out-of-date *)
exception OVERLOAD_ERR of string
91
(* the union (via Lib.union) of the type variables occurring in a list of
   terms, accumulated left to right *)
fun tmlist_tyvs tlist = let
  fun gather seen [] = seen
    | gather seen (tm :: rest) =
        gather (Lib.union (type_vars_in_term tm) seen) rest
in
  gather [] tlist
end
94
local
  open stmonad Lib Type
  infix >- >>
  (* The computation below threads a state of the form
       (env, (next_var, sofar))
     where env is an association list from pairs of types to the type
     variable already chosen for that pair, next_var is the name of the
     next type variable to hand out, and sofar lists names to stay clear
     of. *)

  (* look a pair of types up in env; the state is passed on unchanged *)
  fun lookup key (st as (env, _)) =
    case assoc1 key env of
      SOME (_, v) => (st, SOME v)
    | NONE => (st, NONE)

  (* push a new binding onto the front of env *)
  fun extend binding (env, avds) = ((binding :: env, avds), ())

  (* invariant on type generation part of state:
       not (next_var MEM sofar)
  *)
  fun newtyvar (env, (next_var, sofar)) = let
    (* the replacement for next_var can't be in next_var::sofar because
       gen_variant ensures that it won't be in sofar, and tyvar_vary
       ensures it won't be equal to next_var *)
    val candidate = tyvar_vary next_var
    val new_next = gen_variant tyvar_vary sofar candidate
  in
    ((env, (new_next, next_var :: sofar)), mk_vartype next_var)
  end

  (* allocate a fresh type variable for the pair tys, record the choice in
     env and return the variable *)
  fun fresh_for tys =
    newtyvar >- (fn fresh => extend (tys, fresh) >> return fresh)

  (* Anti-unify one pair of types.  Identical types stand for themselves,
     a pair already in env reuses its variable, two applications of the
     same type operator (same theory, same name) are handled argument by
     argument, and anything else gets a fresh variable. *)
  fun au (tys as (ty1, ty2)) =
    if ty1 = ty2 then return ty1
    else
      lookup tys >-
      (fn SOME known => return known
        | NONE =>
            if is_vartype ty1 orelse is_vartype ty2 then fresh_for tys
            else let
                val {Thy = thy1, Tyop = tyop1, Args = args1} = dest_thy_type ty1
                val {Thy = thy2, Tyop = tyop2, Args = args2} = dest_thy_type ty2
              in
                if tyop1 = tyop2 andalso thy1 = thy2 then
                  mmap au (ListPair.zip (args1, args2)) >-
                  (fn tylist =>
                    return (mk_thy_type{Thy = thy1, Tyop = tyop1,
                                        Args = tylist}))
                else fresh_for tys
              end)

  (* empty env; avoid every type variable of the two inputs, and start
     generating from a variant of 'a that is not among them *)
  fun initial_state (ty1, ty2) = let
    val avoids = map dest_vartype (type_varsl [ty1, ty2])
  in
    ([], (gen_variant tyvar_vary avoids "'a", avoids))
  end

  (* [x, f x, f (f x), ...] of length n; empty when n <= 0 *)
  fun generate_iterates n f x =
    if n > 0 then x :: generate_iterates (n - 1) f (f x)
    else []

  (* rename the type variables of ty, in the order type_vars reports them,
     to the sequence obtained by iterating tyvar_vary from 'a *)
  fun canonicalise ty = let
    val tyvars = type_vars ty
    val fresh_names = generate_iterates (length tyvars) tyvar_vary "'a"
    val subst =
      ListPair.map (fn (old, name) => Lib.|->(old, mk_vartype name))
                   (tyvars, fresh_names)
  in
    type_subst subst ty
  end
in
  (* anti-unification of two types, with canonically named type variables *)
  fun anti_unify ty1 ty2 =
    case au (ty1, ty2) (initial_state (ty1, ty2)) of
      (_, result) => canonicalise result
end
168
169(* find anti-unification for list of types *)
(* fold anti_unify over a non-empty list of types, starting from its head *)
fun aul [] =
      raise Fail "Overload.aul applied to empty list - shouldn't happen"
  | aul (first :: others) =
      foldl (fn (ty, sofar) => anti_unify ty sofar) first others
174
(* anti-unification of the types of a non-empty list of terms *)
fun au_tml [] =
      raise Fail "Overload.au_tml applied to empty list: shouldn't happen"
  | au_tml (first :: others) = let
      fun step (tm, sofar) = anti_unify (type_of tm) sofar
    in
      foldl step (type_of first) others
    end
181
(* functional field updates: each applies f to one field of the record and
   copies the other two across unchanged *)
fun fupd_actual_ops f {base_type = bty, actual_ops = ops, tyavoids = avds} =
  {actual_ops = f ops, base_type = bty, tyavoids = avds}

fun fupd_base_type f {base_type = bty, actual_ops = ops, tyavoids = avds} =
  {actual_ops = ops, base_type = f bty, tyavoids = avds}

fun fupd_tyavoids f {base_type = bty, actual_ops = ops, tyavoids = avds} =
  {actual_ops = ops, base_type = bty, tyavoids = f avds}
190
(* Replace the value stored at key k by f applied to it.  The key is taken
   out with Binarymap.remove, so an exception from that call (absent key)
   propagates to the caller. *)
fun fupd_dict_at_key k f dict =
  case Binarymap.remove (dict, k) of
    (without_k, old_value) => Binarymap.insert (without_k, k, f old_value)
196
(* the overloaded_op_info recorded for name s, if there is one *)
fun info_for_name ((opdict, _) : overload_info) s =
  Binarymap.peek (opdict, s)

(* true iff name s has an entry in the first component *)
fun is_overloaded (overloads:overload_info) s =
  case info_for_name overloads s of
    SOME _ => true
  | NONE => false
201
(* Partial order on types induced by Type.match_type:
     SOME EQUAL    each type matches the other
     SOME GREATER  match_type ty1 ty2 succeeds but not the converse
     SOME LESS     match_type ty2 ty1 succeeds but not the converse
     NONE          neither call succeeds *)
fun type_compare (ty1, ty2) = let
  fun matches pattern target = Lib.can (Type.match_type pattern) target
in
  if matches ty1 ty2 then
    if matches ty2 ty1 then SOME EQUAL else SOME GREATER
  else if matches ty2 ty1 then SOME LESS
  else NONE
end
212
(* Remove every trace of the overloaded name s.  Returns
     ((dictionary without s, rebuilt print map),
      (up-to-date terms s stood for, terms whose print entry named s)).
   The print map is rebuilt from empty, so entries whose term is
   out-of-date are dropped as well. *)
fun remove_overloaded_form s (oinfo:overload_info) = let
  val (op2cnst, cnst2op) = oinfo
  (* if s is not in the dictionary, there is nothing to take out of it *)
  val (okopc, dropped_ops) =
      (case Binarymap.remove (op2cnst, s) of
         (rest, info) => (rest, #actual_ops info))
      handle Binarymap.NotFound => (op2cnst, [])
  val badopc = List.filter Theory.uptodate_term dropped_ops
  (* keep okopc, but rebuild cnst2op without the entries printed as s *)
  fun split (key, entry as (tm, name, _), (kept, removed)) =
      if Theory.uptodate_term tm then
        if name = s then (kept, tm :: removed)
        else (PrintMap.insert (kept, key, entry), removed)
      else (kept, removed)

  val (okcop, badcop) = PrintMap.fold split (PrintMap.empty, []) cnst2op
in
  ((okopc, okcop), (badopc, badcop))
end
230
(* Install name s wholesale.  new_op2cs are the constants ({Name,Thy}
   records) that s may parse to; new_c2ops are the constants to be printed
   as s.  Raises OVERLOAD_ERR if any record names a constant that
   prim_mk_const rejects.  When new_op2cs is empty the parsing dictionary
   is returned unchanged. *)
fun raw_map_insert s (new_op2cs, new_c2ops) (op2c_map, c2op_map) = let
  fun const_of (r as {Name,Thy}) =
    Term.prim_mk_const r
    handle HOL_ERR _ =>
      raise OVERLOAD_ERR ("No such constant: "^Thy^"$"^Name)
  val parse_consts = map const_of new_op2cs
  val print_consts = map const_of new_c2ops

  (* every printing entry gets a fresh positive timestamp *)
  fun add_print_entry (c, net) =
      PrintMap.insert (net, ([], c), (c, s, tstamp ()))
  val new_c2op_map = List.foldl add_print_entry c2op_map print_consts

  val new_op2c_map =
      if null parse_consts then op2c_map
      else
        Binarymap.insert
          (op2c_map, s,
           {base_type = au_tml parse_consts,
            actual_ops = parse_consts,
            tyavoids = tmlist_tyvs (HOLset.listItems
                                      (FVL parse_consts empty_tmset))})
in
  (new_op2c_map, new_c2op_map)
end
259
260(* a predicate on pairs of operations and types that returns true if
261   they're equal, given that two types are equal if they can match
262   each other *)
fun ntys_equal {Ty = ty1, Name = n1, Thy = thy1}
               {Ty = ty2, Name = n2, Thy = thy2} =
  n1 = n2 andalso thy1 = thy2 andalso
  (case type_compare (ty1, ty2) of
     SOME EQUAL => true
   | _ => false)
266
267
(* put a new overloading resolution into the database.  If it's already
   there for a given operator, don't mind.  In either case, make sure that
   it's at the head of the list, meaning that it will be the first choice
   in ambiguous resolutions.
   update: abstracted the inserter to allow adding at the
           end of the list for inferior resolutions.  *)
(* Record that opname may stand for term.
     inserter - places the term relative to the other terms already
                recorded for opname, e.g. (op ::) to put it first
     tstamp   - called once to produce the timestamp of the print entry
   Raises OVERLOAD_ERR if term is out-of-date.  Out-of-date terms already
   recorded for opname are discarded; when that happens base_type and
   tyavoids are recomputed from scratch instead of being updated. *)
fun add_overloading_with_inserter inserter tstamp (opname, term) oinfo = let
  val _ = if Theory.uptodate_term term then ()
          else raise OVERLOAD_ERR ("Term is out-of-date; opname = "^opname)
  val (opc0, cop0) = oinfo
  val new_info =
      case info_for_name oinfo opname of
        NONE =>
          (* this name not overloaded at all *)
          {actual_ops = [term], base_type = type_of term,
           tyavoids = tmlist_tyvs (free_vars term)}
      | SOME {base_type, actual_ops = a0, tyavoids} => let
          (* this name is already overloaded *)
          val live_ops = List.filter Theory.uptodate_term a0
          val pruned = length live_ops <> length a0
        in
          case Lib.total (Lib.pluck (aconv term)) live_ops of
            SOME (_, others) => let
              (* the term was already recorded: re-insert it among the
                 others *)
              val (avoids, base) =
                  if pruned then
                    (tmlist_tyvs (free_varsl live_ops), au_tml live_ops)
                  else (tyavoids, base_type)
            in
              {actual_ops = inserter (term, others),
               base_type = base,
               tyavoids = avoids}
            end
          | NONE => let
              (* the term is new for this name *)
              val (base, avoids) =
                  if pruned then let
                      val all_ops = term :: live_ops
                    in
                      (au_tml all_ops, tmlist_tyvs (free_varsl all_ops))
                    end
                  else
                    (anti_unify base_type (type_of term),
                     Lib.union (tmlist_tyvs (free_vars term)) tyavoids)
            in
              {actual_ops = inserter (term, live_ops),
               base_type = base,
               tyavoids = avoids}
            end
        end
  val opc = Binarymap.insert (opc0, opname, new_info)
  (* the print entry is keyed on the term's free variables together with
     the body left after stripping its leading abstractions *)
  val fvs = free_vars term
  val (_, pat) = strip_abs term
  val cop = PrintMap.insert (cop0, (fvs, pat), (term, opname, tstamp ()))
in
  (opc, cop)
end
329
(* new term goes to the front of the list, with a positive timestamp *)
fun add_overloading binding oinfo =
    add_overloading_with_inserter (fn (tm, tms) => tm :: tms)
                                  (fn () => pos_tstamp true)
                                  binding oinfo
(* new term goes to the back of the list, with a non-positive timestamp *)
fun add_inferior_overloading binding oinfo =
    add_overloading_with_inserter (fn (tm, tms) => tms @ [tm])
                                  (fn () => pos_tstamp false)
                                  binding oinfo
332
local
  (* Reorder the terms recorded for opname by applying f to a predicate
     recognising the constant realthy$realname and to the current list.
     Raises OVERLOAD_ERR if the constant does not exist, if opname is not
     overloaded, or if the constant is not among opname's terms.  The
     print map is passed through unchanged. *)
  fun foverloading f {opname, realname, realthy} oinfo = let
    val cnst = prim_mk_const {Name = realname, Thy = realthy}
      handle HOL_ERR _ =>
        raise OVERLOAD_ERR ("No such constant: "^realthy^"$"^realname)
    val (opc0, cop0) = oinfo
    val is_cnst = aconv cnst
    val opc =
        case info_for_name oinfo opname of
          NONE => raise OVERLOAD_ERR
                    ("No overloading for Operator: "^opname)
        | SOME {base_type, actual_ops, tyavoids} =>
            if List.exists is_cnst actual_ops then
              Binarymap.insert(opc0, opname,
                {actual_ops = f is_cnst actual_ops,
                 base_type = base_type,
                 tyavoids = tyavoids})
            else raise OVERLOAD_ERR
                        ("Constant not overloaded: "^realthy^"$"^realname)
  in
    (opc, cop0)
  end

  (* move the element picked out by Lib.pluck P to the end / to the front *)
  fun send_to_back P l =
      case Lib.pluck P l of (hit, others) => others @ [hit]
  fun bring_to_front P l =
      case Lib.pluck P l of (hit, others) => hit :: others
in
  fun send_to_back_overloading x oinfo = foverloading send_to_back x oinfo
  fun bring_to_front_overloading x oinfo = foverloading bring_to_front x oinfo
end;
366
367
(* the first SOME result of f over the list, or NONE if f never gives one *)
fun myfind f xs =
    case xs of
      [] => NONE
    | x :: rest =>
        (case f x of
           NONE => myfind f rest
         | found => found)
370
(* size of a substitution: starting from acc, each {redex,residue} element
   contributes f residue + 1; the redex is ignored *)
fun isize0 acc f subst =
    List.foldl (fn ({redex = _, residue}, total) => total + f residue + 1)
               acc
               subst
fun isize f subst = isize0 0 f subst
374
(* Try to view t as an instance of one of the patterns in the print map.
   Only entries whose name satisfies namePred and whose pattern raw_match
   accepts are considered.  Candidates are sorted by size of term
   instantiation, then size of type instantiation, then timestamp with
   larger stamps first, and the first one is used.  The result is
   SOME (f, args) with f a variable carrying a fake-constant name, or NONE
   if nothing fits or the winning name is the empty string. *)
fun strip_comb ((_, prmap): overload_info) namePred t = let
  val candidates = PrintMap.match(prmap, t)
  val by_term_inst = measure_cmp (isize term_size)
  val by_type_inst = measure_cmp (isize type_size)
  val larger_stamp_first = flip_order o Int.compare
  val cmp0 = pair_compare (by_term_inst,
                           pair_compare (by_type_inst, larger_stamp_first))
  val cmp = inv_img_cmp (fn (tmi, tyi, stamp, _) => (tmi, (tyi, stamp))) cmp0

  (* SOME instantiation data if the entry applies to t; NONE when namePred
     rejects the name or matching raises HOL_ERR *)
  fun instantiate ((fvs, pat), (orig, nm, stamp)) = let
    val _ = assert namePred nm
    val local_tyvs = tmlist_tyvs fvs
    val local_tms = HOLset.addList(empty_tmset, fvs)
    val ((tmi0,tmeq),(tyi0,tyeq)) =
        raw_match local_tyvs local_tms pat t ([],[])
    (* add v |-> v for the items in the second components returned by
       raw_match, except the pattern's own fixed variables *)
    fun add_tm_identity (v, acc) =
        if HOLset.member(local_tms, v) then acc else (v |-> v) :: acc
    fun add_ty_identity (ty, acc) =
        if mem ty local_tyvs then acc else (ty |-> ty) :: acc
    val tmi = HOLset.foldl add_tm_identity tmi0 tmeq
    val tyi = List.foldl add_ty_identity tyi0 tyeq
  in
    SOME (tmi, tyi, stamp, (orig, nm))
  end handle HOL_ERR _ => NONE

  val ranked = Listsort.sort cmp (List.mapPartial instantiate candidates)

  (* build the fake-constant head and the argument list for a candidate *)
  fun rearrange (tmi, _, _, (orig, nm)) = let
    val (bvs, basepat) = strip_abs orig
    (* a bound variable with no instantiation is given ARB at its type *)
    fun findarg v =
        case List.find (fn {redex,residue} => aconv redex v) tmi of
          SOME {redex = _, residue} => residue
        | NONE => mk_const("ARB", type_of v)
    val args = map findarg bvs
    val fconst_ty = List.foldr (fn (arg, ty) => type_of arg --> ty)
                               (type_of t)
                               args
    val (head, head_args) = HolKernel.strip_comb basepat
    (* remember the underlying constant only when the pattern body is that
       constant applied to the bound variables *)
    val origopt =
        if ListPair.all (uncurry aconv) (bvs, head_args) then
          (let
             val {Name,Thy,...} = dest_thy_const head
           in
             SOME {Thy=Thy,Name=Name}
           end handle HOL_ERR _ => NONE)
        else NONE
    val fake_name =
        GrammarSpecials.mk_fakeconst_name {fake = nm, original = origopt}
  in
    (mk_var (fake_name, fconst_ty), args)
  end
in
  case ranked of
    (best as (_, _, _, (_, nm))) :: _ =>
      if nm <> "" then SOME (rearrange best) else NONE
  | [] => NONE
end
(* Like strip_comb, but if t as a whole does not fit, retry on its rator,
   collecting the rands passed on the way.  A term whose head is already a
   variable with the fake-constant prefix is simply split with
   HolKernel.strip_comb. *)
fun oi_strip_combP oinfo P t = let
  fun peel pending tm =
      case strip_comb oinfo P tm of
        SOME (f, args) => SOME (f, args @ pending)
      | NONE =>
          (case Lib.total dest_comb tm of
             SOME (rator, rand) => peel (rand :: pending) rator
           | NONE => NONE)
  val (head, head_args) = HolKernel.strip_comb t
  val already_fake =
      is_var head andalso
      String.isPrefix GrammarSpecials.fakeconst_special (#1 (dest_var head))
in
  if already_fake then SOME (head, head_args)
  else peel [] t
end
451
(* oi_strip_combP with a predicate that accepts every name *)
fun oi_strip_comb oinfo t = let
  fun any_name _ = true
in
  oi_strip_combP oinfo any_name t
end
453
454
(* the name under which t itself (with no arguments left over) would be
   printed, considering only names accepted by P *)
fun overloading_of_termP (oinfo : overload_info) P t =
    case strip_comb oinfo P t of
      SOME (f, []) =>
        (case GrammarSpecials.dest_fakeconst_name (#1 (dest_var f)) of
           SOME fields => SOME (#fake fields)
         | NONE => NONE)
    | _ => NONE
460
(* overloading_of_termP with a predicate that accepts every name *)
fun overloading_of_term oinfo t = let
  fun any_name _ = true
in
  overloading_of_termP oinfo any_name t
end

(* as overloading_of_term, for the constant built by prim_mk_const from r;
   NONE if prim_mk_const fails *)
fun overloading_of_nametype (oinfo:overload_info) r =
    Option.mapPartial (overloading_of_term oinfo) (Lib.total prim_mk_const r)
467
(* push the elements of the first list, one by one from its head, onto the
   front of rest; the result is the first list reversed, followed by rest *)
fun rev_append xs rest = List.foldl (fn (x, sofar) => x :: sofar) rest xs
470
(* boolean flag, initially true, registered with Feedback as the trace
   "show_alias_printing_choices" *)
val show_alias_resolution = ref true
val _ = Feedback.register_btrace ("show_alias_printing_choices",
                                  show_alias_resolution)
474
(* Combine two overload_infos.  A name present in only one dictionary keeps
   its record; a name present in both gets the anti-unification of the two
   base types, the aconv-union of the term lists and the union of the
   tyavoids.  For the print maps, the up-to-date entries of O1's map are
   inserted into O2's map. *)
fun merge_oinfos (O1:overload_info) (O2:overload_info) : overload_info = let
  val (ops1, prmap1) = O1
  val (ops2, prmap2) = O2

  (* the record for a name that both dictionaries know about *)
  fun combine (name, info1 : overloaded_op_info, info2 : overloaded_op_info) =
      (name,
       {base_type = anti_unify (#base_type info1) (#base_type info2),
        actual_ops =
          Lib.op_union aconv (#actual_ops info1) (#actual_ops info2),
        tyavoids = Lib.union (#tyavoids info1) (#tyavoids info2)})

  (* merge two key-sorted entry lists; acc holds the output in reverse *)
  fun merge acc ([], remaining) = rev_append acc remaining
    | merge acc (remaining, []) = rev_append acc remaining
    | merge acc (left as ((e1 as (k1, i1)) :: left'),
                 right as ((e2 as (k2, i2)) :: right')) =
        (case String.compare (k1, k2) of
           LESS => merge (e1 :: acc) (left', right)
         | GREATER => merge (e2 :: acc) (left, right')
         | EQUAL => merge (combine (k1, i1, i2) :: acc) (left', right'))

  fun keep_live (key, entry as (tm, _, _), net) =
      if Theory.uptodate_term tm then PrintMap.insert (net, key, entry)
      else net
  val new_prmap = PrintMap.fold keep_live prmap2 prmap1

  val merged_entries =
      merge [] (Binarymap.listItems ops1, Binarymap.listItems ops2)
in
  (List.foldr (fn ((k, v), dict) => Binarymap.insert (dict, k, v))
              (Binarymap.mkDict String.compare)
              merged_entries,
   new_prmap)
end
513
(* the keys of a Binarymap dictionary, in the dictionary's key order *)
fun keys dict = map (fn (k, _) => k) (Binarymap.listItems dict)

(* the names that have an entry in the first component of oi *)
fun known_constants ((opdict, _) : overload_info) = keys opdict
517
(* Drop the terms alpha-equivalent to t from the list recorded for name str.
   If that leaves the list empty the name is removed altogether; if str
   has no entry the dictionary is returned as it was. *)
fun remove_omapping t str opdict = let
  val (others, info) = Binarymap.remove (opdict, str)
  fun differs_from_t candidate = not (aconv candidate t)
  val pruned = fupd_actual_ops (List.filter differs_from_t) info
in
  case #actual_ops pruned of
    [] => others
  | _ :: _ => Binarymap.insert (others, str, pruned)
end handle Binarymap.NotFound => opdict
526
(* Remove the association between name str and term t in both directions.
   In the print map only the data stored under the key ([], t) is looked
   at; items there that carry the name str are dropped and the remaining
   ones are re-inserted. *)
fun gen_remove_mapping str t ((opc, cop) : overload_info) = let
  val key = ([], t)
  val stored = PrintMap.peek (cop, key)
  val kept = PMDataSet.filter (fn (_, nm, _) => nm <> str) stored
  val cop' =
      if PMDataSet.numItems kept = PMDataSet.numItems stored then cop
      else let
          val (without_key, _) = PrintMap.delete (cop, key)
        in
          PMDataSet.fold (fn (d, net) => PrintMap.insert (net, key, d))
                         without_key
                         kept
        end
in
  (remove_omapping t str opc, cop')
end
(* gen_remove_mapping for the constant that prim_mk_const builds from crec *)
fun remove_mapping str crec = let
  val cnst = prim_mk_const crec
in
  gen_remove_mapping str cnst
end
545
546end (* Overload *)
547