(* non-interactive mode
*)
structure skiTools :> skiTools =
struct
open HolKernel Parse boolLib;

(* interactive mode
val () = loadPath := union ["..", "../finished"] (!loadPath);
val () = app load
  ["HurdUseful",
   "ho_basicTools",
   "unifyTools",
   "skiTheory"];
val () = show_assums := true;
*)

open hurdUtils ho_basicTools unifyTools skiTheory;

infixr 0 oo ## ++ << || THENC ORELSEC THENR ORELSER;
infix 1 >> |->;

val op++ = op THEN;
val op<< = op THENL;
val op|| = op ORELSE;
val op>> = op THEN1;
val !! = REPEAT;

(* non-interactive mode
*)
(* No-op tracing stand-ins; richer versions are presumably used only in
   interactive mode (NOTE(review): confirm). *)
fun trace _ _ = ();
fun printVal _ = ();

(* ------------------------------------------------------------------------- *)
(* Type/term substitutions                                                   *)
(* ------------------------------------------------------------------------- *)

(* The empty raw substitution: no term bindings (with an empty set of term
   ids) and no type bindings (with no type ids). *)
val empty_raw_subst : raw_substitution = (([], empty_tmset), ([], []));

(* raw_match with its arguments rearranged so that a raw_substitution,
   packed as ((tmS, tmIds), (tyS, tyIds)), can be threaded through a
   sequence of matches. *)
fun raw_match' tm1 tm2 ((tmS, tmIds), (tyS, tyIds)) =
  raw_match tyIds tmIds tm1 tm2 (tmS, tyS);

