(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC TERMS              *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure TermNet :> TermNet =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Anonymous variables.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Quotient terms forget variable names, so when one is turned back into an *)
(* ordinary term (see ppQterm below) every variable is printed as "_".      *)
val anonymousName = Name.fromString "_";
val anonymousVar = Term.Var anonymousName;

(* ------------------------------------------------------------------------- *)
(* Quotient terms.                                                           *)
(*                                                                           *)
(* A quotient term identifies all variables: every variable becomes the     *)
(* single constructor Var, and every function symbol is paired with its    *)
(* arity.                                                                   *)
(* ------------------------------------------------------------------------- *)

datatype qterm =
    Var
  | Fn of NameArity.nameArity * qterm list;

(* Total order on quotient terms.  Var sorts before any Fn; two Fn nodes    *)
(* are ordered by name/arity and then by their arguments, left to right.    *)
(* Pending argument pairs are kept on an explicit work list, so cmp and     *)
(* fnCmp call each other only in tail position.  Pairs that are physically  *)
(* equal are skipped without being traversed.                              *)
local
  fun cmp [] = EQUAL
    | cmp (q1_q2 :: qs) =
      if Portable.pointerEqual q1_q2 then cmp qs
      else
        case q1_q2 of
          (* Two Vars are equal, but the remaining pending pairs must still *)
          (* be compared; returning EQUAL here would ignore them.          *)
          (Var,Var) => cmp qs
        | (Var, Fn _) => LESS
        | (Fn _, Var) => GREATER
        | (Fn f1, Fn f2) => fnCmp f1 f2 qs

  and fnCmp (n1,q1) (n2,q2) qs =
      case NameArity.compare (n1,n2) of
        LESS => LESS
      | EQUAL => cmp (zip q1 q2 @ qs)
      | GREATER => GREATER;
in
  fun compareQterm q1_q2 = cmp [q1_q2];

  fun compareFnQterm (f1,f2) = fnCmp f1 f2 [];
end;

fun equalQterm q1 q2 = compareQterm (q1,q2) = EQUAL;

fun equalFnQterm f1 f2 = compareFnQterm (f1,f2) = EQUAL;

(* Forget variable names and annotate each function symbol with its arity. *)
fun termToQterm (Term.Var _) = Var
  | termToQterm (Term.Fn (f,l)) = Fn ((f, length l), List.map termToQterm l);

(* Does the first quotient term match the second?  A Var on the left        *)
(* matches anything; a Fn on the left never matches a Var on the right.     *)
local
  fun qm [] = true
    | qm ((Var,_) :: rest) = qm rest
    | qm ((Fn _, Var) :: _) = false
    | qm ((Fn (f,a), Fn (g,b)) :: rest) =
      NameArity.equal f g andalso qm (zip a b @ rest);
in
  fun matchQtermQterm qtm qtm' = qm [(qtm,qtm')];
end;

(* Does the quotient term match the ordinary term?  Function symbols must   *)
(* agree on both name and arity.                                            *)
local
  fun qm [] = true
    | qm ((Var,_) :: rest) = qm rest
    | qm ((Fn _, Term.Var _) :: _) = false
    | qm ((Fn ((f,n),a), Term.Fn (g,b)) :: rest) =
      Name.equal f g andalso n = length b andalso qm (zip a b @ rest);
in
  fun matchQtermTerm qtm tm = qm [(qtm,tm)];
end;

(* Match an ordinary term (the pattern) against a quotient term, extending  *)
(* the substitution qsub from variable names to quotient terms.  A variable *)
(* that is already bound must be bound to an equal quotient term.  Returns  *)
(* NONE when there is no match.                                             *)
local
  fun qn qsub [] = SOME qsub
    | qn qsub ((Term.Var v, qtm) :: rest) =
      (case NameMap.peek qsub v of
         NONE => qn (NameMap.insert qsub (v,qtm)) rest
       | SOME qtm' => if equalQterm qtm qtm' then qn qsub rest else NONE)
    | qn _ ((Term.Fn _, Var) :: _) = NONE
    | qn qsub ((Term.Fn (f,a), Fn ((g,n),b)) :: rest) =
      if Name.equal f g andalso length a = n then qn qsub (zip a b @ rest)
      else NONE;
