1(*  Title:      Pure/General/table.ML
2    Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
3
4Generic tables.  Efficient purely functional implementation using
5balanced 2-3 trees.
6*)
7
(* Parameter of the Table functor: the key type together with the
   comparison used to arrange keys in the search tree and a printer
   used when a table is pretty-printed. *)
signature KEY =
sig
  type key

  (* Comparison driving every search in Table; presumably a total order,
     since the tree is only consistent under one -- TODO confirm. *)
  val ord : key * key -> order

  (* Printer for a single key (used by Table.pp). *)
  val pp : key HOLPP.pprinter
end;
14
(* Purely functional finite maps from key to values of type 'a. *)
signature TABLE =
sig
  type key
  type 'a table

  (* exceptions *)
  exception DUP of key     (* key already bound to a conflicting value *)
  exception SAME           (* raised by an update function: "no change" *)
  exception UNDEF of key   (* key not bound *)

  (* construction and size tests *)
  val empty : 'a table
  val is_empty : 'a table -> bool
  val is_single : 'a table -> bool

  (* traversal, in key order *)
  val map : (key -> 'a -> 'b) -> 'a table -> 'b table
  val fold : (key * 'b -> 'a -> 'a) -> 'b table -> 'a -> 'a
  val fold_rev : (key * 'b -> 'a -> 'a) -> 'b table -> 'a -> 'a
  val dest : 'a table -> (key * 'a) list
  val keys : 'a table -> key list
  val min : 'a table -> (key * 'a) option
  val max : 'a table -> (key * 'a) option
  val get_first : (key * 'a -> 'b option) -> 'a table -> 'b option
  val exists : (key * 'a -> bool) -> 'a table -> bool
  val forall : (key * 'a -> bool) -> 'a table -> bool

  (* lookup *)
  val lookup_key : 'a table -> key -> (key * 'a) option
  val lookup : 'a table -> key -> 'a option
  val defined : 'a table -> key -> bool

  (* single-key updates *)
  val update : key * 'a -> 'a table -> 'a table
  val update_new : key * 'a -> 'a table -> 'a table            (*exception DUP*)
  val default : key * 'a -> 'a table -> 'a table
  val map_entry : key -> ('a -> 'a) (*exception SAME*) -> 'a table -> 'a table
  val map_default : key * 'a -> ('a -> 'a) -> 'a table -> 'a table

  (* whole-table construction and combination *)
  val make : (key * 'a) list -> 'a table                       (*exception DUP*)
  val join : (key -> 'a * 'a -> 'a) (*exception SAME*) ->
             'a table * 'a table -> 'a table                   (*exception DUP*)
  val merge : ('a -> 'a -> bool) -> 'a table * 'a table -> 'a table
                                                               (*exception DUP*)

  (* deletion *)
  val delete : key -> 'a table -> 'a table                     (*exception UNDEF*)
  val delete_safe : key -> 'a table -> 'a table

  (* membership modulo an equality test on values *)
  val member : ('b -> 'a -> bool) -> 'a table -> key * 'b -> bool
  val insert : ('a -> 'a -> bool) -> key * 'a -> 'a table -> 'a table
                                                               (*exception DUP*)
  val remove : ('b -> 'a -> bool) -> key * 'b -> 'a table -> 'a table

  (* tables whose values are lists *)
  val lookup_list : 'a list table -> key -> 'a list
  val cons_list : key * 'a -> 'a list table -> 'a list table
  val insert_list :
    ('a -> 'a -> bool) -> key * 'a -> 'a list table -> 'a list table
  val remove_list :
    ('b -> 'a -> bool) -> key * 'b -> 'a list table -> 'a list table
  val update_list :
    ('a -> 'a -> bool) -> key * 'a -> 'a list table -> 'a list table
  val make_list : (key * 'a) list -> 'a list table
  val dest_list : 'a list table -> (key * 'a) list
  val merge_list :
    ('a -> 'a -> bool) -> 'a list table * 'a list table -> 'a list table

  (* sets, represented as tables with unit values *)
  type set = unit table
  val insert_set : key -> set -> set
  val remove_set : key -> set -> set
  val make_set : key list -> set

  (* pretty-printing, given a printer for values *)
  val pp : 'a HOLPP.pprinter -> 'a table HOLPP.pprinter
end;
70
71functor Table(Key: KEY) : TABLE =
72struct
73
74open Portable
75(* datatype table *)
76
type key = Key.key;

(* A table is a 2-3 search tree.  A Branch2 node holds one key/value entry
   and two subtrees, a Branch3 node two entries and three subtrees.  Keys in
   a left subtree compare LESS than the node's (first) entry, keys in a right
   subtree GREATER than its (last) entry, and the middle subtree of a Branch3
   lies strictly between its two entries. *)
datatype 'a table
  = Empty
  | Branch2 of 'a table * (key * 'a) * 'a table
  | Branch3 of 'a table * (key * 'a) * 'a table * (key * 'a) * 'a table;
