1(* ========================================================================= *)
2(* MATCHING AND UNIFICATION FOR SETS OF TERMS                                *)
3(* Copyright (c) 2001-2004 Joe Hurd.                                         *)
4(* ========================================================================= *)
5
6(*
7app load ["mlibUseful", "mlibTerm"];
8*)
9
10(*
11*)
12structure mlibTermnet :> mlibTermnet =
13struct
14
15open mlibUseful mlibTerm;
16
17structure M = Binarymap; local open Binarymap in end;
18
19(* ------------------------------------------------------------------------- *)
20(* Tuning parameters.                                                        *)
21(* ------------------------------------------------------------------------- *)
22
(* Tuning parameters for a termnet.
   fifo: when true, query results are sorted by the insertion stamp attached
   to each stored value (see fifoize), i.e. returned in insertion order;
   when false they are returned in whatever order the traversal yields. *)
type parameters = {fifo : bool};
24
25(* ------------------------------------------------------------------------- *)
26(* Helper functions.                                                         *)
27(* ------------------------------------------------------------------------- *)
28
(* Concatenate a list of lists. *)
fun flatten xss = List.concat xss;

(* Apply f under SOME, leaving NONE untouched. *)
fun omap f opt = Option.map f opt;
32
33(* ------------------------------------------------------------------------- *)
34(* Variable and function maps                                                *)
35(* ------------------------------------------------------------------------- *)
36
(* Maps keyed by variable name. *)
type 'a vmap = (string, 'a) M.dict;
fun empty_vmap () : 'a vmap = Binarymap.mkDict String.compare;

(* Maps keyed by (function name, arity), ordered lexicographically. *)
type 'a fmap = (string * int, 'a) M.dict;
local
  val fn_key_compare = lex_order String.compare Int.compare
in
  fun empty_fmap () : 'a fmap = M.mkDict fn_key_compare;
end;
42
43(* ------------------------------------------------------------------------- *)
44(* Dealing with the terms that emerge from termnets.                         *)
45(* ------------------------------------------------------------------------- *)
46
(* Results come out of a net as (stamp, value) pairs. When the fifo flag is
   set, order them by increasing stamp; otherwise leave them as found. *)
fun fifoize ({fifo, ...} : parameters) l =
  if not fifo then l
  else sort (fn ((i, _), (j, _)) => Int.compare (i, j)) l;

(* Apply the fifo ordering, then drop the stamps. *)
fun finally parm l = List.map (fn (_, x) => x) (fifoize parm l);
51
52(* ------------------------------------------------------------------------- *)
53(* Quotient terms                                                            *)
54(* ------------------------------------------------------------------------- *)
55
(* A quotient term forgets variable names: every variable becomes VAR, and
   function applications record their name together with their arity. *)
datatype qterm = VAR | FN of (string * int) * qterm list;

fun qterm tm =
  case tm of
    Var _ => VAR
  | Fn (name, args) => FN ((name, List.length args), List.map qterm args);
60
(* qmatch qtm tm: can tm be an instance of the qterm qtm? A VAR in qtm
   matches anything; a function qterm requires the same name and arity in
   tm, with every argument matching in turn (checked left to right,
   stopping at the first failure). *)
fun qmatch VAR _ = true
  | qmatch (FN _) (Var _) = false
  | qmatch (FN ((f, n), qargs)) (Fn (g, args)) =
    f = g andalso n = length args andalso
    List.all (fn (q, t) => qmatch q t) (zip qargs args);
70
local
  (* Record that variable v corresponds to qtm, insisting on agreement with
     any earlier binding. *)
  fun bind sub v qtm =
    case M.peek (sub, v) of
      SOME prev => if qtm = prev then sub else raise Error "matchq: vars"
    | NONE => M.insert (sub, v, qtm);