in
  fun matchTermQterm qsub tm qtm = qn qsub [(tm,qtm)];
end;

(* Unification of quotient terms.                                          *)
(*   qv: combines two quotient terms (Var yields the other side) and raises *)
(*       Error on a symbol clash.                                          *)
(*   qu: unifies quotient terms against ordinary terms, accumulating        *)
(*       variable bindings in qsub; when a variable is already bound, the   *)
(*       old and new quotient terms are combined with qv.                  *)
(* Both entry points convert the Error into NONE via total.                *)
local
  fun qv Var x = x
    | qv x Var = x
    | qv (Fn (f,a)) (Fn (g,b)) =
      let
        val _ = NameArity.equal f g orelse raise Error "TermNet.qv"
      in
        Fn (f, zipWith qv a b)
      end;

  fun qu qsub [] = qsub
    | qu qsub ((Var, _) :: rest) = qu qsub rest
    | qu qsub ((qtm, Term.Var v) :: rest) =
      let
        val qtm =
            case NameMap.peek qsub v of NONE => qtm | SOME qtm' => qv qtm qtm'
      in
        qu (NameMap.insert qsub (v,qtm)) rest
      end
    | qu qsub ((Fn ((f,n),a), Term.Fn (g,b)) :: rest) =
      if Name.equal f g andalso n = length b then qu qsub (zip a b @ rest)
      else raise Error "TermNet.qu";
in
  fun unifyQtermQterm qtm qtm' = total (qv qtm) qtm';

  fun unifyQtermTerm qsub qtm tm = total (qu qsub) [(qtm,tm)];
end;

(* Print a quotient term by converting it back to a term, with every Var    *)
(* shown as the anonymous variable.                                        *)
local
  fun qtermToTerm Var = anonymousVar
    | qtermToTerm (Fn ((f,_),l)) = Term.Fn (f, List.map qtermToTerm l);
in
  val ppQterm = Print.ppMap qtermToTerm Term.pp;
end;

(* ------------------------------------------------------------------------- *)
(* A type of term sets that can be efficiently matched and unified.          *)
(* ------------------------------------------------------------------------- *)

type parameters = {fifo : bool};

(* A tree indexed by quotient terms read in pre-order:                      *)
(*   Result   : leaf holding the values stored at the end of a path;        *)
(*   Single   : a node with exactly one outgoing edge, labelled by a whole  *)
(*              quotient term;                                             *)
(*   Multiple : a branching node with an optional Var child and children    *)
(*              keyed by function symbol and arity.                        *)
datatype 'a net =
    Result of 'a list
  | Single of qterm * 'a net
  | Multiple of 'a net option * 'a net NameArityMap.map;

(* A term net holds: the parameters; a counter used to number entries as    *)
(* they are inserted (each value is stored paired with its number); and an  *)
(* optional root annotated with the count of stored entries.               *)
datatype 'a termNet = Net of parameters * int * (int * (int * 'a) net) option;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

fun new parm = Net (parm,0,NONE);

(* Annotate an optional net with the number of values stored in its leaves. *)
local
  fun computeSize (Result l) = length l
    | computeSize (Single (_,n)) = computeSize n
    | computeSize (Multiple (vs,fs)) =
      NameArityMap.foldl
        (fn (_,n,acc) => acc + computeSize n)
        (case vs of SOME n => computeSize n | NONE => 0)
        fs;
in
  fun netSize NONE = NONE
    | netSize (SOME n) = SOME (computeSize n, n);
end;

fun size (Net (_,_,NONE)) = 0
  | size (Net (_, _, SOME (i,_))) = i;

fun null net = size net = 0;

(* Build a chain of Single nodes for the quotient terms, ending at a.       *)
fun singles qtms a = List.foldr Single a qtms;

