1(* non-interactive mode
2*)
3structure skiTools :> skiTools =
4struct
5open HolKernel Parse boolLib;
6
7(* interactive mode
8val () = loadPath := union ["..", "../finished"] (!loadPath);
9val () = app load
10  ["HurdUseful",
11   "ho_basicTools",
12   "unifyTools",
13   "skiTheory"];
14val () = show_assums := true;
15*)
16
17open hurdUtils ho_basicTools unifyTools skiTheory;
18
(* Fixities for the combinator syntax used below: all of these group to the
   right at precedence 0, except >> and |-> which are infix at precedence 1. *)
infixr 0 oo ##;
infixr 0 ++ << ||;
infixr 0 THENC ORELSEC THENR ORELSER;
infix 1 >> |->;

(* Short aliases for the tactic combinators THEN, THENL, ORELSE, THEN1 and
   REPEAT. *)
val (op++, op<<, op||, op>>) = (op THEN, op THENL, op ORELSE, op THEN1);
val !! = REPEAT;
27
28(* non-interactive mode
29*)
(* Non-interactive stand-ins for the tracing hooks: each accepts arguments of
   any type, ignores them, and returns unit. *)
fun trace _ = fn _ => ();
val printVal = fn _ => ();
32
33(* ------------------------------------------------------------------------- *)
34(* Type/term substitutions                                                   *)
35(* ------------------------------------------------------------------------- *)
36
(* The raw substitution with no term or type bindings and empty id components
   (raw_match' passes those components to raw_match as its first two
   arguments). *)
val empty_raw_subst : raw_substitution =
  let
    val no_tm_bindings = ([], empty_tmset)
    val no_ty_bindings = ([], [])
  in
    (no_tm_bindings, no_ty_bindings)
  end;
38
(* raw_match' pat ob sub: call raw_match to match pat against ob, threading the
   accumulated state as a single raw_substitution.  The id components of sub
   go to raw_match as its first two arguments; its two substitution
   components are passed as the accumulator. *)
fun raw_match' pat ob sub =
  let
    val ((tm_sub, tm_ids), (ty_sub, ty_ids)) = sub
  in
    raw_match ty_ids tm_ids pat ob (tm_sub, ty_sub)
  end;
