1(* ========================================================================= *)
2(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC TERMS              *)
3(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
6structure TermNet :> TermNet =
7struct
8
9open Useful;
10
11(* ------------------------------------------------------------------------- *)
12(* Anonymous variables.                                                      *)
13(* ------------------------------------------------------------------------- *)
14
(* The single variable name "_" used to display every variable of a
   quotient term (quotient terms do not remember variable names). *)
val anonymousName = Name.fromString "_"

val anonymousVar = Term.Var anonymousName;
17
18(* ------------------------------------------------------------------------- *)
19(* Quotient terms.                                                           *)
20(* ------------------------------------------------------------------------- *)
21
(* A quotient term: all variables are identified (collapsed to Var), and
   every function application is tagged with its name and arity. *)
datatype qterm = Var | Fn of NameArity.nameArity * qterm list;
25
local
  (* cmp walks a work list of qterm pairs from left to right and returns the
     first non-EQUAL verdict; if every pair is equal the answer is EQUAL.
     Var sorts before any Fn; two Fns are ordered first by NameArity and then
     lexicographically by their arguments.

     Portable.pointerEqual is only a shortcut that skips physically identical
     pairs, so the result must not depend on it.  In particular a (Var,Var)
     pair must go on to compare the remaining pairs: answering EQUAL outright
     would ignore everything still pending, e.g. it would declare f(_,a) and
     f(_,b) equal whenever pointerEqual does not catch the Var pair. *)
  fun cmp [] = EQUAL
    | cmp (q1_q2 :: qs) =
      if Portable.pointerEqual q1_q2 then cmp qs
      else
        case q1_q2 of
          (Var,Var) => cmp qs
        | (Var, Fn _) => LESS
        | (Fn _, Var) => GREATER
        | (Fn f1, Fn f2) => fnCmp f1 f2 qs

  (* Compare two Fn payloads.  When name and arity agree, their argument
     pairs are queued ahead of the pending pairs qs.  This relies on each
     argument list having the length recorded in its arity (as termToQterm
     builds it), so that zip pairs every argument. *)
  and fnCmp (n1,q1) (n2,q2) qs =
    case NameArity.compare (n1,n2) of
      LESS => LESS
    | EQUAL => cmp (zip q1 q2 @ qs)
    | GREATER => GREATER;
in
  (* Total order on quotient terms. *)
  fun compareQterm q1_q2 = cmp [q1_q2];

  (* Total order on Fn payloads (nameArity, argument list). *)
  fun compareFnQterm (f1,f2) = fnCmp f1 f2 [];
end;
47
(* Structural equality of quotient terms, decided by compareQterm. *)
fun equalQterm q1 q2 =
    case compareQterm (q1, q2) of
      EQUAL => true
    | _ => false;
49
(* Structural equality of Fn payloads, decided by compareFnQterm. *)
fun equalFnQterm f1 f2 =
    case compareFnQterm (f1, f2) of
      EQUAL => true
    | _ => false;
51
(* Quotient a term: every variable becomes Var, and every function symbol is
   paired with the number of arguments it is applied to. *)
fun termToQterm tm =
    case tm of
      Term.Var _ => Var
    | Term.Fn (f, args) => Fn ((f, length args), List.map termToQterm args);
54
local
  (* Work-list check that each pattern (left) matches its target (right):
     a Var pattern accepts anything, and an Fn pattern needs an Fn target
     with an equal name and arity whose arguments match in turn. *)
  fun matches [] = true
    | matches ((pat, target) :: pending) =
      case (pat, target) of
        (Var, _) => matches pending
      | (Fn _, Var) => false
      | (Fn (f, pats), Fn (g, targets)) =>
        NameArity.equal f g andalso matches (zip pats targets @ pending);
in
  fun matchQtermQterm qtm qtm' = matches [(qtm, qtm')];
end;
64
local
  (* Work-list check that a quotient-term pattern matches an ordinary term:
     Var accepts any subterm; Fn ((f,n),_) needs a Term.Fn with the same
     name and exactly n arguments, which must match pairwise. *)
  fun matches [] = true
    | matches ((pat, tm) :: pending) =
      case (pat, tm) of
        (Var, _) => matches pending
      | (Fn _, Term.Var _) => false
      | (Fn ((f, arity), pats), Term.Fn (g, args)) =>
        Name.equal f g andalso arity = length args andalso
        matches (zip pats args @ pending);
