(* ========================================================================= *)
(* FIRST ORDER LOGIC ATOMS                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Atom :> Atom =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type for storing first order logic atoms.                               *)
(* ------------------------------------------------------------------------- *)

type relationName = Name.name;

(* A relation is identified by its name together with its arity. *)

type relation = relationName * int;

(* An atom is a relation name applied to a list of argument terms. *)

type atom = relationName * Term.term list;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

fun name ((rel,_) : atom) = rel;

fun arguments ((_,args) : atom) = args;

(* The arity of an atom is the number of its arguments. *)

fun arity atm = length (arguments atm);

fun relation atm = (name atm, arity atm);

(* Union of the function (name, arity) pairs occurring in every argument. *)

val functions =
    let
      fun f (tm,acc) = NameAritySet.union (Term.functions tm) acc
    in
      fn atm => List.foldl f NameAritySet.empty (arguments atm)
    end;

(* Union of the function names occurring in every argument. *)

val functionNames =
    let
      fun f (tm,acc) = NameSet.union (Term.functionNames tm) acc
    in
      fn atm => List.foldl f NameSet.empty (arguments atm)
    end;

(* Binary relations *)

(* Build the two-argument atom p(a,b). *)

fun mkBinop p (a,b) : atom = (p,[a,b]);

(* Split a two-argument atom whose relation name equals p into its two     *)
(* arguments; raises Error if the name differs or the arity is not two.    *)

fun destBinop p (x,[a,b]) =
    if Name.equal x p then (a,b) else raise Error "Atom.destBinop: wrong binop"
  | destBinop _ _ = raise Error "Atom.destBinop: not a binop";

fun isBinop p = can (destBinop p);

(* ------------------------------------------------------------------------- *)
(* The size of an atom in symbols.                                           *)
(* ------------------------------------------------------------------------- *)

(* One symbol for the relation name plus the symbols of every argument. *)

fun symbols atm =
    List.foldl (fn (tm,z) => Term.symbols tm + z) 1 (arguments atm);

(* ------------------------------------------------------------------------- *)
(* A total comparison function for atoms.                                    *)
(* ------------------------------------------------------------------------- *)

(* Order first by relation name, then lexicographically by argument list. *)

fun compare ((p1,tms1),(p2,tms2)) =
    case Name.compare (p1,p2) of
      LESS => LESS
    | EQUAL => lexCompare Term.compare (tms1,tms2)
    | GREATER => GREATER;

fun equal atm1 atm2 = compare (atm1,atm2) = EQUAL;

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* A path into an atom starts with the 0-based index of an argument, and   *)
(* the rest of the path is handed to Term.subterm on that argument.        *)
(* An empty path is a Bug, because an atom is not itself a term.           *)
(* An out-of-range index, including a negative one, raises Error. Without *)
(* the explicit h < 0 test, List.nth would raise Subscript instead.        *)

fun subterm _ [] = raise Bug "Atom.subterm: empty path"
  | subterm ((_,tms) : atom) (h :: t) =
    if h < 0 orelse h >= length tms then raise Error "Atom.subterm: bad path"
    else Term.subterm (List.nth (tms,h)) t;