83
exception DUP of key;


(* empty and single *)

val empty = Empty;

(* true exactly for the table without entries *)
fun is_empty tab =
  (case tab of
    Empty => true
  | _ => false);

(* true exactly for a table holding one entry: a lone 2-node *)
fun is_single tab =
  (case tab of
    Branch2 (Empty, _, Empty) => true
  | _ => false);
96
97
98(* map and fold combinators *)
99
(* map_table f tab: same tree shape, each value v at key k replaced by
   f k v.  f is applied in ascending key order (left subtree, entry, middle,
   entry, right subtree). *)
fun map_table f =
  let
    fun walk Empty = Empty
      | walk (Branch2 (l, (k, v), r)) =
          let
            val l' = walk l
            val e = (k, f k v)
            val r' = walk r
          in
            Branch2 (l', e, r')
          end
      | walk (Branch3 (l, (k1, v1), m, (k2, v2), r)) =
          let
            val l' = walk l
            val e1 = (k1, f k1 v1)
            val m' = walk m
            val e2 = (k2, f k2 v2)
            val r' = walk r
          in
            Branch3 (l', e1, m', e2, r')
          end
  in
    walk
  end;
108
(* fold_table f tab acc: thread acc through f over every (key, value) entry
   in ascending key order. *)
fun fold_table f =
  let
    fun go Empty acc = acc
      | go (Branch2 (l, e, r)) acc =
          let
            val acc1 = go l acc
            val acc2 = f e acc1
          in
            go r acc2
          end
      | go (Branch3 (l, e1, m, e2, r)) acc =
          let
            val acc1 = go l acc
            val acc2 = f e1 acc1
            val acc3 = go m acc2
            val acc4 = f e2 acc3
          in
            go r acc4
          end
  in
    go
  end;
117
(* fold_rev_table f tab acc: like fold_table, but entries are visited in
   descending key order. *)
fun fold_rev_table f =
  let
    fun go Empty acc = acc
      | go (Branch2 (l, e, r)) acc =
          let
            val acc1 = go r acc
            val acc2 = f e acc1
          in
            go l acc2
          end
      | go (Branch3 (l, e1, m, e2, r)) acc =
          let
            val acc1 = go r acc
            val acc2 = f e2 acc1
            val acc3 = go m acc2
            val acc4 = f e1 acc3
          in
            go l acc4
          end
  in
    go
  end;
126
(* all entries, in ascending key order (built back to front) *)
fun dest tab = fold_rev_table (fn entry => fn acc => entry :: acc) tab [];
(* all keys, in ascending order *)
fun keys tab = fold_rev_table (fn (k, _) => fn acc => k :: acc) tab [];
129
130
131(* min/max entries *)
132
(* entry with the least key (NONE for the empty table): descend leftwards
   until the node's left subtree is empty *)
fun min Empty = NONE
  | min (Branch2 (l, p, _)) =
      (case l of Empty => SOME p | _ => min l)
  | min (Branch3 (l, p, _, _, _)) =
      (case l of Empty => SOME p | _ => min l);
138
(* entry with the greatest key (NONE for the empty table): descend rightwards
   until the node's right subtree is empty *)
fun max Empty = NONE
  | max (Branch2 (_, p, r)) =
      (case r of Empty => SOME p | _ => max r)
  | max (Branch3 (_, _, _, p, r)) =
      (case r of Empty => SOME p | _ => max r);
144
145
146(* get_first *)
147
(* get_first f tab: the first SOME result of f over the entries in ascending
   key order; entries after that one are not examined. *)
fun get_first f =
  let
    (* keep an existing result, otherwise compute the next candidate lazily *)
    fun first_some NONE next = next ()
      | first_some found _ = found

    fun get Empty = NONE
      | get (Branch2 (l, e, r)) =
          first_some (get l) (fn () =>
            first_some (f e) (fn () =>
              get r))
      | get (Branch3 (l, e1, m, e2, r)) =
          first_some (get l) (fn () =>
            first_some (f e1) (fn () =>
              first_some (get m) (fn () =>
                first_some (f e2) (fn () =>
                  get r))))
  in
    get
  end;
172
(* true iff pred holds for some entry; stops at the first match *)
fun exists pred tab =
  (case get_first (fn entry => if pred entry then SOME () else NONE) tab of
    SOME () => true
  | NONE => false);

(* true iff pred holds for every entry *)
fun forall pred tab = not (exists (fn entry => not (pred entry)) tab);
176
177
178(* lookup *)
179
(* lookup tab key: the value bound to key, or NONE *)
fun lookup tab key =
  let
    fun search Empty = NONE
      | search (Branch2 (l, (k, v), r)) =
          (case Key.ord (key, k) of
            LESS => search l
          | GREATER => search r
          | EQUAL => SOME v)
      | search (Branch3 (l, (k1, v1), m, (k2, v2), r)) =
          (case Key.ord (key, k1) of
            LESS => search l
          | EQUAL => SOME v1
          | GREATER =>
              (case Key.ord (key, k2) of
                LESS => search m
              | GREATER => search r
              | EQUAL => SOME v2))
  in
    search tab
  end;
198
(* lookup_key tab key: the stored entry (stored key, value) whose key
   compares EQUAL to key, or NONE *)
fun lookup_key tab key =
  let
    fun find Empty = NONE
      | find (Branch2 (l, entry as (k, _), r)) =
          (case Key.ord (key, k) of
            EQUAL => SOME entry
          | LESS => find l
          | GREATER => find r)
      | find (Branch3 (l, entry1 as (k1, _), m, entry2 as (k2, _), r)) =
          (case Key.ord (key, k1) of
            EQUAL => SOME entry1
          | LESS => find l
          | GREATER =>
              (case Key.ord (key, k2) of
                EQUAL => SOME entry2
              | LESS => find m
              | GREATER => find r))
  in
    find tab
  end;
217
(* defined tab key: true iff some entry's key compares EQUAL to key.
   lookup performs exactly the same search, so reuse it. *)
fun defined tab key = isSome (lookup tab key);
236
237
238(* modify *)
239
(* Outcome of rebuilding a subtree during modify:
   Stay t        -- t replaces the old subtree directly;
   Sprout (l,q,r) -- the subtree was split into l and r around entry q,
                    which the parent node has to absorb. *)
datatype 'a growth
  = Stay of 'a table
  | Sprout of 'a table * (key * 'a) * 'a table;

(* Raised by an update function to request "leave the table as it is";
   modify catches it and returns the original table. *)
exception SAME;
245
(* modify key f tab: store f applied to the current binding of key.
   f receives NONE when key is unbound (a new entry (key, f NONE) is
   inserted) or SOME x for the stored value x (replaced by f (SOME x);
   the key already stored in the tree is kept).  If f raises SAME, the
   original tab is returned unchanged; any other exception raised by f
   propagates to the caller. *)
fun modify key f tab =
  let
    (* modfy rebuilds the search path for key.  It answers Stay t when the
       rebuilt subtree t fits in place, or Sprout (l, q, r) when the subtree
       overflowed and was split around entry q for the parent to absorb. *)
    fun modfy Empty = Sprout (Empty, (key, f NONE), Empty)
      | modfy (Branch2 (left, p as (k, x), right)) =
          (case Key.ord (key, k) of
            LESS =>
              (case modfy left of
                Stay left' => Stay (Branch2 (left', p, right))
              (* a 2-node absorbs a split child by growing into a 3-node *)
              | Sprout (left1, q, left2) => Stay (Branch3 (left1, q, left2, p, right)))
          | EQUAL => Stay (Branch2 (left, (k, f (SOME x)), right))
          | GREATER =>
              (case modfy right of
                Stay right' => Stay (Branch2 (left, p, right'))
              | Sprout (right1, q, right2) =>
                  Stay (Branch3 (left, p, right1, q, right2))))
      | modfy (Branch3 (left, p1 as (k1, x1), mid, p2 as (k2, x2), right)) =
          (* a 3-node cannot absorb a split child: it splits into two 2-nodes
             and passes its middle entry up via Sprout *)
          (case Key.ord (key, k1) of
            LESS =>
              (case modfy left of
                Stay left' => Stay (Branch3 (left', p1, mid, p2, right))
              | Sprout (left1, q, left2) =>
                  Sprout (Branch2 (left1, q, left2), p1, Branch2 (mid, p2, right)))
          | EQUAL => Stay (Branch3 (left, (k1, f (SOME x1)), mid, p2, right))
          | GREATER =>
              (case Key.ord (key, k2) of
                LESS =>
                  (case modfy mid of
                    Stay mid' => Stay (Branch3 (left, p1, mid', p2, right))
                  | Sprout (mid1, q, mid2) =>
                      Sprout (Branch2 (left, p1, mid1), q, Branch2 (mid2, p2, right)))
              | EQUAL => Stay (Branch3 (left, p1, mid, (k2, f (SOME x2)), right))
              | GREATER =>
                  (case modfy right of
                    Stay right' => Stay (Branch3 (left, p1, mid, p2, right'))
                  | Sprout (right1, q, right2) =>
                      Sprout (Branch2 (left, p1, mid), p2, Branch2 (right1, q, right2)))));

  in
    (case modfy tab of
      Stay tab' => tab'
    (* a split at the root becomes the new (one level taller) root *)
    | Sprout br => Branch2 br)
    (* f asked for no change: share the original tree *)
    handle SAME => tab
  end;
