structure Overload :> Overload =
struct

open HolKernel Lexis
infix ##

(* Intended invariants:
     - for every overloaded_op_info record, base_type is the
       anti-unification of the types of all the terms in its actual_ops
       list;
     - every record stored in the parsing-direction dictionary has a
       non-empty actual_ops list.
*)

type nthy_rec = {Name : string, Thy : string}

(* Forget the Ty field of a {Name,Ty,Thy} constant record. *)
fun lose_constrec_ty {Name,Ty,Thy} = {Name = Name, Thy = Thy}

type overloaded_op_info = {base_type : Type.hol_type, actual_ops : term list,
                           tyavoids : Type.hol_type list}

(* The overload info is a pair:
   * the first component is for the "parsing direction"; it is a map from
     identifier name to an overloaded_op_info record.
   * the second component is for the "printing direction"; it is a term
     net (PrintMap) keyed on (free variables, pattern) pairs.  Each key
     holds a set of printmap_data entries naming the identifier to print.
     When nothing matches, the term is printed without an alias.
*)

type printmap_data = term * string * int
  (* the term is the lambda abstraction provided by the user, the
     string is the name that it is to be used in the printing process, and
     the int is the 'timestamp' *)

(* Order print-map entries by term, then by name; timestamps are ignored. *)
fun pmdata_compare ((t1,s1,_), (t2,s2,_)) =
    case Term.compare(t1,t2) of
      EQUAL => String.compare(s1,s2)
    | r => r

(* Timestamp source.
   - Successive calls to pos_tstamp true return 1, 2, 3, ...
   - Successive calls to pos_tstamp false return 0, ~1, ~2, ...
   strip_comb prefers larger timestamps.  Entries stamped with "false"
   (inferior overloadings) therefore lose ties against ordinary ones. *)
val pos_tstamp : bool -> int = let
  val neg = ref 0
  val cnt = ref 1
in
  fn true => (!cnt before (cnt := !cnt + 1))
   | false => (!neg before (neg := !neg - 1))
end
fun tstamp () = pos_tstamp true

(* Sets of printmap_data, in the shape LVTermNetFunctor expects. *)
structure PMDataSet = struct
  type value = printmap_data
  type t = value HOLset.set
  val empty = HOLset.empty pmdata_compare
  val insert = HOLset.add
  val fold = HOLset.foldl
  val listItems = HOLset.listItems
  fun filter P s =
      fold (fn (v,a) => if P v then insert(a,v) else a) empty s
  val numItems = HOLset.numItems
end

structure PrintMap = LVTermNetFunctor(PMDataSet)

type overload_info =
     ((string,overloaded_op_info) Binarymap.dict * PrintMap.lvtermnet)

(* The printing-direction component of an overload_info. *)
fun raw_print_map ((_,y):overload_info) = y

(* Compare {Name,Thy} records by theory first, then by name. *)
fun nthy_rec_cmp ({Name = n1, Thy = thy1}, {Name = n2, Thy = thy2}) =
    pair_compare (String.compare, String.compare) ((thy1, n1), (thy2, n2))

val null_oinfo : overload_info =
  (Binarymap.mkDict String.compare, PrintMap.empty)

(* All (name, overloaded_op_info) entries of the parsing direction. *)
fun oinfo_ops (oi,_) = Binarymap.listItems oi

(* List the ({Name,Thy}, alias) pairs of the printing direction.
   Out-of-date terms are skipped.  Terms that are not constants
   (dest_thy_const fails) are also skipped. *)
fun print_map (_, pm) = let
  fun foldthis (_,(t,nm,_),acc) =
      if Theory.uptodate_term t then
        ((lose_constrec_ty (dest_thy_const t),nm) :: acc
         handle HOL_ERR _ => acc)
      else acc
in
  PrintMap.fold foldthis [] pm
end

(* Replace the value bound to k in an association list, or append (k,v)
   if k is not present. *)