(* All (path, subterm) pairs of all arguments, with each path prefixed by  *)
(* its argument index. Each argument's list is prepended to the            *)
(* accumulator, so subterms of later arguments appear earlier in the result.*)

fun subterms ((_,tms) : atom) =
    let
      fun f ((n,tm),l) = List.map (fn (p,s) => (n :: p, s)) (Term.subterms tm) @ l
    in
      List.foldl f [] (enumerate tms)
    end;

(* Replace the subterm at the given path with res.                         *)
(* Path indices are checked in the same way as in subterm.                 *)
(* If Term.replace hands back the same argument (by pointer equality),     *)
(* the original atom is returned unchanged, presumably to preserve sharing.*)

fun replace _ ([],_) = raise Bug "Atom.replace: empty path"
  | replace (atm as (rel,tms)) (h :: t, res) : atom =
    if h < 0 orelse h >= length tms then raise Error "Atom.replace: bad path"
    else
      let
        val tm = List.nth (tms,h)
        val tm' = Term.replace tm (t,res)
      in
        if Portable.pointerEqual (tm,tm') then atm
        else (rel, updateNth (h,tm') tms)
      end;

(* Search the arguments in order and return the path of the first subterm *)
(* that Term.find locates for pred, prefixed with its argument index.      *)

fun find pred =
    let
      fun f (i,tm) =
          case Term.find pred tm of
            SOME path => SOME (i :: path)
          | NONE => NONE
    in
      fn (_,tms) : atom => first f (enumerate tms)
    end;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

fun freeIn v atm = List.exists (Term.freeIn v) (arguments atm);

(* Union of the free variables of every argument. *)

val freeVars =
    let
      fun f (tm,acc) = NameSet.union (Term.freeVars tm) acc
    in
      fn atm => List.foldl f NameSet.empty (arguments atm)
    end;

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Apply the substitution to every argument. If Sharing.map returns the   *)
(* same list (by pointer equality), the original atom is returned          *)
(* unchanged.                                                              *)

fun subst sub (atm as (p,tms)) : atom =
    let
      val tms' = Sharing.map (Subst.subst sub) tms
    in
      if Portable.pointerEqual (tms',tms) then atm else (p,tms')
    end;

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Extend sub by matching the arguments pairwise from left to right.       *)
(* Raises Error if the relation names or the arities differ; errors from   *)
(* Subst.match propagate.                                                  *)

local
  fun matchArg ((tm1,tm2),sub) = Subst.match sub tm1 tm2;
in
  fun match sub (p1,tms1) (p2,tms2) =
      let
        val _ = (Name.equal p1 p2 andalso length tms1 = length tms2) orelse
                raise Error "Atom.match"
      in
        List.foldl matchArg sub (zip tms1 tms2)
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Unification.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Extend sub by unifying the arguments pairwise from left to right.       *)
(* Raises Error if the relation names or the arities differ; errors from   *)
(* Subst.unify propagate.                                                  *)

local
  fun unifyArg ((tm1,tm2),sub) = Subst.unify sub tm1 tm2;
in
  fun unify sub (p1,tms1) (p2,tms2) =
      let
        val _ = (Name.equal p1 p2 andalso length tms1 = length tms2) orelse
                raise Error "Atom.unify"
      in
        List.foldl unifyArg sub (zip tms1 tms2)
      end;
end;

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

val eqRelationName = Name.fromString "=";

val eqRelationArity = 2;

val eqRelation = (eqRelationName,eqRelationArity);

val mkEq = mkBinop eqRelationName;

fun destEq x = destBinop eqRelationName x;

fun isEq x = isBinop eqRelationName x;

(* The reflexive equation tm = tm. *)

fun mkRefl tm = mkEq (tm,tm);

(* Return l for an equation l = l; raises Error if the sides differ or the *)
(* atom is not an equation.                                                *)

fun destRefl atm =
    let
      val (l,r) = destEq atm
      val _ = Term.equal l r orelse raise Error "Atom.destRefl"
    in
      l
    end;

fun isRefl x = can destRefl x;

(* Swap the sides of an equation. Raises Error on a reflexive equation,   *)
(* whose swap would be identical to the input.                             *)

fun sym atm =
    let
      val (l,r) = destEq atm
      val _ = not (Term.equal l r) orelse raise Error "Atom.sym: refl"
    in
      mkEq (r,l)
    end;

fun lhs atm = fst (destEq atm);

fun rhs atm = snd (destEq atm);

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

(* One symbol for the relation name plus the typed symbols of every       *)
(* argument.                                                               *)

fun typedSymbols ((_,tms) : atom) =
    List.foldl (fn (tm,z) => Term.typedSymbols tm + z) 1 tms;

(* The (path, subterm) pairs from Term.nonVarTypedSubterms for every       *)
(* argument, with each path prefixed by its argument index.                *)

fun nonVarTypedSubterms (_,tms) =
    let
      fun addArg ((n,arg),acc) =
          let
            fun addTm ((path,tm),acc) = (n :: path, tm) :: acc
          in
            List.foldl addTm acc (Term.nonVarTypedSubterms arg)
          end
    in
      List.foldl addArg [] (enumerate tms)
    end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* An atom is printed and parsed as the function term with the same shape. *)

val pp = Print.ppMap Term.Fn Term.pp;

val toString = Print.toString pp;

fun fromString s = Term.destFn (Term.fromString s);

val parse = Parse.parseQuotation Term.toString fromString;

end

structure AtomOrdered =
struct type t = Atom.atom val compare = Atom.compare end

structure AtomMap = KeyMap (AtomOrdered);

structure AtomSet = ElementSet (AtomMap);