1(* ========================================================================= *)
2(* ML UTILITY FUNCTIONS                                                      *)
3(* Copyright (c) 2001-2004 Joe Hurd.                                         *)
4(* ========================================================================= *)
5
6structure mlibUseful :> mlibUseful =
7struct
8
9(* ------------------------------------------------------------------------- *)
10(* Exceptions, profiling and tracing.                                        *)
11(* ------------------------------------------------------------------------- *)
12
(* Exceptions used throughout the library: Error signals a recoverable
   failure (the one total/can catch), Bug signals a violated internal
   invariant. *)
exception Error of string;
exception Bug of string;

(* Format an Error exception for printing; any other exception is a Bug. *)
fun Error_to_string exn =
  case exn of
    Error msg => "\nError: " ^ msg ^ "\n"
  | _ => raise Bug "Error_to_string: not an Error exception";

(* Format a Bug exception for printing; any other exception is a Bug. *)
fun Bug_to_string exn =
  case exn of
    Bug msg => "\nBug: " ^ msg ^ "\n"
  | _ => raise Bug "Bug_to_string: not a Bug exception";

(* Format either an Error or a Bug exception. *)
fun report exn =
  case exn of
    Error _ => Error_to_string exn
  | Bug _ => Bug_to_string exn
  | _ => raise Bug "report: not an Error or Bug exception";

(* Raise exn unless cond holds. *)
fun assert cond exn = if not cond then raise exn else ();

(* Apply f to a.  If that raises, print a description of the exception to
   stdout and re-raise the very same exception. *)
fun try f a =
  let
    fun describe (x as Error _) = Error_to_string x
      | describe (x as Bug _) = Bug_to_string x
      | describe _ = "\ntry: strange exception raised\n"
  in
    f a handle x => (print (describe x); raise x)
  end;

(* SOME (f x), or NONE when f raises Error; other exceptions propagate. *)
fun total f x =
  let val y = f x in SOME y end
  handle Error _ => NONE;

(* Does f x complete without raising Error? *)
fun can f x = Option.isSome (total f x);

(* Unwrap the option returned by f x, raising the supplied Error exception
   when it is NONE.  Supplying any non-Error exception is a Bug. *)
fun partial exn f x =
  case exn of
    Error _ => (case f x of NONE => raise exn | SOME y => y)
  | _ => raise Bug "partial: must take an Error exception";
41
(* Run f a and return (CPU seconds used, user + system, paired with the result). *)
fun timed f a =
  let
    val timer = Timer.startCPUTimer ()
    val result = f a
    val {usr, sys, ...} = Timer.checkCPUTimer timer
    val elapsed = Time.toReal usr + Time.toReal sys
  in
    (elapsed, result)
  end;

(* Repeat f a until the accumulated CPU time exceeds one second.  Return the
   average time per run together with the result of the last run. *)
local
  val MIN_TOTAL = 1.0;
in
  fun timed_many f a =
    let
      fun loop runs total =
        let
          val (t, result) = timed f a
          val runs = runs + 1
          val total = total + t
        in
          if total > MIN_TOTAL then (total / Real.fromInt runs, result)
          else loop runs total
        end
    in
      loop 0 0.0
    end
end;
65
(* Global trace verbosity; values <= 0 disable all tracing. *)
val trace_level = ref 1;

(* Per-module trace configuration: alignment maps a message level to the
   level that is compared against trace_level. *)
val traces : {module : string, alignment : int -> int} list ref = ref [];

(* Register a module entry; the most recently added entry for a module wins. *)
fun add_trace entry = traces := entry :: !traces

fun set_traces entries = traces := entries

local
  val MAX = 10;

  (* Aligned level of message level l for module m; unregistered modules
     map to MAX. *)
  fun query m l =
    case List.find (fn {module, ...} => module = m) (!traces) of
      SOME {alignment, ...} => alignment l
    | NONE => MAX;
in
  (* Should a message of this level from this module be shown? *)
  fun tracing {module = m, level = l} =
    let
      val t = !trace_level
    in
      if t <= 0 then false
      else MAX <= t orelse MAX <= l orelse query m l <= t
    end;
end;
85
(* Emit a trace message through Lib.say. *)
fun trace msg = Lib.say msg;
87
88(* ------------------------------------------------------------------------- *)
89(* Combinators                                                               *)
90(* ------------------------------------------------------------------------- *)
91
(* Classic combinators: C swaps the two arguments of f, I is the identity,
   K ignores its second argument, S shares an argument between f and g,
   W duplicates an argument. *)
