(* ========================================================================= *)
(* ML UTILITY FUNCTIONS                                                      *)
(* Copyright (c) 2001-2004 Joe Hurd.                                         *)
(* ========================================================================= *)

structure mlibUseful :> mlibUseful =
struct

(* ------------------------------------------------------------------------- *)
(* Exceptions, profiling and tracing.                                        *)
(* ------------------------------------------------------------------------- *)

(* Error is the exception caught by total/can and required by partial;
   Bug is raised by this module when one of its functions is misused. *)
exception Error of string;
exception Bug of string;

(* Render an Error exception as "\nError: <message>\n"; any other exception
   raises Bug. *)
fun Error_to_string (Error message) =
    "\nError: " ^ message ^ "\n"
  | Error_to_string _ = raise Bug "Error_to_string: not an Error exception";

(* Render a Bug exception as "\nBug: <message>\n"; any other exception
   raises Bug. *)
fun Bug_to_string (Bug message) =
    "\nBug: " ^ message ^ "\n"
  | Bug_to_string _ = raise Bug "Bug_to_string: not a Bug exception";

(* Render an Error or Bug exception; any other exception raises Bug. *)
fun report (e as Error _) = Error_to_string e
  | report (e as Bug _) = Bug_to_string e
  | report _ = raise Bug "report: not an Error or Bug exception";

(* assert b e: do nothing when b holds, otherwise raise e. *)
fun assert b e = if b then () else raise e;

(* try f a: evaluate f a; if it raises, print a description of the
   exception and re-raise the same exception. *)
fun try f a = f a
  handle h as Error _ => (print (Error_to_string h); raise h)
       | b as Bug _ => (print (Bug_to_string b); raise b)
       | e => (print "\ntry: strange exception raised\n"; raise e);

(* total f x: SOME (f x), or NONE if f x raises Error.
   Any other exception propagates. *)
fun total f x = SOME (f x) handle Error _ => NONE;

(* can f x: true iff f x returns without raising Error. *)
fun can f = Option.isSome o total f;

(* partial e f x: the value inside f x, raising the given Error exception e
   when f x is NONE. Passing a non-Error exception as e raises Bug. *)
fun partial (e as Error _) f x = (case f x of SOME y => y | NONE => raise e)
  | partial _ _ _ = raise Bug "partial: must take an Error exception";

(* timed f a: (seconds of user + system CPU time spent evaluating f a,
   the result of f a). *)
fun timed f a =
  let
    val tmr = Timer.startCPUTimer ()
    val res = f a
    val {usr,sys,...} = Timer.checkCPUTimer tmr
  in
    (Time.toReal usr + Time.toReal sys, res)
  end;

(* timed_many f a: evaluate f a repeatedly until the accumulated CPU time
   exceeds MIN seconds, then return (mean seconds per run, last result). *)