41
(* type_raw_match ty1 ty2 sub: match type ty1 against type ty2 by matching the
   constants NIL : ty1 list and NIL : ty2 list with raw_match', so that only
   the types can differ between the two sides. *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    fun nil_at ty = mk_const ("NIL", mk_type ("list", [ty]))
  in
    raw_match' (nil_at ty1) (nil_at ty2) sub
  end;
49
(* Convert an accumulated raw substitution into an ordinary substitution by
   normalising it with norm_subst. *)
fun finalize_subst (sub : raw_substitution) : substitution = norm_subst sub;
51
52(* ------------------------------------------------------------------------- *)
53(* Conversion to combinators {S,K,I}.                                        *)
54(* ------------------------------------------------------------------------- *)
55
(* SKI_CONV tm: conversion rewriting lambda abstractions into the combinators
   of skiTheory.  Constants and variables are left alone.  In an application
   the rand is converted first, then the rator.  In an abstraction the body is
   converted first.  The abstraction is then rewritten with the first of
   MK_K, MK_I, MK_S that applies (via ho_REWR_CONV), and the result is
   converted again. *)
fun SKI_CONV tm =
  let
    val conv =
      case dest_term tm of
        COMB _ => RAND_CONV SKI_CONV THENC RATOR_CONV SKI_CONV
      | LAMB _ =>
          let
            val under_binder = ABS_CONV SKI_CONV
            val to_combinator =
              ho_REWR_CONV MK_K ORELSEC ho_REWR_CONV MK_I ORELSEC
              ho_REWR_CONV MK_S
          in
            under_binder THENC to_combinator THENC SKI_CONV
          end
      | _ => ALL_CONV
  in
    conv tm
  end;
67
68(*
69try SKI_CONV ``?x. !y. x y = y + 1``;
70SKI_CONV ``\x. f x o g``;
71SKI_CONV ``\x y. f x y``;
72SKI_CONV ``$? = \P. P ($@ P)``;
73SKI_CONV ``$==> = \a b. ~a \/ b``;
74SKI_CONV ``$! = \P. K T = P``;
75SKI_CONV ``!x y. P x y``;
76SKI_CONV ``!x y. P y x``;
77SKI_CONV ``(P = Q) = (!x. P x = Q x)``;
78*)
79
80(* ------------------------------------------------------------------------- *)
81(* A combinator {S,K,I} unify function.                                      *)
82(* ------------------------------------------------------------------------- *)
83
local
  (* occurs v tm: true when the variable v occurs free in tm. *)
  fun occurs v tm = free_in v tm

  (* unify_types vars sub ty1 ty2: unify ty1 with ty2 over the unification
     type variables (snd vars).  Returns vars with its term variables
     type-instantiated, and sub refined by the new type bindings (no new term
     bindings).  Failure is whatever var_type_unify raises -- presumably
     HOL_ERR; confirm against ho_basicTools. *)
  fun unify_types vars sub ty1 ty2 =
    let
      val sub_extra = var_type_unify (snd vars) ty1 ty2
      val sub' = refine_subst sub ([], sub_extra)
      val vars' = (map (inst_ty sub_extra) ## I) vars
    in
      (vars', sub')
    end

  (* solve vars sub work: unify every pair on the work list, in order.  Each
     pair is instantiated (pinst) with the substitution found so far before
     it is examined. *)
  fun solve _ sub [] = sub
    | solve vars sub (current :: next) =
    solve' vars sub (Df (pinst sub) current) next

  (* solve' vars sub (tm1, tm2) rest: unify tm1 with tm2, then solve rest. *)
  and solve' vars sub (tm1, tm2) rest =
    if is_tmvar vars tm1 then
      if tm1 ~~ tm2 then solve vars sub rest
      else if occurs tm1 tm2 then raise ERR "ski_unify" "occurs check"
      else
        (* If tm1 is already bound, unify its value with tm2 instead;
           otherwise make the types agree and bind tm1 to tm2. *)
        (case total (tfind_redex tm1) (fst sub) of SOME {residue, ...}
           => solve' vars sub (tm2, residue) rest
         | NONE =>
           let
             val (ty1, ty2) = Df type_of (tm1, tm2)
             val sub_extra = var_type_unify (snd vars) ty1 ty2
             val (tm1', tm2') = Df (inst_ty sub_extra) (tm1, tm2)
             val sub' = refine_subst sub ([tm1' |-> tm2'], sub_extra)
             val vars' = (map (inst_ty sub_extra) ## I) vars
           in
             solve vars' sub' rest
           end)
    else if is_tmvar vars tm2 then solve' vars sub (tm2, tm1) rest
    else
      (case Df dest_term (tm1, tm2) of
         (COMB (Rator, Rand), COMB (Rator', Rand'))
         => solve' vars sub (Rator, Rator') ((Rand, Rand') :: rest)
       | (VAR (Name, Ty), VAR (Name', Ty')) =>
         let
           val _ = assert (Name = Name') (ERR "ski_unify" "different vars")
         in
           if Ty = Ty' then solve vars sub rest
           else
             let
               (* Same name, different types: distinct HOL variables, unless
                  the types unify over the unification type variables (just
                  as constant instances do below).  This used to raise BUG,
                  which escaped the `total` in ski_discrim_unify. *)
               val (vars', sub') = unify_types vars sub Ty Ty'
             in
               solve vars' sub' rest
             end
         end
       | (CONST {Name, Thy, Ty}, CONST {Name = Name', Thy = Thy', Ty = Ty'}) =>
         let
           val _ =
             assert (Name = Name' andalso Thy = Thy')
             (ERR "ski_unify" "different consts")
           val (vars', sub') = unify_types vars sub Ty Ty'
         in
           solve vars' sub' rest
         end
       | _ => raise ERR "ski_unify" "terms fundamentally different")
in
  (* ski_unifyl vars work: a substitution unifying every pair in work, treating
     the terms and types in vars as unification variables. *)
  fun ski_unifyl vars work = solve vars empty_subst work;

  (* ski_unify vars tm1 tm2: a substitution unifying tm1 with tm2. *)
  fun ski_unify vars tm1 tm2 = solve' vars empty_subst (tm1, tm2) [];
end;
136
137(* ------------------------------------------------------------------------- *)
138(* A combinator {S,K,I} term discriminator.                                  *)
139(* ------------------------------------------------------------------------- *)
140
(* One token of the flattened encoding of a combinator term that labels the
   paths of a discrimination tree:
   - SKI_COMB_BEGIN / SKI_COMB_END bracket an application, whose rator
     tokens come before its rand tokens;
   - SKI_CONST holds an atom (non-application term) that is not one of the
     entry's variables;
   - SKI_VAR holds an atom that is one of the entry's variables. *)
datatype ski_pattern
  = SKI_COMB_BEGIN
  | SKI_COMB_END
  | SKI_CONST of term
  | SKI_VAR of term;

(* A discrimination tree: the number of entries added so far, plus a forest
   whose branches are labelled by ski_patterns and whose leaves hold each
   entry's variables paired with its stored value. *)
datatype 'a ski_discrim =
  SKI_DISCRIM of int * (ski_pattern, vars * 'a) tree list;
149
(* Equality on ski_patterns: the bracket tokens are equal to themselves, and
   SKI_CONST / SKI_VAR tokens are equal when their terms are alpha-convertible
   (aconv). *)
fun skipat_eq SKI_COMB_BEGIN SKI_COMB_BEGIN = true
  | skipat_eq SKI_COMB_END SKI_COMB_END = true
  | skipat_eq (SKI_CONST c) (SKI_CONST c') = aconv c c'
  | skipat_eq (SKI_VAR v) (SKI_VAR v') = aconv v v'
  | skipat_eq _ _ = false
157
(* The discrimination tree with no entries. *)
val empty_ski_discrim = SKI_DISCRIM (0, []);

(* The number of entries that have been added to a discrimination tree. *)
fun ski_discrim_size sd = case sd of SKI_DISCRIM (count, _) => count;
160
(* ski_pattern_term_build pat stack: push one pattern token onto a term
   rebuilding stack.  Atoms become finished terms (RIGHT) and SKI_COMB_BEGIN
   is pushed as a marker (LEFT).  SKI_COMB_END expects the rand, then the
   rator, then the marker on top of the stack, and replaces them with their
   application.  Any other stack under SKI_COMB_END raises BUG. *)
fun ski_pattern_term_build pat stack =
  case (pat, stack) of
    (SKI_COMB_BEGIN, _) => LEFT SKI_COMB_BEGIN :: stack
  | (SKI_CONST c, _) => RIGHT c :: stack
  | (SKI_VAR v, _) => RIGHT v :: stack
  | (SKI_COMB_END, RIGHT rand :: RIGHT rator :: LEFT SKI_COMB_BEGIN :: below) =>
      RIGHT (mk_comb (rator, rand)) :: below
  | (SKI_COMB_END, _) =>
      raise BUG "ski_pattern_term_build" "badly formed list";
170
local
  (* A complete pattern list rebuilds to a stack holding one finished term. *)
  fun final stack =
    case stack of
      [RIGHT tm] => tm
    | _ => raise ERR "ski_patterns_to_term" "not a complete pattern list"

  (* Leaf handler for tree_trans: pair the leaf's vars with the rebuilt term,
     keeping the stored value alongside. *)
  fun final_a (vars, a) stack = ((vars, final stack), a)
in
  (* Rebuild the term encoded by a complete list of ski_patterns. *)
  fun ski_patterns_to_term pats =
    let
      val stack = trans ski_pattern_term_build [] pats
    in
      final stack
    end

  (* List every entry of a discrimination tree as ((vars, term), value). *)
  fun dest_ski_discrim (SKI_DISCRIM (_, d)) =
    let
      val rebuild_tree = tree_trans ski_pattern_term_build final_a []
    in
      flatten (map rebuild_tree d)
    end
end;
183
(* ski_pattern_term_break vars pieces: expand the unbroken term at the head of
   pieces by one level.  An application becomes SKI_COMB_BEGIN, its rator,
   its rand, SKI_COMB_END.  An atom becomes SKI_VAR if it is among the term
   variables of vars, and SKI_CONST otherwise.  Raises BUG on a lambda, on an
   empty list, or when the head is already a pattern. *)
fun ski_pattern_term_break ((tm_vars, _) : vars) pieces =
  case pieces of
    [] => raise BUG "ski_pattern_term_break" "nothing to break"
  | LEFT _ :: _ =>
      raise BUG "ski_pattern_term_break" "can't break a ski_pattern"
  | RIGHT tm :: rest =>
      (case dest_term tm of
         LAMB _ => raise BUG "ski_pattern_term_break" "can't break a lambda"
       | COMB (Rator, Rand) =>
           LEFT SKI_COMB_BEGIN :: RIGHT Rator :: RIGHT Rand ::
           LEFT SKI_COMB_END :: rest
       | _ =>
           let
             val atom = if tmem tm tm_vars then SKI_VAR tm else SKI_CONST tm
           in
             LEFT atom :: rest
           end);
194
(* vterm_to_ski_patterns (vars, tm): flatten tm into its list of ski_patterns,
   breaking terms open one level at a time until only patterns remain. *)
fun vterm_to_ski_patterns (vars, tm) =
  let
    fun flatten_all acc [] = rev acc
      | flatten_all acc (LEFT p :: rest) = flatten_all (p :: acc) rest
      | flatten_all acc pieces =
        flatten_all acc (ski_pattern_term_break vars pieces)
    val pats = flatten_all [] [RIGHT tm]
  in
    trace "vterm_to_ski_patterns: ((vars, tm), res)"
      (fn () => printVal ((vars, tm), pats));
    pats
  end;
207
local
  (* insert value pats trees: add the path pats, ending in LEAF value, to the
     forest trees.  An existing branch is shared when its label is skipat_eq
     to the next pattern. *)
  fun insert value pats trees =
    case (pats, trees) of
      ([], _) => LEAF value :: trees
    | (pat :: rest, []) => [BRANCH (pat, insert value rest [])]
    | (pat :: rest, (tree as BRANCH (label, subtrees)) :: siblings) =>
        if skipat_eq pat label then
          BRANCH (label, insert value rest subtrees) :: siblings
        else tree :: insert value pats siblings
    | (_ :: _, LEAF _ :: _) =>
        raise BUG "discrim_add" "expected a branch, got a leaf"
in
  (* Add the entry ((vars, tm), a) to a discrimination tree, incrementing its
     entry count.  The leaf stores (vars, a). *)
  fun ski_discrim_add ((vars, tm), a) (SKI_DISCRIM (i, d)) =
    SKI_DISCRIM (i + 1, insert (vars, a) (vterm_to_ski_patterns (vars, tm)) d)
end;
220
(* Add each entry of a list to a discrimination tree, using trans. *)
fun ski_discrim_addl entries sd = trans ski_discrim_add sd entries;

(* Build a discrimination tree holding the given list of entries. *)
fun mk_ski_discrim entries = ski_discrim_addl entries empty_ski_discrim;
224
(* ski_pattern_reduce pat pat' sub: check a discrimination pattern pat against
   a query pattern pat'.  Matching brackets leave sub unchanged.  SKI_CONST
   pairs are matched with raw_match'.  A SKI_VAR pat is a caller error (BUG);
   any other mismatch raises HOL_ERR.
   NOTE(review): raw_match' is called with the accumulated id sets, which
   start empty, so a free non-schematic variable stored as SKI_CONST may be
   bound rather than required to be equal -- confirm this is intended. *)
fun ski_pattern_reduce pat pat' sub =
  case (pat, pat') of
    (SKI_VAR _, _) => raise BUG "ski_pattern_reduce" "can't reduce variables"
  | (SKI_COMB_BEGIN, SKI_COMB_BEGIN) => sub
  | (SKI_COMB_END, SKI_COMB_END) => sub
  | (SKI_CONST c, SKI_CONST c') => raw_match' c c' sub
  | _ => raise ERR "ski_pattern_reduce" "patterns fundamentally different";
233
local
  (* step pat (pieces, sub): consume one discrimination pattern against the
     remaining pieces of the query term.  A pattern variable is matched
     against a whole unbroken subterm.  Otherwise the head subterm is broken
     open (with no variables) until a pattern is at the head, and that
     pattern is reduced against pat. *)
  fun step pat (pieces, sub) =
    case (pat, pieces) of
      (_, []) => raise BUG "ski_discrim_match" "no patterns left in list"
    | (SKI_VAR v, RIGHT tm :: rest) => (rest, raw_match' v tm sub)
    | (_, RIGHT _ :: _) =>
        step pat (ski_pattern_term_break empty_vars pieces, sub)
    | (_, LEFT pat' :: rest) => (rest, ski_pattern_reduce pat pat' sub)

  (* At a leaf the query must be fully consumed; report its substitution. *)
  fun finish (_, a) (pieces, sub) =
    case pieces of
      [] => SOME (finalize_subst sub, a)
    | _ => raise BUG "ski_discrim_match" "patterns left at end"

  (* Walk one tree, pruning branches where step raises HOL_ERR. *)
  fun tree_match tm =
    tree_partial_trans (fn pat => total (step pat)) finish
      ([RIGHT tm], empty_raw_subst)
in
  (* ski_discrim_match sd tm: (substitution, value) for every entry of sd
     whose pattern matches tm. *)
  fun ski_discrim_match (SKI_DISCRIM (_, d)) tm =
    flatten (map (tree_match tm) d)
    handle HOL_ERR _ => raise BUG "ski_discrim_match" "should never fail";
end;
253
local
  (* advance vars pat (building, rstate, work): consume one discrimination
     pattern pat against the query.
     - building = SOME (v, lstate): the query has variable v at this point,
       so discrimination patterns are being reassembled into a term on
       lstate.  Once that term is complete, (v, term) joins the work list.
     - rstate: the remaining pieces of the query term, broken using the
       query's vars.
     - work: pairs of terms still to be unified by ski_unifyl.
     Raises HOL_ERR (pruning this branch) on an incompatible pattern. *)
  fun advance _ pat (SOME (v, lstate), rstate, work) =
    (case ski_pattern_term_build pat lstate of
         [RIGHT tm] => (NONE, rstate, (v, tm) :: work)
       | lstate' => (SOME (v, lstate'), rstate, work))
    | advance _ (SKI_VAR v) (NONE, RIGHT tm :: rest, work) =
        (NONE, rest, (v, tm) :: work)
    | advance vars pat (NONE, rstate as RIGHT _ :: _, work) =
        advance vars pat (NONE, ski_pattern_term_break vars rstate, work)
    | advance vars pat (NONE, LEFT (SKI_VAR v) :: rest, work) =
        advance vars pat (SOME (v, []), rest, work)
    | advance _ (SKI_CONST c) (NONE, LEFT (SKI_CONST c') :: rest, work) =
        (* Constants may differ in type instance: defer to unification. *)
        (NONE, rest, (c, c') :: work)
    | advance _ pat (NONE, LEFT pat' :: rest, work) =
        if skipat_eq pat pat' then (NONE, rest, work)
        else raise ERR "ski_discrim_unify" "terms fundamentally different"
    | advance _ _ (NONE, [], _) =
        raise BUG "ski_discrim_unify" "no patterns left in list";

  (* At a leaf the query must be fully consumed; return the entry's vars with
     the collected work list. *)
  fun finally (vars, a) (NONE, [], work) = SOME ((vars, work), a)
    | finally _ _ = raise BUG "ski_discrim_unify" "patterns left at end";

  fun tree_search vars tm =
    tree_partial_trans (total o advance vars) finally (NONE, [RIGHT tm], []);
in
  (* ski_discrim_unify sd (vars, tm): (substitution, value) for every entry of
     sd whose stored term unifies with tm.  Both sides' vars are treated as
     unification variables. *)
  fun ski_discrim_unify (SKI_DISCRIM (_, d)) (vars, tm) =
    let
      val shortlist = (flatten o map (tree_search vars tm)) d
      fun select ((vars', work), a) =
        (ski_unifyl (vars_union vars vars') work, a)
      val res = partial_map (total select) shortlist
      val _ = trace
        "ski_discrim_unify: ((vars, tm), map fst res)"
        (fn () => printVal ((vars, tm), map fst res))
    in
      res
    end
end;
292
293(* non-interactive mode
294*)
295end;
296