1signature LVSET =
2sig
3  type value
4  type t
5  val empty : t
6  val insert : t * value -> t
7  val fold : (value * 'a -> 'a) -> 'a -> t -> 'a
8  val listItems : t -> value list
9end
10
11signature LV_TERM_NET =
12sig
13
14  (* signature names modelled on Binarymap's *)
15  type lvtermnet
16  type term = Term.term
17  type key = Term.term list * Term.term
18  type value
19  structure Set : LVSET where type value = value
20
21  val empty : lvtermnet
22  val insert : (lvtermnet * key * value) -> lvtermnet
23  val find : lvtermnet * key -> Set.t
24  val peek : lvtermnet * key -> Set.t
25  val match : lvtermnet * term -> (key * value) list
26
27  val delete : lvtermnet * key -> lvtermnet * Set.t
28  val numItems : lvtermnet -> int
29  val listItems : lvtermnet -> (key * value) list
30  val app : (key * Set.t -> unit) -> lvtermnet -> unit
31  val fold : (key * value * 'b -> 'b) -> 'b -> lvtermnet -> 'b
32
33end
34
35functor LVTermNetFunctor(S : LVSET) : LV_TERM_NET =
36struct
37
38open HolKernel
39
40datatype label = V of int
41               | C of {Name : string, Thy : string} * int
42               | Lam of int
43               | LV of string * int
44type key = term list * term
45type value = S.value
46structure Set = S
47
48val tlt_compare = pair_compare (list_compare Term.compare, Term.compare)
49
50fun labcmp (p as (l1, l2)) =
51    case p of
52      (V n1, V n2) => Int.compare(n1, n2)
53    | (V _, _) => LESS
54    | (_, V _) => GREATER
55    | (C ({Name=nm1,Thy=th1}, n1), C ({Name=nm2,Thy=th2}, n2)) =>
56         pair_compare (Int.compare,
57                       pair_compare (String.compare, String.compare))
58                      ((n1, (th1, nm1)), (n2, (th2, nm2)))
59    | (C _, _) => LESS
60    | (_, C _) => GREATER
61    | (Lam n1, Lam n2) => Int.compare (n1, n2)
62    | (Lam _, _) => LESS
63    | (_, Lam _) => GREATER
64    | (LV p1, LV p2) => pair_compare (String.compare, Int.compare) (p1, p2)
65
66datatype N = LF of (key,S.t) Binarymap.dict
67           | ND of (label,N) Binarymap.dict
68
69type lvtermnet = N * int
70
71val empty_node = ND (Binarymap.mkDict labcmp)
72val empty = (empty_node, 0)
73
74fun ndest_term (fvs, tm) = let
75  val (f, args) = strip_comb tm
76  val args' = map (fn t => (fvs, t)) args
77in
78  case dest_term f of
79    VAR(s, ty) => if op_mem aconv f fvs then (LV (s, length args), args')
80                  else (V (length args), args')
81  | LAMB(bv, bod) =>
82      (Lam (length args), (op_set_diff aconv fvs [bv], bod) :: args')
83  | CONST{Name,Thy,Ty} => (C ({Name=Name,Thy=Thy}, length args), args')
84  | COMB _ => raise Fail "impossible"
85end
86
87fun cons_insert (bmap, k, i) =
88    case Binarymap.peek(bmap,k) of
89      NONE => Binarymap.insert(bmap,k,S.insert(S.empty,i))
90    | SOME items => Binarymap.insert(bmap,k,S.insert(items, i))
91
92fun insert ((net,sz), k, item) = let
93  fun newnode labs =
94      case labs of
95        [] => LF (Binarymap.mkDict tlt_compare)
96      | _ => empty_node
97  fun trav (net, tms) =
98      case (net, tms) of
99        (LF d, []) => LF (cons_insert(d,k,item))
100      | (ND d, k::ks0) => let
101          val (lab, rest) = ndest_term k
102          val ks = rest @ ks0
103          val n' =
104              case Binarymap.peek(d,lab) of
105                NONE => trav(newnode ks, ks)
106              | SOME n => trav(n, ks)
107        in
108          ND (Binarymap.insert(d, lab, n'))
109        end
110      | _ => raise Fail "LVTermNet.insert: catastrophic invariant failure"
111in
112  (trav(net,[k]), sz + 1)
113end
114
115fun listItems (net, sz) = let
116  fun cons'(k,vs,acc) = S.fold (fn (v,acc) => (k,v)::acc) acc vs
117  fun trav (net, acc) =
118      case net of
119        LF d => Binarymap.foldl cons' acc d
120      | ND d => let
121          fun foldthis (k,v,acc) = trav(v,acc)
122        in
123          Binarymap.foldl foldthis acc d
124        end
125in
126  trav(net, [])
127end
128
129fun numItems (net, sz) = sz
130
131fun peek ((net,sz), k) = let
132  fun trav (net, tms) =
133      case (net, tms) of
134        (LF d, []) => (valOf (Binarymap.peek(d, k)) handle Option => S.empty)
135      | (ND d, k::ks) => let
136          val (lab, rest) = ndest_term k
137        in
138          case Binarymap.peek(d, lab) of
139            NONE => S.empty
140          | SOME n => trav(n, rest @ ks)
141        end
142      | _ => raise Fail "LVTermNet.peek: catastrophic invariant failure"
143in
144  trav(net, [k])
145end
146
147val find = peek
148
149fun lookup_label tm = let
150  val (f, args) = strip_comb tm
151in
152  case dest_term f of
153    CONST{Name, Thy, ...} => (C ({Name=Name,Thy=Thy}, length args), args)
154  | LAMB(Bvar, Body) => (Lam (length args), Body::args)
155  | VAR (s, _) => (LV (s, length args), args)
156  | _ => raise Fail "LVTermNet.lookup_label: catastrophic invariant failure"
157end
158
159fun conslistItems (d, acc) = let
160  fun listfold k (v,acc) = (k,v)::acc
161  fun mapfold (k,vs,acc) = S.fold (listfold k) acc vs
162in
163  Binarymap.foldl mapfold acc d
164end
165
166fun match ((net,sz), tm) = let
167  fun trav acc (net, ks) =
168      case (net, ks) of
169        (LF d, []) => conslistItems (d, acc)
170      | (ND d, k::ks0) => let
171          val varresult = case Binarymap.peek(d, V 0) of
172                            NONE => acc
173                          | SOME n => trav acc (n, ks0)
174          val (lab, rest) = lookup_label k
175          val restn = length rest
176          val varhead_results = let
177            fun recurse acc n =
178              if n = 0 then acc
179              else
180                case Binarymap.peek (d, V n) of
181                    NONE => recurse acc (n - 1)
182                  | SOME m => recurse
183                                (trav acc (m, List.drop(rest, restn - n) @ ks0))
184                                (n - 1)
185          in
186            recurse varresult (length (#2 (strip_comb k)))
187          end
188        in
189          case Binarymap.peek (d, lab) of
190            NONE => varhead_results
191          | SOME n => trav varhead_results (n, rest @ ks0)
192        end
193      | _ => raise Fail "LVTermNet.match: catastrophic invariant failure"
194in
195  trav [] (net, [tm])
196end
197
198fun delete ((net,sz), k) = let
199  fun trav (p as (net, ks)) =
200      case p of
201        (LF d, []) => let
202          val (d',removed) = Binarymap.remove(d, k)
203        in
204          if Binarymap.numItems d' = 0 then (NONE, removed)
205          else (SOME (LF d'), removed)
206        end
207      | (ND d, k::ks) => let
208          val (lab, rest) = ndest_term k
209        in
210          case Binarymap.peek(d, lab) of
211            NONE => raise Binarymap.NotFound
212          | SOME n => let
213            in
214              case trav (n, rest @ ks) of
215                (NONE, removed) => let
216                  val (d',_) = Binarymap.remove(d, lab)
217                in
218                  if Binarymap.numItems d' = 0 then (NONE, removed)
219                  else (SOME (ND d'), removed)
220                end
221              | (SOME n', removed) => (SOME (ND (Binarymap.insert(d,lab,n'))),
222                                       removed)
223            end
224        end
225      | _ => raise Fail "LVTermNet.delete: catastrophic invariant failure"
226in
227  case trav (net, [k]) of
228    (NONE, removed) => (empty, removed)
229  | (SOME n, removed) =>  ((n,sz-1), removed)
230end
231
232fun app f (net, sz) = let
233  fun trav n =
234      case n of
235        LF d => Binarymap.app f d
236      | ND d => Binarymap.app (fn (lab, n) => trav n) d
237in
238  trav net
239end
240
241fun consfoldl f acc d = let
242  fun setfold k (d, acc) = f (k, d, acc)
243  fun mapfold (k,vs,acc) = S.fold (setfold k) acc vs
244in
245  Binarymap.foldl mapfold acc d
246end
247
248fun fold f acc (net, sz) = let
249  fun trav acc n =
250      case n of
251        LF d => consfoldl f acc d
252      | ND d => Binarymap.foldl (fn (lab,n',acc) => trav acc n') acc d
253in
254  trav acc net
255end
256
257end (* struct *)
258