289
(* bind key to x, overwriting any existing value *)
fun update (key, x) tab = modify key (fn _ => x) tab;

(* bind key to x; raise DUP key if key is already bound *)
fun update_new (key, x) tab =
  let
    fun fresh NONE = x
      | fresh (SOME _) = raise DUP key
  in
    modify key fresh tab
  end;

(* bind key to x only when key is unbound; otherwise tab is returned as is *)
fun default (key, x) tab =
  let
    fun keep_old (SOME _) = raise SAME
      | keep_old NONE = x
  in
    modify key keep_old tab
  end;

(* replace the value bound to key by f of it; an unbound key, or f raising
   SAME, leaves the table unchanged *)
fun map_entry key f =
  let
    fun apply_existing NONE = raise SAME
      | apply_existing (SOME x) = f x
  in
    modify key apply_existing
  end;

(* store f applied to the value bound to key, or to x when key is unbound *)
fun map_default (key, x) f =
  modify key (fn old => f (case old of SOME y => y | NONE => x));
295
296
297(* delete *)
298
exception UNDEF of key;

local

(* Compare an optional search key with an entry.  NONE behaves as a key
   below every key (always LESS), steering del to the minimum entry. *)
fun compare NONE (k2, _) = LESS
  | compare (SOME k1) (k2, _) = Key.ord (k1, k2);

