1(* non-interactive mode
2*)
structure ho_discrimTools :> ho_discrimTools =
struct

open HolKernel Parse boolLib;

(* interactive mode
val () = loadPath := union ["..", "../finished"] (!loadPath);
val () = app load
  ["bossLib",
   "realLib",
   "rich_listTheory",
   "arithmeticTheory",
   "numTheory",
   "pred_setTheory",
   "pairTheory",
   "combinTheory",
   "listTheory",
   "dividesTheory",
   "primeTheory",
   "gcdTheory",
   "probLib",
   "HurdUseful"];
val () = show_assums := true;
*)

open HurdUseful ho_basicTools;

infixr 0 oo ## ++ << || THENC ORELSEC THENR ORELSER;
infix 1 >>;

(* Short-hand tactic combinators; their fixities are declared just above. *)
val op++ = op THEN;
val op<< = op THENL;
val op|| = op ORELSE;
val op>> = op THEN1;
val !! = REPEAT;

(* ------------------------------------------------------------------------- *)
(* Type/term substitutions                                                   *)
(* ------------------------------------------------------------------------- *)

(* The empty raw substitution.  Its shape,
     ((term subst, term ids), (type subst, type ids)),
   is the accumulator that raw_match' threads through raw_match. *)
val empty_raw_subst : raw_substitution = (([], empty_tmset), ([], []));

(* raw_match' tm1 tm2 sub extends sub so that the pattern tm1 matches tm2.
   It only re-orders the components of sub into the argument order that
   raw_match expects; on failure it raises whatever raw_match raises. *)
fun raw_match' tm1 tm2 ((tmS, tmIds), (tyS, tyIds)) =
  raw_match tyIds tmIds tm1 tm2 (tmS, tyS);

(* type_raw_match ty1 ty2 sub extends sub so that type ty1 matches ty2.
   The types are wrapped as the constants NIL : ty1 list and NIL : ty2 list
   and matched as terms with raw_match', so this presumably needs the "list"
   type and the "NIL" constant to be available (mk_type/mk_const). *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    val tm1 = mk_const ("NIL", mk_type ("list", [ty1]))
    val tm2 = mk_const ("NIL", mk_type ("list", [ty2]))
  in
    raw_match' tm1 tm2 sub
  end;

(* Turn an accumulated raw substitution into an ordinary substitution. *)
val finalize_subst : raw_substitution -> substitution = norm_subst;

(* ------------------------------------------------------------------------- *)
(* A term discriminator.                                                     *)
(* ------------------------------------------------------------------------- *)

(* A term is flattened into a prefix sequence of patterns:
     application   : COMB_BEGIN, rator patterns, rand patterns, COMB_END
     abstraction   : ABS_BEGIN ty, body patterns, ABS_END
                     (ty is the type of the bound variable)
     constant      : CONSTANT c
     bound var     : BVAR n, with n as produced by dest_bv
     pattern var   : FVAR (v, l), as produced by dest_ho_pat *)
datatype pattern
  = COMB_BEGIN
  | COMB_END
  | ABS_BEGIN of hol_type
  | ABS_END
  | CONSTANT of term
  | BVAR of int
  | FVAR of term * int list;

(* DISCRIM (n, forest): n counts the entries added so far, and forest is a
   trie over pattern sequences whose leaves carry the stored values. *)
datatype 'a discrim = DISCRIM of int * (pattern, 'a) tree list;

val empty_discrim = DISCRIM (0, []);
fun discrim_size (DISCRIM (i, _)) = i;

local
  val bv_prefix = "bv";

  (* One step of rebuilding a term from its patterns.  The state is
     (bound vars in scope, stack of finished subterms).  COMB_END pops the
     rand and then the rator; ABS_BEGIN invents a bound variable named from
     bv_prefix and the current nesting depth; ABS_END closes it again. *)
  fun advance COMB_BEGIN state = state
    | advance COMB_END (bvs, s1 :: s2 :: srest) =
      (bvs, mk_comb (s2, s1) :: srest)
    | advance (ABS_BEGIN ty) (bvs, stack) =
      let
        val bv =
          mk_var (mk_string_fn bv_prefix [int_to_string (length bvs)], ty)
      in
        (bv :: bvs, stack)
      end
    | advance ABS_END (bv :: bvs, s1 :: stack) =
      (bvs, mk_abs (bv, s1) :: stack)
    | advance (CONSTANT tm) (bvs, stack) = (bvs, tm :: stack)
    | advance (BVAR n) (bvs, stack) = (bvs, mk_bv bvs n :: stack)
    | advance (FVAR fvc) (bvs, stack) = (bvs, mk_ho_pat bvs fvc :: stack)
    | advance _ _ = raise BUG "pats_to_term" "not a valid pattern list"

  (* A complete pattern list leaves exactly one term on the stack. *)
  fun final (_, [tm]) = tm
    | final _ = raise BUG "pats_to_term" "not a complete pattern list"

  (* Pair the rebuilt term with the value stored at a trie leaf. *)
  fun final_a a state = (final state, a)
in
  (* Rebuild a term from its patterns, with bvs as the bound vars in scope. *)
  fun pats_to_term_bvs (bvs, pats) = final (trans advance (bvs, []) pats)

  fun pats_to_term pats = pats_to_term_bvs ([], pats)

  (* List every (term, value) entry stored in a discriminator. *)
  fun dest_discrim (DISCRIM (_, d)) =
    flatten (map (tree_trans advance final_a ([], [])) d)
end;

(* One step of flattening a term into patterns.  The state is
   (bound vars in scope, worklist); in the worklist, LEFT p is a finished
   pattern and RIGHT tm is a term still to be broken up.  The head of the
   worklist must be a RIGHT term.  Bound variables and pattern variables
   become a single pattern.  Otherwise the term is replaced by its
   bracketing patterns, with its immediate subterms left as RIGHT items for
   later steps.  Entering an abstraction pushes its bound variable. *)
fun tm_pat_break (bvs, RIGHT tm :: rest) =
  if is_bv bvs tm then
    (bvs, LEFT (BVAR (dest_bv bvs tm)) :: rest)
  else if is_ho_pat bvs tm then
    (bvs, LEFT (FVAR (dest_ho_pat bvs tm)) :: rest)
  else
    (case dest_term tm of
       CONST _ => (bvs, LEFT (CONSTANT tm) :: rest)
     | COMB (Rator, Rand) =>
         (bvs,
          LEFT COMB_BEGIN :: RIGHT Rator :: RIGHT Rand ::
          LEFT COMB_END :: rest)
     | LAMB (Bvar, Body) =>
         (Bvar :: bvs,
          LEFT (ABS_BEGIN (type_of Bvar)) :: RIGHT Body ::
          LEFT ABS_END :: rest)
     | VAR _ => raise BUG "tm_pat_break" "shouldn't be a var")
  | tm_pat_break (_, []) =
  raise BUG "tm_pat_break" "nothing to break"
  | tm_pat_break (_, LEFT _ :: _) =
  raise BUG "tm_pat_break" "can't break apart a pattern!";

(* Keep the bound-var list in step with a pattern just consumed: leaving an
   abstraction (ABS_END) drops the innermost bound variable. *)
fun tm_pat_correct ABS_END ((_ : term) :: bvs) = bvs
  | tm_pat_correct ABS_END [] =
  raise BUG "tm_pat_correct" "no bvs left at an ABS_END"
  | tm_pat_correct _ bvs = bvs;

local
  (* Drain the worklist, breaking RIGHT terms and collecting finished
     patterns onto res in reverse order. *)
  fun tm_pats res (_, []) = res
    | tm_pats res (bvs, LEFT p :: rest) =
    tm_pats (p :: res) (tm_pat_correct p bvs, rest)
    | tm_pats res unbroken =
    tm_pats res (tm_pat_break unbroken)
in
  (* Flatten tm into its pattern list, with bvs as the bound vars in scope. *)
  fun term_to_pats_bvs (bvs, tm) = rev (tm_pats [] (bvs, [RIGHT tm]))

  fun term_to_pats tm = term_to_pats_bvs ([], tm)
end;

local
  (* add a forest pats inserts value a at the end of path pats, following
     an existing branch whenever its head pattern is equal to the next one. *)
  fun add a ts [] = LEAF a :: ts
    | add a [] (pat :: next) = [BRANCH (pat, add a [] next)]
    | add a ((b as BRANCH (pat', ts')) :: rest) (ps as pat :: next) =
    if pat = pat' then BRANCH (pat', add a ts' next) :: rest
    else b :: add a rest ps
    | add _ (LEAF _ :: _) (_ :: _) =
    raise BUG "discrim_add" "expected a branch, got a leaf"
in
  (* Store value a under term tm, incrementing the entry count. *)
  fun discrim_add (tm, a) (DISCRIM (i, d)) =
    DISCRIM (i + 1, add a d (term_to_pats tm));
end;

(* Build a discriminator from (term, value) pairs by adding each in turn
   with discrim_add (via HurdUseful's trans). *)
fun mk_discrim l = trans discrim_add empty_discrim l;

(* pat_ho_match (sub, thks) pat (bvs, pat') matches a stored pattern pat
   against the next pattern pat' of the term being matched.  sub is the raw
   substitution so far; thks is a stack of delayed theorems (thunks), one per
   finished subterm, which COMB_END and ABS_END combine with MK_COMB and
   MK_ABS.  Pattern variables (FVAR) must be handled by the caller.  A
   mismatch raises ERR here, or the error raised by the raw matcher. *)
fun pat_ho_match _ (FVAR _) _ =
  raise BUG "pat_ho_match" "can't match variables"
  | pat_ho_match sub_thks COMB_BEGIN (_, COMB_BEGIN) = sub_thks
  | pat_ho_match (sub, thks) (ABS_BEGIN ty) (_, ABS_BEGIN ty') =
  (type_raw_match ty ty' sub, thks)
  | pat_ho_match (sub, th1 :: th2 :: thks) COMB_END (_, COMB_END) =
  (sub, (fn () => MK_COMB (th2 (), th1 ())) :: thks)
  | pat_ho_match (sub, th :: thks) ABS_END (bv :: _, ABS_END) =
  (sub, (fn () => MK_ABS (GEN bv (th ()))) :: thks)
  | pat_ho_match (sub, thks) (BVAR n) (bvs, BVAR n') =
  if n = n' then (sub, (fn () => REFL (mk_bv bvs n)) :: thks)
  else raise ERR "pat_ho_match" "different bound vars"
  | pat_ho_match (sub, thks) (CONSTANT c) (_, CONSTANT c') =
  (raw_match' c c' sub, (fn () => REFL c') :: thks)
  | pat_ho_match _ _ _ =
  raise ERR "pat_ho_match" "pats fundamentally different";

local
  (* One matching step: consume the stored pattern against the term
     worklist.  A pattern variable consumes a whole RIGHT subterm through
     ho_raw_match.  Otherwise the head term is broken with tm_pat_break
     until a LEFT pattern is exposed, which is compared by pat_ho_match. *)
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), (sub, thks)) =
    let
      val (sub', thk') = ho_raw_match (fv, fbs) (bvs, tm) sub
    in
      ((bvs, rest), (sub', thk' :: thks))
    end
    | advance pat (state as (_, RIGHT _ :: _), ho_sub) =
    advance pat (tm_pat_break state, ho_sub)
    | advance pat ((bvs, LEFT pat' :: rest), ho_sub) =
    let
      val ho_sub' = pat_ho_match ho_sub pat (bvs, pat')
    in
      ((tm_pat_correct pat' bvs, rest), ho_sub')
    end
    | advance _ ((_, []), _) =
    raise BUG "discrim_ho_match" "no pats left in list"

  (* A complete match has consumed the whole term and left exactly one
     theorem thunk; the result pairs the finalized substitution and the
     SYM of that theorem with the stored value. *)
  fun finally a ((_, []), (sub, [thk])) =
    SOME ((finalize_subst sub, fn () => SYM (thk ())), a)
    | finally _ _ = raise BUG "discrim_ho_match" "weird state at end"

  (* Walk one trie, with failing steps made partial by total (presumably
     pruning that branch). *)
  fun tree_match tm =
    tree_partial_trans
    (fn x => total (advance x))
    finally
    (([], [RIGHT tm]), (empty_raw_subst, []))
in
  (* Higher-order match tm against every stored entry, returning one
     ((substitution, theorem thunk), value) per successful match. *)
  fun discrim_ho_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_ho_match" h;
end;

(* First-order counterpart of pat_ho_match: compare a stored pattern with the
   next term pattern, extending only the raw substitution.  Pattern
   variables (FVAR) must be handled by the caller. *)
fun pat_fo_match _ (FVAR _) _ =
  raise BUG "pat_fo_match" "can't match variables"
  | pat_fo_match sub COMB_BEGIN COMB_BEGIN = sub
  | pat_fo_match sub (ABS_BEGIN ty) (ABS_BEGIN ty') = type_raw_match ty ty' sub
  | pat_fo_match sub COMB_END COMB_END = sub
  | pat_fo_match sub ABS_END ABS_END = sub
  | pat_fo_match sub (BVAR n) (BVAR n') =
  if n = n' then sub else raise ERR "pat_fo_match" "different bound vars"
  | pat_fo_match sub (CONSTANT c) (CONSTANT c') = raw_match' c c' sub
  | pat_fo_match _ _ _ =
  raise ERR "pat_fo_match" "pats fundamentally different";

local
  (* One first-order matching step; same worklist discipline as the
     higher-order matcher, but pattern variables go through fo_raw_match. *)
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), sub) =
    let
      val sub' = fo_raw_match (fv, fbs) (bvs, tm) sub
    in
      ((bvs, rest), sub')
    end
    | advance pat (state as (_, RIGHT _ :: _), sub) =
    advance pat (tm_pat_break state, sub)
    | advance pat ((bvs, LEFT pat' :: rest), sub) =
    let
      val sub' = pat_fo_match sub pat pat'
    in
      ((tm_pat_correct pat' bvs, rest), sub')
    end
    | advance _ ((_, []), _) =
    raise BUG "discrim_fo_match" "no pats left in list"

  (* A complete match has consumed the whole term. *)
  fun finally a ((_, []), sub) =
    SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "discrim_fo_match" "weird state at end"

  fun tree_match tm =
    tree_partial_trans (fn x => total (advance x))
    finally (([], [RIGHT tm]), empty_raw_subst)
in
  (* First-order match tm against every stored entry, returning one
     (substitution, value) per successful match.  An unexpected HOL error is
     reported through err_BUG, as in discrim_ho_match, so that the
     underlying error is not discarded. *)
  fun discrim_fo_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_fo_match" h;
end;

(* ------------------------------------------------------------------------- *)
(* Variable Term Discriminators                                              *)
(* Terms come with a list of variables that may be instantiated, the rest    *)
(* should be treated as constants.                                           *)
(* ------------------------------------------------------------------------- *)

(* Each stored value is kept alongside the variables its term may have
   instantiated. *)
type 'a vdiscrim = (vars * 'a) discrim;

val empty_vdiscrim : 'a vdiscrim = empty_discrim;

val vdiscrim_size : 'a vdiscrim -> int = discrim_size;

local
  fun dest (tm : term, (vars : vars, a)) = ((vars, tm), a)
in
  (* List every ((vars, term), value) entry. *)
  fun dest_vdiscrim d = (map dest o dest_discrim) d;
end;

local
  fun prepare (vars : vars) (tm, a) = (tm, (vars, a))
in
  (* Store value a under term tm, remembering its instantiable vars. *)
  fun vdiscrim_add ((vars, tm), a) = discrim_add (prepare vars (tm, a));
end;

(* Build a variable discriminator from ((vars, term), value) pairs. *)
fun mk_vdiscrim l = trans vdiscrim_add empty_vdiscrim l;

local
  (* Keep a match only if its substitution passes subst_vars_in_set for the
     entry's vars. *)
  fun vars_check (ho_sub as (sub, _) : ho_substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (ho_sub, a) else NONE
  fun check_ho_match ml = partial_map vars_check ml
in
  fun vdiscrim_ho_match d tm = check_ho_match (discrim_ho_match d tm);
end;

local
  (* As above, for first-order substitutions. *)
  fun vars_check (sub : substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (sub, a) else NONE
  fun check_fo_match ml = partial_map vars_check ml
in
  fun vdiscrim_fo_match d tm = check_fo_match (discrim_fo_match d tm);
end;

(* ------------------------------------------------------------------------- *)
(* Ordered (Variable) Term Discriminators                                    *)
(* Entries are returned in the order they arrived (latest first).            *)
(* ------------------------------------------------------------------------- *)

(* Each value is tagged with the discriminator size at insertion time, so
   later entries carry larger numbers. *)
type 'a ovdiscrim = (int * 'a) vdiscrim;

val empty_ovdiscrim : 'a ovdiscrim = empty_vdiscrim;

val ovdiscrim_size : 'a ovdiscrim -> int = vdiscrim_size;

local
  (* Move the insertion number to the front so the results can be sorted. *)
  fun transfer (a, (n : int, b)) = (n, (a, b));
  (* Decreasing insertion number: most recently added first. *)
  fun order (m, _) (n, _) = m > n;
  fun dest dl = (map snd o sort order o map transfer) dl;
in
  fun dest_ovdiscrim d = (dest o dest_vdiscrim) d;

  fun ovdiscrim_ho_match d tm = dest (vdiscrim_ho_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_ho_match" h;

  fun ovdiscrim_fo_match d tm = dest (vdiscrim_fo_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_fo_match" h;
end;

(* Add an entry tagged with the current size, which serves as its arrival
   number. *)
fun ovdiscrim_add (tm, a) d = vdiscrim_add (tm, (vdiscrim_size d, a)) d;

(* Build an ordered discriminator from ((vars, term), value) pairs. *)
fun mk_ovdiscrim l = trans ovdiscrim_add empty_ovdiscrim l;

(* non-interactive mode
*)
end;
324
325
326
327
328
329
330