in
  fun matchQtermTerm qtm tm = matches [(qtm, tm)];
end;
74
local
  (* Try to extend qsub (term variable -> qterm) so that the ordinary term
     matches the quotient term.  A variable seen before must be bound to an
     equal qterm.  Returns NONE as soon as a clash is found. *)
  fun bind qsub [] = SOME qsub
    | bind qsub ((tm, qtm) :: pending) =
      case (tm, qtm) of
        (Term.Var v, _) =>
        (case NameMap.peek qsub v of
           NONE => bind (NameMap.insert qsub (v, qtm)) pending
         | SOME prev =>
           if equalQterm qtm prev then bind qsub pending else NONE)
      | (Term.Fn _, Var) => NONE
      | (Term.Fn (f, args), Fn ((g, arity), qargs)) =>
        if Name.equal f g andalso length args = arity
        then bind qsub (zip args qargs @ pending)
        else NONE;
in
  fun matchTermQterm qsub tm qtm = bind qsub [(tm, qtm)];
end;
88
local
  (* Combine two quotient terms into one that refines both: a Var yields to
     the other side.  Raises Error if two Fns disagree on name or arity. *)
  fun meet Var q = q
    | meet q Var = q
    | meet (Fn (f, xs)) (Fn (g, ys)) =
      if NameArity.equal f g then Fn (f, zipWith meet xs ys)
      else raise Error "TermNet.qv";

  (* Extend qsub (term variable -> qterm) so that each quotient term is
     compatible with its ordinary term.  A variable that is already bound has
     its binding refined with meet.  Raises Error on any clash. *)
  fun bindAll qsub [] = qsub
    | bindAll qsub ((qtm, tm) :: pending) =
      case (qtm, tm) of
        (Var, _) => bindAll qsub pending
      | (_, Term.Var v) =>
        let
          val merged =
              case NameMap.peek qsub v of
                NONE => qtm
              | SOME prev => meet qtm prev
        in
          bindAll (NameMap.insert qsub (v, merged)) pending
        end
      | (Fn ((f, arity), qargs), Term.Fn (g, args)) =>
        if Name.equal f g andalso arity = length args
        then bindAll qsub (zip qargs args @ pending)
        else raise Error "TermNet.qu";
in
  fun unifyQtermQterm qtm qtm' = total (meet qtm) qtm';

  fun unifyQtermTerm qsub qtm tm = total (bindAll qsub) [(qtm, tm)];
end;
116
local
  (* Rebuild an ordinary term for display: each Var becomes anonymousVar and
     the arity tag on function symbols is dropped. *)
  fun toTerm qtm =
      case qtm of
        Var => anonymousVar
      | Fn ((f, _), args) => Term.Fn (f, List.map toTerm args);
in
  val ppQterm = Print.ppMap toTerm Term.pp;
end;
123
124(* ------------------------------------------------------------------------- *)
125(* A type of term sets that can be efficiently matched and unified.          *)
126(* ------------------------------------------------------------------------- *)
127
(* Net parameters.  When fifo is true, query results are sorted back into
   insertion order before being returned. *)
type parameters = {fifo : bool};

(* The tree indexing flattened quotient terms:
   - Result: the values stored under a completely read term;
   - Single: an edge labelled by one whole qterm, shared by everything below;
   - Multiple: a branch on the next symbol, with an optional subtree for Var
     and one subtree per function symbol (keyed by name and arity). *)
datatype 'a net =
    Result of 'a list
  | Single of qterm * 'a net
  | Multiple of 'a net option * 'a net NameArityMap.map;

(* A term net: its parameters, the id the next insertion will receive, and,
   when non-empty, the number of stored entries together with the tree.
   Entries are stored in the tree as (insertion id, value) pairs. *)
