1(* ========================================================================= *)
2(* FINITE SETS WITH A FIXED ELEMENT TYPE                                     *)
3(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
(* Build a finite-set structure whose elements are the keys of the map
   structure KM.  A set is stored as a KM map whose values are all (), and
   most operations below are direct delegations to the matching KM function. *)
functor ElementSet (KM : KeyMap) :> ElementSet
    where type element = KM.key
      and type 'a map = 'a KM.map =
struct

(* ------------------------------------------------------------------------- *)
(* Elements: the key type of the underlying map.                             *)
(* ------------------------------------------------------------------------- *)

type element = KM.key;

(* Element ordering and equality are taken unchanged from the key map. *)
val compareElement = KM.compareKey;

val equalElement = KM.equalKey;

(* ------------------------------------------------------------------------- *)
(* Sets: maps whose values are all unit.                                     *)
(* ------------------------------------------------------------------------- *)

type 'a map = 'a KM.map;

(* The single constructor wraps a unit-valued map; its keys are the members. *)
datatype set = Set of unit map;

(* ------------------------------------------------------------------------- *)
(* Conversions between sets and maps.                                        *)
(* ------------------------------------------------------------------------- *)

(* Unwrap a set to its underlying unit-valued map.  The functions below reach
   the representation through dest instead of matching on Set directly. *)
fun dest set =
    case set of
      Set m => m;

(* Build a map over the set's elements with KM.mapPartial, where f returns an
   option for each element. *)
fun mapPartial f set = KM.mapPartial (fn (elt,()) => f elt) (dest set);

(* Build a map over the set's elements with KM.map, applying f to each one. *)
fun map f set = KM.map (fn (elt,()) => f elt) (dest set);

(* The set of keys of an arbitrary map: every value is replaced by (). *)
fun domain m =
    let
      fun forget _ = ()
    in
      Set (KM.transform forget m)
    end;

(* ------------------------------------------------------------------------- *)
(* Building sets.                                                            *)
(* ------------------------------------------------------------------------- *)

(* The set with no elements, wrapping a fresh KM.new map. *)
val empty = Set (KM.new ());

(* The set containing exactly elt. *)
fun singleton elt =
    let
      val m = KM.singleton (elt,())
    in
      Set m
    end;

(* ------------------------------------------------------------------------- *)
(* Cardinality.                                                              *)
(* ------------------------------------------------------------------------- *)

fun null set = KM.null (dest set);

fun size set = KM.size (dest set);

(* ------------------------------------------------------------------------- *)
(* Membership and element lookup.                                            *)
(* ------------------------------------------------------------------------- *)

(* Look elt up with KM.peekKey and return the key found in the map, if any. *)
fun peek set elt =
    case KM.peekKey (dest set) elt of
      NONE => NONE
    | SOME (key,()) => SOME key;

(* Test whether elt is a key of the underlying map. *)
fun member elt set = KM.inDomain elt (dest set);

(* An element selected by KM.pick. *)
fun pick set = #1 (KM.pick (dest set));

(* The element at position n, as numbered by KM.nth. *)
fun nth set n = #1 (KM.nth (dest set) n);

(* An element selected by KM.random. *)
fun random set = #1 (KM.random (dest set));

(* ------------------------------------------------------------------------- *)
(* Insertion.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Insert elt, mapping it to (). *)
fun add set elt = Set (KM.insert (dest set) (elt,()));

(* Insert each element of the list in turn, from left to right. *)
fun addList set elts = List.foldl (fn (elt,acc) => add acc elt) set elts;

(* ------------------------------------------------------------------------- *)
(* Deletion.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* NOTE(review): delete and remove differ only in calling KM.delete versus
   KM.remove -- presumably in how an absent element is treated; confirm
   against the KeyMap signature. *)
fun delete set elt = Set (KM.delete (dest set) elt);

fun remove set elt = Set (KM.remove (dest set) elt);

(* The following three split off one element -- chosen by KM.deletePick,
   KM.deleteNth or KM.deleteRandom respectively -- and return it together
   with the set of the remaining elements. *)
fun deletePick set =
    case KM.deletePick (dest set) of
      ((elt,()),rest) => (elt, Set rest);

fun deleteNth set n =
    case KM.deleteNth (dest set) n of
      ((elt,()),rest) => (elt, Set rest);

fun deleteRandom set =
    case KM.deleteRandom (dest set) of
      ((elt,()),rest) => (elt, Set rest);

(* ------------------------------------------------------------------------- *)
(* Set algebra.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Each operation combines the underlying maps with the matching KM *Domain
   function and rewraps the result. *)
fun union set1 set2 = Set (KM.unionDomain (dest set1) (dest set2));

fun unionList sets = Set (KM.unionListDomain (List.map dest sets));

fun intersect set1 set2 = Set (KM.intersectDomain (dest set1) (dest set2));

fun intersectList sets = Set (KM.intersectListDomain (List.map dest sets));

fun difference set1 set2 =
    Set (KM.differenceDomain (dest set1) (dest set2));

fun symmetricDifference set1 set2 =
    Set (KM.symmetricDifferenceDomain (dest set1) (dest set2));

(* ------------------------------------------------------------------------- *)
(* Filtering, traversal and folding.                                         *)
(* ------------------------------------------------------------------------- *)

(* Keep the elements satisfying pred. *)
fun filter pred set = Set (KM.filter (fn (elt,()) => pred elt) (dest set));

(* Split the set with KM.partition according to pred, rewrapping both parts
   in the order KM.partition returns them. *)
fun partition pred set =
    let
      val (first,second) = KM.partition (fn (elt,()) => pred elt) (dest set)
    in
      (Set first, Set second)
    end;

(* Run f on every element via KM.app. *)
fun app f set = KM.app (fn (elt,()) => f elt) (dest set);

(* Fold f over the elements via KM.foldl / KM.foldr; f receives each element
   paired with the current accumulator. *)
fun foldl f acc set = KM.foldl (fn (elt,(),a) => f (elt,a)) acc (dest set);

fun foldr f acc set = KM.foldr (fn (elt,(),a) => f (elt,a)) acc (dest set);

(* ------------------------------------------------------------------------- *)
(* Searching and quantifiers.                                                *)
(* ------------------------------------------------------------------------- *)

(* The element found by KM.findl / KM.findr for predicate p, if any. *)
fun findl p set =
    case KM.findl (fn (elt,()) => p elt) (dest set) of
      NONE => NONE
    | SOME (elt,()) => SOME elt;

fun findr p set =
    case KM.findr (fn (elt,()) => p elt) (dest set) of
      NONE => NONE
    | SOME (elt,()) => SOME elt;

(* The result of KM.firstl / KM.firstr with f applied to each element. *)
fun firstl f set = KM.firstl (fn (elt,()) => f elt) (dest set);

fun firstr f set = KM.firstr (fn (elt,()) => f elt) (dest set);

(* Quantify or count over the elements with predicate p. *)
fun exists p set = KM.exists (fn (elt,()) => p elt) (dest set);

fun all p set = KM.all (fn (elt,()) => p elt) (dest set);

fun count p set = KM.count (fn (elt,()) => p elt) (dest set);

(* ------------------------------------------------------------------------- *)
(* Ordering and relations between sets.                                      *)
(* ------------------------------------------------------------------------- *)

(* All stored values are (), so they always compare EQUAL and test equal;
   only the keys influence compare and equal. *)
val compareValue = fn ((),()) => EQUAL;

val equalValue = fn () => fn () => true;

fun compare (set1,set2) = KM.compare compareValue (dest set1, dest set2);

fun equal set1 set2 = KM.equal equalValue (dest set1) (dest set2);

fun subset set1 set2 = KM.subsetDomain (dest set1) (dest set2);

fun disjoint set1 set2 = KM.disjointDomain (dest set1) (dest set2);

(* ------------------------------------------------------------------------- *)
(* Closure under a generating function.                                      *)
(* ------------------------------------------------------------------------- *)

(* closedAdd f closed pending folds over pending.  Each element not already
   in the accumulator is added to it, and then the set f elt is processed in
   the same way.  The result is closed extended by everything reachable from
   pending through repeated application of f; termination relies on f
   eventually producing no new elements. *)
fun closedAdd f =
    let
      fun saturate closed pending = foldl visit closed pending

      and visit (elt,closed) =
          if member elt closed then closed
          else saturate (add closed elt) (f elt)
    in
      saturate
    end;

(* The closure of the given set under f, starting from the empty set. *)
fun close f start = closedAdd f empty start;

(* ------------------------------------------------------------------------- *)
(* Conversions between sets and lists.                                       *)
(* ------------------------------------------------------------------------- *)

(* Apply f to each element and collect the results in a list, built with the
   set's own foldr. *)
fun transform f set = foldr (fn (elt,acc) => f elt :: acc) [] set;

(* The elements, as returned by KM.keys. *)
fun toList set = KM.keys (dest set);

(* The set of the list's elements, built by inserting into the empty set. *)
val fromList = addList empty;

(* ------------------------------------------------------------------------- *)
(* Printing.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Render a set as "{}" when it is empty and "{n}" otherwise, where n is its
   size; the elements themselves are not shown. *)
fun toString set =
    let
      val inner = if null set then "" else Int.toString (size set)
    in
      "{" ^ inner ^ "}"
    end;

(* ------------------------------------------------------------------------- *)
(* Iterators.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Set iterators are the key map's iterators over unit values. *)
type iterator = unit KM.iterator;

fun mkIterator set = KM.mkIterator (dest set);

fun mkRevIterator set = KM.mkRevIterator (dest set);

(* The element at the iterator's current position (its unit value dropped). *)
fun readIterator (iter : iterator) = #1 (KM.readIterator iter);

val advanceIterator = KM.advanceIterator;

end
365
(* Concrete set structures, obtained by applying ElementSet to the key maps
   IntMap, IntPairMap and StringMap. *)
structure IntSet = ElementSet (IntMap);

structure IntPairSet = ElementSet (IntPairMap);

structure StringSet = ElementSet (StringMap);
374