1(* Redblackset -- sets implemented by Okasaki-style Red-Black trees *)
2(* Ken Friis Larsen <kfl@it.edu>                                    *)
3structure Redblackset :>  Redblackset =
4struct
5
  (* Red-black tree.  LEAF is the empty tree; RED and BLACK nodes carry
     (element, left subtree, right subtree). *)
  datatype 'item tree = LEAF
                      | RED   of 'item * 'item tree * 'item tree
                      | BLACK of 'item * 'item tree * 'item tree

  (* A set is a triple: the ordering function, the root of the tree,
     and the number of elements stored in the tree. *)
  type 'item set  = ('item * 'item -> order) * 'item tree * int

  (* Raised by retrieve and delete when the requested element is absent. *)
  exception NotFound
13
14  fun empty compare = (compare, LEAF, 0)
15
16  fun numItems (_, _, n) = n
17
18  fun singleton compare x = (compare, BLACK(x, LEAF, LEAF), 1)
19
20  fun isEmpty (_, LEAF, _) = true
21    | isEmpty _            = false
22
23  fun member ((compare, tree, n), elm) =
24      let fun memShared x left right =
25              case compare(elm,x) of
26                  EQUAL   => true
27                | LESS    => mem left
28                | GREATER => mem right
29          and mem LEAF                    = false
30            | mem (RED(x, left, right))   = memShared x left right
31            | mem (BLACK(x, left, right)) = memShared x left right
32      in  mem tree end
33
34  fun retrieve (set, x) = if member(set, x) then x else raise NotFound
35
36  fun peek (set, x) = if member(set, x) then SOME x else NONE
37
38  fun lbalance z (RED(y,RED(x,a,b),c)) d =
39      RED(y,BLACK(x,a,b),BLACK(z,c,d))
40    | lbalance z (RED(x,a,RED(y,b,c))) d =
41      RED(y,BLACK(x,a,b),BLACK(z,c,d))
42    | lbalance x left right = BLACK(x, left, right)
43
44  fun rbalance x a (RED(y,b,RED(z,c,d))) =
45      RED(y,BLACK(x,a,b),BLACK(z,c,d))
46    | rbalance x a (RED(z,RED(y,b,c),d)) =
47      RED(y,BLACK(x,a,b),BLACK(z,c,d))
48    | rbalance x left right = BLACK(x, left, right)
49
50
51  local
52    fun lbal x (l,inc) r = (lbalance x l r, inc)
53    fun rbal x l (r,inc) = (rbalance x l r, inc)
54    fun REDl(x,(l,inc),r) = (RED(x,l,r), inc)
55    fun REDr(x,l,(r,inc)) = (RED(x,l,r), inc)
56    fun insert compare elm = let
57      fun ins LEAF = (RED(elm,LEAF,LEAF), 1)
58        | ins (BLACK(x,left,right)) = let
59          in
60            case compare(elm, x) of
61              LESS    => lbal x (ins left) right
62            | GREATER => rbal x left (ins right)
63            | EQUAL   => (BLACK(elm,left,right), 0)
64          end
65        | ins (RED(x,left,right)) = let
66          in
67            case compare(elm, x) of
68              LESS    => REDl(x, (ins left), right)
69            | GREATER => REDr(x, left, (ins right))
70            | EQUAL   => (RED(elm,left,right), 0)
71          end
72    in
73      ins
74    end
75  in
76
77  fun add (set as (compare, tree, n), elm) = let
78    val (nt, inc) = insert compare elm tree
79  in
80    ( compare
81    , case nt of
82        RED(e, l, r) => BLACK(e, l, r)
83      | tree         => tree
84    , n+inc)
85  end
86
87  fun addList (set, xs) = List.foldl (fn(x,set) => add(set, x)) set xs
88  end
89
90  fun fromList compare xs =
91    addList (empty compare, xs)
92
93  fun push LEAF stack = stack
94    | push tree stack = tree :: stack
95
96  fun pushNode left x right stack =
97      left :: (BLACK(x, LEAF, LEAF) :: (push right stack))
98
99  fun getMin []             some none = none
100    | getMin (tree :: rest) some none =
101      case tree of
102          LEAF                  => getMin rest some none
103        | RED  (x, LEAF, right) => some x (push right rest)
104        | BLACK(x, LEAF, right) => some x (push right rest)
105        | RED  (x, left, right) => getMin(pushNode left x right rest) some none
106        | BLACK(x, left, right) => getMin(pushNode left x right rest) some none
107
108  fun getMax []             some none = none
109    | getMax (tree :: rest) some none =
110      case tree of
111          LEAF                  => getMax rest some none
112        | RED  (x, left, LEAF)  => some x (push left rest)
113        | BLACK(x, left, LEAF)  => some x (push left rest)
114        | RED  (x, left, right) => getMax(pushNode right x left rest) some none
115        | BLACK(x, left, right) => getMax(pushNode right x left rest) some none
116
117  fun fold get f e (compare, tree, n) =
118      let fun loop stack acc =
119              get stack (fn x => fn stack => loop stack (f(x, acc))) acc
120      in  loop [tree] e end
121
122  fun foldl f = fold getMin f
123
124  fun foldr f = fold getMax f
125
126  fun listItems set = foldr op:: [] set
127
128  fun appAll get f (compare, tree, n) =
129      let fun loop stack = get stack (fn x => (f x; loop)) ()
130      in  loop [tree] end
131
132  fun app f = appAll getMin f
133
134  fun revapp f = appAll getMax f
135
136  fun find p set =
137      let exception EXIT of 'a
138          fun newp x = if p x then raise EXIT x else ()
139      in  (app newp set; NONE)
140          handle EXIT x => SOME x end
141
142
  (* Ralf Hinze's conversion of a sorted list to a red-black tree *)
144  local
      (* Accumulator used while a tree is built from elements supplied
         in increasing order.  It is a list of digits terminated by
         ZERO; a ONE digit holds one (element, tree) pair and a TWO
         digit holds two such pairs. *)
      datatype 'item digits =
               ZERO
             | ONE of 'item * 'item tree * 'item digits
             | TWO of 'item * 'item tree * 'item * 'item tree * 'item digits
149
150      fun incr x a ZERO                  = ONE(x, a, ZERO)
151        | incr x a (ONE(y, b, ds))       = TWO(x, a, y, b, ds)
152        | incr z c (TWO(y, b, x, a, ds)) =
153          ONE(z, c, incr y (BLACK(x, a, b)) ds)
154
155      fun insertMax(a, digits) = incr a LEAF digits
156
157      fun build ZERO                  a = a
158        | build (ONE(x, a, ds))       b = build ds (BLACK(x, a, b))
159        | build (TWO(y, b, x, a, ds)) c = build ds (BLACK(x, a, RED(y, b, c)))
160
161      fun buildAll digits = build digits LEAF
162
163      fun toInt digits =
164          let fun loop ZERO power acc            = acc
165                | loop (ONE(_,_,rest)) power acc =
166                  loop rest (2*power) (power + acc)
167                | loop (TWO(_,_,_,_,rest)) power acc =
168                  loop rest (2*power) (2*power + acc)
169          in  loop digits 1 0 end
170
171      fun get stack = getMin stack (fn x => fn stack => SOME(x,stack)) NONE
172
173      fun insRest stack acc =
174          getMin stack (fn x => fn stack => insRest stack (insertMax(x,acc)))
175                 acc
176
177  in
178  fun fromSortedList (compare, ls) =
179      let val digits = List.foldl insertMax ZERO ls
180      in  (compare, buildAll digits, toInt digits) end
181
182
  (* FIXME: it *must* be possible to write union, equal, isSubset,
            intersection, and difference more elegantly.
  *)
186  fun actual_union (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
187      let fun loop x y stack1 stack2 res =
188              case compare(x, y) of
189                  EQUAL =>
190                  let val res = insertMax(x, res)
191                  in  case (get stack1, get stack2) of
192                          (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
193                        | (NONE, NONE)               => res
194                        | (SOME _, _)                => insRest stack1 res
195                        | (_, SOME _)                => insRest stack2 res
196                  end
197                | LESS =>
198                  let val res = insertMax(x, res)
199                  in  case get stack1 of
200                          NONE => insRest stack2 (insertMax(y, res))
201                        | SOME(x, stack1) => loop x y stack1 stack2 res
202                  end
203                | GREATER =>
204                  let val res = insertMax(y, res)
205                  in  case get stack2 of
206                          NONE => insRest stack1 (insertMax(x, res))
207                        | SOME(y, stack2) => loop x y stack1 stack2 res
208                  end
209      in  (* FIXME: here is lots of room for optimizations *)
210          case (get [t1], get [t2]) of
211              (SOME(x, stack1), SOME(y, stack2)) =>
212              let val digits = loop x y stack1 stack2 ZERO
213              in  (compare, buildAll digits, toInt digits) end
214            | (_, SOME _) => s2
215            | _           => s1 end
216
217  local
218      val ln2 = Math.ln 2.0
219  in
220  fun union ((_, _, 0), s2) = s2
221    | union (s1,(_, _, 0)) = s1
222    | union (s1 as (_, _, n1), s2 as (_, _, n2)) =
223      let val ((smin,nmin),(smax,nmax)) =
224              if (n1<n2) then ((s1,Real.fromInt n1),(s2,Real.fromInt n2))
225              else ((s2,Real.fromInt n2),(s1,Real.fromInt n1))
226      in
227          if Math.ln nmax / ln2 * nmin < nmin + nmax
228          then foldl (fn(x, res) => add(res, x)) smax smin
229          else actual_union(s1,s2)
230      end
231  end
232
233  fun equal ((compare, t1, _), (_, t2, _)) =
234      let fun loop x y stack1 stack2 =
235              case compare(x, y) of
236                  EQUAL =>
237                  (case (get stack1, get stack2) of
238                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2
239                     | (NONE, NONE)               => true
240                     | _                          => false)
241                | _ => false
242      in  (* FIXME: here is lots of room for optimizations *)
243          case (get [t1], get [t2]) of
244              (SOME(x, stack1), SOME(y, stack2)) => loop x y stack1 stack2
245            | (NONE, NONE)                       => true
246            | _                                  => false end
247
248  fun isSubset ((compare, t1, _), (_, t2, _)) =
249      let fun loop x y stack1 stack2 =
250              case compare(x, y) of
251                  EQUAL =>
252                  (case (get stack1, get stack2) of
253                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2
254                     | (NONE, _)                  => true
255                     | _                          => false)
256                | LESS => false
257                | GREATER =>
258                  (case get stack2 of
259                       SOME(y, stack2) => loop x y stack1 stack2
260                     | NONE            => false)
261      in  (* FIXME: here is lots of room for optimizations *)
262          case (get [t1], get [t2]) of
263              (SOME(x, stack1), SOME(y, stack2)) => loop x y stack1 stack2
264            | (NONE, _)                          => true
265            | _                                  => false
266      end
267
268
269  fun intersection (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
270      let fun loop x y stack1 stack2 res =
271              case compare(x, y) of
272                  EQUAL =>
273                  let val res = insertMax(x, res)
274                  in  case (get stack1, get stack2) of
275                          (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
276                        | _                          => res
277                  end
278                | LESS =>
279                  (case get stack1 of
280                       NONE            => res
281                     | SOME(x, stack1) => loop x y stack1 stack2 res)
282                | GREATER =>
283                  (case get stack2 of
284                       NONE            => res
285                     | SOME(y, stack2) => loop x y stack1 stack2 res)
286      in  (* FIXME: here is lots of room for optimizations *)
287          case (get [t1], get [t2]) of
288              (SOME(x, stack1), SOME(y, stack2)) =>
289              let val digits = loop x y stack1 stack2 ZERO
290              in  (compare, buildAll digits, toInt digits) end
291            | _           => empty compare end
292
293
294  fun difference (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
295      let fun loop x y stack1 stack2 res =
296              case compare(x, y) of
297                  EQUAL =>
298                  (case (get stack1, get stack2) of
299                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
300                     | (SOME _, _)                => insRest stack1 res
301                     | _                          => res)
302                | LESS =>
303                  let val res = insertMax(x, res)
304                  in  case get stack1 of
305                          NONE            => res
306                        | SOME(x, stack1) => loop x y stack1 stack2 res
307                  end
308                | GREATER =>
309                  (case get stack2 of
310                       NONE => insRest stack1 (insertMax(x, res))
311                     | SOME(y, stack2) => loop x y stack1 stack2 res)
312      in  (* FIXME: here is lots of room for optimizations *)
313          case (get [t1], get [t2]) of
314              (SOME(x, stack1), SOME(y, stack2)) =>
315              let val digits = loop x y stack1 stack2 ZERO
316              in  (compare, buildAll digits, toInt digits) end
317            | (_, SOME _) => empty compare
318            | _           => s1 end
319  end
320
321
322  (* Peter Sestoft's convert a sorted list to RB tree *)
323  fun fromSortedList' ls =
324      let val len = length ls
325          fun log2 n =
326              let fun loop k p = if p >= n then k else loop (k+1) (2*p)
327              in loop 0 1 end
328          fun h 0 _ xs = (LEAF, xs)
329            | h n d xs =
330              let val m = n div 2
331                  val (t1, ys) = h m       (d-1) xs
332                  val y = hd ys
333                  and yr = tl ys
334                  val (t2, zs) = h (n-m-1) (d-1) yr
335              in (if d=0 then RED(y, t1, t2) else BLACK(y, t1, t2), zs) end
336      in  ( case #1 (h len (log2 (len + 1) - 1) ls) of
337                RED(x, left, right) => BLACK(x, left, right)
338              | tree                => tree
339          , len) end
340
341
  (* Raised by sub1, balleft and balright when a tree does not have the
     shape that the deletion rebalancing expects. *)
  exception RedBlackSetError
343
344  (* delete a la Stefan M. Kahrs *)
345
346  fun sub1 (BLACK arg) = RED arg
347    | sub1 _ = raise RedBlackSetError
348
349  fun balleft y (RED(x,a,b)) c             = RED(y, BLACK(x, a, b), c)
350    | balleft x bl (BLACK(y, a, b))        = rbalance x bl (RED(y, a, b))
351    | balleft x bl (RED(z,BLACK(y,a,b),c)) =
352      RED(y, BLACK(x, bl, a), rbalance z b (sub1 c))
353    | balleft _ _ _ = raise RedBlackSetError
354
355  fun balright x a             (RED(y,b,c)) = RED(x, a, BLACK(y, b, c))
356    | balright y (BLACK(x,a,b))          br = lbalance y (RED(x,a,b)) br
357    | balright z (RED(x,a,BLACK(y,b,c))) br =
358      RED(y, lbalance x (sub1 a) b, BLACK(z, c, br))
359    | balright _ _ _ = raise RedBlackSetError
360
361
362  (* [append left right] constructs a new tree t.
363  PRECONDITIONS: RB left /\ RB right
364              /\ !e in left => !x in right e < x
365  POSTCONDITION: not (RB t)
366  *)
367  fun append LEAF right                    = right
368    | append left LEAF                     = left
369    | append (RED(x,a,b)) (RED(y,c,d))     =
370      (case append b c of
371           RED(z, b, c) => RED(z, RED(x, a, b), RED(y, c, d))
372         | bc           => RED(x, a, RED(y, bc, d)))
373    | append a (RED(x,b,c))                = RED(x, append a b, c)
374    | append (RED(x,a,b)) c                = RED(x, a, append b c)
375    | append (BLACK(x,a,b)) (BLACK(y,c,d)) =
376      (case append b c of
377           RED(z, b, c) => RED(z, BLACK(x, a, b), BLACK(y, c, d))
378         | bc           => balleft x a (BLACK(y, bc, d)))
379
380  fun delete (set as (compare, tree, n), x) =
381      let fun delShared y a b =
382              case compare(x,y) of
383                  EQUAL   => append a b
384                | LESS    => (case a of
385                                  BLACK _ => balleft y (del a) b
386                                | _       => RED(y, del a, b))
387                | GREATER => (case b of
388                                  BLACK _ => balright y a (del b)
389                                | _       => RED(y, a, del b))
390          and del LEAF             = raise NotFound
391            | del (RED(y, a, b))   = delShared y a b
392            | del (BLACK(y, a, b)) = delShared y a b
393      in  ( compare
394          , case del tree of
395                RED arg => BLACK arg
396              | tree    => tree
397          , n-1) end
398
399  fun filter p (set as (compare, _, _)) =
400      foldl (fn (e, acc) => if p e then add(acc,e)
401                            else acc)
402         (empty compare)
403         set
404
405end
406