(* Insert a (term, value) pair.  The value is stored as (k, value), where k *)
(* is the net's counter, and the counter is incremented.                   *)
local
  fun pre NONE = (0,NONE)
    | pre (SOME (i,n)) = (i, SOME n);

  (* add a qtms net: push the Result a down the path spelled by qtms. *)
  fun add (Result l) [] (Result l') = Result (l @ l')
    | add a (input1 as qtm :: qtms) (Single (qtm',n)) =
      if equalQterm qtm qtm' then Single (qtm, add a qtms n)
      (* Mismatch: turn the Single node into a Multiple node holding the *)
      (* existing subtree, then retry the insertion there.               *)
      else add a input1 (add n [qtm'] (Multiple (NONE, NameArityMap.new ())))
    | add a (Var :: qtms) (Multiple (vs,fs)) =
      Multiple (SOME (oadd a qtms vs), fs)
    | add a (Fn (f,l) :: qtms) (Multiple (vs,fs)) =
      let
        val n = NameArityMap.peek fs f
      in
        Multiple (vs, NameArityMap.insert fs (f, oadd a (l @ qtms) n))
      end
    | add _ _ _ = raise Bug "TermNet.insert: Match"

  and oadd a qtms NONE = singles qtms a
    | oadd a qtms (SOME n) = add a qtms n;

  fun ins a qtm (i,n) = SOME (i + 1, oadd (Result [a]) [qtm] n);
in
  fun insert (Net (p,k,n)) (tm,a) =
      Net (p, k + 1, ins (k,a) (termToQterm tm) (pre n))
      handle Error _ => raise Bug "TermNet.insert: should never fail";
end;

fun fromList parm l = List.foldl (fn (tm_a,n) => insert n tm_a) (new parm) l;

(* Keep only the values satisfying pred, pruning branches that become empty *)
(* and recomputing the size.  The insertion counter is left unchanged.      *)
(* Any exception raised by pred propagates to the caller.                   *)
fun filter pred net =
    let
      fun filt (Result l) =
          (case List.filter (fn (_,a) => pred a) l of
             [] => NONE
           | l => SOME (Result l))
        | filt (Single (qtm,n)) =
          (case filt n of
             NONE => NONE
           | SOME n => SOME (Single (qtm,n)))
        | filt (Multiple (vs,fs)) =
          let
            val vs = Option.mapPartial filt vs

            val fs = NameArityMap.mapPartial (fn (_,n) => filt n) fs
          in
            if not (Option.isSome vs) andalso NameArityMap.null fs then NONE
            else SOME (Multiple (vs,fs))
          end
    in
      case net of
        Net (_,_,NONE) => net
      | Net (p, k, SOME (_,n)) => Net (p, k, netSize (filt n))
    end;

fun toString net = "TermNet[" ^ Int.toString (size net) ^ "]";

(* ------------------------------------------------------------------------- *)
(* Specialized fold operations to support matching and unification.         *)
(* ------------------------------------------------------------------------- *)

(* A stack for rebuilding quotient terms from a pre-order walk:            *)
(*   ks   : for each pending function symbol, arguments still needed;       *)
(*   fs   : the pending function symbols;                                  *)
(*   qtms : completed quotient terms, most recent first.                   *)
(* norm collapses any function symbol whose arguments are all present.     *)
local
  fun norm (0 :: ks, (f as (_,n)) :: fs, qtms) =
      let
        val (a,qtms) = revDivide qtms n
      in
        addQterm (Fn (f,a)) (ks,fs,qtms)
      end
    | norm stack = stack

  and addQterm qtm (ks,fs,qtms) =
      let
        val ks = case ks of [] => [] | k :: ks => (k - 1) :: ks
      in
        norm (ks, fs, qtm :: qtms)
      end

  and addFn (f as (_,n)) (ks,fs,qtms) = norm (n :: ks, f :: fs, qtms);
in
  val stackEmpty = ([],[],[]);

  val stackAddQterm = addQterm;

  val stackAddFn = addFn;

  fun stackValue ([],[],[qtm]) = qtm
    | stackValue _ = raise Bug "TermNet.stackValue";
end;

(* Walk every path that consumes exactly one complete quotient term from   *)
(* the net.  For each path, inc receives the reconstructed quotient term    *)
(* and the subnet reached.  The integer in each work item counts the        *)
(* quotient terms still to be consumed.                                    *)
local
  fun fold _ acc [] = acc
    | fold inc acc ((0,stack,net) :: rest) =
      fold inc (inc (stackValue stack, net, acc)) rest
    | fold inc acc ((n, stack, Single (qtm,net)) :: rest) =
      fold inc acc ((n - 1, stackAddQterm qtm stack, net) :: rest)
    | fold inc acc ((n, stack, Multiple (v,fns)) :: rest) =
      let
        val n = n - 1

        val rest =
            case v of
              NONE => rest
            | SOME net => (n, stackAddQterm Var stack, net) :: rest

        fun getFns (f as (_,k), net, x) =
            (k + n, stackAddFn f stack, net) :: x
      in
        fold inc acc (NameArityMap.foldr getFns rest fns)
      end
    | fold _ _ _ = raise Bug "TermNet.foldTerms.fold";
in
  fun foldTerms inc acc net = fold inc acc [(1,stackEmpty,net)];
end;

(* Follow only the path whose quotient term equals pat.  If it exists, call *)
(* inc with pat and the subnet reached; otherwise return acc.              *)
fun foldEqualTerms pat inc acc =
    let
      fun fold ([],net) = inc (pat,net,acc)
        | fold (pat :: pats, Single (qtm,net)) =
          if equalQterm pat qtm then fold (pats,net) else acc
        | fold (Var :: pats, Multiple (v,_)) =
          (case v of NONE => acc | SOME net => fold (pats,net))
        | fold (Fn (f,a) :: pats, Multiple (_,fns)) =
          (case NameArityMap.peek fns f of
             NONE => acc
           | SOME net => fold (a @ pats, net))
        | fold _ = raise Bug "TermNet.foldEqualTerms.fold";
    in
      fn net => fold ([pat],net)
    end;

(* Follow every path whose quotient term unifies with pat.  inc receives    *)
(* the unified quotient term and the subnet reached.  A Var in pat accepts  *)
(* any complete term below it (via foldTerms).                              *)
local
  fun fold _ acc [] = acc
    | fold inc acc (([],stack,net) :: rest) =
      fold inc (inc (stackValue stack, net, acc)) rest
    | fold inc acc ((Var :: pats, stack, net) :: rest) =
      let
        fun harvest (qtm,n,l) = (pats, stackAddQterm qtm stack, n) :: l
      in
        fold inc acc (foldTerms harvest rest net)
      end
    | fold inc acc ((pat :: pats, stack, Single (qtm,net)) :: rest) =
      (case unifyQtermQterm pat qtm of
         NONE => fold inc acc rest
       | SOME qtm =>
         fold inc acc ((pats, stackAddQterm qtm stack, net) :: rest))
    | fold
        inc acc
        (((pat as Fn (f,a)) :: pats, stack, Multiple (v,fns)) :: rest) =
      let
        val rest =
            case v of
              NONE => rest
            | SOME net => (pats, stackAddQterm pat stack, net) :: rest

        val rest =
            case NameArityMap.peek fns f of
              NONE => rest
            | SOME net => (a @ pats, stackAddFn f stack, net) :: rest
      in
        fold inc acc rest
      end
    | fold _ _ _ = raise Bug "TermNet.foldUnifiableTerms.fold";
in
  fun foldUnifiableTerms pat inc acc net =
      fold inc acc [([pat],stackEmpty,net)];
end;

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

(* Strip insertion numbers from the results.  When the fifo parameter is    *)
(* set, results are first sorted by insertion number.                      *)
local
  fun idwise ((m,_),(n,_)) = Int.compare (m,n);

  fun fifoize ({fifo, ...} : parameters) l = if fifo then sort idwise l else l;
in
  fun finally parm l = List.map snd (fifoize parm l);
end;

(* Values whose stored term, as a quotient term, matches tm: the stored     *)
(* term is the pattern and tm the instance.                                *)
local
  fun mat acc [] = acc
    | mat acc ((Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((Single (qtm,n), tm :: tms) :: rest) =
      mat acc (if matchQtermTerm qtm tm then (n,tms) :: rest else rest)
    | mat acc ((Multiple (vs,fs), tm :: tms) :: rest) =
      let
        val rest = case vs of NONE => rest | SOME n => (n,tms) :: rest

        val rest =
            case tm of
              Term.Var _ => rest
            | Term.Fn (f,l) =>
              case NameArityMap.peek fs (f, length l) of
                NONE => rest
              | SOME n => (n, l @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.match: Match";
in
  fun match (Net (_,_,NONE)) _ = []
    | match (Net (p, _, SOME (_,n))) tm =
      finally p (mat [] [(n,[tm])])
      handle Error _ => raise Bug "TermNet.match: should never fail";
end;

(* Values whose stored term is matched by tm: tm is the pattern.  Bindings  *)
(* for tm's variables are tracked in qsub, so a repeated variable must be   *)
(* matched against equal quotient terms.                                   *)
local
  fun unseenInc qsub v tms (qtm,net,rest) =
      (NameMap.insert qsub (v,qtm), net, tms) :: rest;

  fun seenInc qsub tms (_,net,rest) = (qsub,net,tms) :: rest;

  fun mat acc [] = acc
    | mat acc ((_, Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((qsub, Single (qtm,net), tm :: tms) :: rest) =
      (case matchTermQterm qsub tm qtm of
         NONE => mat acc rest
       | SOME qsub => mat acc ((qsub,net,tms) :: rest))
    | mat acc ((qsub, net as Multiple _, Term.Var v :: tms) :: rest) =
      (case NameMap.peek qsub v of
         NONE => mat acc (foldTerms (unseenInc qsub v tms) rest net)
       | SOME qtm => mat acc (foldEqualTerms qtm (seenInc qsub tms) rest net))
    | mat acc ((qsub, Multiple (_,fns), Term.Fn (f,a) :: tms) :: rest) =
      let
        val rest =
            case NameArityMap.peek fns (f, length a) of
              NONE => rest
            | SOME net => (qsub, net, a @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.matched.mat";
in
  fun matched (Net (_,_,NONE)) _ = []
    | matched (Net (parm, _, SOME (_,net))) tm =
      finally parm (mat [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.matched: should never fail";
end;

(* Values whose stored term may unify with tm.  Bindings for tm's variables *)
(* are tracked in qsub as unified quotient terms.                          *)
local
  fun inc qsub v tms (qtm,net,rest) =
      (NameMap.insert qsub (v,qtm), net, tms) :: rest;

  fun mat acc [] = acc
    | mat acc ((_, Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((qsub, Single (qtm,net), tm :: tms) :: rest) =
      (case unifyQtermTerm qsub qtm tm of
         NONE => mat acc rest
       | SOME qsub => mat acc ((qsub,net,tms) :: rest))
    | mat acc ((qsub, net as Multiple _, Term.Var v :: tms) :: rest) =
      (case NameMap.peek qsub v of
         NONE => mat acc (foldTerms (inc qsub v tms) rest net)
       | SOME qtm => mat acc (foldUnifiableTerms qtm (inc qsub v tms) rest net))
    | mat acc ((qsub, Multiple (v,fns), Term.Fn (f,a) :: tms) :: rest) =
      let
        val rest = case v of NONE => rest | SOME net => (qsub,net,tms) :: rest

        val rest =
            case NameArityMap.peek fns (f, length a) of
              NONE => rest
            | SOME net => (qsub, net, a @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.unify.mat";
in
  fun unify (Net (_,_,NONE)) _ = []
    | unify (Net (parm, _, SOME (_,net))) tm =
      finally parm (mat [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.unify: should never fail";
end;

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Print the net as a list of "quotient term |-> value" entries.  Entries   *)
(* are sorted by insertion number when the fifo parameter is set.          *)
local
  fun inc (qtm, Result l, acc) =
      List.foldl (fn ((n,a),acc) => (n,(qtm,a)) :: acc) acc l
    | inc _ = raise Bug "TermNet.pp.inc";

  fun toList (Net (_,_,NONE)) = []
    | toList (Net (parm, _, SOME (_,net))) =
      finally parm (foldTerms inc [] net);
in
  fun pp ppA =
      Print.ppMap toList (Print.ppList (Print.ppOp2 " |->" ppQterm ppA));
end;

end