1(* ========================================================================= *)
2(* FIRST ORDER LOGIC ATOMS                                                   *)
3(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
6structure Atom :> Atom =
7struct
8
9open Useful;
10
11(* ------------------------------------------------------------------------- *)
12(* A type for storing first order logic atoms.                               *)
13(* ------------------------------------------------------------------------- *)
14
(* The name of a relation symbol; an alias for Name.name. *)
type relationName = Name.name;

(* A relation symbol: a name paired with an int (the arity, as built by *)
(* the `relation` function in this structure).                         *)
type relation = relationName * int;

(* An atom: a relation name applied to a list of argument terms. *)
type atom = relationName * Term.term list;
20
21(* ------------------------------------------------------------------------- *)
22(* Constructors and destructors.                                             *)
23(* ------------------------------------------------------------------------- *)
24
(* name atm: the relation name at the head of the atom. *)
fun name (atm : atom) = #1 atm;

(* arguments atm: the list of argument terms of the atom. *)
fun arguments (atm : atom) = #2 atm;

(* arity atm: the number of argument terms. *)
fun arity ((_,args) : atom) = length args;

(* relation atm: the relation name paired with the number of arguments. *)
fun relation ((rel,args) : atom) = (rel, length args);
32
(* functions atm: the union of Term.functions over every argument of the *)
(* atom, starting from the empty NameAritySet.                          *)
fun functions atm =
    let
      fun addTerm (tm,fns) = NameAritySet.union (Term.functions tm) fns
    in
      List.foldl addTerm NameAritySet.empty (arguments atm)
    end;
39
(* functionNames atm: the union of Term.functionNames over every argument *)
(* of the atom, starting from the empty NameSet.                         *)
fun functionNames atm =
    let
      fun addTerm (tm,names) = NameSet.union (Term.functionNames tm) names
    in
      List.foldl addTerm NameSet.empty (arguments atm)
    end;
46
47(* Binary relations *)
48
(* mkBinop rel (left,right): the two-argument atom rel(left,right). *)
fun mkBinop rel (left,right) : atom = (rel, left :: right :: []);
50
(* destBinop p atm: the two arguments of atm, provided atm has exactly two *)
(* arguments and its relation name equals p (by Name.equal).              *)
(* Raises Error "Atom.destBinop: not a binop" when the argument list does *)
(* not have exactly two elements, and Error "Atom.destBinop: wrong binop" *)
(* when the relation name differs from p.                                 *)
fun destBinop p atm =
    case atm of
      (rel,[left,right]) =>
      if not (Name.equal rel p) then raise Error "Atom.destBinop: wrong binop"
      else (left,right)
    | _ => raise Error "Atom.destBinop: not a binop";
54
(* isBinop p atm: tests destBinop p on atm using `can`, so the result *)
(* reflects whether that destructor succeeds.                         *)
fun isBinop p atm = can (destBinop p) atm;
56
57(* ------------------------------------------------------------------------- *)
58(* The size of an atom in symbols.                                           *)
59(* ------------------------------------------------------------------------- *)
60
(* symbols atm: the sum of Term.symbols over the arguments, plus 1.    *)
(* The initial 1 presumably counts the relation symbol itself.         *)
fun symbols atm =
    let
      fun addTerm (tm,count) = Term.symbols tm + count
    in
      List.foldl addTerm 1 (arguments atm)
    end;
63
64(* ------------------------------------------------------------------------- *)
65(* A total comparison function for atoms.                                    *)
66(* ------------------------------------------------------------------------- *)
67
(* compare (atm1,atm2): orders atoms first by relation name (Name.compare) *)
(* and, when the names are EQUAL, by their argument lists using           *)
(* lexCompare with Term.compare.                                          *)
fun compare ((p1,tms1),(p2,tms2)) =
    case Name.compare (p1,p2) of
      EQUAL => lexCompare Term.compare (tms1,tms2)
    | unequal => unequal;
73
(* equal atm1 atm2: true exactly when compare (atm1,atm2) is EQUAL. *)
fun equal atm1 atm2 =
    case compare (atm1,atm2) of
      EQUAL => true
    | _ => false;
75
76(* ------------------------------------------------------------------------- *)
77(* Subterms.                                                                 *)
78(* ------------------------------------------------------------------------- *)
79
(* subterm atm path: the subterm of atm at the given path.                *)
(* The head of the path is a zero-based index into the argument list (it  *)
(* is passed to List.nth); the rest of the path is handed to Term.subterm *)
(* on the selected argument.                                              *)
(* Raises Bug "Atom.subterm: empty path" for the empty path, and          *)
(* Error "Atom.subterm: bad path" when the head index is out of range.    *)
fun subterm _ [] = raise Bug "Atom.subterm: empty path"
  | subterm ((_,tms) : atom) (h :: t) =
    (* A negative index must be rejected here too: otherwise List.nth *)
    (* would raise Subscript instead of the documented Error.         *)
    if h < 0 orelse h >= length tms then raise Error "Atom.subterm: bad path"
    else Term.subterm (List.nth (tms,h)) t;
84
(* subterms atm: every (path,subterm) pair reported by Term.subterms for *)
(* each argument, with the argument's index (as numbered by `enumerate`) *)
(* consed onto the front of the path.  Results for later arguments are   *)
(* placed ahead of those for earlier ones, because each argument's batch *)
(* is prepended to the accumulator of a left fold.                       *)
fun subterms ((_,tms) : atom) =
    let
      fun extend n (path,sub) = (n :: path, sub)

      fun addArg ((n,tm),acc) = List.map (extend n) (Term.subterms tm) @ acc
    in
      List.foldl addArg [] (enumerate tms)
    end;
91
(* replace atm (path,res): replaces the subterm of atm at path with res.  *)
(* The head of the path is a zero-based index into the argument list; the *)
(* rest of the path and res are passed to Term.replace on that argument.  *)
(* If Term.replace returns a pointer-equal term, the original atom is     *)
(* returned unchanged so that sharing is preserved.                       *)
(* Raises Bug "Atom.replace: empty path" for the empty path, and          *)
(* Error "Atom.replace: bad path" when the head index is out of range.    *)
fun replace _ ([],_) = raise Bug "Atom.replace: empty path"
  | replace (atm as (rel,tms)) (h :: t, res) : atom =
    (* A negative index must be rejected here too: otherwise List.nth *)
    (* would raise Subscript instead of the documented Error.         *)
    if h < 0 orelse h >= length tms then raise Error "Atom.replace: bad path"
    else
      let
        val tm = List.nth (tms,h)
        val tm' = Term.replace tm (t,res)
      in
        if Portable.pointerEqual (tm,tm') then atm
        else (rel, updateNth (h,tm') tms)
      end;
103
(* find pred atm: runs Term.find pred on the arguments via `first` over  *)
(* the enumerated argument list; when an argument yields SOME path, the  *)
(* argument's index is consed onto that path.  An argument for which     *)
(* Term.find gives NONE contributes NONE.                                *)
fun find pred ((_,tms) : atom) =
    let
      fun inArg (i,tm) = Option.map (fn path => i :: path) (Term.find pred tm)
    in
      first inArg (enumerate tms)
    end;
113
114(* ------------------------------------------------------------------------- *)
115(* Free variables.                                                           *)
116(* ------------------------------------------------------------------------- *)
117
(* freeIn v atm: true when Term.freeIn v holds for some argument of atm. *)
fun freeIn v atm =
    let
      val occursIn = Term.freeIn v
    in
      List.exists occursIn (arguments atm)
    end;
119
(* freeVars atm: the union of Term.freeVars over every argument of the *)
(* atom, starting from the empty NameSet.                              *)
fun freeVars atm =
    let
      fun addTerm (tm,vars) = NameSet.union (Term.freeVars tm) vars
    in
      List.foldl addTerm NameSet.empty (arguments atm)
    end;
126
127(* ------------------------------------------------------------------------- *)
128(* Substitutions.                                                            *)
129(* ------------------------------------------------------------------------- *)
130
(* subst sub atm: applies Subst.subst sub to each argument through       *)
(* Sharing.map.  When the resulting list is pointer-equal to the         *)
(* original argument list, the original atom itself is returned.         *)
fun subst sub (atm as (rel,args)) : atom =
    let
      val applySub = Subst.subst sub
      val args' = Sharing.map applySub args
      val unchanged = Portable.pointerEqual (args',args)
    in
      if unchanged then atm else (rel,args')
    end;
137
138(* ------------------------------------------------------------------------- *)
139(* Matching.                                                                 *)
140(* ------------------------------------------------------------------------- *)
141
(* match sub atm1 atm2: threads the substitution sub through Subst.match *)
(* on each pair of corresponding arguments, left to right.               *)
(* Raises Error "Atom.match" when the relation names differ (by          *)
(* Name.equal) or the argument lists have different lengths.             *)
fun match sub (p1,tms1) (p2,tms2) =
    if not (Name.equal p1 p2) orelse length tms1 <> length tms2 then
      raise Error "Atom.match"
    else
      List.foldl
        (fn ((tm1,tm2),acc) => Subst.match acc tm1 tm2)
        sub
        (zip tms1 tms2);
153
154(* ------------------------------------------------------------------------- *)
155(* Unification.                                                              *)
156(* ------------------------------------------------------------------------- *)
157
(* unify sub atm1 atm2: threads the substitution sub through Subst.unify *)
(* on each pair of corresponding arguments, left to right.               *)
(* Raises Error "Atom.unify" when the relation names differ (by          *)
(* Name.equal) or the argument lists have different lengths.             *)
fun unify sub (p1,tms1) (p2,tms2) =
    if not (Name.equal p1 p2) orelse length tms1 <> length tms2 then
      raise Error "Atom.unify"
    else
      List.foldl
        (fn ((tm1,tm2),acc) => Subst.unify acc tm1 tm2)
        sub
        (zip tms1 tms2);
169
170(* ------------------------------------------------------------------------- *)
171(* The equality relation.                                                    *)
172(* ------------------------------------------------------------------------- *)
173
(* The name of the equality relation: the name built from the string "=". *)
val eqRelationName : relationName = Name.fromString "=";

(* Equality takes two arguments. *)
val eqRelationArity : int = 2;

(* The equality relation as a (name,arity) pair. *)
val eqRelation : relation = (eqRelationName,eqRelationArity);
179
(* mkEq (l,r): the equality atom with arguments l and r. *)
fun mkEq (l,r) = mkBinop eqRelationName (l,r);

(* destEq atm: the two sides of an equality atom; raises Error (from *)
(* destBinop) when atm is not a two-argument atom named "=".         *)
fun destEq atm = destBinop eqRelationName atm;

(* isEq atm: tests whether destructing atm as an equality succeeds. *)
fun isEq atm = can (destBinop eqRelationName) atm;
185
(* mkRefl tm: the equality atom with tm on both sides. *)
fun mkRefl tm = mkEq (tm,tm);

(* destRefl atm: for an equality atom whose two sides are Term.equal,  *)
(* returns the left-hand side.  Raises Error "Atom.destRefl" when the  *)
(* sides differ; errors from destEq propagate when atm is not an       *)
(* equality.                                                           *)
fun destRefl atm =
    let
      val (l,r) = destEq atm
    in
      if Term.equal l r then l else raise Error "Atom.destRefl"
    end;

(* isRefl atm: tests whether destRefl succeeds on atm. *)
fun isRefl atm = can destRefl atm;
197
(* sym atm: swaps the two sides of an equality atom.                   *)
(* Raises Error "Atom.sym: refl" when both sides are Term.equal;       *)
(* errors from destEq propagate when atm is not an equality.           *)
fun sym atm =
    let
      val (l,r) = destEq atm
    in
      if Term.equal l r then raise Error "Atom.sym: refl" else mkEq (r,l)
    end;
205
(* lhs atm: the first argument of an equality atom (errors from destEq *)
(* propagate when atm is not an equality).                             *)
fun lhs atm =
    let
      val (l,_) = destEq atm
    in
      l
    end;

(* rhs atm: the second argument of an equality atom (errors from destEq *)
(* propagate when atm is not an equality).                              *)
fun rhs atm =
    let
      val (_,r) = destEq atm
    in
      r
    end;
209
210(* ------------------------------------------------------------------------- *)
211(* Special support for terms with type annotations.                          *)
212(* ------------------------------------------------------------------------- *)
213
(* typedSymbols atm: the sum of Term.typedSymbols over the arguments,  *)
(* plus 1.  The initial 1 presumably counts the relation symbol.       *)
fun typedSymbols ((_,tms) : atom) =
    let
      fun addTerm (tm,count) = Term.typedSymbols tm + count
    in
      List.foldl addTerm 1 tms
    end;
216
(* nonVarTypedSubterms atm: collects the (path,term) pairs produced by    *)
(* Term.nonVarTypedSubterms for each argument, consing the argument's     *)
(* index (as numbered by `enumerate`) onto each path.  Pairs are pushed   *)
(* one at a time onto the front of a single accumulator, so the result    *)
(* is in reverse order of visiting.                                       *)
fun nonVarTypedSubterms (_,tms) =
    let
      fun addArg ((n,arg),acc) =
          List.foldl
            (fn ((path,tm),acc') => (n :: path, tm) :: acc')
            acc
            (Term.nonVarTypedSubterms arg)
    in
      List.foldl addArg [] (enumerate tms)
    end;
228
229(* ------------------------------------------------------------------------- *)
230(* Parsing and pretty printing.                                              *)
231(* ------------------------------------------------------------------------- *)
232
(* Pretty-printer for atoms: the atom is wrapped with Term.Fn and then *)
(* printed with the term pretty-printer.                               *)
val pp = Print.ppMap Term.Fn Term.pp;

(* String rendering of an atom, derived from pp. *)
val toString = Print.toString pp;

(* Parses a string as a term and takes it apart with Term.destFn; any *)
(* failure raised by those two functions propagates.                  *)
val fromString = Term.destFn o Term.fromString;

(* Quotation parser for atoms, built from Term.toString and fromString. *)
(* Term.toString is presumably used to render antiquoted terms; confirm *)
(* against Parse.parseQuotation.                                        *)
val parse = Parse.parseQuotation Term.toString fromString;
240
241end
242
(* Atoms as an ordered key type, using Atom.compare as the ordering. *)
structure AtomOrdered =
struct
  type t = Atom.atom
  val compare = Atom.compare
end

(* Finite maps keyed by atoms. *)
structure AtomMap = KeyMap (AtomOrdered);

(* Finite sets of atoms, built on AtomMap. *)
structure AtomSet = ElementSet (AtomMap);
249