fun update_assoc k v [] = [(k,v)]
  | update_assoc k v ((k',v')::kvs) = if k = k' then (k,v)::kvs
                                      else (k',v')::update_assoc k v kvs

exception OVERLOAD_ERR of string

(* Union of the type variables occurring in a list of terms. *)
fun tmlist_tyvs tlist =
  List.foldl (fn (t,acc) => Lib.union (type_vars_in_term t) acc) [] tlist

local
  open stmonad Lib Type
  infix >- >>
  (* The state is (env, (next_var, sofar)).  env maps already-generalised
     pairs of types to the type variable chosen for them. *)
  fun lookup n (env,avds) =
      case assoc1 n env of
        NONE => ((env,avds), NONE)
      | SOME (_,v) => ((env,avds), SOME v)
  fun extend x (env,avds) = ((x::env,avds), ())
  (* invariant on type generation part of state:
       not (next_var MEM sofar)
  *)
  fun newtyvar (env, (next_var, sofar)) = let
    val new_sofar = next_var::sofar
    val new_next = gen_variant tyvar_vary sofar (tyvar_vary next_var)
    (* new_next can't be in new_sofar because gen_variant ensures that
       it won't be in sofar, and tyvar_vary ensures it won't be equal to
       next_var *)
  in
    ((env, (new_next, new_sofar)), mk_vartype next_var)
  end

  (* Monadic anti-unification.  Rules:
     - equal types are kept;
     - two applications of the same type operator are generalised
       argument-wise;
     - any other pair becomes a type variable.  The same pair always maps
       to the same variable, because the pair is recorded in env. *)
  fun au (ty1, ty2) =
      if ty1 = ty2 then return ty1
      else
        lookup (ty1, ty2) >-
        (fn result =>
            case result of
              NONE =>
              if not (is_vartype ty1) andalso not (is_vartype ty2) then let
                  val {Thy = thy1, Tyop = tyop1, Args = args1} =
                      dest_thy_type ty1
                  val {Thy = thy2, Tyop = tyop2, Args = args2} =
                      dest_thy_type ty2
                in
                  if tyop1 = tyop2 andalso thy1 = thy2 then
                    mmap au (ListPair.zip (args1, args2)) >-
                    (fn tylist =>
                        return (mk_thy_type{Thy = thy1, Tyop = tyop1,
                                            Args = tylist}))
                  else
                    newtyvar >- (fn new_ty => extend ((ty1, ty2), new_ty) >>
                                              return new_ty)
                end
              else
                newtyvar >- (fn new_ty =>
                                extend ((ty1, ty2), new_ty) >>
                                return new_ty)
            | SOME v => return v)

  (* Fresh variables must avoid every type variable of the two inputs. *)
  fun initial_state (ty1, ty2) = let
    val avoids = map dest_vartype (type_varsl [ty1, ty2])
    val first_var = gen_variant tyvar_vary avoids "'a"
  in
    ([], (first_var, avoids))
  end

  (* [x, f x, f (f x), ...] of length n (empty when n <= 0). *)
  fun generate_iterates n f x =
      if n <= 0 then []
      else x::generate_iterates (n - 1) f (f x)

  (* Rename the type variables of ty, in type_vars order, to the sequence
     produced from "'a" by tyvar_vary. *)
  fun canonicalise ty = let
    val tyvars = type_vars ty
    val replacements =
        map mk_vartype (generate_iterates (length tyvars) tyvar_vary "'a")
    val subst =
        ListPair.map (fn (ty1, ty2) => Lib.|->(ty1, ty2)) (tyvars, replacements)
  in
    type_subst subst ty
  end
in
  (* Anti-unification of two types, with canonically renamed variables. *)
  fun anti_unify ty1 ty2 = let
    val (_, result) = au (ty1, ty2) (initial_state (ty1, ty2))
  in
    canonicalise result
  end
end

