1(* ========================================================================= *)
2(* FINITE MAPS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
3(* Copyright (c) 2004-2006 Joe Hurd, distributed under the GNU GPL version 2 *)
4(* ========================================================================= *)
5
6structure Map :> Map =
7struct
8
9(* ------------------------------------------------------------------------- *)
10(* Helper functions.                                                         *)
11(* ------------------------------------------------------------------------- *)
12
13exception Bug = mlibUseful.Bug;
14
15exception Error = mlibUseful.Error;
16
17val pointerEqual = Portable.pointer_eq;
18
19val K = mlibUseful.K;
20
21val snd = mlibUseful.snd;
22
23(* ------------------------------------------------------------------------- *)
24(* Random search trees.                                                      *)
25(* ------------------------------------------------------------------------- *)
26
(* A tree node, parameterised over the type 'c of its two subtrees.  The *)
(* size field caches the number of keys stored at or below the node (it  *)
(* is filled in by treeSingleton and mkT), and priority is the node's    *)
(* random heap priority.                                                 *)
type ('a,'b,'c) binaryNode =
     {size : int, priority : real,
      left : 'c, key : 'a, value : 'b, right : 'c};

(* A tree is either empty (E) or a node (T) whose subtrees are trees. *)
datatype ('a,'b) tree =
    E
  | T of ('a, 'b, ('a,'b) tree) binaryNode;

(* Abbreviation for the node record carried by the T constructor. *)
type ('a,'b) node = ('a, 'b, ('a,'b) tree) binaryNode;

(* A map pairs the key comparison function with the tree of bindings. *)
datatype ('a,'b) map =
    Map of ('a * 'a -> order) * ('a,'b) tree;
40
41(* ------------------------------------------------------------------------- *)
42(* Random priorities.                                                        *)
43(* ------------------------------------------------------------------------- *)
44
local
  (* One generator, created once with a fixed seed, feeds every priority. *)
  val priorityGenerator = Random.newgenseed 2.0;

  fun freshPriority () = Random.random priorityGenerator;
in
  (* treeSingleton builds a one-node tree holding the given binding, *)
  (* with a freshly drawn priority.                                  *)
  fun treeSingleton (key,value) =
      let
        val p = freshPriority ()
      in
        T {size = 1, priority = p,
           left = E, key = key, value = value, right = E}
      end;

  (* nodePriorityOrder compares two nodes by priority, falling back to *)
  (* the key comparison cmp when the priorities are equal.             *)
  fun nodePriorityOrder cmp (x1 : ('a,'b) node, x2 : ('a,'b) node) =
      case Real.compare (#priority x1, #priority x2) of
        EQUAL => cmp (#key x1, #key x2)
      | unequal => unequal;
end;
70
71(* ------------------------------------------------------------------------- *)
72(* Basic operations.                                                         *)
73(* ------------------------------------------------------------------------- *)
74
local
  (* checkSizes recomputes the size of every subtree and raises Error *)
  (* if a cached size field disagrees.  Returns the size of the tree. *)
  fun checkSizes E = 0
    | checkSizes (T {size,left,right,...}) =
      let
        val nl = checkSizes left
        val nr = checkSizes right
      in
        if nl + 1 + nr = size then size else raise Error "wrong size"
      end;

  (* checkSorted walks the tree in order, threading through the last *)
  (* key seen (if any), and raises Error unless each key is strictly *)
  (* greater than its predecessor.  Returns the last key visited.    *)
  fun checkSorted _ prev E = prev
    | checkSorted cmp prev (T {left,key,right,...}) =
      let
        fun checkAfter NONE = ()
          | checkAfter (SOME k) =
            (case cmp (k,key) of
               LESS => ()
             | EQUAL => raise Error "duplicate keys"
             | GREATER => raise Error "unsorted")

        val () = checkAfter (checkSorted cmp prev left)
      in
        checkSorted cmp (SOME key) right
      end;

  (* checkPriorities raises Error unless every child compares LESS   *)
  (* than its parent under nodePriorityOrder.  Returns the root node *)
  (* of the tree, or NONE for the empty tree.                        *)
  fun checkPriorities _ E = NONE
    | checkPriorities cmp (T (x as {left,right,...})) =
      let
        fun checkChild child equalMsg greaterMsg =
            case checkPriorities cmp child of
              NONE => ()
            | SOME c =>
              (case nodePriorityOrder cmp (c,x) of
                 LESS => ()
               | EQUAL => raise Error equalMsg
               | GREATER => raise Error greaterMsg)

        val () =
            checkChild left
              "left child has equal key"
              "left child has greater priority"

        val () =
            checkChild right
              "right child has equal key"
              "right child has greater priority"
      in
        SOME x
      end;
in
  (* checkWellformed s m runs all three checks on m, prints a dot and *)
  (* returns m unchanged.  Any Error or Bug raised along the way is   *)
  (* reported as a Bug tagged with the caller-supplied description s. *)
  fun checkWellformed s (m as Map (cmp,tree)) =
      let
        fun report msg =
            raise Bug ("RandomMap.checkWellformed: " ^ msg ^ " (" ^ s ^ ")")

        fun runChecks () =
            let
              val _ = checkSizes tree
              val _ = checkSorted cmp NONE tree
              val _ = checkPriorities cmp tree
            in
              print "."
            end
      in
        (runChecks (); m)
        handle Error err => report err
             | Bug bug => report bug
      end;
end;
138
(* comparison returns the key ordering stored in the map. *)
fun comparison m =
    case m of
      Map (cmp,_) => cmp;

(* new creates an empty map using the given key ordering. *)
fun new cmp = Map (cmp, E);

(* treeSize reads the cached size of a tree; the empty tree has size 0. *)
fun treeSize tree =
    case tree of
      E => 0
    | T node => #size node;

(* size returns the number of bindings in the map. *)
fun size m =
    case m of
      Map (_,tree) => treeSize tree;

(* mkT builds a node from a priority, left subtree, key, value and right *)
(* subtree, computing the cached size from the two subtrees.             *)
fun mkT p l k v r =
    let
      val n = treeSize l + 1 + treeSize r
    in
      T {size = n, priority = p, left = l, key = k, value = v, right = r}
    end;

(* singleton creates a map containing exactly one binding. *)
fun singleton cmp (key,value) = Map (cmp, treeSingleton (key,value));
153
(* peek looks up a key, returning SOME value if it is bound in the map *)
(* and NONE otherwise.                                                 *)
fun peek (Map (cmp,tree)) key =
    let
      fun search E = NONE
        | search (T {left, key = k, value, right, ...}) =
          (case cmp (key,k) of
             EQUAL => SOME value
           | LESS => search left
           | GREATER => search right)
    in
      search tree
    end;
164
(* treeAppend joins two trees into one.  It assumes that every element *)
(* of the first tree is less than every element of the second tree.    *)
(* The root of the result is whichever of the two roots is greater     *)
(* under nodePriorityOrder; the other tree is appended into the        *)
(* adjacent subtree of that root.                                      *)
(* Raises Bug if the two roots compare EQUAL, which can only happen if *)
(* they have equal priorities and equal keys.                          *)

fun treeAppend _ t1 E = t1
  | treeAppend _ E t2 = t2
  | treeAppend cmp (t1 as T x1) (t2 as T x2) =
    case nodePriorityOrder cmp (x1,x2) of
      LESS =>
      (* x2 becomes the root: push all of t1 down its left side. *)
      let
        val {priority = p2,
             left = l2, key = k2, value = v2, right = r2, ...} = x2
      in
        mkT p2 (treeAppend cmp t1 l2) k2 v2 r2
      end
    | EQUAL => raise Bug "RandomMap.treeAppend: equal keys"
    | GREATER =>
      (* x1 becomes the root: push all of t2 down its right side. *)
      let
        val {priority = p1,
             left = l1, key = k1, value = v1, right = r1, ...} = x1
      in
        mkT p1 l1 k1 v1 (treeAppend cmp r1 t2)
      end;
187
(* nodePartition cmp x pkey splits the node x around pkey, returning a  *)
(* triple: the tree of keys comparing less than pkey, the binding whose *)
(* key compares equal to pkey (if there is one), and the tree of keys   *)
(* comparing greater than pkey.                                         *)

local
  (* Nodes passed on the way down because their key was less than pkey, *)
  (* innermost first: each one keeps its left subtree and takes the     *)
  (* tree built so far as its new right subtree.                        *)
  fun rebuildLeft (spine : ('a,'b) node list) tree =
      List.foldl
        (fn (x,acc) => mkT (#priority x) (#left x) (#key x) (#value x) acc)
        tree spine;

  (* The mirror image, for nodes whose key was greater than pkey. *)
  fun rebuildRight (spine : ('a,'b) node list) tree =
      List.foldl
        (fn (x,acc) => mkT (#priority x) acc (#key x) (#value x) (#right x))
        tree spine;
in
  fun nodePartition cmp x pkey =
      let
        fun descend lefts rights E =
            (rebuildLeft lefts E, NONE, rebuildRight rights E)
          | descend lefts rights (T (node as {left,key,value,right,...})) =
            (case cmp (pkey,key) of
               LESS => descend lefts (node :: rights) left
             | EQUAL =>
               (rebuildLeft lefts left,
                SOME (key,value),
                rebuildRight rights right)
             | GREATER => descend (node :: lefts) rights right)
      in
        descend [] [] (T x)
      end;
end;
211
(* union f m1 m2 works in two phases.  First treeCombineRemove handles   *)
(* the keys present in both maps: f is applied to the pair of values;    *)
(* SOME v stores v in the first tree, NONE drops the key from the first  *)
(* tree, and either way the key is removed from the second tree.  Then   *)
(* treeUnionDisjoint merges the two, now key-disjoint, trees.            *)
(* Note that the combined key is always the one from the second map.     *)
(* The result uses the comparison function of the first map.             *)
(* NOTE(review): when both maps share the very same tree (pointerEqual), *)
(* m1 is returned without calling f at all, so an f that drops or        *)
(* rewrites values is not applied in that case.  Confirm this shortcut   *)
(* is intended.                                                          *)

local
  (* Returns the pair (t1',t2') described above.  The original trees are *)
  (* returned unchanged (preserving sharing) when nothing was removed    *)
  (* from t2 below this point.                                           *)
  fun treeCombineRemove _ _ t1 E = (t1,E)
    | treeCombineRemove _ _ E t2 = (E,t2)
    | treeCombineRemove cmp f (t1 as T x1) (t2 as T x2) =
      let
        val {priority = p1,
             left = l1, key = k1, value = v1, right = r1, ...} = x1
        (* Split the second tree around the root key of the first. *)
        val (l2,k2_v2,r2) = nodePartition cmp x2 k1
        val (l1,l2) = treeCombineRemove cmp f l1 l2
        and (r1,r2) = treeCombineRemove cmp f r1 r2
      in
        case k2_v2 of
          NONE =>
          (* k1 is not in t2: if the pieces of t2 still account for all *)
          (* of its nodes then no key matched anywhere below.           *)
          if treeSize l2 + treeSize r2 = #size x2 then (t1,t2)
          else (mkT p1 l1 k1 v1 r1, treeAppend cmp l2 r2)
        | SOME (k2,v2) =>
          case f (v1,v2) of
            NONE => (treeAppend cmp l1 r1, treeAppend cmp l2 r2)
          | SOME v => (mkT p1 l1 k2 v r1, treeAppend cmp l2 r2)
      end;

  (* Merges two trees with no key in common.  The root that is greater *)
  (* under nodePriorityOrder becomes the root of the result, and the   *)
  (* other tree is partitioned around its key.  Raises Bug if the two  *)
  (* roots compare EQUAL.                                              *)
  fun treeUnionDisjoint _ t1 E = t1
    | treeUnionDisjoint _ E t2 = t2
    | treeUnionDisjoint cmp (T x1) (T x2) =
      case nodePriorityOrder cmp (x1,x2) of
        LESS => nodeUnionDisjoint cmp x2 x1
      | EQUAL => raise Bug "RandomMap.unionDisjoint: equal keys"
      | GREATER => nodeUnionDisjoint cmp x1 x2
  and nodeUnionDisjoint cmp x1 x2 =
      let
        val {priority = p1,
             left = l1, key = k1, value = v1, right = r1, ...} = x1
        val (l2,_,r2) = nodePartition cmp x2 k1
        val l = treeUnionDisjoint cmp l1 l2
        and r = treeUnionDisjoint cmp r1 r2
      in
        mkT p1 l k1 v1 r
      end;
in
  fun union f (m1 as Map (cmp,t1)) (Map (_,t2)) =
      if pointerEqual (t1,t2) then m1
      else
        let
          val (t1,t2) = treeCombineRemove cmp f t1 t2
        in
          Map (cmp, treeUnionDisjoint cmp t1 t2)
        end;
end;
264
265(*
266val union =
267    fn f => fn t1 => fn t2 =>
268    checkWellformed
269      "after union"
270      (union f (checkWellformed "before union 1" t1)
271               (checkWellformed "before union 2" t2));
272*)
273
(* intersect f m1 m2 keeps only the keys bound in both maps.  For each  *)
(* such key f is applied to the pair of values: SOME v keeps the key    *)
(* (taken from the second map) bound to v, NONE drops it.  If both maps *)
(* share the very same tree, m1 is returned as is, without calling f.   *)

local
  fun treeIntersect cmp f t1 t2 =
      case (t1,t2) of
        (_,E) => E
      | (E,_) => E
      | (T x1, T x2) =>
        let
          val {priority = p1,
               left = l1, key = k1, value = v1, right = r1, ...} = x1
          (* Split the second tree around the root key of the first. *)
          val (l2,found,r2) = nodePartition cmp x2 k1
          val l = treeIntersect cmp f l1 l2
          val r = treeIntersect cmp f r1 r2

          fun withoutRoot () = treeAppend cmp l r
        in
          case found of
            NONE => withoutRoot ()
          | SOME (k2,v2) =>
            (case f (v1,v2) of
               NONE => withoutRoot ()
             | SOME v => mkT p1 l k2 v r)
        end;
in
  fun intersect f (m1 as Map (cmp,t1)) (Map (_,t2)) =
      if not (pointerEqual (t1,t2)) then Map (cmp, treeIntersect cmp f t1 t2)
      else m1;
end;
299
300(*
301val intersect =
302    fn f => fn t1 => fn t2 =>
303    checkWellformed
304      "after intersect"
305      (intersect f (checkWellformed "before intersect 1" t1)
306                   (checkWellformed "before intersect 2" t2));
307*)
308
(* delete removes the binding for the supplied key.  It raises Error if *)
(* the key is not found, which makes it simpler to maximize sharing.    *)

fun delete (Map (cmp,tree)) key =
    let
      fun remove E = raise Error "RandomMap.delete: element not found"
        | remove (T {priority, left, key = k, value, right, ...}) =
          (case cmp (key,k) of
             EQUAL => treeAppend cmp left right
           | LESS => mkT priority (remove left) k value right
           | GREATER => mkT priority left k value (remove right))
    in
      Map (cmp, remove tree)
    end;
322
323(*
324val delete =
325    fn t => fn x =>
326    checkWellformed
327      "after delete" (delete (checkWellformed "before delete" t) x);
328*)
329
(* difference m1 m2 removes from m1 every key that is bound in m2.  It *)
(* is mainly used when using maps as sets.                             *)

local
  fun treeDifference cmp t1 t2 =
      case (t1,t2) of
        (_,E) => t1
      | (E,_) => E
      | (T x1, T x2) =>
        let
          val {size = s1, priority = p1,
               left = l1, key = k1, value = v1, right = r1} = x1
          val (l2,found,r2) = nodePartition cmp x2 k1
          val l = treeDifference cmp l1 l2
          val r = treeDifference cmp r1 r2
        in
          case found of
            SOME _ => treeAppend cmp l r
          | NONE =>
            (* Nothing was removed below: hand back the original tree. *)
            if treeSize l + treeSize r + 1 = s1 then t1
            else mkT p1 l k1 v1 r
        end;
in
  fun difference (Map (cmp,tree1)) (Map (_,tree2)) =
      Map (cmp,
           if pointerEqual (tree1,tree2) then E
           else treeDifference cmp tree1 tree2);
end;
352
353(*
354val difference =
355    fn t1 => fn t2 =>
356    checkWellformed
357      "after difference"
358      (difference (checkWellformed "before difference 1" t1)
359                  (checkWellformed "before difference 2" t2));
360*)
361
(* subsetDomain m1 m2 tests whether every key bound in m1 is also bound *)
(* in m2.  It is mainly used when using maps as sets.                   *)

local
  fun treeSubsetDomain cmp t1 t2 =
      case (t1,t2) of
        (E,_) => true
      | (_,E) => false
      | (T x1, T x2) =>
        (* A tree with more keys cannot be a subset of a smaller one. *)
        #size x1 <= #size x2 andalso
        (case nodePartition cmp x2 (#key x1) of
           (_,NONE,_) => false
         | (l2,SOME _,r2) =>
           treeSubsetDomain cmp (#left x1) l2 andalso
           treeSubsetDomain cmp (#right x1) r2);
in
  fun subsetDomain (Map (cmp,tree1)) (Map (_,tree2)) =
      if pointerEqual (tree1,tree2) then true
      else treeSubsetDomain cmp tree1 tree2
end;
386
(* equalDomain m1 m2 tests whether the two maps bind exactly the same *)
(* set of keys.  It is mainly used when using maps as sets.           *)

local
  (* Two trees have the same domain when they have the same size and *)
  (* every key of the first occurs in the second.  An empty tree     *)
  (* therefore only matches another empty tree: the size test below  *)
  (* is reached only when both trees are non-empty, so the empty     *)
  (* cases must be decided here.                                     *)
  fun treeEqualDomain _ E E = true
    | treeEqualDomain _ E (T _) = false
    | treeEqualDomain _ (T _) E = false
    | treeEqualDomain cmp (T x1) (T x2) =
      let
        val {size = s1, left = l1, key = k1, right = r1, ...} = x1
        and {size = s2, ...} = x2
      in
        s1 = s2 andalso
        let
          (* Split the second tree around the root key of the first. *)
          val (l2,k2_v2,r2) = nodePartition cmp x2 k1
        in
          Option.isSome k2_v2 andalso
          treeEqualDomain cmp l1 l2 andalso
          treeEqualDomain cmp r1 r2
        end
      end;
in
  fun equalDomain (Map (cmp,tree1)) (Map (_,tree2)) =
      pointerEqual (tree1,tree2) orelse
      treeEqualDomain cmp tree1 tree2
end;
411
(* mapPartial f m applies f to every binding of m, visiting them in key *)
(* order.  A binding is kept with its new value when f returns SOME and *)
(* is removed when f returns NONE.  Surviving nodes keep their          *)
(* priorities, so the tree structure is preserved.                      *)

local
  fun treeMapPartial cmp f tree =
      case tree of
        E => E
      | T {priority,left,key,value,right,...} =>
        let
          (* Left subtree, then this node, then right subtree. *)
          val newLeft = treeMapPartial cmp f left
          val result = f (key,value)
          val newRight = treeMapPartial cmp f right
        in
          case result of
            SOME newValue => mkT priority newLeft key newValue newRight
          | NONE => treeAppend cmp newLeft newRight
        end;
in
  fun mapPartial f (Map (cmp,tree)) = Map (cmp, treeMapPartial cmp f tree);
end;
430
(* map f m replaces every value by the result of applying f to its     *)
(* binding, visiting the bindings in key order.  It is primitive for   *)
(* efficiency reasons: sizes, priorities and keys are copied unchanged. *)

local
  fun treeMap f tree =
      case tree of
        E => E
      | T {size,priority,left,key,value,right} =>
        let
          (* Left subtree, then this node, then right subtree. *)
          val newLeft = treeMap f left
          val newValue = f (key,value)
          val newRight = treeMap f right
        in
          T {size = size, priority = priority, left = newLeft,
             key = key, value = newValue, right = newRight}
        end;
in
  fun map f (Map (cmp,tree)) = Map (cmp, treeMap f tree);
end;
448
(* nth m n picks the binding with the nth smallest key (counting from *)
(* 0).  It raises Error when the search runs off the tree.            *)

local
  fun treeNth tree n =
      case tree of
        E => raise Error "RandomMap.nth"
      | T {left,key,value,right,...} =>
        let
          (* Number of keys that come before this node. *)
          val k = treeSize left
        in
          case Int.compare (n,k) of
            EQUAL => (key,value)
          | LESS => treeNth left n
          | GREATER => treeNth right (n - (k + 1))
        end;
in
  fun nth (Map (_,tree)) n = treeNth tree n;
end;
465
466(* ------------------------------------------------------------------------- *)
467(* Iterators.                                                                *)
468(* ------------------------------------------------------------------------- *)
469
(* leftSpine pushes the tree and each of its successive left children *)
(* onto the stack acc, so the leftmost node ends up on top.           *)
fun leftSpine tree acc =
    case tree of
      E => acc
    | T {left,...} => leftSpine left (tree :: acc);

(* rightSpine is the mirror image, following right children. *)
fun rightSpine tree acc =
    case tree of
      E => acc
    | T {right,...} => rightSpine right (tree :: acc);

(* An iterator holds the current binding, the subtree still to be    *)
(* explored below the current node, and a stack of pending trees.    *)
(* LR iterators run in increasing key order, RL in decreasing order. *)
datatype ('key,'a) iterator =
    LR of ('key * 'a) * ('key,'a) tree * ('key,'a) tree list
  | RL of ('key * 'a) * ('key,'a) tree * ('key,'a) tree list;

(* mkLR turns a stack of pending trees into a left-to-right iterator, *)
(* or NONE when the stack is empty.                                   *)
fun mkLR stack =
    case stack of
      [] => NONE
    | T {key,value,right,...} :: rest => SOME (LR ((key,value),right,rest))
    | E :: _ => raise Bug "RandomMap.mkLR";

(* mkRL does the same for a right-to-left iterator. *)
fun mkRL stack =
    case stack of
      [] => NONE
    | T {key,value,left,...} :: rest => SOME (RL ((key,value),left,rest))
    | E :: _ => raise Bug "RandomMap.mkRL";

(* mkIterator starts at the smallest key; NONE if the map is empty. *)
fun mkIterator (Map (_,tree)) = mkLR (leftSpine tree []);

(* mkRevIterator starts at the largest key; NONE if the map is empty. *)
fun mkRevIterator (Map (_,tree)) = mkRL (rightSpine tree []);

(* readIterator returns the binding the iterator currently points at. *)
fun readIterator iter =
    case iter of
      LR (key_value,_,_) => key_value
    | RL (key_value,_,_) => key_value;

(* advanceIterator moves to the next binding in the iterator's *)
(* direction, or returns NONE when there are no more.          *)
fun advanceIterator iter =
    case iter of
      LR (_,next,pending) => mkLR (leftSpine next pending)
    | RL (_,next,pending) => mkRL (rightSpine next pending);
497
498(* ------------------------------------------------------------------------- *)
499(* Derived operations.                                                       *)
500(* ------------------------------------------------------------------------- *)
501
(* null tests whether the map has no bindings. *)
fun null m = (0 = size m);

(* get looks up a key, raising Error if it is not bound. *)
fun get m key =
    case peek m key of
      SOME value => value
    | NONE => raise Error "RandomMap.get: element not found";

(* inDomain tests whether a key is bound in the map. *)
fun inDomain key m =
    case peek m key of
      NONE => false
    | SOME _ => true;

(* insert adds a binding by taking the union with a singleton map.  The *)
(* combining function keeps the second component of the value pair,     *)
(* i.e. the newly supplied value.                                       *)
fun insert m key_value =
    let
      val single = singleton (comparison m) key_value
    in
      union (fn values => SOME (snd values)) m single
    end;
513
514(*
515val insert =
516    fn m => fn x =>
517    checkWellformed
518      "after insert" (insert (checkWellformed "before insert" m) x);
519*)
520
local
  (* Drives an iterator to exhaustion, feeding each binding together *)
  (* with the accumulator to f.                                      *)
  fun fold f iter acc =
      case iter of
        NONE => acc
      | SOME i =>
        let
          val (key,value) = readIterator i
          val rest = advanceIterator i
        in
          fold f rest (f (key,value,acc))
        end;
in
  (* foldl visits the bindings in increasing key order. *)
  fun foldl f b m = fold f (mkIterator m) b;

  (* foldr visits the bindings in decreasing key order. *)
  fun foldr f b m = fold f (mkRevIterator m) b;
end;
534
local
  (* Steps through an iterator and returns the first binding that *)
  (* satisfies pred, or NONE if the iterator runs out.            *)
  fun find pred iter =
      case iter of
        NONE => NONE
      | SOME i =>
        let
          val key_value = readIterator i
        in
          if not (pred key_value) then find pred (advanceIterator i)
          else SOME key_value
        end;
in
  (* findl searches in increasing key order. *)
  fun findl p m = find p (mkIterator m);

  (* findr searches in decreasing key order. *)
  fun findr p m = find p (mkRevIterator m);
end;
549
(* fromList builds a map by inserting the bindings of the list one at *)
(* a time, front to back, so later bindings override earlier ones.    *)
fun fromList cmp l =
    let
      fun add m [] = m
        | add m (k_v :: rest) = add (insert m k_v) rest
    in
      add (new cmp) l
    end;

(* insertList adds all the bindings of a list to a map; where a key is *)
(* already bound, the value from the list is the one kept.             *)
fun insertList m l =
    let
      val extra = fromList (comparison m) l
    in
      union (fn values => SOME (snd values)) m extra
    end;

(* filter keeps the bindings satisfying p, leaving values unchanged. *)
fun filter p =
    mapPartial
      (fn (key_value as (_,value)) =>
          if p key_value then SOME value else NONE);

(* app applies f to every binding, in increasing key order. *)
fun app f m =
    let
      fun visit (key,value,()) = f (key,value)
    in
      foldl visit () m
    end;

(* transform applies f to every value, ignoring the keys. *)
fun transform f =
    let
      fun onValue (_,value) = f value
    in
      map onValue
    end;

(* toList returns the bindings as a list in increasing key order. *)
fun toList m =
    let
      fun cons (key,value,l) = (key,value) :: l
    in
      foldr cons [] m
    end;

(* domain returns the keys as a list in increasing key order. *)
fun domain m =
    let
      fun consKey (key,_,l) = key :: l
    in
      foldr consKey [] m
    end;

(* exists tests whether some binding satisfies p. *)
fun exists p m =
    case findl p m of
      NONE => false
    | SOME _ => true;

(* all tests whether every binding satisfies p. *)
fun all p m = not (exists (fn key_value => not (p key_value)) m);
573
local
  (* Walks two iterators in step.  A map that runs out first is LESS;  *)
  (* otherwise the first differing key, or the first differing value   *)
  (* under an equal key, decides the result.                           *)
  fun iterCompare kcmp vcmp io1 io2 =
      case (io1,io2) of
        (NONE,NONE) => EQUAL
      | (NONE,SOME _) => LESS
      | (SOME _,NONE) => GREATER
      | (SOME i1,SOME i2) =>
        let
          val (k1,v1) = readIterator i1
          val (k2,v2) = readIterator i2
        in
          case kcmp (k1,k2) of
            EQUAL =>
            (case vcmp (v1,v2) of
               EQUAL =>
               iterCompare kcmp vcmp (advanceIterator i1) (advanceIterator i2)
             | unequal => unequal)
          | unequal => unequal
        end;
in
  (* compare cmp (m1,m2) orders two maps by comparing their bindings in *)
  (* increasing key order, using the key ordering of m1 for keys and    *)
  (* cmp for values.                                                    *)
  fun compare cmp (m1,m2) =
      iterCompare (comparison m1) cmp (mkIterator m1) (mkIterator m2);
end;
594
595end
596