datatype 'a termNet =
    Net of parameters * int * (int * (int * 'a) net) option;
136
137(* ------------------------------------------------------------------------- *)
138(* Basic operations.                                                         *)
139(* ------------------------------------------------------------------------- *)
140
(* An empty net: no tree, and the first insertion will receive id 0. *)
fun new (parm : parameters) =
    Net (parm, 0, NONE);
142
local
  (* Number of entries stored in a tree: the total length of its Results. *)
  fun count tree =
      case tree of
        Result entries => length entries
      | Single (_, sub) => count sub
      | Multiple (varSub, fnSubs) =>
        let
          val fromVar =
              case varSub of
                NONE => 0
              | SOME sub => count sub

          fun addBranch (_, sub, sum) = sum + count sub
        in
          NameArityMap.foldl addBranch fromVar fnSubs
        end;
in
  (* Pair an optional tree with its entry count, as a termNet stores it. *)
  fun netSize optTree = Option.map (fn tree => (count tree, tree)) optTree;
end;
155
(* Number of entries currently stored in the net. *)
fun size (Net (_, _, tree)) =
    case tree of
      NONE => 0
    | SOME (count, _) => count;
158
(* True exactly when the net stores no entries. *)
fun null net =
    case size net of
      0 => true
    | _ => false;
160
(* Prefix the subtree a with a chain of Single edges, one per qterm, with
   the first qterm of the list outermost. *)
fun singles qtms a =
    case qtms of
      [] => a
    | qtm :: rest => Single (qtm, singles rest a);
162
local
  (* Split an optional (size, tree) into the size (0 when absent) and an
     optional tree. *)
  fun unpack opt =
      case opt of
        NONE => (0, NONE)
      | SOME (count, tree) => (count, SOME tree);

  (* add leaf path tree: thread the flattened qterm path through tree and
     hang leaf at its end.  On the initial call leaf is a one-entry Result.
     When a Single edge disagrees with the path, the edge is first rebuilt
     as a Multiple branch, re-adding its old subtree as the "leaf" under the
     edge's own qterm. *)
  fun add (Result fresh) [] (Result stored) = Result (fresh @ stored)
    | add leaf (path as qtm :: more) (Single (edge, sub)) =
      if equalQterm qtm edge then Single (qtm, add leaf more sub)
      else
        let
          val branch = add sub [edge] (Multiple (NONE, NameArityMap.new ()))
        in
          add leaf path branch
        end
    | add leaf (Var :: more) (Multiple (varSub, fnSubs)) =
      Multiple (SOME (addOpt leaf more varSub), fnSubs)
    | add leaf (Fn (f, args) :: more) (Multiple (varSub, fnSubs)) =
      let
        val existing = NameArityMap.peek fnSubs f
        val updated = addOpt leaf (args @ more) existing
      in
        Multiple (varSub, NameArityMap.insert fnSubs (f, updated))
      end
    | add _ _ _ = raise Bug "TermNet.insert: Match"

  (* Like add, but an absent tree becomes a fresh chain of Single edges. *)
  and addOpt leaf path NONE = singles path leaf
    | addOpt leaf path (SOME tree) = add leaf path tree;

  (* Store one (id, value) entry under qtm and bump the entry count. *)
  fun addEntry entry qtm (count, tree) =
      SOME (count + 1, addOpt (Result [entry]) [qtm] tree);
in
  fun insert (Net (p,k,n)) (tm,a) =
      Net (p, k + 1, addEntry (k, a) (termToQterm tm) (unpack n))
      handle Error _ => raise Bug "TermNet.insert: should never fail";
end;
190
(* Build a net by inserting the (term, value) pairs of l from left to right,
   so earlier pairs receive smaller insertion ids. *)
fun fromList parm l =
    let
      fun step (entry, net) = insert net entry
    in
      List.foldl step (new parm) l
    end;
192
(* Keep only the entries whose value satisfies pred.  Subtrees left with no
   entries are pruned and the entry count is recomputed.  The next insertion
   id k is kept, so entries inserted later still get fresh ids.

   No exception handler is installed.  An exception raised by pred
   (including Error) propagates to the caller unchanged.  A handler written
   after the let-expression below would only cover the construction of the
   returned function, never its application, so it could not fire. *)
fun filter pred =
    let
      fun filt (Result l) =
          (case List.filter (fn (_,a) => pred a) l of
             [] => NONE
           | l => SOME (Result l))
        | filt (Single (qtm,n)) = Option.map (fn n => Single (qtm,n)) (filt n)
        | filt (Multiple (vs,fs)) =
          let
            (* Filter the Var subtree first, then the function subtrees,
               pruning any that become empty. *)
            val vs = Option.mapPartial filt vs

            val fs = NameArityMap.mapPartial (fn (_,n) => filt n) fs
          in
            if Option.isSome vs orelse not (NameArityMap.null fs)
            then SOME (Multiple (vs,fs))
            else NONE
          end
    in
      fn net as Net (_,_,NONE) => net
       | Net (p, k, SOME (_,n)) => Net (p, k, netSize (filt n))
    end;
217
(* A short summary of the form "TermNet[<number of entries>]". *)
fun toString net =
    String.concat ["TermNet[", Int.toString (size net), "]"];
219
220(* ------------------------------------------------------------------------- *)
221(* Specialized fold operations to support matching and unification.          *)
222(* ------------------------------------------------------------------------- *)
223
local
  (* A stack (counts, fns, done) rebuilds a qterm from its preorder
     flattening.  fns holds the function symbols still waiting for
     arguments, counts holds how many arguments each of them still needs,
     and done holds completed qterms, most recent first. *)

  (* While the innermost pending symbol needs no more arguments, pop its
     arguments from done (via revDivide) and push the finished Fn. *)
  fun reduce (0 :: counts, (f as (_, arity)) :: fns, done) =
      let
        val (args, done') = revDivide done arity
      in
        push (Fn (f, args)) (counts, fns, done')
      end
    | reduce st = st

  (* Push a completed qterm; it fills one argument slot of the innermost
     pending symbol, if there is one. *)
  and push qtm (counts, fns, done) =
      let
        val counts' =
            case counts of
              [] => []
            | c :: cs => (c - 1) :: cs
      in
        reduce (counts', fns, qtm :: done)
      end

  (* Open a new function symbol that still needs all of its arguments. *)
  and pushFn (f as (_, arity)) (counts, fns, done) =
      reduce (arity :: counts, f :: fns, done);
in
  val stackEmpty = ([], [], []);

  val stackAddQterm = push;

  val stackAddFn = pushFn;

  (* The finished qterm, once the stack holds exactly one complete term. *)
  fun stackValue st =
      case st of
        ([], [], [qtm]) => qtm
      | _ => raise Bug "TermNet.stackValue";
end;
251
local
  (* Work-list traversal over (pending, stack, net) items.  pending is how
     many qterm positions remain to be read before the current term is
     complete.  stack accumulates the qterm read so far.  When pending hits
     0, the rebuilt qterm and the subtree reached are passed to f. *)
  fun walk _ acc [] = acc
    | walk f acc ((0, stack, net) :: todo) =
      walk f (f (stackValue stack, net, acc)) todo
    | walk f acc ((pending, stack, Single (qtm, sub)) :: todo) =
      walk f acc ((pending - 1, stackAddQterm qtm stack, sub) :: todo)
    | walk f acc ((pending, stack, Multiple (varSub, fnSubs)) :: todo) =
      let
        val left = pending - 1

        val todo =
            case varSub of
              NONE => todo
            | SOME sub => (left, stackAddQterm Var stack, sub) :: todo

        (* Taking a function symbol adds its arguments to the positions still
           to be read. *)
        fun branch (fa as (_, arity), sub, items) =
            (left + arity, stackAddFn fa stack, sub) :: items
      in
        walk f acc (NameArityMap.foldr branch todo fnSubs)
      end
    | walk _ _ _ = raise Bug "TermNet.foldTerms.fold";
in
  fun foldTerms inc acc net = walk inc acc [(1, stackEmpty, net)];
end;
276
(* Follow the single path of the net spelling out exactly pat and, if it
   exists, apply inc to (pat, subtree reached, acc); otherwise return acc. *)
fun foldEqualTerms pat inc acc =
    let
      fun walk (pats, net) =
          case (pats, net) of
            ([], _) => inc (pat, net, acc)
          | (p :: ps, Single (qtm, sub)) =>
            if equalQterm p qtm then walk (ps, sub) else acc
          | (Var :: ps, Multiple (varSub, _)) =>
            (case varSub of
               NONE => acc
             | SOME sub => walk (ps, sub))
          | (Fn (f, args) :: ps, Multiple (_, fnSubs)) =>
            (case NameArityMap.peek fnSubs f of
               NONE => acc
             | SOME sub => walk (args @ ps, sub))
          | _ => raise Bug "TermNet.foldEqualTerms.fold"
    in
      fn net => walk ([pat], net)
    end;
292
local
  (* Work-list traversal over (pats, stack, net) items.  pats are the
     pattern qterms still to be consumed, and stack accumulates the unified
     qterm built so far.  Once pats is empty, the unified qterm and the
     subtree reached are passed to f. *)
  fun walk _ acc [] = acc
    | walk f acc (([], stack, net) :: todo) =
      walk f (f (stackValue stack, net, acc)) todo
    | walk f acc ((Var :: pats, stack, net) :: todo) =
      let
        (* A Var pattern is compatible with every term stored here. *)
        fun harvest (qtm, sub, items) =
            (pats, stackAddQterm qtm stack, sub) :: items
      in
        walk f acc (foldTerms harvest todo net)
      end
    | walk f acc ((pat :: pats, stack, Single (qtm, sub)) :: todo) =
      (case unifyQtermQterm pat qtm of
         NONE => walk f acc todo
       | SOME joined =>
         walk f acc ((pats, stackAddQterm joined stack, sub) :: todo))
    | walk f acc
        (((pat as Fn (fa, args)) :: pats, stack, Multiple (varSub, fnSubs))
         :: todo) =
      let
        (* A stored Var accepts the whole Fn pattern. *)
        val todo =
            case varSub of
              NONE => todo
            | SOME sub => (pats, stackAddQterm pat stack, sub) :: todo

        (* The same stored symbol continues with the pattern's arguments. *)
        val todo =
            case NameArityMap.peek fnSubs fa of
              NONE => todo
            | SOME sub => (args @ pats, stackAddFn fa stack, sub) :: todo
      in
        walk f acc todo
      end
    | walk _ _ _ = raise Bug "TermNet.foldUnifiableTerms.fold";
in
  fun foldUnifiableTerms pat inc acc net =
      walk inc acc [([pat], stackEmpty, net)];
end;
329
330(* ------------------------------------------------------------------------- *)
331(* Matching and unification queries.                                         *)
332(*                                                                           *)
333(* These functions return OVER-APPROXIMATIONS!                               *)
334(* Filter afterwards to get the precise set of satisfying values.            *)
335(* ------------------------------------------------------------------------- *)
336
local
  fun byId ((i, _), (j, _)) = Int.compare (i, j);
in
  (* Drop the insertion ids from query results.  With fifo = true, the
     results are first sorted by id, i.e. into insertion order. *)
  fun finally ({fifo, ...} : parameters) entries =
      List.map snd (if fifo then sort byId entries else entries);
end;
344
local
  (* Work-list search over (subtree, query terms still to read) items,
     collecting the values of stored terms that match the query term. *)
  fun search found [] = found
    | search found ((Result entries, []) :: todo) =
      search (entries @ found) todo
    | search found ((Single (qtm, sub), tm :: tms) :: todo) =
      let
        val todo = if matchQtermTerm qtm tm then (sub, tms) :: todo else todo
      in
        search found todo
      end
    | search found ((Multiple (varSub, fnSubs), tm :: tms) :: todo) =
      let
        (* A Var stored in the net matches any query subterm. *)
        val todo =
            case varSub of
              NONE => todo
            | SOME sub => (sub, tms) :: todo

        (* A stored function symbol needs the same name and arity here. *)
        val todo =
            case tm of
              Term.Var _ => todo
            | Term.Fn (f, args) =>
              (case NameArityMap.peek fnSubs (f, length args) of
                 NONE => todo
               | SOME sub => (sub, args @ tms) :: todo)
      in
        search found todo
      end
    | search _ _ = raise Bug "TermNet.match: Match";
in
  fun match (Net (_,_,NONE)) _ = []
    | match (Net (p, _, SOME (_,n))) tm =
      finally p (search [] [(n, [tm])])
      handle Error _ => raise Bug "TermNet.match: should never fail";
end;
371
local
  (* Work-list search over (qsub, subtree, query terms still to read) items.
     qsub maps query variables to the qterms they have been bound to.  The
     search collects the values of stored terms that are instances of the
     query term. *)
  fun search found [] = found
    | search found ((_, Result entries, []) :: todo) =
      search (entries @ found) todo
    | search found ((qsub, Single (qtm, sub), tm :: tms) :: todo) =
      (case matchTermQterm qsub tm qtm of
         NONE => search found todo
       | SOME qsub' => search found ((qsub', sub, tms) :: todo))
    | search found ((qsub, net as Multiple _, Term.Var v :: tms) :: todo) =
      let
        val todo =
            case NameMap.peek qsub v of
              NONE =>
              (* First occurrence of v: bind it to each stored term in turn. *)
              let
                fun bindNew (qtm, sub, items) =
                    (NameMap.insert qsub (v, qtm), sub, tms) :: items
              in
                foldTerms bindNew todo net
              end
            | SOME qtm =>
              (* v is already bound: only stored terms equal to it fit. *)
              let
                fun keep (_, sub, items) = (qsub, sub, tms) :: items
              in
                foldEqualTerms qtm keep todo net
              end
      in
        search found todo
      end
    | search found
        ((qsub, Multiple (_, fnSubs), Term.Fn (f, args) :: tms) :: todo) =
      (case NameArityMap.peek fnSubs (f, length args) of
         NONE => search found todo
       | SOME sub => search found ((qsub, sub, args @ tms) :: todo))
    | search _ _ = raise Bug "TermNet.matched.mat";
in
  fun matched (Net (_,_,NONE)) _ = []
    | matched (Net (parm, _, SOME (_,net))) tm =
      finally parm (search [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.matched: should never fail";
end;
404
local
  (* Work-list search over (qsub, subtree, query terms still to read) items.
     qsub maps query variables to qterms.  The search collects the values of
     stored terms that may unify with the query term. *)
  fun search found [] = found
    | search found ((_, Result entries, []) :: todo) =
      search (entries @ found) todo
    | search found ((qsub, Single (qtm, sub), tm :: tms) :: todo) =
      (case unifyQtermTerm qsub qtm tm of
         NONE => search found todo
       | SOME qsub' => search found ((qsub', sub, tms) :: todo))
    | search found ((qsub, net as Multiple _, Term.Var v :: tms) :: todo) =
      let
        (* Record v's (possibly refined) binding for each term reached. *)
        fun bind (qtm, sub, items) =
            (NameMap.insert qsub (v, qtm), sub, tms) :: items

        val todo =
            case NameMap.peek qsub v of
              NONE => foldTerms bind todo net
            | SOME qtm => foldUnifiableTerms qtm bind todo net
      in
        search found todo
      end
    | search found
        ((qsub, Multiple (varSub, fnSubs), Term.Fn (f, args) :: tms) :: todo) =
      let
        (* A stored Var unifies with the whole query subterm. *)
        val todo =
            case varSub of
              NONE => todo
            | SOME sub => (qsub, sub, tms) :: todo

        val todo =
            case NameArityMap.peek fnSubs (f, length args) of
              NONE => todo
            | SOME sub => (qsub, sub, args @ tms) :: todo
      in
        search found todo
      end
    | search _ _ = raise Bug "TermNet.unify.mat";
in
  fun unify (Net (_,_,NONE)) _ = []
    | unify (Net (parm, _, SOME (_,net))) tm =
      finally parm (search [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.unify: should never fail";
end;
437
438(* ------------------------------------------------------------------------- *)
439(* Pretty printing.                                                          *)
440(* ------------------------------------------------------------------------- *)
441
local
  (* Collect (id, (qterm, value)) for every entry of a Result reached by
     foldTerms. *)
  fun collect (qtm, Result entries, acc) =
      List.foldl (fn ((id, a), rest) => (id, (qtm, a)) :: rest) acc entries
    | collect _ = raise Bug "TermNet.pp.inc";

  (* All stored (qterm, value) pairs, in insertion order when fifo is set. *)
  fun toList (Net (_, _, NONE)) = []
    | toList (Net (parm, _, SOME (_, net))) =
      finally parm (foldTerms collect [] net);
in
  fun pp ppA =
      Print.ppMap toList (Print.ppList (Print.ppOp2 " |->" ppQterm ppA));
end;
454
455end
456