(* if_eq ord x y selects x when ord is EQUAL, y otherwise.  del uses it to
   switch to minimum-deletion (NONE) and to substitute the moved-up entry
   once the target key has been found at an internal node. *)
fun if_eq EQUAL x y = x
  | if_eq _ x y = y;

(* del k t, for k = SOME key: remove the entry with that key, raising
   UNDEF key when absent; for k = NONE: remove the minimum entry.
   Result (e, (shrunk, t')):
     t'     -- the tree without the entry;
     shrunk -- true when t' is one level shorter than t, so the caller
               must rebalance using a sibling;
     e      -- the entry taken off the bottom of the tree: the target
               itself when it sits in a bottom node, otherwise the minimum
               of the subtree to its right, which replaces the target in
               place (delete below only uses t'). *)
fun del (SOME k) Empty = raise UNDEF k
  (* bottom nodes: remove the minimum *)
  | del NONE (Branch2 (Empty, p, Empty)) = (p, (true, Empty))
  | del NONE (Branch3 (Empty, p, Empty, q, Empty)) =
      (p, (false, Branch2 (Empty, q, Empty)))
  (* bottom nodes: remove a specific key *)
  | del k (Branch2 (Empty, p, Empty)) = (case compare k p of
      EQUAL => (p, (true, Empty)) | _ => raise UNDEF (valOf k))
  | del k (Branch3 (Empty, p, Empty, q, Empty)) = (case compare k p of
      EQUAL => (p, (false, Branch2 (Empty, q, Empty)))
    | _ => (case compare k q of
        EQUAL => (q, (false, Branch2 (Empty, p, Empty)))
      | _ => raise UNDEF (valOf k)))
  (* internal 2-node: recurse, then borrow from or merge with the sibling
     when the child shrank *)
  | del k (Branch2 (l, p, r)) = (case compare k p of
      LESS => (case del k l of
        (p', (false, l')) => (p', (false, Branch2 (l', p, r)))
      | (p', (true, l')) => (p', case r of
          Branch2 (rl, rp, rr) =>
            (true, Branch3 (l', p, rl, rp, rr))
        | Branch3 (rl, rp, rm, rq, rr) => (false, Branch2
            (Branch2 (l', p, rl), rp, Branch2 (rm, rq, rr)))
        | _ => raise Fail "Impossible case - table.del Branch2-LESS"))
    (* EQUAL: delete the minimum of r and put it in p's place;
       GREATER: continue into r *)
    | ord => (case del (if_eq ord NONE k) r of
        (p', (false, r')) => (p', (false, Branch2 (l, if_eq ord p' p, r')))
      | (p', (true, r')) => (p', case l of
          Branch2 (ll, lp, lr) =>
            (true, Branch3 (ll, lp, lr, if_eq ord p' p, r'))
        | Branch3 (ll, lp, lm, lq, lr) => (false, Branch2
            (Branch2 (ll, lp, lm), lq, Branch2 (lr, if_eq ord p' p, r')))
        | _ => raise Fail "Impossible case - table.del Branch2-<any>")))
  (* internal 3-node: a shrunken child is repaired from a sibling without
     the node itself shrinking, so shrunk is always false here *)
  | del k (Branch3 (l, p, m, q, r)) = (case compare k q of
      LESS => (case compare k p of
        LESS => (case del k l of
          (p', (false, l')) => (p', (false, Branch3 (l', p, m, q, r)))
        | (p', (true, l')) => (p', (false, case (m, r) of
            (Branch2 (ml, mp, mr), Branch2 _) =>
              Branch2 (Branch3 (l', p, ml, mp, mr), q, r)
          | (Branch3 (ml, mp, mm, mq, mr), _) =>
              Branch3 (Branch2 (l', p, ml), mp, Branch2 (mm, mq, mr), q, r)
          | (Branch2 (ml, mp, mr), Branch3 (rl, rp, rm, rq, rr)) =>
              Branch3 (Branch2 (l', p, ml), mp, Branch2 (mr, q, rl), rp,
                Branch2 (rm, rq, rr))
          | _ => raise Fail "Impossible case - Table.del LESS-LESS")))
      (* k equals p (replace p by the minimum of m) or lies between p and q *)
      | ord => (case del (if_eq ord NONE k) m of
          (p', (false, m')) =>
            (p', (false, Branch3 (l, if_eq ord p' p, m', q, r)))
        | (p', (true, m')) => (p', (false, case (l, r) of
            (Branch2 (ll, lp, lr), Branch2 _) =>
              Branch2 (Branch3 (ll, lp, lr, if_eq ord p' p, m'), q, r)
          | (Branch3 (ll, lp, lm, lq, lr), _) =>
              Branch3 (Branch2 (ll, lp, lm), lq,
                Branch2 (lr, if_eq ord p' p, m'), q, r)
          | (_, Branch3 (rl, rp, rm, rq, rr)) =>
              Branch3 (l, if_eq ord p' p, Branch2 (m', q, rl), rp,
                Branch2 (rm, rq, rr))
          | _ => raise Fail "Impossible case - Table.del LESS-<any>"))))
    (* k equals q (replace q by the minimum of r) or lies above q *)
    | ord => (case del (if_eq ord NONE k) r of
        (q', (false, r')) =>
          (q', (false, Branch3 (l, p, m, if_eq ord q' q, r')))
      | (q', (true, r')) => (q', (false, case (l, m) of
          (Branch2 _, Branch2 (ml, mp, mr)) =>
            Branch2 (l, p, Branch3 (ml, mp, mr, if_eq ord q' q, r'))
        | (_, Branch3 (ml, mp, mm, mq, mr)) =>
            Branch3 (l, p, Branch2 (ml, mp, mm), mq,
              Branch2 (mr, if_eq ord q' q, r'))
        | (Branch3 (ll, lp, lm, lq, lr), Branch2 (ml, mp, mr)) =>
            Branch3 (Branch2 (ll, lp, lm), lq, Branch2 (lr, p, ml), mp,
              Branch2 (mr, if_eq ord q' q, r'))
        | _ => raise Fail "Impossible case - Table.del <any>"))))
  | del _ _ = raise Fail "Impossible case - Table.del <topmost defn>";

