1(* ========================================================================= *)
2(* FINITE MAPS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
3(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
6structure Map :> Map =
7struct
8
9(* ------------------------------------------------------------------------- *)
10(* Importing useful functionality.                                           *)
11(* ------------------------------------------------------------------------- *)
12
13exception Bug = Useful.Bug;
14
15exception Error = Useful.Error;
16
17val pointerEqual = Portable.pointerEqual;
18
19val K = Useful.K;
20
21val randomInt = Portable.randomInt;
22
23val randomWord = Portable.randomWord;
24
25(* ------------------------------------------------------------------------- *)
26(* Converting a comparison function to an equality function.                 *)
27(* ------------------------------------------------------------------------- *)
28
(* Turns a three-way key comparison into a curried equality test: the two   *)
(* keys are considered equal exactly when the comparison returns EQUAL.     *)
fun equalKey compareKey key1 key2 =
    case compareKey (key1,key2) of
      EQUAL => true
    | _ => false;
30
31(* ------------------------------------------------------------------------- *)
32(* Priorities.                                                               *)
33(* ------------------------------------------------------------------------- *)
34
(* Node priorities are machine words. *)
type priority = Word.word;

(* Draws a fresh priority from the portable random word generator. *)
fun randomPriority () : priority = randomWord ();

(* Three-way comparison of two priorities. *)
fun comparePriority (p1 : priority, p2 : priority) = Word.compare (p1,p2);
40
41(* ------------------------------------------------------------------------- *)
42(* Priority search trees.                                                    *)
43(* ------------------------------------------------------------------------- *)
44
(* A tree is either empty (E) or wraps a node (T).                          *)
(* A node stores one key/value pair, its priority, its two subtrees, and a  *)
(* size field; mkNode sets size to                                          *)
(* (size of left) + 1 + (size of right).                                    *)
datatype ('k,'v) tree =
    E
  | T of ('k,'v) node

and ('k,'v) node =
    Node of
      {key : 'k,
       value : 'v,
       priority : priority,
       size : int,
       left : ('k,'v) tree,
       right : ('k,'v) tree};
57
(* True exactly when the first node's priority compares strictly less than  *)
(* the second node's priority.                                              *)
fun lowerPriorityNode (Node {priority = p1, ...}) (Node {priority = p2, ...}) =
    case comparePriority (p1,p2) of
      LESS => true
    | _ => false;
65
66(* ------------------------------------------------------------------------- *)
67(* Tree debugging functions.                                                 *)
68(* ------------------------------------------------------------------------- *)
69
70(*BasicDebug
71local
72  fun checkSizes tree =
73      case tree of
74        E => 0
75      | T (Node {size,left,right,...}) =>
76        let
77          val l = checkSizes left
78          and r = checkSizes right
79
80          val () = if l + 1 + r = size then () else raise Bug "wrong size"
81        in
82          size
83        end;
84
85  fun checkSorted compareKey x tree =
86      case tree of
87        E => x
88      | T (Node {left,key,right,...}) =>
89        let
90          val x = checkSorted compareKey x left
91
92          val () =
93              case x of
94                NONE => ()
95              | SOME k =>
96                case compareKey (k,key) of
97                  LESS => ()
98                | EQUAL => raise Bug "duplicate keys"
99                | GREATER => raise Bug "unsorted"
100
101          val x = SOME key
102        in
103          checkSorted compareKey x right
104        end;
105
106  fun checkPriorities compareKey tree =
107      case tree of
108        E => NONE
109      | T node =>
110        let
111          val Node {left,right,...} = node
112
113          val () =
114              case checkPriorities compareKey left of
115                NONE => ()
116              | SOME lnode =>
117                if not (lowerPriorityNode node lnode) then ()
118                else raise Bug "left child has greater priority"
119
120          val () =
121              case checkPriorities compareKey right of
122                NONE => ()
123              | SOME rnode =>
124                if not (lowerPriorityNode node rnode) then ()
125                else raise Bug "right child has greater priority"
126        in
127          SOME node
128        end;
129in
130  fun treeCheckInvariants compareKey tree =
131      let
132        val _ = checkSizes tree
133
134        val _ = checkSorted compareKey NONE tree
135
136        val _ = checkPriorities compareKey tree
137      in
138        tree
139      end
140      handle Error err => raise Bug err;
141end;
142*)
143
144(* ------------------------------------------------------------------------- *)
145(* Tree operations.                                                          *)
146(* ------------------------------------------------------------------------- *)
147
(* Creates the empty tree. *)
val treeNew = fn () => E;
149
(* Reads the size field stored in a node. *)
fun nodeSize node =
    let
      val Node {size,...} = node
    in
      size
    end;
151
(* Size of a tree: 0 for the empty tree, otherwise the root's size field. *)
fun treeSize E = 0
  | treeSize (T node) = nodeSize node;
156
(* Builds a node from its components, computing the size field from the    *)
(* sizes of the two subtrees plus one for the node itself.                  *)
fun mkNode priority left key value right =
    Node
      {size = treeSize left + 1 + treeSize right,
       priority = priority,
       left = left,
       key = key,
       value = value,
       right = right};
169
(* Like mkNode, but wraps the fresh node as a (non-empty) tree. *)
fun mkTree priority left key value right =
    T (mkNode priority left key value right);
176
177(* ------------------------------------------------------------------------- *)
178(* Extracting the left and right spines of a tree.                           *)
179(* ------------------------------------------------------------------------- *)
180
(* Walks down the left children of a tree, pushing each visited node onto   *)
(* the accumulator; the last (leftmost) node ends up at the head.           *)
fun treeLeftSpine acc E = acc
  | treeLeftSpine acc (T node) = nodeLeftSpine acc node

and nodeLeftSpine acc (node as Node {left,...}) =
    treeLeftSpine (node :: acc) left;
192
(* Walks down the right children of a tree, pushing each visited node onto  *)
(* the accumulator; the last (rightmost) node ends up at the head.          *)
fun treeRightSpine acc E = acc
  | treeRightSpine acc (T node) = nodeRightSpine acc node

and nodeRightSpine acc (node as Node {right,...}) =
    treeRightSpine (node :: acc) right;
204
205(* ------------------------------------------------------------------------- *)
206(* Singleton trees.                                                          *)
207(* ------------------------------------------------------------------------- *)
208
(* A node of size 1 with the given priority and key/value pair, and empty   *)
(* subtrees on both sides.                                                  *)
fun mkNodeSingleton priority key value =
    Node
      {size = 1,
       priority = priority,
       left = E,
       key = key,
       value = value,
       right = E};
223
(* A singleton node for the key/value pair with a freshly drawn random      *)
(* priority.                                                                *)
fun nodeSingleton (key,value) =
    mkNodeSingleton (randomPriority ()) key value;
230
(* A one-element tree holding the key/value pair (random priority). *)
fun treeSingleton key_value = T (nodeSingleton key_value);
237
238(* ------------------------------------------------------------------------- *)
239(* Appending two trees, where every element of the first tree is less than   *)
240(* every element of the second tree.                                         *)
241(* ------------------------------------------------------------------------- *)
242
(* Joins two trees. Whichever root does not have the lower priority becomes *)
(* the root of the result, and the other tree is recursively appended into  *)
(* the subtree facing it. The caller must ensure that every key of the      *)
(* first tree is less than every key of the second.                         *)
fun treeAppend E tree2 = tree2
  | treeAppend tree1 E = tree1
  | treeAppend (tree1 as T node1) (tree2 as T node2) =
    if lowerPriorityNode node1 node2 then
      let
        (* node2 stays on top; tree1 is pushed into its left subtree. *)
        val Node {priority = p, left = l, key = k, value = v, right = r, ...} =
            node2
      in
        mkTree p (treeAppend tree1 l) k v r
      end
    else
      let
        (* node1 stays on top; tree2 is pushed into its right subtree. *)
        val Node {priority = p, left = l, key = k, value = v, right = r, ...} =
            node1
      in
        mkTree p l k v (treeAppend r tree2)
      end;
266
267(* ------------------------------------------------------------------------- *)
268(* Appending two trees and a node, where every element of the first tree is  *)
269(* less than the node, which in turn is less than every element of the       *)
270(* second tree.                                                              *)
271(* ------------------------------------------------------------------------- *)
272
(* Joins left ++ [node] ++ right using two appends: the node is first       *)
(* appended to the left tree, and the right tree is appended to the result. *)
fun treeCombine left node right =
    treeAppend (treeAppend left (T node)) right;
279
280(* ------------------------------------------------------------------------- *)
281(* Searching a tree for a value.                                             *)
282(* ------------------------------------------------------------------------- *)
283
(* Binary search for pkey: returns SOME of the value stored under a key     *)
(* that compares EQUAL to it, or NONE when the search reaches an empty      *)
(* tree.                                                                    *)
fun treePeek _ _ E = NONE
  | treePeek compareKey pkey (T node) = nodePeek compareKey pkey node

and nodePeek compareKey pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME value
    | LESS => treePeek compareKey pkey left
    | GREATER => treePeek compareKey pkey right;
298
299(* ------------------------------------------------------------------------- *)
300(* Tree paths.                                                               *)
301(* ------------------------------------------------------------------------- *)
302
303(* Generating a path by searching a tree for a key/value pair *)
304
(* Searches for pkey while recording the route taken. Each step pushed on   *)
(* the path is (wentLeft,node): the node that was passed through and        *)
(* whether the search continued into its left subtree. The most recent      *)
(* step is at the head of the path. The result pairs the path with the      *)
(* matching node, if any.                                                   *)
fun treePeekPath _ _ path E = (path,NONE)
  | treePeekPath compareKey pkey path (T node) =
    nodePeekPath compareKey pkey path node

and nodePeekPath compareKey pkey path (node as Node {left,key,right,...}) =
    let
      fun descend wentLeft subtree =
          treePeekPath compareKey pkey ((wentLeft,node) :: path) subtree
    in
      case compareKey (pkey,key) of
        EQUAL => (path, SOME node)
      | LESS => descend true left
      | GREATER => descend false right
    end;
319
320(* A path splits a tree into left/right components *)
321
(* Folds one path step into a (leftTree,rightTree) pair. If the search went *)
(* left at the node, the node and its right subtree are placed above        *)
(* rightTree; otherwise the node and its left subtree are placed above      *)
(* leftTree.                                                                *)
fun addSidePath ((true, Node {priority,key,value,right,...}),
                 (leftTree,rightTree)) =
    (leftTree, mkTree priority rightTree key value right)
  | addSidePath ((false, Node {priority,left,key,value,...}),
                 (leftTree,rightTree)) =
    (mkTree priority left key value leftTree, rightTree);
329
(* Applies addSidePath to every step of a path, head first, starting from   *)
(* the given (left,right) pair.                                             *)
fun addSidesPath left_right path = List.foldl addSidePath left_right path;
331
(* Splits along a path starting from two empty sides. *)
fun mkSidesPath path = List.foldl addSidePath (E,E) path;
333
334(* Updating the subtree at a path *)
335
local
  (* Rebuilds one path node with the given tree substituted for the child   *)
  (* that the search descended into.                                        *)
  fun updateTree ((true, Node {priority,key,value,right,...}), tree) =
      mkTree priority tree key value right
    | updateTree ((false, Node {priority,left,key,value,...}), tree) =
      mkTree priority left key value tree;
in
  (* Replaces the subtree at the end of a path with the given tree,         *)
  (* rebuilding every node along the path.                                  *)
  fun updateTreePath tree path = List.foldl updateTree tree path;
end;
347
348(* Inserting a new node at a path position *)
349
(* Inserts a node at the position described by a search path. Path nodes    *)
(* whose priority is lower than the new node's are split into left/right    *)
(* sides that end up below the new node. At the first path node that is not *)
(* lower (or when the path is exhausted) the new node is combined with the  *)
(* sides gathered so far and put back under the remaining path.             *)
fun insertNodePath node =
    let
      fun insert (left,right) [] = treeCombine left node right
        | insert (sides as (left,right)) (path as (step as (_,snode)) :: rest) =
          if lowerPriorityNode snode node then
            insert (addSidePath (step,sides)) rest
          else
            updateTreePath (treeCombine left node right) path
    in
      insert (E,E)
    end;
378
379(* ------------------------------------------------------------------------- *)
380(* Using a key to split a node into three components: the keys comparing     *)
381(* less than the supplied key, an optional equal key, and the keys comparing *)
382(* greater.                                                                  *)
383(* ------------------------------------------------------------------------- *)
384
(* Splits the tree rooted at node around pkey. The result is                *)
(* (left, found, right), where found is SOME (key,value) when a key         *)
(* comparing EQUAL to pkey is present, and NONE otherwise.                  *)
fun nodePartition compareKey pkey node =
    case nodePeekPath compareKey pkey [] node of
      (path,NONE) =>
      let
        val (smaller,larger) = mkSidesPath path
      in
        (smaller,NONE,larger)
      end
    | (path, SOME (Node {left,key,value,right,...})) =>
      let
        (* Start from the matching node's own subtrees. *)
        val (smaller,larger) = addSidesPath (left,right) path
      in
        (smaller, SOME (key,value), larger)
      end;
405
406(* ------------------------------------------------------------------------- *)
407(* Searching a tree for a key/value pair.                                    *)
408(* ------------------------------------------------------------------------- *)
409
(* Like treePeek, but returns the stored key together with its value. *)
fun treePeekKey _ _ E = NONE
  | treePeekKey compareKey pkey (T node) = nodePeekKey compareKey pkey node

and nodePeekKey compareKey pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME (key,value)
    | LESS => treePeekKey compareKey pkey left
    | GREATER => treePeekKey compareKey pkey right;
424
425(* ------------------------------------------------------------------------- *)
426(* Inserting new key/values into the tree.                                   *)
427(* ------------------------------------------------------------------------- *)
428
(* Inserts a key/value pair. If no key compares EQUAL, a fresh singleton    *)
(* node (with a random priority) is inserted at the search position. If one *)
(* does, that node's key and value are both replaced by the supplied ones,  *)
(* keeping its size, priority and subtrees.                                 *)
fun treeInsert compareKey (key,value) tree =
    case treePeekPath compareKey key [] tree of
      (path,NONE) => insertNodePath (nodeSingleton (key,value)) path
    | (path, SOME (Node {size,priority,left,right,...})) =>
      let
        val replaced =
            Node
              {size = size,
               priority = priority,
               left = left,
               key = key,
               value = value,
               right = right}
      in
        updateTreePath (T replaced) path
      end;
458
459(* ------------------------------------------------------------------------- *)
460(* Deleting key/value pairs: it raises an exception if the supplied key is   *)
461(* not present.                                                              *)
462(* ------------------------------------------------------------------------- *)
463
(* Deletes the entry whose key compares EQUAL to dkey. Raises Bug if the    *)
(* search reaches an empty tree. Each node on the search route is rebuilt   *)
(* with its size reduced by one; the matching node is replaced by the       *)
(* append of its two subtrees.                                              *)
fun treeDelete _ _ E = raise Bug "Map.delete: element not found"
  | treeDelete compareKey dkey (T node) = nodeDelete compareKey dkey node

and nodeDelete compareKey dkey (Node {size,priority,left,key,value,right}) =
    let
      (* The same node with one fewer element and the given subtrees. *)
      fun shrink size' left' right' =
          T (Node
               {size = size',
                priority = priority,
                left = left',
                key = key,
                value = value,
                right = right'})
    in
      case compareKey (dkey,key) of
        EQUAL => treeAppend left right
      | LESS =>
        let
          val size' = size - 1
          val left' = treeDelete compareKey dkey left
        in
          shrink size' left' right
        end
      | GREATER =>
        let
          val size' = size - 1
          val right' = treeDelete compareKey dkey right
        in
          shrink size' left right'
        end
    end;
508
509(* ------------------------------------------------------------------------- *)
510(* Partial map is the basic operation for preserving tree structure.         *)
511(* It applies its argument function to the elements *in order*.              *)
512(* ------------------------------------------------------------------------- *)
513
(* Applies f to every key/value pair: left subtree first, then the node,    *)
(* then the right subtree. Entries for which f returns NONE are dropped;    *)
(* the others keep their key and priority and take the new value.           *)
fun treeMapPartial _ E = E
  | treeMapPartial f (T node) = nodeMapPartial f node

and nodeMapPartial f (Node {priority,left,key,value,right,...}) =
    let
      (* These three bindings are sequential so that f sees keys in order. *)
      val left' = treeMapPartial f left
      val result = f (key,value)
      val right' = treeMapPartial f right
    in
      case result of
        SOME value' => mkTree priority left' key value' right'
      | NONE => treeAppend left' right'
    end;
529
530(* ------------------------------------------------------------------------- *)
531(* Mapping tree values.                                                      *)
532(* ------------------------------------------------------------------------- *)
533
(* Replaces every value by f (key,value), visiting the left subtree, the    *)
(* node and then the right subtree. Sizes, priorities, keys and the tree    *)
(* shape are carried over unchanged.                                        *)
fun treeMap _ E = E
  | treeMap f (T node) = T (nodeMap f node)

and nodeMap f (Node {size,priority,left,key,value,right}) =
    let
      val left' = treeMap f left
      val value' = f (key,value)
      val right' = treeMap f right
    in
      Node
        {size = size,
         priority = priority,
         left = left',
         key = key,
         value = value',
         right = right'}
    end;
555
556(* ------------------------------------------------------------------------- *)
557(* Merge is the basic operation for joining two trees. Note that the merged  *)
558(* key is always the one from the second map.                                *)
559(* ------------------------------------------------------------------------- *)
560
(* General merge of two trees. f1 handles entries found only in the first   *)
(* tree, f2 entries found only in the second, and fb entries whose key      *)
(* occurs in both; each returns an optional result value. The key and       *)
(* priority of a kept entry that went through f2 or fb come from the second *)
(* tree.                                                                    *)
fun treeMerge compareKey f1 f2 fb tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => treeMapPartial f2 tree2
    | (_,E) => treeMapPartial f1 tree1
    | (T node1, T node2) => nodeMerge compareKey f1 f2 fb node1 node2

and nodeMerge compareKey f1 f2 fb node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      (* Split the first tree around the root key of the second. *)
      val (below,hit,above) = nodePartition compareKey key node1

      val left' = treeMerge compareKey f1 f2 fb below left
      val right' = treeMerge compareKey f1 f2 fb above right

      val merged =
          case hit of
            SOME kv => fb (kv,(key,value))
          | NONE => f2 (key,value)
    in
      case merged of
        SOME value' =>
        treeCombine left' (mkNodeSingleton priority key value') right'
      | NONE => treeAppend left' right'
    end;
592
593(* ------------------------------------------------------------------------- *)
594(* A union operation on trees.                                               *)
595(* ------------------------------------------------------------------------- *)
596
(* Union of two trees. Entries present in only one tree are kept as they    *)
(* are; for a key present in both, f decides the resulting value (NONE      *)
(* drops the entry). When both arguments are the very same node, f2 is      *)
(* mapped over it instead.                                                  *)
fun treeUnion compareKey f f2 tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => tree2
    | (_,E) => tree1
    | (T node1, T node2) => nodeUnion compareKey f f2 node1 node2

and nodeUnion compareKey f f2 node1 node2 =
    if not (pointerEqual (node1,node2)) then
      let
        val Node {priority,left,key,value,right,...} = node2

        (* Split the first tree around the root key of the second. *)
        val (below,hit,above) = nodePartition compareKey key node1

        val left' = treeUnion compareKey f f2 below left
        val right' = treeUnion compareKey f f2 above right

        val merged =
            case hit of
              SOME kv => f (kv,(key,value))
            | NONE => SOME value
      in
        case merged of
          SOME value' =>
          treeCombine left' (mkNodeSingleton priority key value') right'
        | NONE => treeAppend left' right'
      end
    else nodeMapPartial f2 node1;
630
631(* ------------------------------------------------------------------------- *)
632(* An intersect operation on trees.                                          *)
633(* ------------------------------------------------------------------------- *)
634
(* Intersection of two trees: only keys present in both survive, and f      *)
(* computes the resulting value from the two entries (NONE drops the key).  *)
(* Kept entries take their key and priority from the second tree.           *)
fun treeIntersect compareKey f t1 t2 =
    case (t1,t2) of
      (E,_) => E
    | (_,E) => E
    | (T n1, T n2) => nodeIntersect compareKey f n1 n2

and nodeIntersect compareKey f n1 n2 =
    let
      val Node {priority,left,key,value,right,...} = n2

      (* Split the first tree around the root key of the second. *)
      val (below,hit,above) = nodePartition compareKey key n1

      val left' = treeIntersect compareKey f below left
      val right' = treeIntersect compareKey f above right

      val merged =
          case hit of
            SOME kv => f (kv,(key,value))
          | NONE => NONE
    in
      case merged of
        SOME value' => mkTree priority left' key value' right'
      | NONE => treeAppend left' right'
    end;
661
662(* ------------------------------------------------------------------------- *)
663(* A union operation on trees which simply chooses the second value.         *)
664(* ------------------------------------------------------------------------- *)
665
(* Union that needs no combining function: when a key occurs in both trees  *)
(* the entry from the second tree is the one kept. Identical nodes are      *)
(* returned as they are.                                                    *)
fun treeUnionDomain compareKey tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => tree2
    | (_,E) => tree1
    | (T node1, T node2) =>
      if pointerEqual (node1,node2) then tree2
      else nodeUnionDomain compareKey node1 node2

and nodeUnionDomain compareKey node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      (* Any entry of the first tree under the same key is discarded. *)
      val (below,_,above) = nodePartition compareKey key node1

      val left' = treeUnionDomain compareKey below left
      val right' = treeUnionDomain compareKey above right
    in
      treeCombine left' (mkNodeSingleton priority key value) right'
    end;
689
690(* ------------------------------------------------------------------------- *)
691(* An intersect operation on trees which simply chooses the second value.    *)
692(* ------------------------------------------------------------------------- *)
693
(* Intersection that needs no combining function: keeps the entries of the  *)
(* second tree whose keys also occur in the first. Identical nodes are      *)
(* returned as they are.                                                    *)
fun treeIntersectDomain compareKey tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => E
    | (_,E) => E
    | (T node1, T node2) =>
      if pointerEqual (node1,node2) then tree2
      else nodeIntersectDomain compareKey node1 node2

and nodeIntersectDomain compareKey node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      val (below,hit,above) = nodePartition compareKey key node1

      val left' = treeIntersectDomain compareKey below left
      val right' = treeIntersectDomain compareKey above right
    in
      case hit of
        SOME _ => mkTree priority left' key value right'
      | NONE => treeAppend left' right'
    end;
716
717(* ------------------------------------------------------------------------- *)
718(* A difference operation on trees.                                          *)
719(* ------------------------------------------------------------------------- *)
720
(* Difference: the entries of the first tree whose keys do not occur in the *)
(* second. Identical nodes give the empty tree.                             *)
fun treeDifferenceDomain compareKey t1 t2 =
    case (t1,t2) of
      (E,_) => E
    | (_,E) => t1
    | (T n1, T n2) => nodeDifferenceDomain compareKey n1 n2

and nodeDifferenceDomain compareKey n1 n2 =
    if not (pointerEqual (n1,n2)) then
      let
        val Node {priority,left,key,value,right,...} = n1

        (* Here it is the second tree that is split, around a key of the *)
        (* first.                                                        *)
        val (below,hit,above) = nodePartition compareKey key n2

        val left' = treeDifferenceDomain compareKey left below
        val right' = treeDifferenceDomain compareKey right above
      in
        case hit of
          SOME _ => treeAppend left' right'
        | NONE => mkTree priority left' key value right'
      end
    else E;
743
744(* ------------------------------------------------------------------------- *)
745(* A subset operation on trees.                                              *)
746(* ------------------------------------------------------------------------- *)
747
(* Tests whether every key of the first tree also occurs in the second.     *)
(* Identical nodes succeed immediately, and a first tree whose size field   *)
(* exceeds the second's fails without further search.                       *)
fun treeSubsetDomain compareKey tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => true
    | (_,E) => false
    | (T node1, T node2) => nodeSubsetDomain compareKey node1 node2

and nodeSubsetDomain compareKey node1 node2 =
    if pointerEqual (node1,node2) then true
    else
      let
        val Node {size,left,key,right,...} = node1
      in
        if size > nodeSize node2 then false
        else
          case nodePartition compareKey key node2 of
            (_,NONE,_) => false
          | (below, SOME _, above) =>
            treeSubsetDomain compareKey left below andalso
            treeSubsetDomain compareKey right above
      end;
770
771(* ------------------------------------------------------------------------- *)
772(* Picking an arbitrary key/value pair from a tree.                          *)
773(* ------------------------------------------------------------------------- *)
774
(* The key/value pair stored directly in the node. *)
fun nodePick (Node {key,value,...}) = (key,value);
781
(* The key/value pair at the root; raises Bug on the empty tree. *)
fun treePick E = raise Bug "Map.treePick"
  | treePick (T node) = nodePick node;
786
787(* ------------------------------------------------------------------------- *)
788(* Removing an arbitrary key/value pair from a tree.                         *)
789(* ------------------------------------------------------------------------- *)
790
(* Removes the node's own entry: returns its key/value pair together with   *)
(* the append of its two subtrees.                                          *)
fun nodeDeletePick (Node {left,key,value,right,...}) =
    ((key,value), treeAppend left right);
797
(* Removes the root entry of a tree; raises Bug on the empty tree. *)
fun treeDeletePick E = raise Bug "Map.treeDeletePick"
  | treeDeletePick (T node) = nodeDeletePick node;
802
803(* ------------------------------------------------------------------------- *)
804(* Finding the nth smallest key/value (counting from 0).                     *)
805(* ------------------------------------------------------------------------- *)
806
(* Returns the key/value pair at in-order position n (counting from 0),     *)
(* using subtree sizes to choose the side to descend into. Raises Bug when  *)
(* the search reaches an empty tree.                                        *)
fun treeNth _ E = raise Bug "Map.treeNth"
  | treeNth n (T node) = nodeNth n node

and nodeNth n (Node {left,key,value,right,...}) =
    let
      (* Number of entries that precede this node's own entry. *)
      val k = treeSize left
    in
      if n < k then treeNth n left
      else if n > k then treeNth (n - (k + 1)) right
      else (key,value)
    end;
822
823(* ------------------------------------------------------------------------- *)
824(* Removing the nth smallest key/value (counting from 0).                    *)
825(* ------------------------------------------------------------------------- *)
826
(* Removes the entry at in-order position n (counting from 0), returning    *)
(* the removed key/value pair and the remaining tree. Raises Bug when the   *)
(* search reaches an empty tree.                                            *)
fun treeDeleteNth _ E = raise Bug "Map.treeDeleteNth"
  | treeDeleteNth n (T node) = nodeDeleteNth n node

and nodeDeleteNth n (Node {size,priority,left,key,value,right}) =
    let
      (* Number of entries that precede this node's own entry. *)
      val k = treeSize left

      (* This node with one fewer element and the given subtrees. *)
      fun shrink left' right' =
          T (Node
               {size = size - 1,
                priority = priority,
                left = left',
                key = key,
                value = value,
                right = right'})
    in
      if n < k then
        let
          val (key_value,left') = treeDeleteNth n left
        in
          (key_value, shrink left' right)
        end
      else if n > k then
        let
          val (key_value,right') = treeDeleteNth (n - (k + 1)) right
        in
          (key_value, shrink left right')
        end
      else ((key,value), treeAppend left right)
    end;
876
877(* ------------------------------------------------------------------------- *)
878(* Iterators.                                                                *)
879(* ------------------------------------------------------------------------- *)
880
(* An iterator holds the current key/value pair, the subtree of the current *)
(* node that is still to be visited, and a stack of pending nodes. The      *)
(* constructor records the direction of travel.                             *)
datatype ('k,'v) iterator =
    LeftToRightIterator of ('k * 'v) * ('k,'v) tree * ('k,'v) node list
  | RightToLeftIterator of ('k * 'v) * ('k,'v) tree * ('k,'v) node list;
886
(* Pops the head of a node stack into a left-to-right iterator positioned   *)
(* on that node, with the node's right subtree still to be visited. An      *)
(* empty stack gives NONE.                                                  *)
fun fromSpineLeftToRightIterator [] = NONE
  | fromSpineLeftToRightIterator (Node {key,value,right,...} :: pending) =
    SOME (LeftToRightIterator ((key,value),right,pending));
892
(* Pops the head of a node stack into a right-to-left iterator positioned   *)
(* on that node, with the node's left subtree still to be visited. An empty *)
(* stack gives NONE.                                                        *)
fun fromSpineRightToLeftIterator [] = NONE
  | fromSpineRightToLeftIterator (Node {key,value,left,...} :: pending) =
    SOME (RightToLeftIterator ((key,value),left,pending));
898
(* Pushes the left spine of the tree onto the pending nodes and builds a    *)
(* left-to-right iterator from the resulting stack.                         *)
fun addLeftToRightIterator nodes tree =
    let
      val spine = treeLeftSpine nodes tree
    in
      fromSpineLeftToRightIterator spine
    end;
900
(* Pushes the right spine of the tree onto the pending nodes and builds a   *)
(* right-to-left iterator from the resulting stack.                         *)
fun addRightToLeftIterator nodes tree =
    let
      val spine = treeRightSpine nodes tree
    in
      fromSpineRightToLeftIterator spine
    end;
902
(* A left-to-right iterator over the whole tree (NONE if the tree is        *)
(* empty).                                                                  *)
fun treeMkIterator tree =
    fromSpineLeftToRightIterator (treeLeftSpine [] tree);
904
(* A right-to-left iterator over the whole tree (NONE if the tree is        *)
(* empty).                                                                  *)
fun treeMkRevIterator tree =
    fromSpineRightToLeftIterator (treeRightSpine [] tree);
906
(* The key/value pair the iterator is currently positioned on. *)
fun readIterator (LeftToRightIterator (key_value,_,_)) = key_value
  | readIterator (RightToLeftIterator (key_value,_,_)) = key_value;
911
(* Moves to the next entry in the iterator's direction of travel; NONE when *)
(* no entries remain.                                                       *)
fun advanceIterator (LeftToRightIterator (_,tree,nodes)) =
    addLeftToRightIterator nodes tree
  | advanceIterator (RightToLeftIterator (_,tree,nodes)) =
    addRightToLeftIterator nodes tree;
916
(* Folds f over the remaining entries in iteration order; f receives        *)
(* (key,value,accumulator).                                                 *)
fun foldIterator _ acc NONE = acc
  | foldIterator f acc (SOME iter) =
    let
      val (key,value) = readIterator iter

      val acc' = f (key,value,acc)
    in
      foldIterator f acc' (advanceIterator iter)
    end;
926
(* The first remaining key/value pair that satisfies pred, if any. *)
fun findIterator _ NONE = NONE
  | findIterator pred (SOME iter) =
    let
      val current = readIterator iter
    in
      if not (pred current) then findIterator pred (advanceIterator iter)
      else SOME current
    end;
937
(* The first SOME result obtained by applying f to the remaining key/value  *)
(* pairs in iteration order; NONE if f never succeeds.                      *)
fun firstIterator _ NONE = NONE
  | firstIterator f (SOME iter) =
    case f (readIterator iter) of
      (found as SOME _) => found
    | NONE => firstIterator f (advanceIterator iter);
949
(* Lexicographic comparison of two entry sequences: keys are compared       *)
(* first, then values, then the remaining entries. A sequence that runs out *)
(* first is LESS.                                                           *)
fun compareIterator _ _ NONE NONE = EQUAL
  | compareIterator _ _ NONE (SOME _) = LESS
  | compareIterator _ _ (SOME _) NONE = GREATER
  | compareIterator compareKey compareValue (SOME i1) (SOME i2) =
    let
      val (k1,v1) = readIterator i1
      val (k2,v2) = readIterator i2
    in
      case compareKey (k1,k2) of
        EQUAL =>
        (case compareValue (v1,v2) of
           EQUAL =>
           compareIterator compareKey compareValue
             (advanceIterator i1) (advanceIterator i2)
         | valueOrder => valueOrder)
      | keyOrder => keyOrder
    end;
975
(* Tests whether two entry sequences have the same length and pairwise      *)
(* equal keys and values, according to the supplied curried equality tests. *)
fun equalIterator _ _ NONE NONE = true
  | equalIterator equalKey equalValue (SOME i1) (SOME i2) =
    let
      val (k1,v1) = readIterator i1
      val (k2,v2) = readIterator i2
    in
      if equalKey k1 k2 andalso equalValue v1 v2 then
        equalIterator equalKey equalValue
          (advanceIterator i1) (advanceIterator i2)
      else false
    end
  | equalIterator _ _ _ _ = false;
995
996(* ------------------------------------------------------------------------- *)
997(* A type of finite maps.                                                    *)
998(* ------------------------------------------------------------------------- *)
999
(* A finite map: a key comparison function paired with the tree that *)
(* stores the key/value bindings.                                     *)
datatype ('key,'value) map =
    Map of ('key * 'key -> order) * ('key,'value) tree;
1002
1003(* ------------------------------------------------------------------------- *)
1004(* Map debugging functions.                                                  *)
1005(* ------------------------------------------------------------------------- *)
1006
1007(*BasicDebug
1008fun checkInvariants s m =
1009    let
1010      val Map (compareKey,tree) = m
1011
1012      val _ = treeCheckInvariants compareKey tree
1013    in
1014      m
1015    end
1016    handle Bug bug => raise Bug (s ^ "\n" ^ "Map.checkInvariants: " ^ bug);
1017*)
1018
1019(* ------------------------------------------------------------------------- *)
1020(* Constructors.                                                             *)
1021(* ------------------------------------------------------------------------- *)
1022
(* Build a map that uses compareKey and holds the tree from treeNew. *)
fun new compareKey = Map (compareKey, treeNew ());
1029
(* Build a map that uses compareKey and holds the tree produced by *)
(* treeSingleton for the given key/value pair.                     *)
fun singleton compareKey key_value =
    Map (compareKey, treeSingleton key_value);
1036
1037(* ------------------------------------------------------------------------- *)
1038(* Map size.                                                                 *)
1039(* ------------------------------------------------------------------------- *)
1040
(* The size of a map is the size reported by treeSize for its tree. *)
fun size m =
    let
      val Map (_,t) = m
    in
      treeSize t
    end;
1042
(* A map is null exactly when its size is zero. *)
fun null m =
    case size m of
      0 => true
    | _ => false;
1044
1045(* ------------------------------------------------------------------------- *)
1046(* Querying.                                                                 *)
1047(* ------------------------------------------------------------------------- *)
1048
(* Delegate to treePeekKey using the map's own comparison function. *)
fun peekKey (Map (cmp,t)) key = treePeekKey cmp key t;
1050
(* Delegate to treePeek using the map's own comparison function. *)
fun peek (Map (cmp,t)) key = treePeek cmp key t;
1052
(* A key is in the domain when peek finds something for it. *)
fun inDomain key m =
    case peek m key of
      SOME _ => true
    | NONE => false;
1054
(* Like peek, but raises Error instead of returning NONE. *)
fun get m key =
    case peek m key of
      SOME found => found
    | NONE => raise Error "Map.get: element not found";
1059
(* Delegate to treePick on the underlying tree. *)
fun pick m =
    let
      val Map (_,t) = m
    in
      treePick t
    end;
1061
(* Delegate to treeNth with index n on the underlying tree. *)
fun nth (Map (_,t)) n = treeNth n t;
1063
(* Select an element by passing randomInt (size m) to nth; raises Bug *)
(* when the map has size zero.                                        *)
fun random m =
    case size m of
      0 => raise Bug "Map.random: empty"
    | n => nth m (randomInt n);
1071
1072(* ------------------------------------------------------------------------- *)
1073(* Adding.                                                                   *)
1074(* ------------------------------------------------------------------------- *)
1075
(* Return a map with the same comparison function whose tree is the *)
(* result of treeInsert for the given key/value pair.               *)
fun insert (Map (cmp,t)) key_value =
    Map (cmp, treeInsert cmp key_value t);
1082
1083(*BasicDebug
1084val insert = fn m => fn kv =>
1085    checkInvariants "Map.insert: result"
1086      (insert (checkInvariants "Map.insert: input" m) kv);
1087*)
1088
(* Insert each key/value pair of the list in turn, left to right. *)
fun insertList m l =
    List.foldl (fn (key_value,acc) => insert acc key_value) m l;
1095
1096(* ------------------------------------------------------------------------- *)
1097(* Removing.                                                                 *)
1098(* ------------------------------------------------------------------------- *)
1099
(* Return a map with the same comparison function whose tree is the *)
(* result of treeDelete for dkey.                                   *)
fun delete (Map (cmp,t)) dkey = Map (cmp, treeDelete cmp dkey t);
1106
1107(*BasicDebug
1108val delete = fn m => fn k =>
1109    checkInvariants "Map.delete: result"
1110      (delete (checkInvariants "Map.delete: input" m) k);
1111*)
1112
(* Delete key only when it is in the domain; otherwise return the map *)
(* unchanged.                                                         *)
fun remove m key =
    if not (inDomain key m) then m
    else delete m key;
1114
(* Run treeDeletePick on the tree and return its key/value pair together *)
(* with a map wrapping the resulting tree.                               *)
fun deletePick (Map (cmp,t)) =
    case treeDeletePick t of
      (key_value,rest) => (key_value, Map (cmp,rest));
1121
1122(*BasicDebug
1123val deletePick = fn m =>
1124    let
1125      val (kv,m) = deletePick (checkInvariants "Map.deletePick: input" m)
1126    in
1127      (kv, checkInvariants "Map.deletePick: result" m)
1128    end;
1129*)
1130
(* Run treeDeleteNth with index n and return its key/value pair together *)
(* with a map wrapping the resulting tree.                               *)
fun deleteNth (Map (cmp,t)) n =
    case treeDeleteNth n t of
      (key_value,rest) => (key_value, Map (cmp,rest));
1137
1138(*BasicDebug
1139val deleteNth = fn m => fn n =>
1140    let
1141      val (kv,m) = deleteNth (checkInvariants "Map.deleteNth: input" m) n
1142    in
1143      (kv, checkInvariants "Map.deleteNth: result" m)
1144    end;
1145*)
1146
(* Delete the element at index randomInt (size m); raises Bug when the *)
(* map has size zero.                                                  *)
fun deleteRandom m =
    case size m of
      0 => raise Bug "Map.deleteRandom: empty"
    | n => deleteNth m (randomInt n);
1154
1155(* ------------------------------------------------------------------------- *)
1156(* Joining (all join operations prefer keys in the second map).              *)
1157(* ------------------------------------------------------------------------- *)
1158
(* Combine two maps with treeMerge, passing it the three functions from *)
(* the record argument.  The comparison function of the first map is    *)
(* used for the merge and for the result; the second map's is ignored.  *)
fun merge {first = onFirst, second = onSecond, both = onBoth}
          (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeMerge cmp onFirst onSecond onBoth t1 t2);
1165
1166(*BasicDebug
1167val merge = fn f => fn m1 => fn m2 =>
1168    checkInvariants "Map.merge: result"
1169      (merge f
1170         (checkInvariants "Map.merge: input 1" m1)
1171         (checkInvariants "Map.merge: input 2" m2));
1172*)
1173
(* Combine two maps with treeUnion.  Besides f itself, treeUnion is given *)
(* a one-argument variant that applies f to a key/value pair paired with  *)
(* itself.  The first map's comparison function is used throughout.       *)
fun union f (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeUnion cmp f (fn kv => f (kv,kv)) t1 t2);
1182
1183(*BasicDebug
1184val union = fn f => fn m1 => fn m2 =>
1185    checkInvariants "Map.union: result"
1186      (union f
1187         (checkInvariants "Map.union: input 1" m1)
1188         (checkInvariants "Map.union: input 2" m2));
1189*)
1190
(* Combine two maps with treeIntersect, using the first map's comparison *)
(* function for the operation and for the result.                        *)
fun intersect f (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeIntersect cmp f t1 t2);
1197
1198(*BasicDebug
1199val intersect = fn f => fn m1 => fn m2 =>
1200    checkInvariants "Map.intersect: result"
1201      (intersect f
1202         (checkInvariants "Map.intersect: input 1" m1)
1203         (checkInvariants "Map.intersect: input 2" m2));
1204*)
1205
1206(* ------------------------------------------------------------------------- *)
1207(* Iterators over maps.                                                      *)
1208(* ------------------------------------------------------------------------- *)
1209
(* Delegate to treeMkIterator on the underlying tree. *)
fun mkIterator m =
    let
      val Map (_,t) = m
    in
      treeMkIterator t
    end;
1211
(* Delegate to treeMkRevIterator on the underlying tree. *)
fun mkRevIterator m =
    let
      val Map (_,t) = m
    in
      treeMkRevIterator t
    end;
1213
1214(* ------------------------------------------------------------------------- *)
1215(* Mapping and folding.                                                      *)
1216(* ------------------------------------------------------------------------- *)
1217
(* Apply treeMapPartial f to the tree, keeping the comparison function. *)
fun mapPartial f (Map (cmp,t)) = Map (cmp, treeMapPartial f t);
1224
1225(*BasicDebug
1226val mapPartial = fn f => fn m =>
1227    checkInvariants "Map.mapPartial: result"
1228      (mapPartial f (checkInvariants "Map.mapPartial: input" m));
1229*)
1230
(* Apply treeMap f to the tree, keeping the comparison function. *)
fun map f (Map (cmp,t)) = Map (cmp, treeMap f t);
1237
1238(*BasicDebug
1239val map = fn f => fn m =>
1240    checkInvariants "Map.map: result"
1241      (map f (checkInvariants "Map.map: input" m));
1242*)
1243
(* Like map, but f sees only the value of each binding, not the key. *)
fun transform f m =
    let
      fun onValue (_,v) = f v
    in
      map onValue m
    end;
1245
(* Keep the bindings whose key/value pair satisfies pred.  Implemented *)
(* with mapPartial: accepted pairs map to SOME of their own value.     *)
fun filter pred m =
    mapPartial
      (fn (entry as (_,v)) => if pred entry then SOME v else NONE)
      m;
1253
(* Return the pair (filter p m, filter (not p) m).  The map is filtered *)
(* twice, first with p and then with its negation.                      *)
fun partition p m =
    let
      val accepted = filter p m
      val rejected = filter (fn x => not (p x)) m
    in
      (accepted,rejected)
    end;
1260
(* Fold f over the map using the iterator from mkIterator. *)
fun foldl f b m =
    let
      val io = mkIterator m
    in
      foldIterator f b io
    end;
1262
(* Fold f over the map using the iterator from mkRevIterator. *)
fun foldr f b m =
    let
      val io = mkRevIterator m
    in
      foldIterator f b io
    end;
1264
(* Apply f to every key/value pair, in foldl order, for its effects. *)
fun app f m =
    let
      fun visit (k,v,()) = f (k,v)
    in
      foldl visit () m
    end;
1266
1267(* ------------------------------------------------------------------------- *)
1268(* Searching.                                                                *)
1269(* ------------------------------------------------------------------------- *)
1270
(* Find the first pair satisfying p along the iterator from mkIterator. *)
fun findl p m =
    let
      val io = mkIterator m
    in
      findIterator p io
    end;
1272
(* Find the first pair satisfying p along the iterator from *)
(* mkRevIterator.                                           *)
fun findr p m =
    let
      val io = mkRevIterator m
    in
      findIterator p io
    end;
1274
(* Return the first SOME result of f along the iterator from mkIterator. *)
fun firstl f m =
    let
      val io = mkIterator m
    in
      firstIterator f io
    end;
1276
(* Return the first SOME result of f along the iterator from *)
(* mkRevIterator.                                            *)
fun firstr f m =
    let
      val io = mkRevIterator m
    in
      firstIterator f io
    end;
1278
(* True when findl locates some pair satisfying p. *)
fun exists p m =
    case findl p m of
      SOME _ => true
    | NONE => false;
1280
(* True when no pair falsifies p, i.e. when no pair satisfies its *)
(* negation.                                                      *)
fun all p m = not (exists (fn x => not (p x)) m);
1287
(* Count the key/value pairs satisfying pred with a foldl starting at 0. *)
fun count pred m =
    foldl (fn (k,v,n) => if pred (k,v) then n + 1 else n) 0 m;
1294
1295(* ------------------------------------------------------------------------- *)
1296(* Comparing.                                                                *)
1297(* ------------------------------------------------------------------------- *)
1298
(* Order two maps.  Pointer-equal maps are EQUAL; otherwise the sizes are *)
(* compared first, and only for equal sizes are the elements compared     *)
(* pairwise, using the first map's key ordering and compareValue.         *)
fun compare compareValue (m1 as Map (compareKey,_), m2) =
    if pointerEqual (m1,m2) then EQUAL
    else
      case Int.compare (size m1, size m2) of
        EQUAL =>
        compareIterator compareKey compareValue
          (mkIterator m1) (mkIterator m2)
      | sizeOrder => sizeOrder;
1314
(* Test two maps for equality.  Pointer-equal maps are equal; maps of    *)
(* different size are not; otherwise the elements are compared pairwise, *)
(* with key equality derived from the first map's comparison function.   *)
fun equal equalValue m1 m2 =
    if pointerEqual (m1,m2) then true
    else if size m1 <> size m2 then false
    else
      let
        val Map (compareKey,_) = m1
      in
        equalIterator (equalKey compareKey) equalValue
          (mkIterator m1) (mkIterator m2)
      end;
1326
1327(* ------------------------------------------------------------------------- *)
1328(* Set operations on the domain.                                             *)
1329(* ------------------------------------------------------------------------- *)
1330
(* Combine two maps with treeUnionDomain, using the first map's *)
(* comparison function for the operation and for the result.    *)
fun unionDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeUnionDomain cmp t1 t2);
1337
1338(*BasicDebug
1339val unionDomain = fn m1 => fn m2 =>
1340    checkInvariants "Map.unionDomain: result"
1341      (unionDomain
1342         (checkInvariants "Map.unionDomain: input 1" m1)
1343         (checkInvariants "Map.unionDomain: input 2" m2));
1344*)
1345
(* Fold unionDomain over a non-empty list of maps, left to right, with  *)
(* the head of the list as the initial accumulator.  Raises Bug on the  *)
(* empty list.                                                          *)
fun unionListDomain [] = raise Bug "Map.unionListDomain: no sets"
  | unionListDomain (m :: ms) =
    List.foldl (fn (next,acc) => unionDomain acc next) m ms;
1354
(* Combine two maps with treeIntersectDomain, using the first map's *)
(* comparison function for the operation and for the result.        *)
fun intersectDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeIntersectDomain cmp t1 t2);
1361
1362(*BasicDebug
1363val intersectDomain = fn m1 => fn m2 =>
1364    checkInvariants "Map.intersectDomain: result"
1365      (intersectDomain
1366         (checkInvariants "Map.intersectDomain: input 1" m1)
1367         (checkInvariants "Map.intersectDomain: input 2" m2));
1368*)
1369
(* Fold intersectDomain over a non-empty list of maps, left to right,  *)
(* with the head of the list as the initial accumulator.  Raises Bug   *)
(* on the empty list.                                                  *)
fun intersectListDomain [] = raise Bug "Map.intersectListDomain: no sets"
  | intersectListDomain (m :: ms) =
    List.foldl (fn (next,acc) => intersectDomain acc next) m ms;
1378
(* Combine two maps with treeDifferenceDomain, using the first map's *)
(* comparison function for the operation and for the result.         *)
fun differenceDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeDifferenceDomain cmp t1 t2);
1385
1386(*BasicDebug
1387val differenceDomain = fn m1 => fn m2 =>
1388    checkInvariants "Map.differenceDomain: result"
1389      (differenceDomain
1390         (checkInvariants "Map.differenceDomain: input 1" m1)
1391         (checkInvariants "Map.differenceDomain: input 2" m2));
1392*)
1393
(* The union of the two one-sided differences: m1 minus m2, then m2 *)
(* minus m1.                                                        *)
fun symmetricDifferenceDomain m1 m2 =
    let
      val onlyIn1 = differenceDomain m1 m2
      val onlyIn2 = differenceDomain m2 m1
    in
      unionDomain onlyIn1 onlyIn2
    end;
1396
(* Map equality with a value test that always succeeds, so only the *)
(* keys are actually compared.                                      *)
fun equalDomain m1 m2 = equal (fn _ => fn _ => true) m1 m2;
1398
(* Delegate to treeSubsetDomain, using the first map's comparison *)
(* function.                                                      *)
fun subsetDomain (Map (cmp,t1)) (Map (_,t2)) = treeSubsetDomain cmp t1 t2;
1401
(* Two maps are disjoint when their domain intersection is null. *)
fun disjointDomain m1 m2 =
    let
      val common = intersectDomain m1 m2
    in
      null common
    end;
1403
1404(* ------------------------------------------------------------------------- *)
1405(* Converting to and from lists.                                             *)
1406(* ------------------------------------------------------------------------- *)
1407
(* Collect the keys into a list by consing each one during a foldr. *)
fun keys m =
    let
      fun consKey (k,_,rest) = k :: rest
    in
      foldr consKey [] m
    end;
1409
(* Collect the values into a list by consing each one during a foldr. *)
fun values m =
    let
      fun consValue (_,v,rest) = v :: rest
    in
      foldr consValue [] m
    end;
1411
(* Collect the key/value pairs into a list by consing each one during *)
(* a foldr.                                                           *)
fun toList m =
    let
      fun consPair (k,v,rest) = (k,v) :: rest
    in
      foldr consPair [] m
    end;
1413
(* Build a map by inserting the list's pairs into a new map that uses *)
(* compareKey.                                                        *)
fun fromList compareKey l = insertList (new compareKey) l;
1420
1421(* ------------------------------------------------------------------------- *)
1422(* Pretty-printing.                                                          *)
1423(* ------------------------------------------------------------------------- *)
1424
(* Render a map as its size in angle brackets; a null map is shown *)
(* with nothing between the brackets.                              *)
fun toString m =
    let
      val contents = if null m then "" else Int.toString (size m)
    in
      "<" ^ contents ^ ">"
    end;
1426
1427end
1428