fun C f a b = f b a;

fun I a = a;

fun K a _ = a;

fun S f g a = let val fa = f a in fa (g a) end;

fun W f a = let val fa = f a in fa a end;

(* funpow n f x applies f to x n times.  A negative n never reaches 0, so
   it does not terminate. *)
fun funpow n f =
  let
    fun go 0 acc = acc
      | go k acc = go (k - 1) (f acc)
  in
    go n
  end;
103
104(* ------------------------------------------------------------------------- *)
105(* Booleans                                                                  *)
106(* ------------------------------------------------------------------------- *)
107
(* Render a boolean as "true" or "false". *)
fun bool_to_string b = Bool.toString b;

(* Pointwise negation of a predicate. *)
fun non f x = not (f x);

(* Order booleans with true before false. *)
fun bool_compare (a, b) =
  if a = b then EQUAL
  else if a then LESS
  else GREATER;
116
117(* ------------------------------------------------------------------------- *)
118(* Pairs                                                                     *)
119(* ------------------------------------------------------------------------- *)
120
(* Apply a pair of functions componentwise to a pair. *)
fun op## (f, g) = fn (x, y) => (f x, g y);

(* Duplicate a value into a pair. *)
fun D a = (a, a);

(* Apply the same function to both components of a pair. *)
fun Df f = op## (f, f);
126
(* Basic pair and currying helpers. *)
fun fst (a, _) = a;

fun snd (_, b) = b;

fun pair a b = (a, b);

fun swap (a, b) = (b, a);

(* Turn a function on pairs into a curried function. *)
fun curry f a b = f (a, b);

(* Turn a curried function into a function on pairs. *)
fun uncurry f (a, b) = f a b;

(* Curried structural equality. *)
fun equal a b = (a = b);
140
141(* ------------------------------------------------------------------------- *)
142(* State transformers                                                        *)
143(* ------------------------------------------------------------------------- *)
144
(* State-transformer monad: a computation takes a state and returns a value
   together with the new state. *)

(* Return a without touching the state. *)
fun unit (a : 'a) (s : 's) : 'a * 's = (a, s);