in

(* remove the entry for key; raises UNDEF key when key is unbound *)
fun delete key tab = snd (snd (del (SOME key) tab));
(* remove the entry for key if present; otherwise return tab itself *)
fun delete_safe key tab = if defined tab key then delete key tab else tab;

end;
384
385
386(* membership operations *)
387
(* true iff key is bound to some y with eq x y *)
fun member eq tab (key, x) =
  (case lookup tab key of
    SOME y => eq x y
  | NONE => false);

(* bind key to x when unbound; if key is already bound to y, the table is
   unchanged when eq x y and DUP key is raised otherwise *)
fun insert eq (key, x) =
  let
    fun ins NONE = x
      | ins (SOME y) = if eq x y then raise SAME else raise DUP key
  in
    modify key ins
  end;

(* delete key when it is bound to some y with eq x y; otherwise no change *)
fun remove eq (key, x) tab =
  (case lookup tab key of
    SOME y => if eq x y then delete key tab else tab
  | NONE => tab);
400
401
(* simultaneous modifications *)

(* table from an association list; each entry is added with update_new,
   so a repeated key raises DUP *)
fun make entries =
  Portable.foldl' (fn entry => fn tab => update_new entry tab) entries empty;

(* join f (table1, table2): all entries of both tables; where a key is bound
   in both, the result holds f key (x1, x2).  f may raise SAME to keep x1. *)