in
  (* matchq sub tm qtm extends sub so that tm's variables map to the
     corresponding qterms of qtm; raises Error when tm does not match. *)
  fun matchq sub (Var v) qtm = bind sub v qtm
    | matchq _ (Fn _) VAR = raise Error "matchq: match fn var"
    | matchq sub (Fn (f, args)) (FN ((g, n), qargs)) =
      if f <> g orelse length args <> n then raise Error "matchq: match fn fn"
      else List.foldl (fn ((t, q), s) => matchq s t q) sub (zip args qargs);
end;
85
local
  (* Most specific qterm compatible with both arguments; raises Error on a
     function-symbol clash. *)
  fun combine VAR q = q
    | combine q VAR = q
    | combine (FN (f, xs)) (FN (g, ys)) =
      (assert (f = g) (Error "qunify: incompatible vars");
       FN (f, zipwith combine xs ys));

  (* Constrain query variable v by q, merging with any existing constraint. *)
  fun bind sub v q =
    let
      val merged =
        case M.peek (sub, v) of
          SOME old => combine q old
        | NONE => q
    in
      M.insert (sub, v, merged)
    end;

  (* Walk a (qterm, term) pair depth first, left to right. *)
  fun walk sub (VAR, _) = sub
    | walk sub (q, Var v) = bind sub v q
    | walk sub (FN ((f, n), qs), Fn (g, ts)) =
      if f <> g orelse n <> length ts then
        raise Error "unifyq: structurally different"
      else List.foldl (fn (pair, s) => walk s pair) sub (zip qs ts);
in
  fun qunify qtm qtm' = total (combine qtm) qtm';
  fun unifyq sub qtm tm = total (walk sub) (qtm, tm);
end;
109
(* Turn a qterm back into a term; every VAR becomes the variable "_". *)
fun qterm' qtm =
  case qtm of
    VAR => Var "_"
  | FN ((name, _), args) => Fn (name, List.map qterm' args);

val pp_qterm = pp_map qterm' pp_term;
113
114(* ------------------------------------------------------------------------- *)
(* Term discrimination trees are optimized for match queries.                *)
116(* ------------------------------------------------------------------------- *)
117
(* A net is walked along the preorder sequence of a term's qterms: a function
   node is followed by its arguments and then by the rest of the sequence.
   - RESULT l: reached once the whole term has been consumed; l holds the
     values stored for it.
   - SINGLE (qtm, n): compressed node used while only one qterm has been
     inserted at this position. The whole of qtm (arguments included) is
     consumed in one step before continuing into n. Inserting a different
     qterm here expands the node into a MULTIPLE.
   - MULTIPLE (vs, fs): branching node. vs is the subtree for a variable at
     this position; fs maps each (function name, arity) to the subtree that
     continues with that function's arguments. *)
datatype 'a net =
  RESULT of 'a list
| SINGLE of qterm * 'a net
| MULTIPLE of 'a net option * 'a net fmap;

(* NET (parm, k, root):
   - k counts insertions so far and stamps each stored value as (k, value),
     which fifoize uses for ordering.
   - root is NONE when nothing is stored, otherwise SOME (n, net), where n is
     the number of stored values. *)
datatype 'a termnet = NET of parameters * int * (int * (int * 'a) net) option;
124
(* A net holding nothing, with the insertion counter at zero. *)
fun empty parm : 'a termnet = NET (parm, 0, NONE);

(* Number of values currently stored. *)
fun size (NET (_, _, root)) =
  case root of
    NONE => 0
  | SOME (count, _) => count;