(* Run f, then feed its value and the resulting state to g. *)
fun bind f (g : 'a -> 's -> 'b * 's) =
  fn s => let val (a, s') = f s in g a s' end;

(* Apply f to the value produced by m. *)
fun mmap f (m : 's -> 'a * 's) = bind m (fn a => unit (f a));

(* Run a computation that returns a computation, then run that one. *)
fun mjoin (f : 's -> ('s -> 'a * 's) * 's) = bind f (fn inner => inner);

(* While c holds of the current value, replace it with the result of b. *)
fun mwhile c b =
  let
    fun loop a = if not (c a) then unit a else bind (b a) loop
  in
    loop
  end;
154
155(* ------------------------------------------------------------------------- *)
156(* Lists                                                                     *)
157(* ------------------------------------------------------------------------- *)
158
(* Curried list constructor. *)
fun cons h t = h :: t;

(* Split a non-empty list into head and tail; raises Empty on []. *)
fun hd_tl l =
  case l of
    h :: t => (h, t)
  | [] => raise Empty;

fun append front back = front @ back;

(* Singleton list. *)
fun sing a = a :: [];

(* First SOME result of f over the list, scanning left to right. *)
fun first f l =
  case l of
    [] => NONE
  | x :: rest => (case f x of NONE => first f rest | found => found);

(* Zero-based position of the first element satisfying p, if any. *)
fun index p l =
  let
    fun go (_, []) = NONE
      | go (i, x :: rest) = if p x then SOME i else go (i + 1, rest)
  in
    go (0, l)
  end;
177
(* Monadic map: apply f to each element in order, threading the state. *)
fun maps (_ : 'a -> 's -> 'b * 's) [] = unit []
  | maps f (x :: xs) =
    let
      fun rest y = bind (maps f xs) (fn ys => unit (y :: ys))
    in
      bind (f x) rest
    end;

(* Like maps, but elements for which f yields NONE are dropped. *)
fun partial_maps (_ : 'a -> 's -> 'b option * 's) [] = unit []
  | partial_maps f (x :: xs) =
    let
      fun keep NONE ys = ys
        | keep (SOME y) ys = y :: ys
      fun rest yo = bind (partial_maps f xs) (fn ys => unit (keep yo ys))
    in
      bind (f x) rest
    end;

(* Pair each element with consecutive indices starting at n. *)
fun enumerate n l =
  let
    val (indexed, _) = maps (fn x => fn m => ((m, x), m + 1)) l n
  in
    indexed
  end;
189
(* Combine two lists elementwise with f.  If the lengths differ, raises Error
   after f has been applied to the common prefix. *)
fun zipwith f =
  let
    fun loop acc (x :: xs, y :: ys) = loop (f x y :: acc) (xs, ys)
      | loop acc ([], []) = rev acc
      | loop _ _ = raise Error "zipwith: lists different lengths"
  in
    fn xs => fn ys => loop [] (xs, ys)
  end;

(* Pair up two lists of equal length; raises Error otherwise. *)
fun zip xs ys = zipwith (fn x => fn y => (x, y)) xs ys;

(* Split a list of pairs into its first and second components. *)
fun unzip ab = foldr (fn ((x, y), (xs, ys)) => (x :: xs, y :: ys)) ([], []) ab;

(* Cartesian product combined with f, with the first list varying fastest:
   [f x1 y1, f x2 y1, ..., f xn y1, f x1 y2, ...]. *)
fun cartwith f xs ys =
  let
    val rxs = rev xs
    fun row (y, acc) = foldl (fn (x, acc') => f x y :: acc') acc rxs
  in
    foldl row [] (rev ys)
  end;

fun cart xs ys = cartwith (fn x => fn y => (x, y)) xs ys;
216
(* Split l into its first n elements and the rest.  Raises Subscript if l
   has fewer than n elements. *)
fun divide l n =
  let
    fun take (acc, rest, 0) = (rev acc, rest)
      | take (_, [], _) = raise Subscript
      | take (acc, h :: t, k) = take (h :: acc, t, k - 1)
  in
    take ([], l, n)
  end;

(* Replace the n-th (zero-based) element x of l with f x; raises Subscript
   if there is no such element. *)
fun update_nth f n l =
  case divide l n of
    (_, []) => raise Subscript
  | (front, h :: t) => front @ (f h :: t);
232
(* Map f over a list, reusing the original list's suffix after the last
   element whose image is not pointer-equal to it (per
   mlibPortable.pointer_eq). *)
fun shared_map f =
  let
    (* out: all results so far, reversed.
       (keep, tail): keep is the reversed results up to and including the
       last changed element; tail is the input suffix after it. *)
    fun go _ (keep, tail) [] = List.revAppend (keep, tail)
      | go out (keep, tail) (x :: rest) =
        let
          val y = f x
          val out' = y :: out
          val state =
            if mlibPortable.pointer_eq x y then (keep, tail) else (out', rest)
        in
          go out' state rest
        end
  in
    fn l => go [] ([], l) l
  end;
246
247(* ------------------------------------------------------------------------- *)
248(* Lists-as-sets                                                             *)
249(* ------------------------------------------------------------------------- *)
250
(* Lists used as sets, compared with structural equality. *)

fun mem x l = List.exists (fn y => x = y) l;

(* Add x to the front unless it is already present. *)
fun insert x s = if not (mem x s) then x :: s else s;

(* Remove every occurrence of x. *)
fun delete x s = List.filter (fn y => x <> y) s;

(* Removes duplicates, keeping the first occurrence of each element; the
   result lists those first occurrences in reverse order. *)
fun setify s = foldl (fn (v, acc) => insert v acc) [] s;

(* Elements of s not in t (in order), followed by t. *)
fun union s t = List.filter (fn v => not (mem v t)) s @ t;

(* Elements of s that are in t, in order. *)
fun intersect s t = List.filter (fn v => mem v t) s;

(* Elements of s that are not in t, in order. *)
fun subtract s t = List.filter (fn v => not (mem v t)) s;

fun subset s t = not (List.exists (fn x => not (mem x t)) s);

(* No element occurs twice. *)
fun distinct l =
  case l of
    [] => true
  | x :: rest => not (mem x rest) andalso distinct rest;
267
268(* ------------------------------------------------------------------------- *)
269(* Comparisons.                                                              *)
270(* ------------------------------------------------------------------------- *)
271
(* A comparison function. *)
type 'a ordering = 'a * 'a -> order;

fun order_to_string ord =
  case ord of
    LESS => "LESS"
  | EQUAL => "EQUAL"
  | GREATER => "GREATER";

(* Compare values by comparing their images under mf. *)
fun map_order mf f (a, b) = f (mf a, mf b);

(* Flip the result of a comparison. *)
fun rev_order f xy =
  let
    val r = f xy
  in
    if r = LESS then GREATER
    else if r = GREATER then LESS
    else EQUAL
  end;

(* Lexicographic order on pairs: f on the first components, then g. *)
fun lex_order f g ((a1, a2), (b1, b2)) =
  let
    val primary = f (a1, b1)
  in
    if primary <> EQUAL then primary else g (a2, b2)
  end;

fun lex_order2 f p = lex_order f f p;

(* Lexicographic order on triples using f for every component. *)
fun lex_order3 f ((a1, a2, a3), (b1, b2, b3)) =
  lex_order f (lex_order2 f) ((a1, (a2, a3)), (b1, (b2, b3)));

(* Compare with f, breaking ties with g on the same values. *)
fun lex_seq_order f g (a, b) =
  case f (a, b) of
    EQUAL => g (a, b)
  | r => r;
291
(* Lexicographic order on lists: compare elementwise with f; a proper prefix
   is smaller. *)
fun lex_list_order f =
  let
    fun lex ([], []) = EQUAL
      | lex ([], _ :: _) = LESS
      | lex (_ :: _, []) = GREATER
      | lex (x :: xs, y :: ys) =
        (case f (x, y) of EQUAL => lex (xs, ys) | r => r)
  in
    lex
  end;
301
302(* ------------------------------------------------------------------------- *)
303(* Finding the minimum and maximum element of a list, wrt some order.        *)
304(* ------------------------------------------------------------------------- *)
305
(* Return the first minimal element of a non-empty list under cmp, together
   with the remaining elements in their original order.  Raises Error on []. *)
fun min cmp =
  let
    (* best = (reversed elements before the current minimum, minimum,
       elements after it); seen = reversed elements scanned so far *)
    fun scan (before, m, after) _ [] = (m, List.revAppend (before, after))
      | scan (best as (_, m, _)) seen (x :: rest) =
        let
          val best' = if cmp (x, m) = LESS then (seen, x, rest) else best
        in
          scan best' (x :: seen) rest
        end
  in
    fn [] => raise Error "min: empty list"
     | h :: t => scan ([], h, t) [h] t
  end;

(* Like min, but for the maximal element. *)
fun max cmp l = min (rev_order cmp) l;
317
318(* ------------------------------------------------------------------------- *)
319(* Merge (for the following merge-sort, but generally useful too).           *)
320(* ------------------------------------------------------------------------- *)
321
(* Merge two lists that are each sorted by cmp.  On ties the element from
   the first list comes first, which keeps merge-sort stable. *)
fun merge cmp =
  let
    fun go acc (xs as x :: xt) (ys as y :: yt) =
        if cmp (x, y) = GREATER then go (y :: acc) xs yt
        else go (x :: acc) xt ys
      | go acc [] ys = List.revAppend (acc, ys)
      | go acc xs [] = List.revAppend (acc, xs)
  in
    go []
  end;
332
333(* ------------------------------------------------------------------------- *)
(* Merge sort (stable).                                                      *)
335(* ------------------------------------------------------------------------- *)
336
(* Stable merge sort. *)
fun sort cmp =
  let
    val m = merge cmp
    fun msort l =
      case l of
        [] => []
      | [_] => l
      | _ =>
        let
          val (front, back) = divide l (length l div 2)
        in
          m (msort front) (msort back)
        end
  in
    msort
  end;

(* Stably sort by the key f x, computing each key only once. *)
fun sort_map f cmp l =
  case l of
    [] => []
  | [_] => l
  | _ =>
    let
      val keyed = map (fn x => (f x, x)) l
      fun by_key ((k1, _), (k2, _)) = cmp (k1, k2)
    in
      map (fn (_, x) => x) (sort by_key keyed)
    end;
357
358(* ------------------------------------------------------------------------- *)
359(* Topological sort                                                          *)
360(* ------------------------------------------------------------------------- *)
361
(* Order elements so that every element appears after its parents (as given
   by `parents`).  Ancestors reached through `parents` are included, and each
   element appears once.  Raises Error if a cycle is found. *)
fun top_sort cmp parents =
  let
    (* path: elements on the current DFS path; done: elements already emitted *)
    fun visit path (x, (order, done)) =
      if Binaryset.member (path, x) then raise Error "top_sort: cycle"
      else if Binaryset.member (done, x) then (order, done)
      else
        let
          val path' = Binaryset.add (path, x)
          val (order', done') = foldl (visit path') (order, done) (parents x)
        in
          (x :: order', Binaryset.add (done', x))
        end
    val start = ([], Binaryset.empty cmp)
  in
    fn xs =>
      let
        val (order, _) = foldl (visit (Binaryset.empty cmp)) start xs
      in
        rev order
      end
  end
379
380(* ------------------------------------------------------------------------- *)
381(* Integers.                                                                 *)
382(* ------------------------------------------------------------------------- *)
383
fun int_to_string n = Int.toString n;

(* Parse an integer with Int.fromString, raising Error when it fails. *)
fun string_to_int s =
  let
    val parsed = Int.fromString s
  in
    if Option.isSome parsed then Option.valOf parsed
    else raise Error "string_to_int"
  end;
387
(* Binary digits of a non-negative integer, least significant bit first
   (true = 1, no trailing false); 0 maps to [].
   Raises Error on a negative argument.  SML's div rounds toward negative
   infinity, so for n < 0 `n div 2` never reaches 0 and the digit recursion
   would otherwise loop forever. *)
fun int_to_bits n =
  let
    fun bits 0 = []
      | bits m = (m mod 2 <> 0) :: bits (m div 2)
  in
    if n < 0 then raise Error "int_to_bits: negative argument" else bits n
  end;
390
(* Inverse of int_to_bits: read a least-significant-bit-first bit list. *)
fun bits_to_int bits =
  foldr (fn (b, acc) => 2 * acc + (if b then 1 else 0)) 0 bits;
393
local
  val enc = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  (* Number of symbols in the alphabet. *)
  val count = size enc;

  (* Reverse lookup from symbol to its position in enc. *)
  val rev_enc =
    let
      fun build (i, m) =
        if i = count then m
        else build (i + 1, Binarymap.insert (m, String.sub (enc, i), i))
    in
      build (0, Binarymap.mkDict Char.compare)
    end;
in
  (* Base64 symbol for a digit in [0, 64); raises Error otherwise. *)
  fun int_to_base64 n =
    if n < 0 orelse n >= count then raise Error "int_to_base64: out of range"
    else String.sub (enc, n);

  (* Digit value of a base64 symbol; raises Error for other characters. *)
  fun base64_to_int c =
    case Binarymap.peek (rev_enc, c) of
      NONE => raise Error "base64_to_int: out of range"
    | SOME n => n;
end;
409
(* The len consecutive integers starting at m. *)
fun interval m len =
  if len = 0 then [] else m :: interval (m + 1) (len - 1);

(* Does a divide b?  0 divides only 0. *)
fun divides a b =
  case a of
    0 => b = 0
  | _ => b mod Int.abs a = 0;

fun even n = n mod 2 = 0;

fun odd n = not (even n);
418
local
  (* Smallest x' >= x with ok x'. *)
  fun next_satisfying ok =
    let fun from x = if ok x then x else from (x + 1) in from end;

  (* Collect n more primes starting the search at x; ok rejects multiples of
     every prime found so far. *)
  fun collect acc 0 _ _ = rev acc
    | collect acc n ok x =
      let
        val p = next_satisfying ok x
        val ok' = fn m => ok m andalso not (divides p m)
      in
        collect (p :: acc) (n - 1) ok' (p + 1)
      end
in
  (* The first n primes, in increasing order. *)
  fun primes n = collect [] n (fn _ => true) 2
end;
435
(* Greatest common divisor of the absolute values (gcd 0 0 = 0). *)
fun gcd m n =
  let
    (* Euclid's algorithm; expects a <= b *)
    fun hcf (0, b) = b
      | hcf (1, _) = 1
      | hcf (a, b) = hcf (b mod a, a)
    val (a, b) = (Int.abs m, Int.abs n)
  in
    if a < b then hcf (a, b) else hcf (b, a)
  end;
447
448(* ------------------------------------------------------------------------- *)
449(* Strings                                                                   *)
450(* ------------------------------------------------------------------------- *)
451
local
  (* Rotate letter c by k places within the 26-letter alphabet starting at
     base; mod keeps the offset in [0, 26) even for negative k. *)
  fun shift base c k =
    Char.chr (Char.ord base + (k + (Char.ord c - Char.ord base)) mod 26);
in
  (* Caesar-rotate ASCII letters by k, preserving case; other characters
     are returned unchanged. *)
  fun rot k c =
    if Char.isLower c then shift #"a" c k
    else if Char.isUpper c then shift #"A" c k
    else c;
end;
463
(* A string of n copies of the character x. *)
fun nchars x n =
  let
    fun fill 0 acc = acc
      | fill k acc = fill (k - 1) (x :: acc)
  in
    implode (fill n [])
  end;
468
(* Remove a single trailing newline, if present. *)
fun chomp s =
  case size s of
    0 => s
  | n =>
      if String.sub (s, n - 1) = #"\n" then String.extract (s, 0, SOME (n - 1))
      else s;
476
(* Strip leading and trailing whitespace (as judged by Char.isSpace). *)
fun unpad s =
  let
    val whole = Substring.extract (s, 0, NONE)
    val trimmed = Substring.dropr Char.isSpace (Substring.dropl Char.isSpace whole)
  in
    Substring.string trimmed
  end;
483
(* Concatenate strs with sep between consecutive elements. *)
fun join sep strs =
  case strs of
    [] => ""
  | head :: rest =>
      String.concat (head :: List.concat (map (fn s => [sep, s]) rest));
485
local
  (* If pattern pat is a prefix of l, return what follows it in l. *)
  fun match [] l = SOME l
    | match _ [] = NONE
    | match (x :: xs) (y :: ys) = if x = y then match xs ys else NONE;

  (* Implode each piece, reversing the (reversed) list of pieces. *)
  fun stringify acc [] = acc
    | stringify acc (h :: t) = stringify (implode h :: acc) t;
in
  (* Split a string on every occurrence of the separator sep, e.g. splitting
     "a,b" on "," gives ["a", "b"].  A string without sep gives a singleton
     list ("" gives [""]).
     Raises Error when sep is empty: an empty pattern matches without
     consuming input, so the scan would never terminate. *)
  fun split sep =
    let
      val pat = String.explode sep
      (* prev: completed pieces (reversed); recent: current piece (reversed) *)
      fun div1 prev recent [] = stringify [] (rev recent :: prev)
        | div1 prev recent (l as h :: t) =
          case match pat l of
            NONE => div1 prev (h :: recent) t
          | SOME rest => div1 (rev recent :: prev) [] rest
    in
      fn s =>
        if null pat then raise Error "split: empty separator"
        else div1 [] [] (explode s)
    end;
end;
505
(* Append primes to x until it no longer occurs in vars. *)
fun variant x vars =
  let
    fun prime v = if mem v vars then prime (v ^ "'") else v
  in
    prime x
  end;

(* x itself if unused, otherwise x1, x2, ... -- the first name not in vars. *)
fun variant_num x vars =
  let
    fun try_from n =
      let
        val candidate = x ^ int_to_string n
      in
        if mem candidate vars then try_from (n + 1) else candidate
      end
  in
    if not (mem x vars) then x else try_from 1
  end;
515
(* Strip prefix p from a string; raises Error if s does not start with p. *)
fun dest_prefix p =
  let
    val n = size p
  in
    fn s =>
      if String.isPrefix p s then String.extract (s, n, NONE)
      else raise Error "dest_prefix"
  end;

(* Does s start with p? *)
fun is_prefix p s = String.isPrefix p s;

fun mk_prefix p s = String.concat [p, s];
527
(* Lay out a table (a list of rows, each a list of cell strings) into one
   string per row.  Each column is padded with the character `pad` to the
   width of its widest cell.  With `left`, cells are left-justified (padding
   on the right) and the last column is not padded; otherwise padding goes on
   the left.  An empty table yields [] (previously `min` raised Error on the
   empty column list).
   NOTE(review): rows are assumed to all have the same number of cells; a
   shorter first row drops cells, and a shorter later row raises Empty. *)
fun align_table {left, pad} =
  let
    fun pad_col n s =
      let val p = nchars pad (n - size s)
      in if left then s ^ p else p ^ s
      end
    fun pad_cols [] = []
      | pad_cols (l as [] :: _) = map (K "") l
      | pad_cols l =
        let
          val hs = map hd l
          (* widest cell in this column: min under the reversed order *)
          val (n, _) = min (Int.compare o swap) (map size hs)
          (* last column when left-justifying: no trailing padding needed *)
          val last_left = left andalso length (hd l) = 1
          val hs = if last_left then hs else map (pad_col n) hs
        in
          zipwith (fn x => fn y => x ^ y) hs (pad_cols (map tl l))
        end
  in
    pad_cols
  end;
547
548(* ------------------------------------------------------------------------- *)
549(* Reals.                                                                    *)
550(* ------------------------------------------------------------------------- *)
551
fun real_to_string r = Real.toString r;

(* Format a fraction as a rounded whole percentage, e.g. 0.5 gives "50%". *)
fun percent_to_string x = int_to_string (Real.round (x * 100.0)) ^ "%";

(* Clamp at zero from below (via Real.max). *)
fun pos r = Real.max (r, 0.0);

(* Base-2 logarithm. *)
local
  val ln_two = Math.ln 2.0
in
  fun log2 x = Math.ln x / ln_two
end;
559
560(* ------------------------------------------------------------------------- *)
561(* Pretty-printing.                                                          *)
562(* ------------------------------------------------------------------------- *)
563
564(* Generic pretty-printers *)
565
type 'a pp = 'a Parse.pprinter

(* Target line width used by to_string. *)
val LINE_LENGTH = ref 75;

(* Print a value by first converting it with f. *)
fun pp_map f pp_a = fn x => pp_a (f x);

(* Surround pp_a's output with l and r, in a block indented by size l. *)
fun pp_bracket l r pp_a a =
  let
    val parts = [PP.add_string l, pp_a a, PP.add_string r]
  in
    PP.block PP.INCONSISTENT (size l) parts
  end

(* Print elements separated by sep, each separator followed by a break. *)
fun pp_sequence sep pp_a els =
  let
    fun items [] = []
      | items [e] = [pp_a e]
      | items (e :: rest) =
        pp_a e :: PP.add_string sep :: PP.add_break (1, 0) :: items rest
  in
    PP.block PP.INCONSISTENT 0 (items els)
  end

(* Print a pair as a, then s, then a break, then b. *)
fun pp_binop s pp_a pp_b (a, b) =
  let
    val parts = [pp_a a, PP.add_string s, PP.add_break (1, 0), pp_b b]
  in
    PP.block PP.INCONSISTENT 0 parts
  end

(* Pretty-printers for common types *)

fun pp_string str = PP.add_string str

fun pp_unit () = pp_string "()";

fun pp_char c = pp_string (str c);

fun pp_bool b = pp_string (bool_to_string b);

fun pp_int n = pp_string (int_to_string n);

fun pp_real r = pp_string (real_to_string r);

fun pp_order ord = pp_string (order_to_string ord);

fun pp_porder po =
  pp_string (case po of NONE => "INCOMPARABLE" | SOME ord => order_to_string ord);

fun pp_list pp_a l = pp_bracket "[" "]" (pp_sequence "," pp_a) l;

fun pp_pair pp_a pp_b p = pp_bracket "(" ")" (pp_binop "," pp_a pp_b) p;

fun pp_triple pp_a pp_b pp_c (a, b, c) =
  pp_bracket "(" ")" (pp_binop "," pp_a (pp_binop "," pp_b pp_c)) (a, (b, c));

(* Render a value to a string using the current LINE_LENGTH. *)
fun to_string pp_a a =
  let val width = !LINE_LENGTH in PP.pp_to_string width pp_a a end;
620
621(* ------------------------------------------------------------------------- *)
622(* Sums                                                                      *)
623(* ------------------------------------------------------------------------- *)
624
(* Disjoint union of two types. *)
datatype ('a, 'b) sum = INL of 'a | INR of 'b

fun is_inl s = case s of INL _ => true | INR _ => false;

fun is_inr s = not (is_inl s);

(* Dispatch to the printer for whichever side is present. *)
fun pp_sum pp_a pp_b s =
  case s of
    INL a => pp_a a
  | INR b => pp_b b;
633
634(* ------------------------------------------------------------------------- *)
635(* Maplets.                                                                  *)
636(* ------------------------------------------------------------------------- *)
637
(* A single key/value binding, written a |-> b. *)
datatype ('a, 'b) maplet = op|-> of 'a * 'b;

(* Print a maplet as "a |-> b". *)
fun pp_maplet pp_a pp_b (op|-> (a, b)) = pp_binop " |->" pp_a pp_b (a, b);
642
643(* ------------------------------------------------------------------------- *)
644(* Trees.                                                                    *)
645(* ------------------------------------------------------------------------- *)
646
(* Rose trees with branch labels of type 'a and leaf values of type 'b. *)
datatype ('a, 'b) tree = BRANCH of 'a * ('a, 'b) tree list | LEAF of 'b;

(* Count the leaves and branches of a tree. *)
fun tree_size t =
  case t of
    LEAF _ => {leaves = 1, branches = 0}
  | BRANCH (_, ts) =>
      foldl (fn (sub, {leaves, branches}) =>
               let val {leaves = l, branches = b} = tree_size sub
               in {leaves = leaves + l, branches = branches + b} end)
            {leaves = 0, branches = 1} ts;

(* Bottom-up fold: leaves go through f_l, branches combine their label with
   the folded children via f_b. *)
fun tree_foldr f_b f_l t =
  case t of
    LEAF l => f_l l
  | BRANCH (p, s) => f_b p (map (tree_foldr f_b f_l) s);

(* Top-down fold: the state is transformed by f_b at each branch; every leaf
   yields f_l leaf state.  Results are returned in reverse visiting order. *)
fun tree_foldl f_b f_l =
  let
    fun walk st (node, acc) =
      case node of
        LEAF x => f_l x st :: acc
      | BRANCH (p, kids) =>
        let val st' = f_b p st in foldl (walk st') acc kids end
  in
    fn state => fn t => walk state (t, [])
  end;

(* Like tree_foldl, but f_b returning NONE prunes that subtree and f_l
   returning NONE drops that leaf. *)
fun tree_partial_foldl f_b f_l =
  let
    fun walk st (node, acc) =
      case node of
        LEAF x => (case f_l x st of SOME y => y :: acc | NONE => acc)
      | BRANCH (p, kids) =>
        (case f_b p st of SOME st' => foldl (walk st') acc kids | NONE => acc)
  in
    fn state => fn t => walk state (t, [])
  end;
678
679(* ------------------------------------------------------------------------- *)
680(* mlibUseful impure features                                                    *)
681(* ------------------------------------------------------------------------- *)
682
(* Wrap a thunk in a Susp suspension: Susp.delay is applied to f once, and
   every call of the returned function goes through Susp.force on it. *)
fun memoize f =
  let
    val cell = Susp.delay f
  in
    fn () => Susp.force cell
  end;
684
local
  (* One shared counter, created with Portable.make_counter{inc=1,init=0}. *)
  val counter = Portable.make_counter {inc = 1, init = 0}
in
  (* Next value from the shared counter. *)
  fun new_int () = counter ()

  (* The next k values from the shared counter, in the order drawn. *)
  fun new_ints k =
    if k = 0 then []
    else let val n = counter () in n :: new_ints (k - 1) end
end;
693
local
  (* Generator seeded with the fixed value 1.0. *)
  val seed = Random.newgenseed 1.0;
in
  (* Next value of Random.random from the shared generator. *)
  fun uniform () = Random.random seed;

  (* True when Random.range (0, 2) draws 0. *)
  fun coin_flip () =
    case Random.range (0, 2) seed of
      0 => true
    | _ => false;
end;
700
(* Evaluate f x with r temporarily set to update (!r).  The old value is
   restored afterwards, including when f raises (the exception is re-raised). *)
fun with_flag (r, update) f x =
  let
    val saved = !r
    fun restore () = r := saved
    val () = r := update saved
    val result = f x handle exn => (restore (); raise exn)
  in
    restore ();
    result
  end;
710
711(* ------------------------------------------------------------------------- *)
712(* Environment.                                                              *)
713(* ------------------------------------------------------------------------- *)
714
local
  (* Write "tag: msg" plus a newline to standard error. *)
  fun emit tag msg = TextIO.output (TextIO.stdErr, tag ^ ": " ^ msg ^ "\n");
in
  fun warn msg = emit "WARNING" msg;

  (* Report a fatal error on stderr and exit with a failure status. *)
  fun die msg = (emit "\nFATAL ERROR" msg; OS.Process.exit OS.Process.failure);
end
721
722end
723