1(* non-interactive mode
2*)
3structure skiTools :> skiTools =
4struct
5open HolKernel Parse boolLib;
6
7(* interactive mode
8val () = loadPath := union ["..", "../finished"] (!loadPath);
9val () = app load
10  ["HurdUseful",
11   "ho_basicTools",
12   "unifyTools",
13   "skiTheory"];
14val () = show_assums := true;
15*)
16
17open HurdUseful ho_basicTools unifyTools skiTheory;
18
(* Operator fixities local to this structure.  Note that THENC and ORELSEC
   share precedence 0 and both associate to the right, so a chain that mixes
   them must be parenthesised explicitly (as SKI_CONV does). *)
infixr 0 THENC ORELSEC THENR ORELSER oo ## ++ << ||;
infix 1 |-> >>;

(* Short aliases for the standard tacticals. *)
val op++ = op THEN
and op<< = op THENL
and op|| = op ORELSE
and op>> = op THEN1;
val !! = REPEAT;
27
28(* non-interactive mode
29*)
(* Debugging hooks, stubbed out for non-interactive use: both accept
   arguments of any type, ignore them, and return unit. *)
val trace = fn _ => fn _ => ();
val printVal = fn _ => ();
32
33(* ------------------------------------------------------------------------- *)
34(* Type/term substitutions                                                   *)
35(* ------------------------------------------------------------------------- *)
36
(* The raw substitution with no term or type bindings and empty
   term/type "identity" sets. *)
val empty_raw_subst : raw_substitution =
  let
    val no_term_part = ([], empty_tmset)
    val no_type_part = ([], [])
  in
    (no_term_part, no_type_part)
  end;
38
(* raw_match with its arguments rearranged so that it threads a
   raw_substitution ((term subst, term ids), (type subst, type ids))
   through repeated calls: match pattern tm1 against tm2, extending sub. *)
fun raw_match' tm1 tm2 sub =
  let
    val ((term_sub, term_ids), (type_sub, type_ids)) = sub
  in
    raw_match type_ids term_ids tm1 tm2 (term_sub, type_sub)
  end;