(* find anti-unification for list of types *)
fun aul tyl =
    case tyl of
      [] => raise Fail "Overload.aul applied to empty list - shouldn't happen"
    | (h::t) => foldl (uncurry anti_unify) h t

(* Anti-unification of the types of a non-empty list of terms. *)
fun au_tml tml =
    case tml of
      [] => raise Fail "Overload.au_tml applied to empty list: shouldn't happen"
    | tm :: tms => foldl (fn (t,acc) => anti_unify (type_of t) acc)
                         (type_of tm)
                         tms

(* Functional updates of single fields of an overloaded_op_info record. *)
fun fupd_actual_ops f {base_type, actual_ops, tyavoids} =
    {base_type = base_type, actual_ops = f actual_ops, tyavoids = tyavoids}

fun fupd_base_type f {base_type, actual_ops, tyavoids} =
    {base_type = f base_type, actual_ops = actual_ops, tyavoids = tyavoids}

fun fupd_tyavoids f {base_type, actual_ops, tyavoids} =
    {base_type = base_type, actual_ops = actual_ops, tyavoids = f tyavoids}

(* Apply f to the value stored at key k.  Raises Binarymap.NotFound if k
   is absent. *)
fun fupd_dict_at_key k f dict = let
  val (newdict, kitem) = Binarymap.remove(dict,k)
in
  Binarymap.insert(newdict,k,f kitem)
end