local
  val MIN = 1.0;

  fun several n t f a =
    let
      val (t',res) = timed f a
      val t = t + t'
      val n = n + 1
    in
      if t > MIN then (t / Real.fromInt n, res) else several n t f a
    end;
in
  fun timed_many f a = several 0 0.0 f a
end;

(* Global trace verbosity consulted by tracing; 0 or less disables tracing. *)
val trace_level = ref 1;

(* Per-module trace configuration. A module's alignment maps a message level
   to the trace_level required before messages at that level are shown. *)
val traces : {module : string, alignment : int -> int} list ref = ref [];

fun add_trace t = traces := t :: !traces
fun set_traces ts = traces := ts

(* tracing {module, level}: should a trace message from this module at this
   level be shown?
   - never when trace_level <= 0;
   - always when trace_level >= MAX or level >= MAX;
   - otherwise when the module's alignment of the level is at most
     trace_level (a module with no entry aligns every level to MAX). *)
local
  val MAX = 10;
  fun query m l =
    let val t = List.find (fn {module, ...} => module = m) (!traces)
    in case t of NONE => MAX | SOME {alignment, ...} => alignment l
    end;
in
  fun tracing {module = m, level = l} =
    let val t = !trace_level
    in 0 < t andalso (MAX <= t orelse MAX <= l orelse query m l <= t)
    end;
end;

val trace = Lib.say;

(* ------------------------------------------------------------------------- *)
(* Combinators                                                               *)
(* ------------------------------------------------------------------------- *)

fun C f x y = f y x;

fun I x = x;

fun K x y = x;

fun S f g x = f x (g x);

fun W f x = f x x;

(* funpow n f x: apply f to x n times. A non-positive n applies f zero times.
   Matching only on 0 would make a negative n recurse forever. *)
fun funpow n f x = if n <= 0 then x else funpow (n - 1) f (f x);

(* ------------------------------------------------------------------------- *)
(* Booleans                                                                  *)
(* ------------------------------------------------------------------------- *)

fun bool_to_string true = "true"
  | bool_to_string false = "false";

(* non f: the predicate that holds exactly where f does not. *)
fun non f = not o f;

(* Note: this ordering places true before false. *)
fun bool_compare (true,false) = LESS
  | bool_compare (false,true) = GREATER
  | bool_compare _ = EQUAL;

(* ------------------------------------------------------------------------- *)
(* Pairs                                                                     *)
(* ------------------------------------------------------------------------- *)

(* (f ## g) (x,y) = (f x, g y) *)
fun op## (f,g) (x,y) = (f x, g y);

fun D x = (x,x);

fun Df f = f ## f;

fun fst (x,_) = x;

fun snd (_,y) = y;

fun pair x y = (x,y);

fun swap (x,y) = (y,x);

fun curry f x y = f (x,y);

fun uncurry f (x,y) = f x y;

fun equal x y = (x = y);

(* ------------------------------------------------------------------------- *)
(* State transformers                                                        *)
(* ------------------------------------------------------------------------- *)

(* A state transformer has type 's -> 'a * 's: it takes a state and returns
   a value together with the updated state. *)

(* unit a: return a and leave the state unchanged. *)
val unit : 'a -> 's -> 'a * 's = pair;

(* bind f g: run f, then feed its value and the new state to g. *)
fun bind f (g : 'a -> 's -> 'b * 's) = uncurry g o f;

(* mmap f m: run m and apply f to its value. *)
fun mmap f (m : 's -> 'a * 's) = bind m (unit o f);

(* mjoin f: run f, then run the transformer it returned. *)
fun mjoin (f : 's -> ('s -> 'a * 's) * 's) = bind f I;

(* mwhile c b a: while c holds of the current value, run b on it to get the
   next value; finally return the first value for which c fails. *)
fun mwhile c b = let fun f a = if c a then bind (b a) f else unit a in f end;

(* ------------------------------------------------------------------------- *)
(* Lists                                                                     *)
(* ------------------------------------------------------------------------- *)

fun cons x y = x :: y;

(* hd_tl l = (hd l, tl l); raises Empty on []. *)
fun hd_tl l = (hd l, tl l);

fun append xs ys = xs @ ys;

fun sing a = [a];

(* first f l: the first SOME produced by f on the elements of l, or NONE. *)
fun first f [] = NONE
  | first f (x :: xs) = (case f x of NONE => first f xs | s => s);

(* index p l: SOME i for the 0-based position i of the first element of l
   satisfying p, or NONE. *)
fun index p =
  let
    fun idx _ [] = NONE
      | idx n (x :: xs) = if p x then SOME n else idx (n + 1) xs
  in
    idx 0
  end;

(* maps f l: state-threading map; runs f on each element from left to right. *)
fun maps (_ : 'a -> 's -> 'b * 's) [] = unit []
  | maps f (x :: xs) =
    bind (f x) (fn y => bind (maps f xs) (fn ys => unit (y :: ys)));

(* partial_maps f l: like maps, but drops elements for which f returns NONE. *)
fun partial_maps (_ : 'a -> 's -> 'b option * 's) [] = unit []
  | partial_maps f (x :: xs) =
    bind (f x)
    (fn yo => bind (partial_maps f xs)
     (fn ys => unit (case yo of NONE => ys | SOME y => y :: ys)));

(* enumerate n l: pair the elements of l with n, n + 1, n + 2, ... *)
fun enumerate n = fst o C (maps (fn x => fn m => ((m, x), m + 1))) n;

(* zipwith f xs ys: [f x1 y1, f x2 y2, ...]; raises Error if the lists
   have different lengths. *)
fun zipwith f =
  let
    fun z l [] [] = l
      | z l (x :: xs) (y :: ys) = z (f x y :: l) xs ys
      | z _ _ _ = raise Error "zipwith: lists different lengths";
  in
    fn xs => fn ys => rev (z [] xs ys)
  end;

fun zip xs ys = zipwith pair xs ys;

fun unzip ab =
  foldl (fn ((x, y), (xs, ys)) => (x :: xs, y :: ys)) ([], []) (rev ab);

(* cartwith f xs ys: f x y for every x in xs and y in ys. Elements of xs
   vary fastest: [f x1 y1, f x2 y1, ..., f x1 y2, ...]. *)
fun cartwith f =
  let
    fun aux _ res _ [] = res
      | aux xs_copy res [] (y :: yt) = aux xs_copy res xs_copy yt
      | aux xs_copy res (x :: xt) (ys as y :: _) =
        aux xs_copy (f x y :: res) xt ys
  in
    fn xs => fn ys =>
       let val xs' = rev xs in aux xs' [] xs' (rev ys) end
  end;

fun cart xs ys = cartwith pair xs ys;

(* divide l n: (first n elements of l, the rest); raises Subscript if l has
   fewer than n elements. *)
local
  fun aux res l 0 = (rev res, l)
    | aux _ [] _ = raise Subscript
    | aux res (h :: t) n = aux (h :: res) t (n - 1);
in
  fun divide l n = aux [] l n;
end;

(* update_nth f n l: replace the 0-based nth element h of l with f h;
   raises Subscript if there is no such element. *)
fun update_nth f n l =
  let
    val (a, b) = divide l n
  in
    case b of [] => raise Subscript
    | h :: t => a @ (f h :: t)
  end;

(* shared_map f l: map f over l. The result reuses the suffix of l after the
   last element whose image is not pointer-equal to it, so l itself is
   returned when every image is pointer-equal. *)
fun shared_map f =
  let
    fun map _ (a,b) [] = List.revAppend (a,b)
      | map c (a,b) (x :: xs) =
        let
          val y = f x
          val c = y :: c
        in
          map c (if mlibPortable.pointer_eq x y then (a,b) else (c,xs)) xs
        end
  in
    fn l => map [] ([],l) l
  end;

(* ------------------------------------------------------------------------- *)
(* Lists-as-sets                                                             *)
(* ------------------------------------------------------------------------- *)

fun mem x = List.exists (equal x);

fun insert x s = if mem x s then s else x :: s;
fun delete x s = List.filter (not o equal x) s;

(* Removes duplicates, keeping the first occurrence of each element; the
   result lists those occurrences in reverse order. *)
fun setify s = foldl (fn (v,x) => if mem v x then x else v :: x) [] s;

(* union s t: the elements of s not in t (in s's order), followed by t. *)
fun union s t = foldl (fn (v,x) => if mem v t then x else v::x) t (rev s);
(* intersect s t: the elements of s that are in t, in s's order. *)
fun intersect s t = foldl (fn (v,x) => if mem v t then v::x else x) [] (rev s);
(* subtract s t: the elements of s not in t, in s's order. *)
fun subtract s t = foldl (fn (v,x) => if mem v t then x else v::x) [] (rev s);

fun subset s t = List.all (fn x => mem x t) s;

fun distinct [] = true
  | distinct (x :: rest) = not (mem x rest) andalso distinct rest;

(* ------------------------------------------------------------------------- *)
(* Comparisons.                                                              *)
(* ------------------------------------------------------------------------- *)

type 'a ordering = 'a * 'a -> order;

fun order_to_string LESS = "LESS"
  | order_to_string EQUAL = "EQUAL"
  | order_to_string GREATER = "GREATER";

(* map_order mf f: compare values by comparing their images under mf. *)
fun map_order mf f (a,b) = f (mf a, mf b);

(* rev_order f: the reverse of ordering f. *)
fun rev_order f xy =
  case f xy of LESS => GREATER | EQUAL => EQUAL | GREATER => LESS;

(* Lexicographic order on pairs: compare first components with f, and break
   ties on the second components with g. *)
fun lex_order f g ((a,c),(b,d)) = case f (a,b) of EQUAL => g (c,d) | x => x;

fun lex_order2 f = lex_order f f;

fun lex_order3 f =
  map_order (fn (a,b,c) => (a,(b,c))) (lex_order f (lex_order2 f));

(* lex_seq_order f g: compare with f, breaking ties on the same values with g. *)
fun lex_seq_order f g (a,b) = lex_order f g ((a,a),(b,b));

(* Lexicographic order on lists; a proper prefix is LESS than the longer list. *)
fun lex_list_order f =
  let
    fun lex [] [] = EQUAL
      | lex [] (_ :: _) = LESS
      | lex (_ :: _) [] = GREATER
      | lex (x :: xs) (y :: ys) = case f (x,y) of EQUAL => lex xs ys | r => r
  in
    uncurry lex
  end;

(* ------------------------------------------------------------------------- *)
(* Finding the minimum and maximum element of a list, wrt some order.        *)
(* ------------------------------------------------------------------------- *)

(* min cmp l: (m, rest), where m is the first least element of l under cmp
   and rest is l with that occurrence removed (order preserved).
   Raises Error on the empty list. *)
fun min cmp =
  let
    fun min_acc (l,m,r) _ [] = (m, List.revAppend (l,r))
      | min_acc (best as (_,m,_)) l (x :: r) =
        min_acc (case cmp (x,m) of LESS => (l,x,r) | _ => best) (x :: l) r
  in
    fn [] => raise Error "min: empty list"
     | h :: t => min_acc ([],h,t) [h] t
  end;

(* max cmp l: like min, but for the first greatest element. *)
fun max cmp = min (rev_order cmp);

(* ------------------------------------------------------------------------- *)
(* Merge (for the following merge-sort, but generally useful too).           *)
(* ------------------------------------------------------------------------- *)

(* merge cmp xs ys: merge two cmp-sorted lists. On ties the element from xs
   comes first, which keeps the merge stable. *)
fun merge cmp =
  let
    fun mrg acc [] ys = List.revAppend (acc, ys)
      | mrg acc xs [] = List.revAppend (acc, xs)
      | mrg acc (xs as x :: xt) (ys as y :: yt) =
        (case cmp (x,y) of GREATER => mrg (y :: acc) xs yt
         | _ => mrg (x :: acc) xt ys)
  in
    mrg []
  end;

(* ------------------------------------------------------------------------- *)
(* Merge sort (stable).                                                      *)
(* ------------------------------------------------------------------------- *)

fun sort cmp =
  let
    val m = merge cmp
    fun f [] = []
      | f (xs as [_]) = xs
      | f xs = let val (l,r) = divide xs (length xs div 2) in m (f l) (f r) end
  in
    f
  end;

(* sort_map f cmp l: stably sort l by comparing the keys f x with cmp;
   each key is computed once. *)
fun sort_map _ _ [] = []
  | sort_map _ _ (xs as [_]) = xs
  | sort_map f cmp xs =
  let
    fun ncmp ((m,_),(n,_)) = cmp (m,n)
    val nxs = map (fn x => (f x, x)) xs
    val nys = sort ncmp nxs
  in
    map snd nys
  end;

(* ------------------------------------------------------------------------- *)
(* Topological sort                                                          *)
(* ------------------------------------------------------------------------- *)

(* top_sort cmp parents xs: every element of xs and, transitively, its
   parents, each listed once and after all of its parents.
   Raises Error "top_sort: cycle" if the parent relation has a cycle. *)
fun top_sort cmp parents =
  let
    fun f stack (x,(acc,seen)) =
      if Binaryset.member (stack,x) then raise Error "top_sort: cycle"
      else if Binaryset.member (seen,x) then (acc,seen)
      else
        let
          val stack = Binaryset.add (stack,x)
          val (acc,seen) = foldl (f stack) (acc,seen) (parents x)
          val acc = x :: acc
          val seen = Binaryset.add (seen,x)
        in
          (acc,seen)
        end
  in
    rev o fst o foldl (f (Binaryset.empty cmp)) ([], Binaryset.empty cmp)
  end

(* ------------------------------------------------------------------------- *)
(* Integers.                                                                 *)
(* ------------------------------------------------------------------------- *)

val int_to_string = Int.toString;

fun string_to_int s =
  case Int.fromString s of SOME n => n | NONE => raise Error "string_to_int";

(* int_to_bits n: the binary digits of n, least significant first
   (true = 1), with no trailing false digits.
   Raises Error for negative n: floor division of a negative number by 2
   never reaches 0, so the recursion would not terminate. *)
fun int_to_bits n =
  let
    fun bits 0 = []
      | bits m = (m mod 2 <> 0) :: bits (m div 2)
  in
    if n < 0 then raise Error "int_to_bits: negative argument" else bits n
  end;

(* bits_to_int: inverse of int_to_bits (least significant bit first). *)
fun bits_to_int [] = 0
  | bits_to_int (h :: t) = (if h then curry op+ 1 else I) (2 * bits_to_int t);

(* Conversions between 0..63 and the base64 alphabet in enc. Both directions
   raise Error outside that range. *)
local
  val enc = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  val (max, rev_enc) =
    foldl (fn (c,(i,m)) => (i + 1, Binarymap.insert (m,c,i)))
    (0, Binarymap.mkDict Char.compare) (String.explode enc);
in
  fun int_to_base64 n =
    if 0 <= n andalso n < max then String.sub (enc,n)
    else raise Error "int_to_base64: out of range";

  fun base64_to_int c =
    case Binarymap.peek (rev_enc, c) of SOME n => n
    | NONE => raise Error "base64_to_int: out of range";
end;

(* interval m len = [m, m + 1, ..., m + len - 1]. A non-positive len gives
   []; matching only on 0 would make a negative len recurse forever. *)
fun interval m len = if len <= 0 then [] else m :: interval (m + 1) (len - 1);

(* divides a b: does a divide b? Only 0 is divisible by 0. *)
fun divides a b = if a = 0 then b = 0 else b mod (Int.abs a) = 0;

fun even n = divides 2 n;

fun odd n = not (even n);

(* primes n: the first n primes, in increasing order. A non-positive n
   gives []. *)
local
  fun both f g n = f n andalso g n;
  fun next f = let fun nx x = if f x then x else nx (x + 1) in nx end;

  (* f accepts exactly the candidates not divisible by any prime found so far. *)
  fun looking res n f x =
    if n <= 0 then rev res
    else
      let
        val p = next f x
        val f' = both f (not o divides p)
      in
        looking (p :: res) (n - 1) f' (p + 1)
      end
in
  fun primes n = looking [] n (K true) 2
end;

(* gcd m n: greatest common divisor of |m| and |n| (gcd 0 0 = 0). *)
local
  fun hcf 0 n = n | hcf 1 _ = 1 | hcf m n = hcf (n mod m) m;
in
  fun gcd m n =
    let
      val m = Int.abs m
      val n = Int.abs n
    in
      if m < n then hcf m n else hcf n m
    end;
end;

(* ------------------------------------------------------------------------- *)
(* Strings                                                                   *)
(* ------------------------------------------------------------------------- *)

(* rot k c: rotate the letter c by k places within its (upper or lower case)
   alphabet; any other character is returned unchanged. *)
local
  fun len l = (length l, l)
  val upper = len (explode "ABCDEFGHIJKLMNOPQRSTUVWXYZ");
  val lower = len (explode "abcdefghijklmnopqrstuvwxyz");
  fun rotate (n,l) c k = List.nth (l, (k+Option.valOf(index(equal c)l)) mod n);
in
  fun rot k c =
    if Char.isLower c then rotate lower c k
    else if Char.isUpper c then rotate upper c k
    else c;
end;

(* nchars x n: a string of n copies of x. A non-positive n gives "";
   matching only on 0 would make a negative n loop forever. *)
fun nchars x =
  let
    fun dup n l = if n <= 0 then l else dup (n - 1) (x :: l)
  in
    fn n => implode (dup n [])
  end;

(* chomp s: s without a single trailing newline, if it has one. *)
fun chomp s =
  let
    val n = size s
  in
    if n = 0 orelse String.sub (s, n - 1) <> #"\n" then s
    else String.substring (s, 0, n - 1)
  end;

(* unpad s: s with leading and trailing whitespace removed. *)
local
  fun chop [] = []
    | chop (l as (h :: t)) = if Char.isSpace h then chop t else l;
in
  val unpad = implode o chop o rev o chop o rev o explode;
end;

(* join s l: concatenate the strings in l with s between consecutive ones. *)
fun join _ [] = "" | join s (h :: t) = foldl (fn (x,y) => y ^ s ^ x) h t;

(* split sep s: the pieces of s between left-to-right, non-overlapping
   occurrences of sep, e.g. split "," "a,,b" = ["a","","b"].
   Raises Error if sep is empty: an empty pattern matches without consuming
   any input, so the scan would never terminate. *)
local
  fun match [] l = SOME l
    | match _ [] = NONE
    | match (x :: xs) (y :: ys) = if x = y then match xs ys else NONE;

  fun stringify acc [] = acc
    | stringify acc (h :: t) = stringify (implode h :: acc) t;
in
  fun split sep =
    let
      val pat = String.explode sep
      (* prev: completed pieces, most recent first, each as a char list;
         recent: characters of the current piece, in reverse. *)
      fun div1 prev recent [] = stringify [] (rev recent :: prev)
        | div1 prev recent (l as h :: t) =
          case match pat l of
            NONE => div1 prev (h :: recent) t
          | SOME rest => div1 (rev recent :: prev) [] rest
    in
      fn s =>
         if null pat then raise Error "split: empty separator"
         else div1 [] [] (explode s)
    end;
end;

(* variant x vars: x with primes (') appended until it is not in vars. *)
fun variant x vars = if mem x vars then variant (x ^ "'") vars else x;

(* variant_num x vars: x if it is not in vars, otherwise x followed by the
   smallest positive number that gives a name not in vars. *)
fun variant_num x vars =
  let
    fun xn n = x ^ int_to_string n
    fun v n = let val x' = xn n in if mem x' vars then v (n + 1) else x' end
  in
    if mem x vars then v 1 else x
  end;

(* dest_prefix p s: s without its prefix p; raises Error "dest_prefix" if s
   does not start with p. *)
fun dest_prefix p =
  let
    fun check s = assert (String.isPrefix p s) (Error "dest_prefix")
    val size_p = size p
  in
    fn s => (check s; String.extract (s, size_p, NONE))
  end;

fun is_prefix p = can (dest_prefix p);

fun mk_prefix p s = p ^ s;

(* align_table {left, pad} rows: lay out a table given as a list of rows,
   each a list of column strings, returning one string per row. Every
   column is padded with the character pad to the width of its widest
   entry: on the right if left is true, otherwise on the left. When left is
   true the last column is not padded. An empty table gives [].
   Assumes every row has the same number of columns. *)
fun align_table {left,pad} =
  let
    fun pad_col n s =
      let val p = nchars pad (n - size s)
      in if left then s ^ p else p ^ s
      end
    fun pad_cols [] = []
      | pad_cols (l as [] :: _) = map (K "") l
      | pad_cols l =
      let
        val hs = map hd l
        (* min under the swapped order picks the widest entry. *)
        val (n,_) = min (Int.compare o swap) (map size hs)
        val last_left = left andalso length (hd l) = 1
        val hs = if last_left then hs else map (pad_col n) hs
      in
        zipwith (fn x => fn y => x ^ y) hs (pad_cols (map tl l))
      end
  in
    pad_cols
  end;

(* ------------------------------------------------------------------------- *)
(* Reals.                                                                    *)
(* ------------------------------------------------------------------------- *)

val real_to_string = Real.toString;

(* percent_to_string x: x as a rounded whole percentage, e.g. 0.5 => "50%". *)
fun percent_to_string x = int_to_string (Real.round (100.0 * x)) ^ "%";

(* pos r: r clamped below at 0.0. *)
fun pos r = Real.max (r,0.0);

local val ln2 = Math.ln 2.0 in fun log2 x = Math.ln x / ln2 end;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Generic pretty-printers *)

type 'a pp = 'a Parse.pprinter

(* Line width used by to_string. *)
val LINE_LENGTH = ref 75;

(* pp_map f pp_a: print a value by printing its image under f. *)
fun pp_map f pp_a x = pp_a (f x);

(* pp_bracket l r pp_a: print with l before and r after, in a block indented
   by the width of l. *)
fun pp_bracket l r pp_a a =
  PP.block PP.INCONSISTENT (size l) [PP.add_string l, pp_a a, PP.add_string r]

(* pp_sequence sep pp_a: print elements separated by sep and a break. *)
fun pp_sequence sep pp_a els =
  let
    fun recurse els =
      case els of
        [] => []
      | [e] => [pp_a e]
      | e::es => [pp_a e, PP.add_string sep, PP.add_break(1,0)] @
                 recurse es
  in
    PP.block PP.INCONSISTENT 0 (recurse els)
  end

(* pp_binop s pp_a pp_b: print a, then s, then a break, then b. *)
fun pp_binop s pp_a pp_b (a,b) =
  PP.block PP.INCONSISTENT 0
  [pp_a a, PP.add_string s, PP.add_break (1,0), pp_b b]

(* Pretty-printers for common types *)

fun pp_string s = PP.add_string s

val pp_unit = pp_map (fn () => "()") pp_string;

val pp_char = pp_map str pp_string;

val pp_bool = pp_map bool_to_string pp_string;

val pp_int = pp_map int_to_string pp_string;

val pp_real = pp_map real_to_string pp_string;

val pp_order = pp_map order_to_string pp_string;

val pp_porder =
  pp_map (fn NONE => "INCOMPARABLE" | SOME x => order_to_string x) pp_string;

fun pp_list pp_a = pp_bracket "[" "]" (pp_sequence "," pp_a);

fun pp_pair pp_a pp_b = pp_bracket "(" ")" (pp_binop "," pp_a pp_b);

fun pp_triple pp_a pp_b pp_c =
  pp_bracket "(" ")"
  (pp_map (fn (a, b, c) => (a, (b, c)))
   (pp_binop "," pp_a (pp_binop "," pp_b pp_c)));

(* to_string pp_a a: render a with pp_a at width !LINE_LENGTH. *)
fun to_string pp_a a = PP.pp_to_string (!LINE_LENGTH) pp_a a;

(* ------------------------------------------------------------------------- *)
(* Sums                                                                      *)
(* ------------------------------------------------------------------------- *)

datatype ('a, 'b) sum = INL of 'a | INR of 'b

fun is_inl (INL _) = true | is_inl (INR _) = false;

fun is_inr (INR _) = true | is_inr (INL _) = false;

fun pp_sum pp_a _ (INL a) = pp_a a
  | pp_sum _ pp_b (INR b) = pp_b b;

(* ------------------------------------------------------------------------- *)
(* Maplets.                                                                  *)
(* ------------------------------------------------------------------------- *)

datatype ('a, 'b) maplet = op|-> of 'a * 'b;

fun pp_maplet pp_a pp_b =
  pp_map (fn a |-> b => (a, b)) (pp_binop " |->" pp_a pp_b);

(* ------------------------------------------------------------------------- *)
(* Trees.                                                                    *)
(* ------------------------------------------------------------------------- *)

datatype ('a, 'b) tree = BRANCH of 'a * ('a, 'b) tree list | LEAF of 'b;

(* tree_size t: the number of leaves and of branch nodes in t. *)
local
  fun f (LEAF _) = {leaves = 1, branches = 0}
    | f (BRANCH (_, ts)) = foldl g {leaves = 0, branches = 1} ts
  and g (t, {leaves = l, branches = b}) =
    let val {leaves=l', branches=b'} = f t in {leaves=l+l', branches=b+b'} end;
in
  fun tree_size t = f t;
end;

(* tree_foldr f_b f_l: bottom-up fold. Leaves are mapped by f_l, and each
   branch combines its label with the list of its children's results. *)
fun tree_foldr f_b f_l (LEAF l) = f_l l
  | tree_foldr f_b f_l (BRANCH (p, s)) = f_b p (map (tree_foldr f_b f_l) s);

(* tree_foldl f_b f_l state t: thread state from the root downwards, updating
   it with f_b at each branch label, and apply f_l to each leaf with the
   state reaching it. Results come back in reverse left-to-right leaf order. *)
fun tree_foldl f_b f_l =
  let
    fun fold state (LEAF l, res) = f_l l state :: res
      | fold state (BRANCH (p, ts), res) = foldl (fold (f_b p state)) res ts
  in
    fn state => fn t => fold state (t, [])
  end;

(* tree_partial_foldl: like tree_foldl, but when f_b returns NONE the
   subtree is skipped, and leaves for which f_l returns NONE are dropped. *)
fun tree_partial_foldl f_b f_l =
  let
    fun fold state (LEAF l, res) =
      (case f_l l state of NONE => res | SOME x => x :: res)
      | fold state (BRANCH (p, ts), res) =
      (case f_b p state of NONE => res | SOME s => foldl (fold s) res ts)
  in
    fn state => fn t => fold state (t, [])
  end;

(* ------------------------------------------------------------------------- *)
(* mlibUseful impure features                                                *)
(* ------------------------------------------------------------------------- *)

(* memoize f: a thunk forcing a suspension of f. Presumably f () is then
   evaluated at most once, per Susp's semantics -- confirm against Susp. *)
fun memoize f = let val s = Susp.delay f in fn () => Susp.force s end;

(* Integers drawn from a shared counter created with {inc=1, init=0}. *)
local
  val generator = Portable.make_counter{inc=1,init=0}
in
  fun new_int () = generator()

  (* new_ints k: k values from the counter, in the order drawn. A
     non-positive k gives []; matching only on 0 would make a negative k
     draw forever. *)
  fun new_ints k = if k <= 0 then [] else generator() :: new_ints (k - 1)
end;

(* Pseudo-random values from a generator with the fixed seed 1.0. *)
local
  val gen = Random.newgenseed 1.0;
in
  fun uniform () = Random.random gen;
  fun coin_flip () = Random.range (0,2) gen = 0;
end;

(* with_flag (r, update) f x: evaluate f x with r temporarily set to
   update !r, restoring the previous value afterwards, even if f x raises. *)
fun with_flag (r,update) f x =
  let
    val old = !r
    val () = r := update old
    val y = f x handle e => (r := old; raise e)
    val () = r := old
  in
    y
  end;

(* ------------------------------------------------------------------------- *)
(* Environment.                                                              *)
(* ------------------------------------------------------------------------- *)

(* warn s writes "WARNING: s" to stderr. die s writes a fatal-error message
   to stderr and exits the process with OS.Process.failure. *)
local
  fun err x s = TextIO.output (TextIO.stdErr, x ^ ": " ^ s ^ "\n");
in
  val warn = err "WARNING";
  fun die s = (err "\nFATAL ERROR" s; OS.Process.exit OS.Process.failure);
end

end