1(* ========================================================================= *)
2(* FIRST ORDER LOGIC TERMS                                                   *)
3(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
6structure Term :> Term =
7struct
8
9open Useful;
10
11(* ------------------------------------------------------------------------- *)
12(* A type of first order logic terms.                                        *)
13(* ------------------------------------------------------------------------- *)
14
(* Variables are identified by their names. *)

type var = Name.name;

(* Function symbols are identified by their names. *)

type functionName = Name.name;

(* A function symbol paired with an arity (its number of arguments). *)

type function = functionName * int;

(* Constants are nullary functions, so they are identified by name alone. *)

type const = functionName;

(* A term is either a variable or a function symbol applied to a list of
   argument terms; a constant is an application to the empty list. *)

datatype term =
    Var of Name.name
  | Fn of Name.name * term list;
26
27(* ------------------------------------------------------------------------- *)
28(* Constructors and destructors.                                             *)
29(* ------------------------------------------------------------------------- *)
30
31(* Variables *)
32
(* destVar tm returns the name of the variable tm; raises Error when tm is a
   function application. *)

fun destVar tm =
    case tm of
      Var name => name
    | Fn _ => raise Error "destVar";

(* isVar tm holds exactly when tm is a variable. *)

fun isVar (Var _) = true
  | isVar (Fn _) = false;

(* equalVar v tm holds exactly when tm is the variable named v. *)

fun equalVar v tm =
    case tm of
      Var w => Name.equal v w
    | Fn _ => false;
40
41(* Functions *)
42
(* destFn tm returns the (name, arguments) pair of a function application;
   raises Error when tm is a variable. *)

fun destFn tm =
    case tm of
      Fn nameArgs => nameArgs
    | Var _ => raise Error "destFn";

(* isFn tm holds exactly when tm is a function application. *)

fun isFn (Fn _) = true
  | isFn (Var _) = false;

(* Projections of a function application; each raises Error on a
   variable. *)

fun fnName tm = #1 (destFn tm);

fun fnArguments tm = #2 (destFn tm);

fun fnArity tm = length (fnArguments tm);

fun fnFunction tm =
    let
      val (name,args) = destFn tm
    in
      (name, length args)
    end;
55
(* functions tm collects the set of (name, arity) pairs of every function
   application occurring anywhere in tm. The walk uses an explicit worklist:
   arguments are pushed in front of the pending terms. *)

local
  fun collect acc [] = acc
    | collect acc (tm :: pending) =
      case tm of
        Var _ => collect acc pending
      | Fn (name,args) =>
        collect (NameAritySet.add acc (name, length args)) (args @ pending);
in
  fun functions tm = collect NameAritySet.empty [tm];
end;

(* functionNames tm collects the set of function names occurring in tm,
   without their arities. *)

local
  fun collect acc [] = acc
    | collect acc (tm :: pending) =
      case tm of
        Var _ => collect acc pending
      | Fn (name,args) => collect (NameSet.add acc name) (args @ pending);
in
  fun functionNames tm = collect NameSet.empty [tm];
end;
72
73(* Constants *)
74
(* A constant c is represented as the application of c to no arguments. *)

fun mkConst c = Fn (c,[]);

(* destConst tm returns the name of the constant tm; raises Error when tm is
   a variable or has arguments. *)

fun destConst tm =
    case tm of
      Fn (c,[]) => c
    | _ => raise Error "destConst";

fun isConst (Fn (_,[])) = true
  | isConst _ = false;
81
82(* Binary functions *)
83
(* mkBinop f (a,b) builds the two-argument application f(a,b). *)

fun mkBinop f (a,b) = Fn (f,[a,b]);

(* destBinop f tm returns (a,b) when tm is f(a,b). It raises Error when tm
   is a different two-argument function, or is not a two-argument
   application at all. *)

fun destBinop f tm =
    case tm of
      Fn (g,[a,b]) =>
      if not (Name.equal g f) then raise Error "Term.destBinop: wrong binop"
      else (a,b)
    | _ => raise Error "Term.destBinop: not a binop";

fun isBinop f tm = can (destBinop f) tm;
91
92(* ------------------------------------------------------------------------- *)
93(* The size of a term in symbols.                                            *)
94(* ------------------------------------------------------------------------- *)
95
(* Symbol weights: each variable occurrence and each function application
   contributes this many symbols. *)

val VAR_SYMBOLS = 1;

val FN_SYMBOLS = 1;

(* symbols tm totals VAR_SYMBOLS per variable occurrence and FN_SYMBOLS per
   function application in tm. *)

local
  fun count acc [] = acc
    | count acc (tm :: pending) =
      case tm of
        Var _ => count (acc + VAR_SYMBOLS) pending
      | Fn (_,args) => count (acc + FN_SYMBOLS) (args @ pending);
in
  fun symbols tm = count 0 [tm];
end;
107
108(* ------------------------------------------------------------------------- *)
109(* A total comparison function for terms.                                    *)
110(* ------------------------------------------------------------------------- *)
111
(* compare (tm1,tm2) is a total order on terms. Both terms are walked
   together with a pair of worklists in depth-first, left-to-right order.
   - Variables sort before function applications.
   - Two variables are ordered by Name.compare.
   - Two function applications are ordered by name, then by number of
     arguments, then by their arguments in sequence. *)

local
  fun cmp [] [] = EQUAL
    | cmp (tm1 :: tms1) (tm2 :: tms2) =
      let
        val tm1_tm2 = (tm1,tm2)
      in
        (* Shortcut: Portable.pointerEqual on the pair is presumably a
           physical-identity test of tm1 and tm2; identical terms are equal,
           so skip them without descending *)
        if Portable.pointerEqual tm1_tm2 then cmp tms1 tms2
        else
          case tm1_tm2 of
            (Var v1, Var v2) =>
            (case Name.compare (v1,v2) of
               LESS => LESS
             | EQUAL => cmp tms1 tms2
             | GREATER => GREATER)
          | (Var _, Fn _) => LESS
          | (Fn _, Var _) => GREATER
          | (Fn (f1,a1), Fn (f2,a2)) =>
            (case Name.compare (f1,f2) of
               LESS => LESS
             | EQUAL =>
               (case Int.compare (length a1, length a2) of
                  LESS => LESS
                  (* Same name and arity: compare the arguments before the
                     remaining pending terms *)
                | EQUAL => cmp (a1 @ tms1) (a2 @ tms2)
                | GREATER => GREATER)
             | GREATER => GREATER)
      end
    (* Both worklists start with one term each and only ever grow by
       argument lists of equal length, so they always have equal length;
       reaching this clause indicates an internal error *)
    | cmp _ _ = raise Bug "Term.compare";
in
  fun compare (tm1,tm2) = cmp [tm1] [tm2];
end;

(* equal tm1 tm2 holds when compare finds the two terms EQUAL. *)

fun equal tm1 tm2 = compare (tm1,tm2) = EQUAL;
144
145(* ------------------------------------------------------------------------- *)
146(* Subterms.                                                                 *)
147(* ------------------------------------------------------------------------- *)
148
(* A path is a list of zero-based argument indices leading from the root of a
   term down to one of its subterms; the empty path denotes the term
   itself. *)

type path = int list;

(* subterm tm path returns the subterm of tm found by following path.
   Raises Error "Term.subterm: Var" if the path tries to descend into a
   variable. Raises Error "Term.subterm: Fn" if an index is negative or not
   less than the number of arguments; negative indices are rejected here
   rather than leaking Subscript from List.nth. *)

fun subterm tm [] = tm
  | subterm (Var _) (_ :: _) = raise Error "Term.subterm: Var"
  | subterm (Fn (_,tms)) (h :: t) =
    if h < 0 orelse h >= length tms then raise Error "Term.subterm: Fn"
    else subterm (List.nth (tms,h)) t;
156
(* subterms tm lists every subterm of tm (including tm itself) paired with
   its path from the root. Paths are kept reversed while walking and are
   reversed when recorded. Each pair is consed onto the result as it is
   visited (depth-first, left-to-right), so the root ends up last. *)

local
  fun walk [] found = found
    | walk ((revPath,tm) :: pending) found =
      let
        val found = (List.rev revPath, tm) :: found
      in
        case tm of
          Var _ => walk pending found
        | Fn (_,args) =>
          let
            fun child (i,arg) = (i :: revPath, arg)
          in
            walk (List.map child (enumerate args) @ pending) found
          end
      end;
in
  fun subterms tm = walk [([],tm)] [];
end;
172
(* replace tm (path,res) returns tm with the subterm at path replaced by res.

   Sharing is preserved. When res is equal to the subterm already at path,
   the original term value is returned. The recursive case relies on this:
   it checks the rebuilt argument with pointer equality to avoid
   reallocating unchanged spines.

   Raises Error "Term.replace: Var" if the path descends into a variable.
   Raises Error "Term.replace: Fn" if an index is negative or not less than
   the number of arguments; negative indices are rejected here rather than
   leaking Subscript from List.nth. *)

fun replace tm ([],res) = if equal res tm then tm else res
  | replace tm (h :: t, res) =
    case tm of
      Var _ => raise Error "Term.replace: Var"
    | Fn (func,tms) =>
      if h < 0 orelse h >= length tms then raise Error "Term.replace: Fn"
      else
        let
          val arg = List.nth (tms,h)
          val arg' = replace arg (t,res)
        in
          (* Only rebuild this node when the argument actually changed *)
          if Portable.pointerEqual (arg',arg) then tm
          else Fn (func, updateNth (h,arg') tms)
        end;
187
(* find pred tm returns SOME path to the first subterm of tm satisfying pred,
   searching pre-order (a term before its arguments) and left-to-right, or
   NONE if no subterm satisfies pred. *)

fun find pred =
    let
      (* The argument terms of a node, each paired with its reversed path *)
      fun children revPath args =
          List.map (fn (i,arg) => (i :: revPath, arg)) (enumerate args)

      fun search [] = NONE
        | search ((revPath,tm) :: pending) =
          if pred tm then SOME (List.rev revPath)
          else
            case tm of
              Var _ => search pending
            | Fn (_,args) => search (children revPath args @ pending)
    in
      fn tm => search [([],tm)]
    end;
205
(* A path is printed as a list of its argument indices. *)

fun ppPath path = Print.ppList Print.ppInt path;

fun pathToString path = Print.toString ppPath path;
209
210(* ------------------------------------------------------------------------- *)
211(* Free variables.                                                           *)
212(* ------------------------------------------------------------------------- *)
213
(* freeIn v tm tests whether the variable v occurs anywhere in tm; every
   variable occurrence in a term counts. *)

local
  fun occurs _ [] = false
    | occurs v (tm :: pending) =
      case tm of
        Var w => Name.equal v w orelse occurs v pending
      | Fn (_,args) => occurs v (args @ pending);
in
  fun freeIn v tm = occurs v [tm];
end;

(* freeVarsList tms collects the set of variable names occurring in any of
   the terms tms; freeVars does the same for a single term. *)

local
  fun gather acc [] = acc
    | gather acc (tm :: pending) =
      case tm of
        Var v => gather (NameSet.add acc v) pending
      | Fn (_,args) => gather acc (args @ pending);
in
  fun freeVarsList tms = gather NameSet.empty tms;

  fun freeVars tm = freeVarsList [tm];
end;
231
232(* ------------------------------------------------------------------------- *)
233(* Fresh variables.                                                          *)
234(* ------------------------------------------------------------------------- *)
235
(* Variables whose names are drawn from Name.newName / Name.newNames. *)

fun newVar () = Var (Name.newName ());

fun newVars n =
    let
      val names = Name.newNames n
    in
      List.map Var names
    end;

(* Name.variantPrime / Name.variantNum specialised to an avoid predicate
   that holds exactly for the members of the name set av. *)

fun variantPrime av =
    Name.variantPrime {avoid = fn n => NameSet.member n av};

fun variantNum av =
    Name.variantNum {avoid = fn n => NameSet.member n av};
247
248(* ------------------------------------------------------------------------- *)
249(* Special support for terms with type annotations.                          *)
250(* ------------------------------------------------------------------------- *)
251
(* A term tm annotated with a type ty is represented as the two-argument
   application ":"(tm,ty). *)

val hasTypeFunctionName = Name.fromString ":";

val hasTypeFunction = (hasTypeFunctionName,2);

(* destFnHasType (f,args) returns (tm,ty) when f is ":" and args is [tm,ty];
   otherwise raises Error. *)

fun destFnHasType ((f,a) : functionName * term list) =
    case (Name.equal f hasTypeFunctionName, a) of
      (true,[tm,ty]) => (tm,ty)
    | _ => raise Error "Term.destFnHasType";

fun isFnHasType f_a = can destFnHasType f_a;
265
(* isTypedVar tm holds when tm is a variable, or is a type annotation
   ":"(t,ty) whose annotated term t is a variable. *)

fun isTypedVar (Var _) = true
  | isTypedVar (Fn func) =
    (case total destFnHasType func of
       SOME (t,_) => isVar t
     | NONE => false);
273
(* typedSymbols tm counts one symbol per variable and per function
   application in tm, except that type annotations are transparent. For
   ":"(t,ty) only t is counted; the annotation and ty contribute
   nothing. *)

local
  fun count acc [] = acc
    | count acc (Var _ :: pending) = count (acc + 1) pending
    | count acc (Fn func :: pending) =
      case total destFnHasType func of
        SOME (t,_) => count acc (t :: pending)
      | NONE => count (acc + 1) (#2 func @ pending);
in
  fun typedSymbols tm = count 0 [tm];
end;
291
(* nonVarTypedSubterms tm lists (path, subterm) pairs for every subterm of tm
   that is neither a variable nor a type annotation of a variable.

   - A type annotation ":"(t,ty) of a non-variable t is listed itself, and
     then only t is explored, at argument index 0; ty is never visited.
   - Pairs are consed onto the result in depth-first, left-to-right visiting
     order. *)

local
  fun walk [] found = found
    | walk ((revPath,tm) :: pending) found =
      case tm of
        Var _ => walk pending found
      | Fn (name,args) =>
        case total destFnHasType (name,args) of
          SOME (Var _,_) => walk pending found
        | SOME (t,_) =>
          walk ((0 :: revPath, t) :: pending) ((List.rev revPath, tm) :: found)
        | NONE =>
          let
            fun child (i,arg) = (i :: revPath, arg)
          in
            walk (List.map child (enumerate args) @ pending)
              ((List.rev revPath, tm) :: found)
          end;
in
  fun nonVarTypedSubterms tm = walk [([],tm)] [];
end;
324
325(* ------------------------------------------------------------------------- *)
326(* Special support for terms with an explicit function application operator. *)
327(* ------------------------------------------------------------------------- *)
328
(* Explicit application of fTm to aTm is represented as the two-argument
   function "."(fTm,aTm). *)

val appName = Name.fromString ".";

fun mkFnApp (fTm,aTm) = (appName,[fTm,aTm]);

fun mkApp (fTm,aTm) = Fn (appName,[fTm,aTm]);

(* destFnApp (f,args) returns (fTm,aTm) when f is "." and args is
   [fTm,aTm]; otherwise raises Error. *)

fun destFnApp ((f,a) : Name.name * term list) =
    case (Name.equal f appName, a) of
      (true,[fTm,aTm]) => (fTm,aTm)
    | _ => raise Error "Term.destFnApp";

fun isFnApp f_a = can destFnApp f_a;

fun destApp (Var _) = raise Error "Term.destApp"
  | destApp (Fn func) = destFnApp func;

fun isApp tm = can destApp tm;
350
(* listMkApp (f,[a1,...,an]) builds the left-nested application
   (...((f . a1) . a2) ...) . an, so that stripApp recovers (f,[a1,...,an]);
   an empty argument list returns f unchanged.
   List.foldl passes (element, accumulator) to its function, so the
   arguments must be swapped before calling mkApp (fTm,aTm). Passing mkApp
   directly would build a1 . f and then a2 . (a1 . f), which is the wrong
   way round. *)

local
  fun applyTo (arg,fTm) = mkApp (fTm,arg);
in
  fun listMkApp (f,l) = List.foldl applyTo f l;
end;
352
(* stripApp tm splits a left-nested application (...(f . a1) ...) . an into
   (f,[a1,...,an]); a term that is not an application yields (tm,[]). *)

fun stripApp tm =
    let
      fun strip args t =
          case total destApp t of
            NONE => (t,args)
          | SOME (f,a) => strip (a :: args) f
    in
      strip [] tm
    end;
361
362(* ------------------------------------------------------------------------- *)
363(* Parsing and pretty printing.                                              *)
364(* ------------------------------------------------------------------------- *)
365
366(* Operators parsed and printed infix *)
367
(* The infix operator table shared by the pretty-printer and the parser.
   - It is held in a reference so it can be replaced; pp and the term parser
     dereference it on every call.
   - Each entry gives a token, a precedence and an associativity.
   - Presumably a larger precedence binds more tightly (as with "*" at 7
     over "+" at 6), following ML conventions; confirm against
     Print.Infixes. *)

val infixes =
    (ref o Print.Infixes)
      [(* ML symbols *)
       {token = "/", precedence = 7, assoc = Print.LeftAssoc},
       {token = "div", precedence = 7, assoc = Print.LeftAssoc},
       {token = "mod", precedence = 7, assoc = Print.LeftAssoc},
       {token = "*", precedence = 7, assoc = Print.LeftAssoc},
       {token = "+", precedence = 6, assoc = Print.LeftAssoc},
       {token = "-", precedence = 6, assoc = Print.LeftAssoc},
       {token = "^", precedence = 6, assoc = Print.LeftAssoc},
       {token = "@", precedence = 5, assoc = Print.RightAssoc},
       {token = "::", precedence = 5, assoc = Print.RightAssoc},
       {token = "=", precedence = 4, assoc = Print.NonAssoc},
       {token = "<>", precedence = 4, assoc = Print.NonAssoc},
       {token = "<=", precedence = 4, assoc = Print.NonAssoc},
       {token = "<", precedence = 4, assoc = Print.NonAssoc},
       {token = ">=", precedence = 4, assoc = Print.NonAssoc},
       {token = ">", precedence = 4, assoc = Print.NonAssoc},
       {token = "o", precedence = 3, assoc = Print.LeftAssoc},
       {token = "->", precedence = 2, assoc = Print.RightAssoc},
       {token = ":", precedence = 1, assoc = Print.NonAssoc},
       {token = ",", precedence = 0, assoc = Print.RightAssoc},

       (* Logical connectives *)
       {token = "/\\", precedence = ~1, assoc = Print.RightAssoc},
       {token = "\\/", precedence = ~2, assoc = Print.RightAssoc},
       {token = "==>", precedence = ~3, assoc = Print.RightAssoc},
       {token = "<=>", precedence = ~4, assoc = Print.RightAssoc},

       (* Other symbols; "." is also the explicit application operator *)
       {token = ".", precedence = 9, assoc = Print.LeftAssoc},
       {token = "**", precedence = 8, assoc = Print.LeftAssoc},
       {token = "++", precedence = 6, assoc = Print.LeftAssoc},
       {token = "--", precedence = 6, assoc = Print.LeftAssoc},
       {token = "==", precedence = 4, assoc = Print.NonAssoc}];

(* The negation symbol: neg t is represented as the one-argument function
   named by this string. The lexer compares single characters against it,
   so only a one-character symbol is lexed as a standalone negation token. *)

val negation : string ref = ref "~";

(* Binder symbols: q v. body is represented as Fn (q, [Var v, body]) *)

val binders : string list ref = ref ["\\","!","?","?!"];

(* Bracket symbols: a pair (b1,b2) writes the one-argument function named
   b1 ^ b2 (e.g. "[]") as b1 tm b2 *)

val brackets : (string * string) list ref = ref [("[","]"),("{","}")];
415
416(* Pretty printing *)
417
(* pp tm builds the pretty-printer output for tm. The binder, infix,
   negation and bracket tables are dereferenced each time pp is applied, so
   updates to those references affect subsequent printing. *)

fun pp inputTerm =
    let
      val quants = !binders
      and iOps = !infixes
      and neg = !negation
      and bracks = !brackets

      (* Each bracket pair (b1,b2) is keyed by the one-argument function name
         b1 ^ b2 *)
      val bMap =
          let
            fun f (b1,b2) = (b1 ^ b2, b1, b2)
          in
            List.map f bracks
          end

      (* All opening and closing bracket tokens *)
      val bTokens = op@ (unzip bracks)

      val iTokens = Print.tokensInfixes iOps

      (* An infix term is a two-argument application whose function name is
         one of the infix tokens *)
      fun destI tm =
          case tm of
            Fn (f,[a,b]) =>
            let
              val f = Name.toString f
            in
              if StringSet.member f iTokens then SOME (f,a,b) else NONE
            end
          | _ => NONE

      fun isI tm = Option.isSome (destI tm)

      (* An infix operator is printed as a space (omitted for ","), the
         token, then a break *)
      fun iToken (_,tok) =
          Print.program
            [(if tok = "," then Print.skip else Print.ppString " "),
             Print.ppString tok,
             Print.break];

      val iPrinter = Print.ppInfixes iOps destI iToken

      (* Tokens with syntactic meaning to the parser; a function with one of
         these names must be parenthesised *)
      val specialTokens =
          StringSet.addList iTokens (neg :: quants @ ["$","(",")"] @ bTokens)

      (* bv is the set of variable names bound by enclosing quantifiers *)
      fun vName bv s = StringSet.member s bv

      (* Variables not bound by an enclosing quantifier get a "$" prefix *)
      fun checkVarName bv n =
          let
            val s = Name.toString n
          in
            if vName bv s then s else "$" ^ s
          end

      fun varName bv = Print.ppMap (checkVarName bv) Print.ppString

      (* Function names that clash with a special token or a bound variable
         are wrapped in parentheses *)
      fun checkFunctionName bv n =
          let
            val s = Name.toString n
          in
            if StringSet.member s specialTokens orelse vName bv s then
              "(" ^ s ^ ")"
            else
              s
          end

      fun functionName bv = Print.ppMap (checkFunctionName bv) Print.ppString

      (* stripNeg tm returns (n,tm') where tm is n nested negations of tm'
         and tm' is not itself a negation *)
      fun stripNeg tm =
          case tm of
            Fn (f,[a]) =>
            if Name.toString f <> neg then (0,tm)
            else let val (n,tm) = stripNeg a in (n + 1, tm) end
          | _ => (0,tm)

      (* destQuant recognises Fn (q,[Var v, body]) for a binder q, merging
         directly nested uses of the same binder into one variable list:
         SOME (q, first variable, remaining variables, innermost body) *)
      val destQuant =
          let
            fun dest q (Fn (q', [Var v, body])) =
                if Name.toString q' <> q then NONE
                else
                  (case dest q body of
                     NONE => SOME (q,v,[],body)
                   | SOME (_,v',vs,body) => SOME (q, v, v' :: vs, body))
              | dest _ _ = NONE
          in
            fn tm => Useful.first (fn q => dest q tm) quants
          end

      fun isQuant tm = Option.isSome (destQuant tm)

      (* destBrack recognises a one-argument application of a bracket name
         from bMap, returning (opening token, argument, closing token) *)
      fun destBrack (Fn (b,[tm])) =
          let
            val s = Name.toString b
          in
            case List.find (fn (n,_,_) => n = s) bMap of
              NONE => NONE
            | SOME (_,b1,b2) => SOME (b1,tm,b2)
          end
        | destBrack _ = NONE

      fun isBrack tm = Option.isSome (destBrack tm)

      (* Function arguments: bracket terms, variables and constants are
         printed bare; anything else is parenthesised *)
      fun functionArgument bv tm =
          Print.sequence
            Print.break
            (if isBrack tm then customBracket bv tm
             else if isVar tm orelse isConst tm then basic bv tm
             else bracket bv tm)

      (* A variable, or a function name followed by its arguments *)
      and basic bv (Var v) = varName bv v
        | basic bv (Fn (f,args)) =
          Print.inconsistentBlock 2
            (functionName bv f :: List.map (functionArgument bv) args)

      and customBracket bv tm =
          case destBrack tm of
            SOME (b1,tm,b2) => Print.ppBracket b1 b2 (term bv) tm
          | NONE => basic bv tm

      (* Prints q v1 ... vn. body, adding v1 ... vn to the bound set for the
         variables and the body *)
      and innerQuant bv tm =
          case destQuant tm of
            NONE => term bv tm
          | SOME (q,v,vs,tm) =>
            let
              val bv = StringSet.addList bv (List.map Name.toString (v :: vs))
            in
              Print.program
                [Print.ppString q,
                 varName bv v,
                 Print.program
                   (List.map (Print.sequence Print.break o varName bv) vs),
                 Print.ppString ".",
                 Print.break,
                 innerQuant bv tm]
            end

      and quantifier bv tm =
          if not (isQuant tm) then customBracket bv tm
          else Print.inconsistentBlock 2 [innerQuant bv tm]

      (* Prints the leading negations, then the remaining term. An infix term
         is parenthesised here. The flag r is supplied by Print.ppInfixes;
         when it is true a quantifier is also parenthesised — presumably so
         its body cannot swallow a following infix operand (confirm against
         Print.ppInfixes) *)
      and molecule bv (tm,r) =
          let
            val (n,tm) = stripNeg tm
          in
            Print.inconsistentBlock n
              [Print.duplicate n (Print.ppString neg),
               if isI tm orelse (r andalso isQuant tm) then bracket bv tm
               else quantifier bv tm]
          end

      and term bv tm = iPrinter (molecule bv) (tm,false)

      and bracket bv tm = Print.ppBracket "(" ")" (term bv) tm
    in
      (* Start with no bound variables, and apply the printer to the input *)
      term StringSet.empty
    end inputTerm;
570
(* toString tm renders tm as a string using the pretty-printer pp. *)

fun toString tm = Print.toString pp tm;
572
573(* Parsing *)
574
(* Parsing: fromString turns a string into a term, reading the same binder,
   infix, negation and bracket tables as the pretty-printer. *)

local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Name characters: letters, digits, "_" and "'" *)
  val isAlphaNum =
      let
        val alphaNumChars = String.explode "_'"
      in
        fn c => mem c alphaNumChars orelse Char.isAlphaNum c
      end;

  local
    (* A maximal run of name characters *)
    val alphaNumToken = atLeastOne (some isAlphaNum) >> String.implode;

    (* Either the (one-character) negation symbol on its own, or a run of
       symbol characters starting with a non-negation symbol *)
    val symbolToken =
        let
          fun isNeg c = str c = !negation

          val symbolChars = String.explode "<>=-*+/\\?@|!$%&#^:;~"

          fun isSymbol c = mem c symbolChars

          fun isNonNegSymbol c = not (isNeg c) andalso isSymbol c
        in
          some isNeg >> str ||
          (some isNonNegSymbol ++ many (some isSymbol)) >>
          (String.implode o op::)
        end;

    (* A single punctuation character *)
    val punctToken =
        let
          val punctChars = String.explode "()[]{}.,"

          fun isPunct c = mem c punctChars
        in
          some isPunct >> str
        end;

    val lexToken = alphaNumToken || symbolToken || punctToken;

    val space = many (some Char.isSpace);
  in
    (* One token, discarding surrounding whitespace *)
    val lexer = (space ++ lexToken ++ space) >> (fn (_,(tok,_)) => tok);
  end;

  fun termParser inputStream =
      let
        val quants = !binders
        and iOps = !infixes
        and neg = !negation
        and bracks = ("(",")") :: !brackets

        (* Plain parentheses are included here; they only group, see
           bracket below *)
        val bracks = List.map (fn (b1,b2) => (b1 ^ b2, b1, b2)) bracks

        val bTokens = List.map #2 bracks @ List.map #3 bracks

        fun possibleVarName "" = false
          | possibleVarName s = isAlphaNum (String.sub (s,0))

        (* bv is the set of variable names bound by enclosing quantifiers *)
        fun vName bv s = StringSet.member s bv

        val iTokens = Print.tokensInfixes iOps

        fun iMk (f,a,b) = Fn (Name.fromString f, [a,b])

        val iParser = parseInfixes iOps iMk any

        val specialTokens =
            StringSet.addList iTokens (neg :: quants @ ["$"] @ bTokens)

        (* A variable is either a bound name, or "$" followed by a name (the
           form the pretty-printer uses for unbound variables) *)
        fun varName bv =
            some (vName bv) ||
            (some (Useful.equal "$") ++ some possibleVarName) >> snd

        fun fName bv s =
            not (StringSet.member s specialTokens) andalso not (vName bv s)

        (* A function name is any non-special, non-bound token, or any token
           enclosed in parentheses *)
        fun functionName bv =
            some (fName bv) ||
            (some (Useful.equal "(") ++ any ++ some (Useful.equal ")")) >>
            (fn (_,(s,_)) => s)

        (* Tries a variable, a constant, a bracketed term, then a
           quantifier, in that order *)
        fun basic bv tokens =
            let
              val var = varName bv >> (Var o Name.fromString)

              val const =
                  functionName bv >> (fn f => Fn (Name.fromString f, []))

              (* "(" tm ")" is just tm; other brackets b1 tm b2 become the
                 one-argument function named b1 ^ b2 *)
              fun bracket (ab,a,b) =
                  (some (Useful.equal a) ++ term bv ++ some (Useful.equal b)) >>
                  (fn (_,(tm,_)) =>
                      if ab = "()" then tm else Fn (Name.fromString ab, [tm]))

              (* q v1 ... vn . body becomes nested Fn (q,[Var vi, ...]) with v1
                 outermost; the body is parsed with v1 ... vn bound *)
              fun quantifier q =
                  let
                    fun bind (v,t) =
                        Fn (Name.fromString q, [Var (Name.fromString v), t])
                  in
                    (some (Useful.equal q) ++
                     atLeastOne (some possibleVarName) ++
                     some (Useful.equal ".")) >>++
                    (fn (_,(vs,_)) =>
                        term (StringSet.addList bv vs) >>
                        (fn body => List.foldr bind body vs))
                  end
            in
              var ||
              const ||
              first (List.map bracket bracks) ||
              first (List.map quantifier quants)
            end tokens

        (* Any number of leading negations, then either a function name
           applied to basic arguments, or a basic term *)
        and molecule bv tokens =
            let
              val negations = many (some (Useful.equal neg)) >> length

              val function =
                  (functionName bv ++ many (basic bv)) >>
                  (fn (f,args) => Fn (Name.fromString f, args)) ||
                  basic bv
            in
              (negations ++ function) >>
              (fn (n,tm) => funpow n (fn t => Fn (Name.fromString neg, [t])) tm)
            end tokens

        (* A term is an infix expression over molecules *)
        and term bv tokens = iParser (molecule bv) tokens
      in
        term StringSet.empty
      end inputStream;
in
  (* fromString s lexes s, parses the token stream, and returns the term if
     exactly one term results; otherwise raises Error "Term.fromString".
     NOTE(review): `everything` is assumed to apply its parser repeatedly
     until the input is used up — confirm against Parse. *)
  fun fromString input =
      let
        val chars = Stream.fromList (String.explode input)

        val tokens = everything (lexer >> singleton) chars

        val terms = everything (termParser >> singleton) tokens
      in
        case Stream.toList terms of
          [tm] => tm
        | _ => raise Error "Term.fromString"
      end;
end;
723
(* parse reads a term from a quotation. Antiquoted terms are rendered with pp
   inside parentheses and handed, together with the quoted text, to
   fromString (presumably via Parse.parseQuotation concatenating the
   pieces). *)

val parse =
    Parse.parseQuotation (Print.toString (Print.ppBracket "(" ")" pp)) fromString;
730
731end
732
(* Terms ordered by Term.compare, with finite maps and sets built on that
   ordering. *)

structure TermOrdered =
struct
  type t = Term.term

  val compare = Term.compare
end

structure TermMap = KeyMap (TermOrdered);

structure TermSet = ElementSet (TermMap);
739