(* Parsing-direction record for identifier s, if any. *)
fun info_for_name (overloads:overload_info) s =
    Binarymap.peek (#1 overloads, s)
fun is_overloaded (overloads:overload_info) s =
    isSome (info_for_name overloads s)

(* Compare types by generality:
   - SOME EQUAL when each matches the other;
   - SOME GREATER when only ty1 can be instantiated to ty2;
   - SOME LESS for the converse;
   - NONE when the types are incomparable. *)
fun type_compare (ty1, ty2) = let
  val ty1_gte_ty2 = Lib.can (Type.match_type ty1) ty2
  val ty2_gte_ty1 = Lib.can (Type.match_type ty2) ty1
in
  case (ty1_gte_ty2, ty2_gte_ty1) of
    (true, true) => SOME EQUAL
  | (true, false) => SOME GREATER
  | (false, true) => SOME LESS
  | (false, false) => NONE
end

(* Remove identifier s from both directions of oinfo.
   Returns ((op map without s, print map without entries for s),
            (up-to-date terms s used to parse to,
             up-to-date terms that were printed as s)).
   Out-of-date print-map entries are dropped as a side benefit. *)
fun remove_overloaded_form s (oinfo:overload_info) = let
  val (op2cnst, cnst2op) = oinfo
  val (okopc, badopc0) = (I ## #actual_ops) (Binarymap.remove(op2cnst, s))
      handle Binarymap.NotFound => (op2cnst, [])
  val badopc = List.filter Theory.uptodate_term badopc0
  (* will keep okopc, but should now remove from cnst2op all pairs of the form
       (c, s)
     where s is the s above *)
  fun foldthis (k, fullv as (t,v,_), (kept, removed)) =
      if not (Theory.uptodate_term t) then (kept, removed)
      else if v = s then (kept, t :: removed)
      else (PrintMap.insert(kept, k, fullv), removed)

  val (okcop, badcop) = PrintMap.fold foldthis (PrintMap.empty,[]) cnst2op
in
  ((okopc, okcop), (badopc, badcop))
end

(* Install identifier s directly.
   - Parsing direction: s resolves to the constants named by new_op2cs.
     The op-map entry for s is written with Binarymap.insert.
   - Printing direction: each constant in new_c2ops prints as s, with a
     fresh timestamp.
   Raises OVERLOAD_ERR if any named constant does not exist. *)
fun raw_map_insert s (new_op2cs, new_c2ops) (op2c_map, c2op_map) = let
  fun install_ty (r as {Name,Thy}) =
      Term.prim_mk_const r
      handle HOL_ERR _ =>
             raise OVERLOAD_ERR ("No such constant: "^Thy^"$"^Name)
  val withtypes = map install_ty new_op2cs

  val new_c2op_map = let
    val c2op_consts = map install_ty new_c2ops
  in
    List.foldl (fn (t,acc) => PrintMap.insert(acc, ([],t), (t,s,tstamp())))
               c2op_map
               c2op_consts
  end
in
  case withtypes of
    [] => (op2c_map, new_c2op_map)
  | _ =>
    (Binarymap.insert
       (op2c_map, s,
        {base_type = au_tml withtypes, actual_ops = withtypes,
         tyavoids = tmlist_tyvs (HOLset.listItems
                                   (FVL withtypes empty_tmset))}),
     new_c2op_map)
end

(* a predicate on pairs of operations and types that returns true if
   they're equal, given that two types are equal if they can match
   each other *)
fun ntys_equal {Ty = ty1,Name = n1, Thy = thy1}
               {Ty = ty2, Name = n2, Thy = thy2} =
    type_compare (ty1, ty2) = SOME EQUAL andalso n1 = n2 andalso thy1 = thy2


(* Put a new overloading resolution into the database.
   - If the term is already there (up to aconv) for the operator, it is
     replaced.
   - In either case, inserter decides where the term goes in the
     actual_ops list: at the head for add_overloading (first choice in
     ambiguous resolutions), at the end for add_inferior_overloading.
   - Out-of-date terms already stored for the name are discarded; when any
     were, base_type and tyavoids are recomputed from scratch.
   - The printing direction also gets an entry for the term, stamped
     with tstamp().
   Raises OVERLOAD_ERR if term itself is out of date. *)
fun add_overloading_with_inserter inserter tstamp (opname, term) oinfo = let
  val _ = Theory.uptodate_term term orelse
          raise OVERLOAD_ERR ("Term is out-of-date; opname = "^opname)
  val (opc0, cop0) = oinfo
  val opc =
      case info_for_name oinfo opname of
        SOME {base_type, actual_ops = a0, tyavoids} => let
          (* this name is already overloaded *)
          val actual_ops = List.filter Theory.uptodate_term a0
          val changed = length actual_ops <> length a0
        in
          case Lib.total (Lib.pluck (aconv term)) actual_ops of
            SOME (_, rest) => let
              (* this term was already in the map; must replace it *)
              val (avoids, base_type) =
                  if changed then
                    (tmlist_tyvs (free_varsl actual_ops), au_tml actual_ops)
                  else (tyavoids, base_type)
            in
              Binarymap.insert(opc0, opname,
                               {actual_ops = inserter(term,rest),
                                base_type = base_type,
                                tyavoids = avoids})
            end
          | NONE => let
              (* Wasn't in the map, so can just add its record *)
              val (newbase, new_avoids) =
                  if changed then
                    (au_tml (term::actual_ops),
                     tmlist_tyvs (free_varsl (term::actual_ops)))
                  else
                    (anti_unify base_type (type_of term),
                     Lib.union (tmlist_tyvs (free_vars term)) tyavoids)
            in
              Binarymap.insert(opc0, opname,
                               {actual_ops = inserter(term,actual_ops),
                                base_type = newbase,
                                tyavoids = new_avoids})
            end
        end
      | NONE =>
        (* this name not overloaded at all *)
        Binarymap.insert(opc0, opname,
                         {actual_ops = [term], base_type = type_of term,
                          tyavoids = tmlist_tyvs (free_vars term)})
  val cop = let
    val fvs = free_vars term
    val (_, pat) = strip_abs term
  in
    PrintMap.insert(cop0,(fvs,pat),(term,opname,tstamp()))
  end
in
  (opc, cop)
end

val add_overloading =
    add_overloading_with_inserter (op ::) (fn () => pos_tstamp true)
val add_inferior_overloading =
    add_overloading_with_inserter (fn (a,l) => l @ [a])
                                  (fn () => pos_tstamp false)

local
  (* Reorder the actual_ops of opname with f, leaving the printing
     direction untouched.  The constant is moved with aconv-equality.
     Raises OVERLOAD_ERR if any of the following holds:
     - the constant does not exist;
     - opname is not overloaded;
     - the constant is not among opname's resolutions. *)
  fun foverloading f {opname, realname, realthy} oinfo = let
    val nthy_rec = {Name = realname, Thy = realthy}
    val cnst = prim_mk_const nthy_rec
               handle HOL_ERR _ =>
                      raise OVERLOAD_ERR
                              ("No such constant: "^realthy^"$"^realname)
    val (opc0, cop0) = oinfo
    val opc =
        case info_for_name oinfo opname of
          SOME {base_type, actual_ops, tyavoids} =>
          (* this name is overloaded *)
          if List.exists (aconv cnst) actual_ops then
            (* the constant is in the map *)
            Binarymap.insert(opc0, opname,
                             {actual_ops = f (aconv cnst) actual_ops,
                              base_type = base_type,
                              tyavoids = tyavoids})
          else
            raise OVERLOAD_ERR
                    ("Constant not overloaded: "^realthy^"$"^realname)
        | NONE => raise OVERLOAD_ERR
                        ("No overloading for Operator: "^opname)
  in
    (opc, cop0)
  end

  fun send_to_back P l = let val (m,r) = Lib.pluck P l in r @ [m] end
  fun bring_to_front P l = let val (m,r) = Lib.pluck P l in m::r end
in
  fun send_to_back_overloading x oinfo = foverloading send_to_back x oinfo
  fun bring_to_front_overloading x oinfo = foverloading bring_to_front x oinfo
end;


(* First SOME produced by f on the list, scanning left to right. *)
fun myfind f [] = NONE
  | myfind f (x::xs) = case f x of (v as SOME _) => v | NONE => myfind f xs

(* Size of a substitution: the sum over its bindings of (f residue + 1). *)
fun isize0 acc f [] = acc
  | isize0 acc f ({redex,residue} :: rest) = isize0 (acc + f residue + 1) f rest
fun isize f x = isize0 0 f x

(* Try to view t as an application of a printing-direction alias.
   - Candidates: print-map entries matching t whose name satisfies
     namePred.
   - Preference order:
       1. smallest term instantiation (by term_size);
       2. then smallest type instantiation (by type_size);
       3. then the largest timestamp.
   - Returns NONE when nothing matches or the preferred entry's name is "".
   - Otherwise returns SOME (f, args):
       * args instantiate the abbreviation's bound variables (ARB when a
         bound variable is not instantiated by the match);
       * f is a variable whose name encodes, via
         GrammarSpecials.mk_fakeconst_name, the alias and (when the
         abbreviation is exactly \v1..vn. c v1..vn) the underlying
         constant c. *)
fun strip_comb ((_, prmap): overload_info) namePred t = let
  val matches = PrintMap.match(prmap, t)
  val cmp0 = pair_compare (measure_cmp (isize term_size),
                           pair_compare (measure_cmp (isize type_size),
                                         flip_order o Int.compare))
  val cmp = inv_img_cmp (fn (a,b,c,_) => (a,(b,c))) cmp0

  (* Match one candidate against t.  Bindings left identical by raw_match
     (other than for the pattern's own free variables) are made explicit
     as v |-> v, so they count towards the size ordering. *)
  fun test ((fvs, pat), (orig, nm, tstamp)) = let
    val _ = assert namePred nm
    val tyvs = tmlist_tyvs fvs
    val tmset = HOLset.addList(empty_tmset, fvs)
    val ((tmi0,tmeq),(tyi0,tyeq)) = raw_match tyvs tmset pat t ([],[])
    val tmi = HOLset.foldl (fn (v,acc) => if HOLset.member(tmset,v) then acc
                                          else (v |-> v) :: acc)
                           tmi0
                           tmeq
    val tyi = List.foldl (fn (ty,acc) => if mem ty tyvs then acc
                                         else (ty |-> ty) :: acc)
                         tyi0
                         tyeq
  in
    SOME (tmi, tyi, tstamp, (orig, nm))
  end handle HOL_ERR _ => NONE

  val inst_data = List.mapPartial test matches
  val sorted = Listsort.sort cmp inst_data

  fun rearrange (tmi, _, _, (orig, nm)) = let
    val (bvs,basepat) = strip_abs orig
    fun findarg v =
        case List.find (fn {redex,residue} => aconv redex v) tmi of
          NONE => mk_const("ARB", type_of v)
        | SOME i => #residue i
    val args = map findarg bvs
    val fconst_ty = List.foldr (fn (arg,acc) => type_of arg --> acc)
                               (type_of t)
                               args
    (* The abbreviation names an underlying constant only if its body is
       that constant applied to exactly the bound variables, in order.
       The explicit length check matters: ListPair.all stops at the
       shorter list.  Without it, \x. f x c or \x y. f x would be reported
       as aliases of f. *)
    val origopt = let
      val (hd, hdargs) = HolKernel.strip_comb basepat
    in
      if length hdargs = length bvs andalso
         ListPair.all (uncurry aconv) (bvs, hdargs)
      then
        (let
           val {Name,Thy,...} = dest_thy_const hd
         in
           SOME {Thy=Thy,Name=Name}
         end handle HOL_ERR _ => NONE)
      else NONE
    end
  in
    (mk_var(GrammarSpecials.mk_fakeconst_name {fake = nm, original = origopt},
            fconst_ty),
     args)
  end
in
  case sorted of
    [] => NONE
  | (m as (_, _, _, (_, nm))) :: _ => if nm = "" then NONE
                                      else SOME (rearrange m)
end

(* Strip t into an alias head and arguments.
   - If t's head is already a fake-constant variable, it is returned as-is.
   - Otherwise the longest prefix application of t that strip_comb
     recognises is used, with the remaining arguments appended.
   - Returns NONE if no prefix is recognised. *)
fun oi_strip_combP oinfo P t = let
  fun recurse acc t =
      case strip_comb oinfo P t of
        NONE => (case Lib.total dest_comb t of
                   NONE => NONE
                 | SOME (f,x) => recurse (x::acc) f)
      | SOME (f,args) => SOME(f, args @ acc)
  val (realf, args) = HolKernel.strip_comb t
in
  if is_var realf andalso
     String.isPrefix GrammarSpecials.fakeconst_special (#1 (dest_var realf))
  then
    SOME(realf, args)
  else recurse [] t
end

fun oi_strip_comb oinfo t = oi_strip_combP oinfo (fn _ => true) t


(* The alias under which t as a whole (with no leftover arguments) would
   be printed, if any. *)
fun overloading_of_termP (oinfo : overload_info) P t =
    case strip_comb oinfo P t of
      SOME (f, []) => f |> dest_var |> #1 |> GrammarSpecials.dest_fakeconst_name
                        |> Option.map #fake
    | _ => NONE

fun overloading_of_term oinfo t = overloading_of_termP oinfo (fn _ => true) t

(* Print alias of the constant described by r.  NONE if r does not name
   an existing constant. *)
fun overloading_of_nametype (oinfo:overload_info) r =
    case Lib.total prim_mk_const r of
      NONE => NONE
    | SOME c => overloading_of_term oinfo c

(* rev_append xs rest = List.rev xs @ rest *)
fun rev_append [] rest = rest
  | rev_append (x::xs) rest = rev_append xs (x::rest)

val show_alias_resolution = ref true
val _ = Feedback.register_btrace ("show_alias_printing_choices",
                                  show_alias_resolution)

(* Combine two overload infos.
   - Parsing direction: names present in both get the union (up to aconv)
     of their actual_ops and tyavoids, and the anti-unification of their
     base types.
   - Printing direction: up-to-date entries of O1 are inserted into
     O2's map. *)
fun merge_oinfos (O1:overload_info) (O2:overload_info) : overload_info = let
  val O1ops_sorted = Binarymap.listItems (#1 O1)
  val O2ops_sorted = Binarymap.listItems (#1 O2)
  (* Sorted merge of the two key-ordered entry lists; acc is reversed. *)
  fun merge acc op1s op2s =
      case (op1s, op2s) of
        ([], x) => rev_append acc x
      | (x, []) => rev_append acc x
      | ((k1,op1)::op1s', (k2,op2)::op2s') =>
        (case String.compare (k1, k2) of
           LESS => merge ((k1,op1)::acc) op1s' op2s
         | EQUAL => let
             val name = k1
             val ty1 = #base_type op1
             val ty2 = #base_type op2
             val newty = anti_unify ty1 ty2
             val newopinfo =
                 (name,
                  {base_type = newty,
                   actual_ops =
                     Lib.op_union aconv (#actual_ops op1) (#actual_ops op2),
                   tyavoids = Lib.union (#tyavoids op1) (#tyavoids op2)})
           in
             merge (newopinfo::acc) op1s' op2s'
           end
         | GREATER => merge ((k2, op2)::acc) op1s op2s')
  fun foldthis (k,v as (t,_,_),acc) =
      if Theory.uptodate_term t then PrintMap.insert(acc,k,v)
      else acc
  val new_prmap = PrintMap.fold foldthis (#2 O2) (#2 O1)
in
  (List.foldr (fn ((k,v),dict) => Binarymap.insert(dict,k,v))
              (Binarymap.mkDict String.compare)
              (merge [] O1ops_sorted O2ops_sorted),
   new_prmap)
end

(* Keys of a Binarymap, in ascending order. *)
fun keys dict = Binarymap.foldr (fn (k,_,l) => k::l) [] dict

(* All identifiers that have a parsing-direction entry. *)
fun known_constants (oi:overload_info) = keys (#1 oi)

(* Remove t (up to aconv) from str's parsing-direction resolutions.
   - The entry is dropped entirely when no resolutions remain.
   - Otherwise base_type and tyavoids are recomputed from the remaining
     terms, as add_overloading_with_inserter does when its op list
     changes.  This keeps the invariant that base_type is their
     anti-unification.
   - opdict is returned unchanged if str has no entry. *)
fun remove_omapping t str opdict = let
  val (dictlessk, kitem) = Binarymap.remove(opdict, str)
  fun ok_actual t' = not (aconv t' t)
  val new_rec = fupd_actual_ops (List.filter ok_actual) kitem
  val remaining = #actual_ops new_rec
in
  if null remaining then dictlessk
  else
    Binarymap.insert(dictlessk, str,
                     {actual_ops = remaining,
                      base_type = au_tml remaining,
                      tyavoids = tmlist_tyvs (free_varsl remaining)})
end handle Binarymap.NotFound => opdict

(* Remove the mapping between identifier str and term t in both
   directions.  The print map is only rebuilt if some entry keyed on
   ([], t) actually used str. *)
fun gen_remove_mapping str t ((opc, cop) : overload_info) = let
  val cop' = let
    val ds = PrintMap.peek (cop, ([], t))
    val ds' = PMDataSet.filter (fn (_, s, _) => s <> str) ds
  in
    if PMDataSet.numItems ds' = PMDataSet.numItems ds then cop
    else let
      val (pm',_) = PrintMap.delete(cop, ([], t))
    in
      PMDataSet.fold (fn (d,acc) => PrintMap.insert(acc,([],t),d))
                     pm'
                     ds'
    end
  end
in
  (remove_omapping t str opc, cop')
end

(* gen_remove_mapping for the constant named by crec. *)
fun remove_mapping str crec = gen_remove_mapping str (prim_mk_const crec)

end (* Overload *)