fun join f (table1, table2) =
  let
    fun add (key, y) =
      modify key (fn SOME x => f key (x, y) | NONE => y)
  in
    (case table1 of
      Empty => table2
    | _ => fold_table add table2 table1)
  end;

(* join where a shared key must carry eq-equal values (the first is kept);
   otherwise DUP key is raised *)
fun merge eq =
  let
    fun resolve key (x, y) = if eq x y then raise SAME else raise DUP key
  in
    join resolve
  end;
417
418
(* list tables *)

(* the list bound to key, or [] when key is unbound *)
fun lookup_list tab key = these (lookup tab key);

(* push x on the front of key's list (starting a new list if unbound) *)
fun cons_list (key, x) tab =
  let
    fun push NONE = [x]
      | push (SOME xs) = x :: xs
  in
    modify key push tab
  end;

(* push x on key's list unless an eq-equal element is already there *)
fun insert_list eq (key, x) =
  let
    fun ins NONE = [x]
      | ins (SOME xs) = if op_mem eq x xs then raise SAME else x :: xs
  in
    modify key ins
  end;

(* remove x from key's list via op_remove; the key is deleted altogether
   when its list becomes empty *)
fun remove_list eq (key, x) tab =
  let
    fun drop xs =
      (case op_remove eq x xs of
        [] => raise UNDEF key
      | rest => rest)
  in
    map_entry key drop tab handle UNDEF _ => delete key tab
  end;

(* put x in key's list via op_update, unless it already heads the list *)
fun update_list eq (key, x) =
  let
    fun upd NONE = [x]
      | upd (SOME []) = [x]
      | upd (SOME (xs as y :: _)) =
          if eq x y then raise SAME else op_update eq x xs
  in
    modify key upd
  end;

(* list table from an association list, using cons_list *)
fun make_list entries = Portable.foldr' cons_list entries empty;

(* flatten to (key, element) pairs, keys ascending *)
fun dest_list tab =
  List.concat (List.map (fn (key, xs) => List.map (pair key) xs) (dest tab));

(* join where lists under a shared key are combined with op_union *)
fun merge_list eq = join (fn _ => fn (xs, ys) => op_union eq xs ys);
441
442
(* unit tables *)

type set = unit table;

(* add x to the set; no change if it is already a member *)
fun insert_set x tab = default (x, ()) tab;

(* drop x from the set if it is a member *)
fun remove_set x (s : set) : set = delete_safe x s;

(* set holding the given keys; duplicates are harmless *)
fun make_set ks = Portable.foldl' insert_set ks empty;
450
(* pretty-printing *)

(* pp vpp tab renders tab as "Table{k1 |-> v1, k2 |-> v2, ...}" in
   ascending key order, printing keys with Key.pp and values with vpp *)
fun pp vpp tab =
  let
    open HOLPP
    (* one "key |-> value" pair; the value may break onto an indented line *)
    fun pp_entry (k, v) =
      block CONSISTENT 0
        [Key.pp k, add_string " |->", add_break (1, 2), vpp v]
    val separator = [add_string ",", add_break (1, 0)]
    val body = block INCONSISTENT 6 (pr_list pp_entry separator (dest tab))
  in
    block CONSISTENT 0 [add_string "Table{", body, add_string "}"]
  end;
465
(* Final declarations of this structure.  The public names map, fold and
   fold_rev are bound only here so that earlier code in the structure
   (e.g. dest_list) still sees the ordinary list map. *)
val map = map_table
and fold = fold_table
and fold_rev = fold_rev_table;
470
471end;
472