(* Match type ty1 against type ty2 by embedding each one in a "NIL" constant
   of list type and matching the two resulting terms. *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    val tm1 = mk_const ("NIL", mk_type ("list", [ty1]))
    val tm2 = mk_const ("NIL", mk_type ("list", [ty2]))
  in
    raw_match' tm1 tm2 sub
  end;

(* Turn an accumulated raw substitution into an ordinary one. *)
val finalize_subst : raw_substitution -> substitution = norm_subst;

(* ------------------------------------------------------------------------- *)
(* Conversion to combinators {S,K,I}.                                        *)
(* ------------------------------------------------------------------------- *)

(* Eliminate every lambda abstraction in a term, bottom-up.

   Constants and variables are left alone. A combination has its operand
   and then its operator converted. For an abstraction, the body is
   converted first; the abstraction is then rewritten with the skiTheory
   theorems MK_K, MK_I or MK_S (tried in that order), and the result is
   converted again. *)
fun SKI_CONV tm =
  (case dest_term tm of
     CONST _ => ALL_CONV
   | VAR _ => ALL_CONV
   | COMB _ => RAND_CONV SKI_CONV THENC RATOR_CONV SKI_CONV
   | LAMB _ =>
     ABS_CONV SKI_CONV THENC
     (ho_REWR_CONV MK_K ORELSEC
      ho_REWR_CONV MK_I ORELSEC
      ho_REWR_CONV MK_S) THENC
     SKI_CONV) tm;

(*
try SKI_CONV ``?x. !y. x y = y + 1``;
SKI_CONV ``\x. f x o g``;
SKI_CONV ``\x y. f x y``;
SKI_CONV ``$? = \P. P ($@ P)``;
SKI_CONV ``$==> = \a b. ~a \/ b``;
SKI_CONV ``$! = \P. K T = P``;
SKI_CONV ``!x y. P x y``;
SKI_CONV ``!x y. P y x``;
SKI_CONV ``(P = Q) = (!x. P x = Q x)``;
*)

(* ------------------------------------------------------------------------- *)
(* A combinator {S,K,I} unify function.                                      *)
(* ------------------------------------------------------------------------- *)

local
  fun occurs v tm = free_in v tm

  (* Unify every pair in the work list, threading the substitution. The
     current substitution is applied (pinst) to each pair before the pair
     is examined. *)
  fun solve _ sub [] = sub
    | solve vars sub (current :: next) =
      solve' vars sub (Df (pinst sub) current) next

  (* Unify a single pair (tm1, tm2), then continue with the rest of the
     work list. Raises ERR "ski_unify" when the pair cannot be unified. *)
  and solve' vars sub (tm1, tm2) rest =
    if is_tmvar vars tm1 then
      if tm1 ~~ tm2 then solve vars sub rest
      else if occurs tm1 tm2 then raise ERR "ski_unify" "occurs check"
      else
        (case total (tfind_redex tm1) (fst sub) of
           SOME {residue, ...} => solve' vars sub (tm2, residue) rest
         | NONE =>
           let
             (* Make the types agree first, then bind the variable. *)
             val (ty1, ty2) = Df type_of (tm1, tm2)
             val sub_extra = var_type_unify (snd vars) ty1 ty2
             val (tm1', tm2') = Df (inst_ty sub_extra) (tm1, tm2)
             val sub' = refine_subst sub ([tm1' |-> tm2'], sub_extra)
             val vars' = (map (inst_ty sub_extra) ## I) vars
           in
             solve vars' sub' rest
           end)
    else if is_tmvar vars tm2 then solve' vars sub (tm2, tm1) rest
    else
      (case Df dest_term (tm1, tm2) of
         (COMB (Rator, Rand), COMB (Rator', Rand')) =>
         solve' vars sub (Rator, Rator') ((Rand, Rand') :: rest)
       | (VAR (Name, Ty), VAR (Name', Ty')) =>
         let
           val _ = assert (Name = Name') (ERR "ski_unify" "different vars")
         in
           if Ty = Ty' then solve vars sub rest
           else
             (* HOL variables are identified by name AND type, so two
                same-named variables may legitimately differ in type (e.g.
                a free variable x:'a against x:num). Treat them like
                constants: unify the types, failing with an ordinary ERR
                (not an internal BUG) when that is impossible. *)
             let
               val sub_extra = var_type_unify (snd vars) Ty Ty'
               val sub' = refine_subst sub ([], sub_extra)
               val vars' = (map (inst_ty sub_extra) ## I) vars
             in
               solve vars' sub' rest
             end
         end
       | (CONST {Name, Thy, Ty}, CONST {Name = Name', Thy = Thy', Ty = Ty'}) =>
         let
           val _ =
             assert (Name = Name' andalso Thy = Thy')
             (ERR "ski_unify" "different constants")
           (* Same constant: its instances unify iff the types do. *)
           val sub_extra = var_type_unify (snd vars) Ty Ty'
           val sub' = refine_subst sub ([], sub_extra)
           val vars' = (map (inst_ty sub_extra) ## I) vars
         in
           solve vars' sub' rest
         end
       | _ => raise ERR "ski_unify" "terms fundamentally different")
in
  (* Unify a list of term pairs simultaneously over the variables vars. *)
  fun ski_unifyl vars work = solve vars empty_subst work;

  (* Unify two terms over the variables vars. *)
  fun ski_unify vars tm1 tm2 = solve' vars empty_subst (tm1, tm2) [];
end;

(* ------------------------------------------------------------------------- *)
(* A combinator {S,K,I} term discriminator.                                  *)
(* ------------------------------------------------------------------------- *)

(* Tokens of the pre-order serialisation of a (lambda-free) term.

   A combination f x becomes SKI_COMB_BEGIN, the tokens of f, the tokens of
   x, and then SKI_COMB_END. Any other term is a single token: SKI_VAR for a
   unification variable, SKI_CONST for anything else (including free
   variables that are not unification variables). *)
datatype ski_pattern
  = SKI_COMB_BEGIN
  | SKI_COMB_END
  | SKI_CONST of term
  | SKI_VAR of term;

(* A discrimination net: the number of entries added, and trees keyed by
   ski_pattern tokens whose leaves hold (vars, payload). *)
datatype 'a ski_discrim =
  SKI_DISCRIM of int * (ski_pattern, vars * 'a) tree list;

(* Token equality; payload terms are compared up to alpha-equivalence. *)
fun skipat_eq p1 p2 =
  case (p1, p2) of
    (SKI_COMB_BEGIN, SKI_COMB_BEGIN) => true
  | (SKI_COMB_END, SKI_COMB_END) => true
  | (SKI_CONST t1, SKI_CONST t2) => aconv t1 t2
  | (SKI_VAR t1, SKI_VAR t2) => aconv t1 t2
  | _ => false;

val empty_ski_discrim = SKI_DISCRIM (0, []);

fun ski_discrim_size (SKI_DISCRIM (i, _)) = i;

(* One step of rebuilding a term from tokens.

   The stack holds LEFT SKI_COMB_BEGIN markers for combinations still open
   and RIGHT items for completed subterms, most recent first. On
   SKI_COMB_END, the operand, the operator and the opening marker are popped
   and replaced by their combination. Raises BUG if SKI_COMB_END arrives
   without that shape on the stack. *)
fun ski_pattern_term_build SKI_COMB_BEGIN stack =
    (LEFT SKI_COMB_BEGIN :: stack)
  | ski_pattern_term_build SKI_COMB_END
      (RIGHT rand :: RIGHT rator :: LEFT SKI_COMB_BEGIN :: stack) =
    (RIGHT (mk_comb (rator, rand)) :: stack)
  | ski_pattern_term_build (SKI_CONST c) stack = RIGHT c :: stack
  | ski_pattern_term_build (SKI_VAR v) stack = RIGHT v :: stack
  | ski_pattern_term_build SKI_COMB_END _ =
    raise BUG "ski_pattern_term_build" "badly formed list";

local
  (* A finished rebuild leaves exactly one completed term on the stack. *)
  fun final [RIGHT tm] = tm
    | final _ = raise ERR "ski_patterns_to_term" "not a complete pattern list"

  fun final_a (vars, a) stack = ((vars, final stack), a)
in
  (* Rebuild the term that a complete token list serialises. *)
  fun ski_patterns_to_term pats =
    final (trans ski_pattern_term_build [] pats)

  (* List every entry of the net as ((vars, term), payload). *)
  fun dest_ski_discrim (SKI_DISCRIM (_, d)) =
    flatten (map (tree_trans ski_pattern_term_build final_a []) d)
end;

(* Expand the RIGHT term at the head of a work list by one level. A
   combination becomes BEGIN, operator, operand, END (the two subterms stay
   unexpanded); any other term becomes a single SKI_VAR or SKI_CONST token,
   according to whether it is one of the term variables in vars. Raises BUG
   on lambdas, on an empty list, and on a head that is already a token. *)
fun ski_pattern_term_break ((tm_vars, _) : vars) (RIGHT tm :: rest) =
    (case dest_term tm of
       COMB (Rator, Rand) =>
       LEFT SKI_COMB_BEGIN :: RIGHT Rator :: RIGHT Rand ::
       LEFT SKI_COMB_END :: rest
     | LAMB _ => raise BUG "ski_pattern_term_break" "can't break a lambda"
     | _ =>
       LEFT (if tmem tm tm_vars then SKI_VAR tm else SKI_CONST tm) :: rest)
  | ski_pattern_term_break _ [] =
    raise BUG "ski_pattern_term_break" "nothing to break"
  | ski_pattern_term_break _ (LEFT _ :: _) =
    raise BUG "ski_pattern_term_break" "can't break a ski_pattern";

(* Fully serialise a term into its ski_pattern token list. *)
fun vterm_to_ski_patterns (vars, tm) =
  let
    fun break res [] = rev res
      | break res (LEFT p :: rest) = break (p :: res) rest
      | break res ripe = break res (ski_pattern_term_break vars ripe)
    val res = break [] [RIGHT tm]
    val _ =
      trace "vterm_to_ski_patterns: ((vars, tm), res)"
      (fn () => printVal ((vars, tm), res))
  in
    res
  end;

local
  (* Insert a token path ending in leaf a. Existing branches whose token
     equals the next one (skipat_eq) are shared; otherwise a new branch is
     added at the end of the level. A leaf is added at the front of its
     level. *)
  fun add a [] leaves = LEAF a :: leaves
    | add a (pat :: next) [] = [BRANCH (pat, add a next [])]
    | add a (pats as pat :: next) ((b as BRANCH (pat', trees)) :: branches) =
      if skipat_eq pat pat' then BRANCH (pat', add a next trees) :: branches
      else b :: add a pats branches
    | add _ (_ :: _) (LEAF _ :: _) =
      raise BUG "discrim_add" "expected a branch, got a leaf"
in
  (* Add the entry ((vars, tm), a) to the net, increasing its size by one. *)
  fun ski_discrim_add ((vars, tm), a) (SKI_DISCRIM (i, d)) =
    SKI_DISCRIM (i + 1, add (vars, a) (vterm_to_ski_patterns (vars, tm)) d)
end;

fun ski_discrim_addl l d = trans ski_discrim_add d l;

fun mk_ski_discrim l = ski_discrim_addl l empty_ski_discrim;

(* Match a pattern token against a term token, extending the raw
   substitution. Constants are matched with raw_match'; BEGIN and END must
   coincide. Raises ERR on a mismatch, and BUG for SKI_VAR (variables are
   handled by the caller). *)
fun ski_pattern_reduce (SKI_VAR _) _ _ =
    raise BUG "ski_pattern_reduce" "can't reduce variables"
  | ski_pattern_reduce SKI_COMB_BEGIN SKI_COMB_BEGIN sub = sub
  | ski_pattern_reduce SKI_COMB_END SKI_COMB_END sub = sub
  | ski_pattern_reduce (SKI_CONST c) (SKI_CONST c') sub =
    raw_match' c c' sub
  | ski_pattern_reduce _ _ _ =
    raise ERR "ski_pattern_reduce" "patterns fundamentally different";

local
  (* Consume one pattern token against the term work list. A pattern
     variable swallows the whole next subterm; any other token forces the
     next subterm to be expanded by one level. The term is broken with
     empty_vars, so all of its leaves are treated as SKI_CONST. *)
  fun advance (SKI_VAR v) (RIGHT tm :: rest, sub) =
      (rest, raw_match' v tm sub)
    | advance pat (state as RIGHT _ :: _, sub) =
      advance pat (ski_pattern_term_break empty_vars state, sub)
    | advance pat (LEFT pat' :: rest, sub) =
      (rest, ski_pattern_reduce pat pat' sub)
    | advance _ ([], _) =
      raise BUG "ski_discrim_match" "no patterns left in list"

  fun finally (_, a) ([], sub) = SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "ski_discrim_match" "patterns left at end"

  (* Walk one tree; branches where advance fails are pruned via total. *)
  fun tree_match tm =
    tree_partial_trans (total o advance) finally ([RIGHT tm], empty_raw_subst)
in
  (* Return (substitution, payload) for every entry whose pattern matches
     tm. Any HOL_ERR that escapes is reported as an internal BUG. *)
  fun ski_discrim_match (SKI_DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle HOL_ERR _ => raise BUG "ski_discrim_match" "should never fail";
end;

local
  (* Walk the net alongside the query term, collecting a work list of pairs
     to be unified later.

     State is (lstate, rstate, work):
       lstate = SOME (v, stack) while the pattern subterm facing a query
                variable v is being rebuilt;
       rstate = the query term work list.

     Pattern variables pair with whole query subterms. Query variables pair
     with whole (rebuilt) pattern subterms. Constant tokens are paired
     without comparison (checked later by unification). Other tokens must
     be equal. *)
  fun advance _ pat (SOME (v, lstate), rstate, work) =
      (case ski_pattern_term_build pat lstate of
         [RIGHT tm] => (NONE, rstate, (v, tm) :: work)
       | lstate' => (SOME (v, lstate'), rstate, work))
    | advance _ (SKI_VAR v) (NONE, RIGHT tm :: rest, work) =
      (NONE, rest, (v, tm) :: work)
    | advance vars pat (NONE, rstate as RIGHT _ :: _, work) =
      advance vars pat (NONE, ski_pattern_term_break vars rstate, work)
    | advance vars pat (NONE, LEFT (SKI_VAR v) :: rest, work) =
      advance vars pat (SOME (v, []), rest, work)
    | advance _ (SKI_CONST c) (NONE, LEFT (SKI_CONST c') :: rest, work) =
      (NONE, rest, (c, c') :: work)
    | advance _ pat (NONE, LEFT pat' :: rest, work) =
      if skipat_eq pat pat' then (NONE, rest, work)
      else raise ERR "ski_discrim_unify" "terms fundamentally different"
    | advance _ _ (NONE, [], _) =
      raise BUG "ski_discrim_unify" "no patterns left in list";

  fun finally (vars, a) (NONE, [], work) = SOME ((vars, work), a)
    | finally _ _ = raise BUG "ski_discrim_unify" "patterns left at end";

  fun tree_search vars tm =
    tree_partial_trans (total o advance vars) finally (NONE, [RIGHT tm], []);
in
  (* Return (unifier, payload) for every entry whose pattern unifies with
     the query (vars, tm). Candidate entries are first shortlisted by a
     structural walk of the net; their collected work lists are then solved
     with ski_unifyl over the union of both variable sets, and failures are
     dropped. *)
  fun ski_discrim_unify (SKI_DISCRIM (_, d)) (vars, tm) =
    let
      val shortlist = (flatten o map (tree_search vars tm)) d
      fun select ((vars', work), a) =
        (ski_unifyl (vars_union vars vars') work, a)
      val res = partial_map (total select) shortlist
      val _ = trace
        "ski_discrim_unify: ((vars, tm), map fst res)"
        (fn () => printVal ((vars, tm), map fst res))
    in
      res
    end
end;

(* non-interactive mode
*)
end;