(* Chain the qterms tms as SINGLE nodes in front of the subtree a. *)
fun singles tms a = List.foldr (fn (qtm, rest) => SINGLE (qtm, rest)) a tms;
130
local
  (* Split the optional (size, net) root into a size and an optional net. *)
  fun unpack NONE = (0, NONE)
    | unpack (SOME (count, root)) = (count, SOME root);

  (* add leaf path node threads the qterm sequence path through node and
     merges leaf (a RESULT) in at the end. *)
  fun add leaf path node =
    case (leaf, path, node) of
      (RESULT new, [], RESULT old) => RESULT (new @ old)
    | (_, qtm :: qtms, SINGLE (qtm', next)) =>
        if qtm = qtm' then SINGLE (qtm, add leaf qtms next)
        else
          let
            (* Re-home the existing SINGLE path under a fresh MULTIPLE node,
               then insert into that. *)
            val split = add next [qtm'] (MULTIPLE (NONE, empty_fmap ()))
          in
            add leaf path split
          end
    | (_, VAR :: qtms, MULTIPLE (vnet, fnets)) =>
        MULTIPLE (SOME (add_opt leaf qtms vnet), fnets)
    | (_, FN (f, args) :: qtms, MULTIPLE (vnet, fnets)) =>
        MULTIPLE
          (vnet, M.insert (fnets, f, add_opt leaf (args @ qtms) (M.peek (fnets, f))))
    | _ => raise Bug "mlibTermnet.insert: mlibMatch"
  and add_opt leaf qtms NONE = singles qtms leaf
    | add_opt leaf qtms (SOME node) = add leaf qtms node;
in
  (* Store value a under term tm, stamped with the current insertion count. *)
  fun insert (tm |-> a) (NET (p, k, root)) =
    let
      val (count, net) = unpack root
    in
      NET (p, k + 1, SOME (count + 1, add_opt (RESULT [(k, a)]) [qterm tm] net))
    end
    handle Error _ => raise Bug "mlibTermnet.insert: should never fail";
end;
149
local
  (* Worklist search: each job is a net node paired with the query subterms
     still to be consumed. *)
  fun walk found [] = found
    | walk found (job :: jobs) =
      case job of
        (RESULT vals, []) => walk (vals @ found) jobs
      | (SINGLE (qtm, next), tm :: tms) =>
          walk found (if qmatch qtm tm then (next, tms) :: jobs else jobs)
      | (MULTIPLE (vnet, fnets), tm :: tms) =>
          let
            (* A net variable matches any query subterm. *)
            val jobs =
              case vnet of
                SOME next => (next, tms) :: jobs
              | NONE => jobs
            (* A net function branch needs the same symbol in the query. *)
            val jobs =
              case tm of
                Var _ => jobs
              | Fn (f, args) =>
                  (case M.peek (fnets, (f, length args)) of
                     SOME next => (next, args @ tms) :: jobs
                   | NONE => jobs)
          in
            walk found jobs
          end
      | _ => raise Bug "mlibTermnet.match: mlibMatch";
in
  (* Values whose stored term (as a qterm) matches the query term tm. *)
  fun match (NET (_, _, NONE)) _ = []
    | match (NET (p, _, SOME (_, root))) tm =
      finally p (walk [] [(root, [tm])])
      handle Error _ => raise Bug "mlibTermnet.match: should never fail";
end;
172
(* harvest_gen unifying inc pat net acc enumerates each qterm stored at the
   position at the front of net that is compatible with the qterm pattern
   pat. For each one it calls inc tm net' acc, where tm is the stored qterm
   combined with pat (via qunify), and net' is the subtree after that term.

   The search is a worklist of states (pl, fl, sl, n):
     pl - pattern qterms still to be consumed;
     fl - stack of (function symbol, remaining argument count) for function
          nodes whose arguments are still being rebuilt;
     sl - stack of rebuilt qterms, most recent first;
     n  - the current net node.

   Under a function pattern, a MULTIPLE node's variable branch is followed
   only when unifying is true. A net variable unifies with any term, so the
   qterm rebuilt there is the pattern itself (qunify p VAR = p). When
   unifying is false, only the matching function branch is followed. *)
fun harvest_gen unifying inc =
  let
    fun chk [] acc = acc
      | chk (([],[],[tm],net) :: rest) acc = chk rest (inc tm net acc)
      | chk ((pl, (f as (_,i), 0) :: fl, sl, n) :: l) acc =
      (* Every argument of f has been rebuilt. Presumably divide sl i splits
         off the top i entries of sl; they are stacked most recent first,
         hence the rev. *)
      let val (a,b) = divide sl i
      in chk ((pl, fl, FN (f, rev a) :: b, n) :: l) acc
      end
      | chk ((pl, (f,j)::fl, sl, n) :: l) acc = get (pl,(f,j-1)::fl,sl,n) l acc
      | chk _ _ = raise Bug "mlibTermnet.harvest: mlibMatch 1"
    and get (p :: pl, fl, sl, SINGLE (t,n)) l acc =
      (case qunify p t of NONE => chk l acc
       | SOME t => chk ((pl, fl, t :: sl, n) :: l) acc)
      | get (VAR :: pl, fl, sl, MULTIPLE (vs,fs)) l acc =
      let
        (* A variable pattern explores every function branch, expecting one
           VAR pattern per argument. *)
        fun fget (f as (_,i), n, x) =
          (funpow i (cons VAR) pl, (f,i) :: fl, sl, n) :: x
        val l = case vs of NONE => l | SOME n => (pl, fl, VAR :: sl, n) :: l
      in
        chk (M.foldr fget l fs) acc
      end
      | get ((p as FN (f as (_,i), a)) :: pl, fl, sl, MULTIPLE (vs,fs)) l acc =
      let
        (* When unifying, a net variable here is compatible with p. *)
        val l =
          case vs of
            SOME n => if unifying then (pl, fl, p :: sl, n) :: l else l
          | NONE => l
      in
        case M.peek (fs,f) of NONE => chk l acc
        | SOME n => chk ((a @ pl, (f,i) :: fl, sl, n) :: l) acc
      end
      | get _ _ _ = raise Bug "mlibTermnet.harvest: mlibMatch 2"
  in
    fn pat => fn net => fn acc => get ([pat],[],[],net) [] acc
  end;

(* The original non-unifying traversal: a function pattern follows only the
   identical function branch. Used by matched and to_maplets. *)
fun harvest inc = harvest_gen false inc;

local
  (* The qterm currently constraining query variable v (VAR if none). *)
  fun pat sub v = case M.peek (sub,v) of NONE => VAR | SOME qtm => qtm;

  (* Resume the search with v bound to the harvested qterm tm. *)
  fun inc sub v tms tm net rest = (M.insert (sub,v,tm), net, tms) :: rest;

  fun mat acc [] = acc
    | mat acc ((_, RESULT l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((sub, SINGLE (tm', n), tm :: tms) :: rest) =
    (case unifyq sub tm' tm of NONE => mat acc rest
     | SOME sub => mat acc ((sub,n,tms) :: rest))
    | mat acc ((sub, net as MULTIPLE _, Var v :: tms) :: rest) =
    (* A query variable can unify with net variables even when it is already
       constrained to a function qterm, so use the unifying traversal. *)
    mat acc (harvest_gen true (inc sub v tms) (pat sub v) net rest)
    | mat acc ((sub, MULTIPLE (vs,fs), Fn (f,l) :: tms) :: rest) =
    let
      val rest =
        (case M.peek (fs, (f, length l)) of NONE => rest
         | SOME n => (sub, n, l @ tms) :: rest)
    in
      mat acc (case vs of NONE => rest | SOME n => (sub,n,tms) :: rest)
    end
    | mat _ _ = raise Bug "mlibTermnet.unify: mlibMatch";
in
  (* Values whose stored term may unify with tm. This is a qterm-level
     approximation: candidates may need rechecking, but no unifiable entry
     is omitted. *)
  fun unify (NET (_,_,NONE)) _ = []
    | unify (NET (p, _, SOME (_,n))) tm =
    finally p (mat [] [(empty_vmap (), n, [tm])])
    handle Error _ => raise Bug "mlibTermnet.unify: should never fail";
end;
229
local
  (* A harvested qterm is acceptable if v was unbound or already bound to
     exactly that qterm. *)
  fun agrees _ NONE = true
    | agrees qtm (SOME prev) = qtm = prev;

  fun extend sub v seen tms qtm net rest =
    if agrees qtm seen then (M.insert (sub, v, qtm), net, tms) :: rest
    else rest;

  (* Worklist search over (substitution, net node, remaining query terms). *)
  fun walk found [] = found
    | walk found (job :: jobs) =
      case job of
        (_, RESULT vals, []) => walk (vals @ found) jobs
      | (sub, SINGLE (qtm, next), tm :: tms) =>
          (case total (matchq sub tm) qtm of
             SOME sub' => walk found ((sub', next, tms) :: jobs)
           | NONE => walk found jobs)
      | (sub, node as MULTIPLE _, Var v :: tms) =>
          let
            val seen = M.peek (sub, v)
            val pattern = case seen of SOME qtm => qtm | NONE => VAR
          in
            walk found (harvest (extend sub v seen tms) pattern node jobs)
          end
      | (sub, MULTIPLE (_, fnets), Fn (f, args) :: tms) =>
          (case M.peek (fnets, (f, length args)) of
             SOME next => walk found ((sub, next, args @ tms) :: jobs)
           | NONE => walk found jobs)
      | _ => raise Bug "mlibTermnet.matched: mlibMatch";
in
  (* Values whose stored term is (as a qterm) an instance of the query tm. *)
  fun matched (NET (_, _, NONE)) _ = []
    | matched (NET (p, _, SOME (_, n))) tm =
      finally p (walk [] [(empty_vmap (), n, [tm])])
      handle Error _ => raise Bug "mlibTermnet.matched: should never fail";
end;
257
(* Keep only the stored values satisfying pred, pruning subtrees that become
   empty and recomputing the stored-value count. The insertion counter is
   preserved. *)
fun filter pred =
  let
    (* NONE when nothing in the subtree survives; otherwise the number of
       survivors together with the pruned subtree. *)
    fun prune (RESULT vals) =
        let val kept = List.filter (fn (_, x) => pred x) vals
        in if null kept then NONE else SOME (length kept, RESULT kept)
        end
      | prune (SINGLE (qtm, sub)) =
        (case prune sub of
           NONE => NONE
         | SOME (c, sub') => SOME (c, SINGLE (qtm, sub')))
      | prune (MULTIPLE (vnet, fnets)) =
        let
          val (c0, vnet') =
            case vnet of
              NONE => (0, NONE)
            | SOME v =>
                (case prune v of
                   NONE => (0, NONE)
                 | SOME (c, v') => (c, SOME v'))
          fun step (key, sub, (c, acc)) =
            case prune sub of
              NONE => (c, acc)
            | SOME (c', sub') => (c + c', M.insert (acc, key, sub'))
          val (count, fnets') = M.foldl step (c0, empty_fmap ()) fnets
        in
          if count = 0 then NONE else SOME (count, MULTIPLE (vnet', fnets'))
        end
  in
    fn net as NET (_, _, NONE) => net
     | NET (p, k, SOME (_, root)) => NET (p, k, prune root)
  end
  (* NOTE(review): this handler wraps only the construction of the closure
     above, not its later application, so it never intercepts errors raised
     while filtering. *)
  handle Error _ => raise Bug "mlibTermnet.filter: should never fail";
280
(* Build a net by inserting each maplet, left to right, into an empty net. *)
fun from_maplets p l =
  List.foldl (fn (entry, net) => insert entry net) (empty p) l;
282
local
  (* At a leaf, emit one maplet per stored (stamp, value) pair. The term is
     rebuilt from its qterm, so variables come back as "_". *)
  fun collect qtm (RESULT vals) acc =
      List.foldl (fn (v, rest) => (qterm' qtm |-> v) :: rest) acc vals
    | collect _ _ _ = raise Bug "mlibTermnet.to_maplets: mlibMatch";

  (* Move the insertion stamp outside the maplet so finally can order by it. *)
  fun stamp (tm |-> (n, a)) = (n, tm |-> a);
in
  fun to_maplets (NET (_, _, NONE)) = []
    | to_maplets (NET (p, _, SOME (_, root))) =
      finally p (List.map stamp (harvest collect VAR root []));
end;
292
(* Pretty-print a net as the list of its term |-> value maplets. *)
fun pp_termnet pp_a =
  let val pp_entries = pp_list (pp_maplet pp_term pp_a)
  in pp_map to_maplets pp_entries
  end;
294
295end
296