41
(* Match type ty1 against type ty2, extending sub.  Types are matched by
   embedding each one as the element type of a NIL constant (type ty list)
   and reusing the term matcher raw_match'. *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    fun nil_at ty = mk_const ("NIL", mk_type ("list", [ty]))
  in
    raw_match' (nil_at ty1) (nil_at ty2) sub
  end;
49
(* Turn an accumulated raw substitution into a finished substitution
   (delegates entirely to norm_subst). *)
fun finalize_subst (sub : raw_substitution) : substitution = norm_subst sub;
51
52(* ------------------------------------------------------------------------- *)
53(* Conversion to combinators {S,K,I}.                                        *)
54(* ------------------------------------------------------------------------- *)
55
(* Convert a term to {S,K,I} combinator form.
   - Constants and variables are left alone.
   - A combination is converted rand first, then rator.
   - An abstraction has its body converted first.  The abstraction is then
     rewritten by the first of MK_K, MK_I, MK_S (via ho_REWR_CONV) that
     applies, and the result is converted again.
   Fails if none of the three rewrites applies to a converted abstraction. *)
fun SKI_CONV tm =
  let
    (* THENC/ORELSEC are both infixr 0 here, so this chain is
       MK_K ORELSEC (MK_I ORELSEC MK_S), exactly as before. *)
    fun abstraction_step () =
      ho_REWR_CONV MK_K ORELSEC
      ho_REWR_CONV MK_I ORELSEC
      ho_REWR_CONV MK_S
    val conv =
      case dest_term tm of
        COMB _ => RAND_CONV SKI_CONV THENC RATOR_CONV SKI_CONV
      | LAMB _ => ABS_CONV SKI_CONV THENC abstraction_step () THENC SKI_CONV
      | _ => ALL_CONV
  in
    conv tm
  end;
67
68(*
69try SKI_CONV ``?x. !y. x y = y + 1``;
70SKI_CONV ``\x. f x o g``;
71SKI_CONV ``\x y. f x y``;
72SKI_CONV ``$? = \P. P ($@ P)``;
73SKI_CONV ``$==> = \a b. ~a \/ b``;
74SKI_CONV ``$! = \P. K T = P``;
75SKI_CONV ``!x y. P x y``;
76SKI_CONV ``!x y. P y x``;
77SKI_CONV ``(P = Q) = (!x. P x = Q x)``;
78*)
79
80(* ------------------------------------------------------------------------- *)
81(* A combinator {S,K,I} unify function.                                      *)
82(* ------------------------------------------------------------------------- *)
83
local
  (* Does variable v occur free in tm? *)
  fun occurs v tm = free_in v tm

  (* Solve a work list of term pairs.  The current substitution is applied
     (pinst) to each pair before it is unified, so solve' sees terms that are
     up to date with all bindings made so far. *)
  fun solve _ sub [] = sub
    | solve vars sub (current :: next) =
    solve' vars sub (Df (pinst sub) current) next

  (* Unify one pair (tm1, tm2), then continue with the rest of the work list.
     vars = (term variables, type variables) that may be instantiated. *)
  and solve' vars sub (tm1, tm2) rest =
    if is_tmvar vars tm1 then
      if tm1 = tm2 then solve vars sub rest
      else if occurs tm1 tm2 then raise ERR "ski_unify" "occurs check"
      else
        (case total (find_redex tm1) (fst sub) of SOME {residue, ...}
           (* tm1 is already bound: unify tm2 with its binding instead. *)
           => solve' vars sub (tm2, residue) rest
         | NONE =>
           let
             (* Make the types agree before binding tm1. *)
             val (ty1, ty2) = Df type_of (tm1, tm2)
             val sub_extra = var_type_unify (snd vars) ty1 ty2
             val (tm1', tm2') = Df (inst_ty sub_extra) (tm1, tm2)
             val vars' = (map (inst_ty sub_extra) ## I) vars
           in
             (* Type instantiation can identify tm1 with a same-named
                variable in tm2 (e.g. x:'a vs f (x:'b) once 'a = 'b), so the
                occurs check must be repeated on the instantiated terms;
                otherwise a cyclic binding x |-> f x would be recorded. *)
             if tm1' = tm2' then
               solve vars' (refine_subst sub ([], sub_extra)) rest
             else if occurs tm1' tm2' then
               raise ERR "ski_unify" "occurs check"
             else
               solve vars' (refine_subst sub ([tm1' |-> tm2'], sub_extra)) rest
           end)
    else if is_tmvar vars tm2 then solve' vars sub (tm2, tm1) rest
    else
      (case Df dest_term (tm1, tm2) of
         (COMB (Rator, Rand), COMB (Rator', Rand'))
         (* Unify the rators now and queue the rands. *)
         => solve' vars sub (Rator, Rator') ((Rand, Rand') :: rest)
       | (VAR (Name, Ty), VAR (Name', Ty')) =>
         let
           (* Two fixed (non-unifiable) variables must be the same one. *)
           val _ = assert (Name = Name') (ERR "ski_unify" "different vars")
           (* NOTE(review): same name with different types raises BUG rather
              than a unification failure; confirm this is really unreachable. *)
           val _ = assert (Ty = Ty')
             (BUG "ski_unify" "same var, different types?")
         in
           solve vars sub rest
         end
       | (CONST {Name, Thy, Ty}, CONST {Name = Name', Thy = Thy', Ty = Ty'}) =>
         let
           val _ =
             assert (Name = Name' andalso Thy = Thy')
             (ERR "ski_unify" "different consts")
           (* Same constant: its (possibly polymorphic) types must unify. *)
           val sub_extra = var_type_unify (snd vars) Ty Ty'
           val sub' = refine_subst sub ([], sub_extra)
           val vars' = (map (inst_ty sub_extra) ## I) vars
         in
           solve vars' sub' rest
         end
       | _ => raise ERR "ski_unify" "terms fundamentally different")
in
  (* Unify every pair in the work list, starting from the empty substitution. *)
  fun ski_unifyl vars work = solve vars empty_subst work;
  (* Unify a single pair of terms. *)
  fun ski_unify vars tm1 tm2 = solve' vars empty_subst (tm1, tm2) [];
end;
136
137(* ------------------------------------------------------------------------- *)
138(* A combinator {S,K,I} term discriminator.                                  *)
139(* ------------------------------------------------------------------------- *)
140
(* Tokens of the prefix serialisation used by the discrimination net.
   - A combination is emitted as SKI_COMB_BEGIN, its rator, its rand, then
     SKI_COMB_END.
   - An atom becomes SKI_VAR when it is one of the entry's term variables,
     and SKI_CONST otherwise (so fixed free variables are SKI_CONSTs too). *)
datatype ski_pattern =
    SKI_COMB_BEGIN
  | SKI_COMB_END
  | SKI_CONST of term
  | SKI_VAR of term;
146
(* A discrimination net.  The first component counts the entries added so
   far; the second is a forest keyed on ski_pattern tokens whose leaves hold
   each entry's (vars, payload). *)
datatype 'a ski_discrim =
  SKI_DISCRIM of int * (ski_pattern, vars * 'a) tree list;

(* The net with no entries. *)
val empty_ski_discrim = SKI_DISCRIM (0, []);

(* Number of entries that have been added to the net. *)
fun ski_discrim_size discrim =
  case discrim of SKI_DISCRIM (entries, _) => entries;
152
(* Feed one pattern token to a term-rebuilding stack.  Finished terms sit on
   the stack as RIGHT tm; an open combination is marked by
   LEFT SKI_COMB_BEGIN.  SKI_COMB_END pops rand, rator and the marker and
   pushes the combination; any other shape at SKI_COMB_END is a BUG. *)
fun ski_pattern_term_build pat stack =
  case (pat, stack) of
    (SKI_COMB_BEGIN, _) => LEFT SKI_COMB_BEGIN :: stack
  | (SKI_COMB_END, RIGHT rand :: RIGHT rator :: LEFT SKI_COMB_BEGIN :: below) =>
      RIGHT (mk_comb (rator, rand)) :: below
  | (SKI_COMB_END, _) =>
      raise BUG "ski_pattern_term_build" "badly formed list"
  | (SKI_CONST c, _) => RIGHT c :: stack
  | (SKI_VAR v, _) => RIGHT v :: stack;
162
local
  (* A completed build stack holds exactly one finished term. *)
  fun complete stack =
    case stack of
      [RIGHT tm] => tm
    | _ => raise ERR "ski_patterns_to_term" "not a complete pattern list"

  (* Leaf function handed to tree_trans: rebuild the term for the path just
     walked and pair it with the leaf's variables and payload. *)
  fun complete_leaf (vars, a) stack = ((vars, complete stack), a)
in
  (* Rebuild a term from a complete pattern list. *)
  fun ski_patterns_to_term pats =
    let
      val stack = trans ski_pattern_term_build [] pats
    in
      complete stack
    end

  (* List every entry of the net as ((vars, term), payload). *)
  fun dest_ski_discrim (SKI_DISCRIM (_, trees)) =
    let
      val per_tree =
        map (tree_trans ski_pattern_term_build complete_leaf []) trees
    in
      flatten per_tree
    end
end;
175
(* Break the unbroken term at the top of a work stack into pattern tokens.
   - A combination becomes BEGIN, rator, rand, END; the rator and rand stay
     unbroken (RIGHT).
   - An atom becomes SKI_VAR if it is in tm_vars, otherwise SKI_CONST.
   - Lambdas, empty stacks and already-broken tops are BUGs. *)
fun ski_pattern_term_break ((tm_vars, _) : vars) stack =
  case stack of
    [] => raise BUG "ski_pattern_term_break" "nothing to break"
  | LEFT _ :: _ =>
      raise BUG "ski_pattern_term_break" "can't break a ski_pattern"
  | RIGHT tm :: rest =>
      (case dest_term tm of
         LAMB _ => raise BUG "ski_pattern_term_break" "can't break a lambda"
       | COMB (rator, rand) =>
           LEFT SKI_COMB_BEGIN :: RIGHT rator :: RIGHT rand ::
           LEFT SKI_COMB_END :: rest
       | _ =>
           let
             val atom = if mem tm tm_vars then SKI_VAR tm else SKI_CONST tm
           in
             LEFT atom :: rest
           end);
186
(* Serialise a (vars, term) pair into its full prefix list of ski_pattern
   tokens, breaking terms one layer at a time until only tokens remain. *)
fun vterm_to_ski_patterns (vars, tm) =
  let
    (* Emit finished tokens into acc; break any pending unbroken term. *)
    fun drain acc [] = rev acc
      | drain acc (LEFT p :: more) = drain (p :: acc) more
      | drain acc (pending as RIGHT _ :: _) =
        drain acc (ski_pattern_term_break vars pending)
    val patterns = drain [] [RIGHT tm]
  in
    trace "vterm_to_ski_patterns: ((vars, tm), res)"
      (fn () => printVal ((vars, tm), patterns));
    patterns
  end;
199
local
  (* Thread a pattern list down a forest.  An existing branch whose key
     equals the next token is shared; otherwise a new branch chain is made.
     The payload is prepended as a LEAF once the list is exhausted. *)
  fun insert a pats trees =
    case (pats, trees) of
      ([], _) => LEAF a :: trees
    | (pat :: more, []) => [BRANCH (pat, insert a more [])]
    | (pat :: more, (t as BRANCH (key, kids)) :: siblings) =>
        if pat = key then BRANCH (key, insert a more kids) :: siblings
        else t :: insert a pats siblings
    | (_ :: _, LEAF _ :: _) =>
        raise BUG "discrim_add" "expected a branch, got a leaf"
in
  (* Add one ((vars, tm), payload) entry, bumping the entry count. *)
  fun ski_discrim_add ((vars, tm), a) (SKI_DISCRIM (i, d)) =
    let
      val count = i + 1
      val pats = vterm_to_ski_patterns (vars, tm)
    in
      SKI_DISCRIM (count, insert (vars, a) pats d)
    end
end;
212
(* Add every entry of a list to an existing net, in list order. *)
fun ski_discrim_addl entries d = trans ski_discrim_add d entries;

(* Build a fresh net from a list of entries. *)
fun mk_ski_discrim entries = trans ski_discrim_add empty_ski_discrim entries;
216
(* Check one non-variable net token against one target token, extending
   the raw substitution.
   - BEGIN/END markers must coincide.
   - Two SKI_CONSTs are matched with raw_match'.
   - A SKI_VAR net token is a BUG (variables are handled by the caller).
   - Anything else fails with ERR. *)
fun ski_pattern_reduce pat pat' sub =
  case (pat, pat') of
    (SKI_VAR _, _) =>
      raise BUG "ski_pattern_reduce" "can't reduce variables"
  | (SKI_CONST c, SKI_CONST c') => raw_match' c c' sub
  | (SKI_COMB_BEGIN, SKI_COMB_BEGIN) => sub
  | (SKI_COMB_END, SKI_COMB_END) => sub
  | _ => raise ERR "ski_pattern_reduce" "patterns fundamentally different";
225
(* Matching: find every net entry whose pattern matches (one-way) a target
   term, returning (substitution, payload) pairs. *)
local
  (* Consume one net token against the state (target stack, raw subst).
     The target stack holds unbroken terms (RIGHT) and broken tokens (LEFT).
     Clause order matters:
     - A net variable is matched against a whole unbroken target subterm.
     - Otherwise an unbroken target term is broken first.  empty_vars is
       used, so every target atom becomes a SKI_CONST.
     - A broken target token is checked with ski_pattern_reduce. *)
  fun advance (SKI_VAR v) (RIGHT tm :: rest, sub) = (rest, raw_match' v tm sub)
    | advance pat (state as RIGHT _ :: _, sub) =
    advance pat (ski_pattern_term_break empty_vars state, sub)
    | advance pat (LEFT pat' :: rest, sub) =
    (rest, ski_pattern_reduce pat pat' sub)
    | advance _ ([], _) =
    raise BUG "ski_discrim_match" "no patterns left in list"

  (* At a leaf the target must be fully consumed; the leaf's vars are
     ignored and the raw substitution is finalized. *)
  fun finally (_, a) ([], sub) = SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "ski_discrim_match" "patterns left at end"

  (* Walk one tree.  `total o advance` turns a failing step into NONE;
     presumably tree_partial_trans then prunes that branch (see its
     definition in HurdUseful). *)
  fun tree_match tm =
    tree_partial_trans (total o advance) finally ([RIGHT tm], empty_raw_subst)
in
  (* Match tm against every tree in the net.  Any HOL_ERR escaping the
     per-branch `total` is treated as an internal BUG. *)
  fun ski_discrim_match (SKI_DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
  handle HOL_ERR _ => raise BUG "ski_discrim_match" "should never fail";
end;
245
(* Unification: find every net entry whose pattern is unifiable with a query
   (vars, tm).  A walk over the net collects the term pairs that must be
   unified; ski_unifyl then solves each candidate's pairs. *)
local
  (* One step of the simultaneous walk.  The state is (lstate, rstate, work):
     - lstate is NONE, or SOME (v, stack) while query variable v is absorbing
       the net tokens of the subterm it lines up with (stack is a
       ski_pattern_term_build stack).
     - rstate is the query as a stack of unbroken terms (RIGHT) and broken
       tokens (LEFT).
     - work holds the term pairs to hand to ski_unifyl.
     Clause order matters. *)
  fun advance _ pat (SOME (v, lstate), rstate, work) =
    (* Still absorbing net tokens for query variable v. *)
    (case ski_pattern_term_build pat lstate of [RIGHT tm]
       => (NONE, rstate, (v, tm) :: work)
     | lstate' => (SOME (v, lstate'), rstate, work))
    | advance _ (SKI_VAR v) (NONE, RIGHT tm :: rest, work) =
    (* Net variable lines up with a whole query subterm. *)
    (NONE, rest, (v, tm) :: work)
    | advance vars pat (NONE, rstate as RIGHT _ :: _, work) =
    (* Break the query term using the query's own vars. *)
    advance vars pat (NONE, ski_pattern_term_break vars rstate, work)
    | advance vars pat (NONE, LEFT (SKI_VAR v) :: rest, work) =
    (* Query variable: start absorbing the corresponding net subterm. *)
    advance vars pat (SOME (v, []), rest, work)
    | advance _ (SKI_CONST c) (NONE, LEFT (SKI_CONST c') :: rest, work) =
    (* Atoms are deferred to ski_unifyl (their types may need unifying). *)
    (NONE, rest, (c, c') :: work)
    | advance _ pat (NONE, LEFT pat' :: rest, work) =
    if pat = pat' then (NONE, rest, work)
    else raise ERR "ski_discrim_unify" "terms fundamentally different"
    | advance _ _ (NONE, [], _) =
    raise BUG "ski_discrim_unify" "no patterns left in list";

  (* At a leaf the query must be fully consumed; return the entry's vars,
     the collected work list and the payload. *)
  fun finally (vars, a) (NONE, [], work) = SOME ((vars, work), a)
    | finally _ _ = raise BUG "ski_discrim_unify" "patterns left at end";

  (* Walk one tree; a failing step becomes NONE via `total`. *)
  fun tree_search vars tm =
    tree_partial_trans (total o advance vars) finally (NONE, [RIGHT tm], []);
in
  (* Return (unifier, payload) for every entry unifiable with (vars, tm).
     Candidates whose work list ski_unifyl rejects are dropped. *)
  fun ski_discrim_unify (SKI_DISCRIM (_, d)) (vars, tm) =
    let
      val shortlist = (flatten o map (tree_search vars tm)) d
      fun select ((vars', work), a) = (ski_unifyl (union2 vars vars') work, a)
      val res = partial_map (total select) shortlist
      val _ = trace
        "ski_discrim_unify: ((vars, tm), map fst res)"
        (fn () => printVal ((vars, tm), map fst res))
    in
      res
    end
end;
283
284(* non-interactive mode
285*)
286end;
287