1(* non-interactive mode
2*)
3structure ho_discrimTools :> ho_discrimTools =
4struct
5open HolKernel Parse boolLib;
6
7(* interactive mode
8val () = loadPath := union ["..", "../finished"] (!loadPath);
9val () = app load
10  ["bossLib",
11   "realLib",
12   "rich_listTheory",
13   "arithmeticTheory",
14   "numTheory",
15   "pred_setTheory",
16   "pairTheory",
17   "combinTheory",
18   "listTheory",
19   "dividesTheory",
20   "primeTheory",
21   "gcdTheory",
22   "probLib",
23   "HurdUseful"];
24val () = show_assums := true;
25*)
26
27open hurdUtils ho_basicTools;
28
(* Local fixities and tactic-combinator shorthands:
     ++ is THEN, << is THENL, || is ORELSE, >> is THEN1, !! is REPEAT.
   The fixity declarations also affect how the rest of this structure is
   parsed, so they must stay ahead of every later declaration.
   NOTE(review): none of these aliases (nor oo / ##) appear to be used
   elsewhere in this structure; they look like leftover boilerplate. *)
infixr 0 oo ## ++ << || THENC ORELSEC THENR ORELSER;
infix 1 >>;

val op++ = op THEN;
val op<< = op THENL;
val op|| = op ORELSE;
val op>> = op THEN1;
val !! = REPEAT;
37
38(* ------------------------------------------------------------------------- *)
39(* Type/term substitutions                                                   *)
40(* ------------------------------------------------------------------------- *)
41
(* The starting raw substitution: empty term and type binding lists,
   together with empty "id" collections in the shape raw_match expects. *)
val empty_raw_subst = (([], empty_tmset), ([], [])) : raw_substitution;

(* Curried front end to raw_match over a raw_substitution, which keeps
   each (bindings, ids) pair together rather than split across arguments.
   Raises HOL_ERR (from raw_match) when pat_tm does not match obj_tm. *)
fun raw_match' pat_tm obj_tm ((tm_binds, tm_ids), (ty_binds, ty_ids)) =
  raw_match ty_ids tm_ids pat_tm obj_tm (tm_binds, ty_binds);
46
(* Extend a raw substitution so that type ty1 matches type ty2.
   Only the type half (bindings, ids) of the raw substitution can change;
   the term half is passed through untouched.
   This calls Type.raw_match_type directly. Matching two throw-away "NIL"
   list constants would require the list theory's NIL constant to be
   available to mk_const, and allocates two terms on every call.
   Raises HOL_ERR when the types do not match. *)
fun type_raw_match ty1 ty2 ((tmS, tmIds), tyS_ids) : raw_substitution =
  ((tmS, tmIds), Type.raw_match_type ty1 ty2 tyS_ids);
54
(* Convert an accumulated raw substitution into a `substitution` by
   normalising it with norm_subst. *)
fun finalize_subst (sub : raw_substitution) : substitution = norm_subst sub;
56
57(* ------------------------------------------------------------------------- *)
58(* A term discriminator.                                                     *)
59(* ------------------------------------------------------------------------- *)
60
(* Flattened pre-order encoding of a term. Lists of these patterns label
   the paths of the discrimination tree (see term_to_pats / pats_to_term):
     COMB_BEGIN .. COMB_END    bracket an application, rator before rand;
     ABS_BEGIN ty .. ABS_END   bracket an abstraction whose bound variable
                               has type ty;
     CONSTANT c                a constant;
     BVAR n                    a bound variable, as the int given by dest_bv
                               for the enclosing bound-variable list;
     FVAR (v, ns)              a matchable variable v with the data returned
                               by dest_ho_pat. The int list is presumably
                               the bound variables v is applied to, by
                               index (TODO confirm in ho_basicTools). *)
datatype pattern
  = COMB_BEGIN
  | COMB_END
  | ABS_BEGIN of hol_type
  | ABS_END
  | CONSTANT of term
  | BVAR of int
  | FVAR of term * int list;
69
(* Equality on patterns. The constructors must agree; then types are
   compared with Type.compare, terms up to alpha-equivalence (aconv), and
   integer data exactly. *)
fun pat_eq COMB_BEGIN COMB_BEGIN = true
  | pat_eq COMB_END COMB_END = true
  | pat_eq ABS_END ABS_END = true
  | pat_eq (ABS_BEGIN a) (ABS_BEGIN b) = (Type.compare (a, b) = EQUAL)
  | pat_eq (CONSTANT s) (CONSTANT t) = aconv s t
  | pat_eq (BVAR m) (BVAR n) = (m = n)
  | pat_eq (FVAR (v, ms)) (FVAR (w, ns)) = aconv v w andalso ms = ns
  | pat_eq _ _ = false;
80
(* A discrimination tree pairs the number of entries added so far with a
   forest of pattern-labelled branches whose leaves hold the stored data. *)
datatype 'a discrim = DISCRIM of int * (pattern, 'a) tree list;

val empty_discrim = DISCRIM (0, []);

(* Number of entries that have been added. *)
fun discrim_size d = case d of DISCRIM (count, _) => count;
85
(* Rebuilding terms from pattern lists. Patterns are replayed left to right
   against a state (bvs, stack):
     bvs    the bound variables invented for the abstractions currently
            open, innermost first;
     stack  the sub-terms built so far, most recent first. *)
local
  val bv_prefix = "bv";

  (* Consume one pattern.
     COMB_END combines the top two stack entries: s2 (pushed first) is the
     rator and s1 the rand.
     ABS_BEGIN invents a variable named from bv_prefix and the current
     number of open binders (length bvs).
     ABS_END abstracts that variable out of the top stack entry.
     Any pattern/state combination not listed is a malformed list. *)
  fun advance COMB_BEGIN state = state
    | advance COMB_END (bvs, s1 :: s2 :: srest) =
    (bvs, mk_comb (s2, s1) :: srest)
    | advance (ABS_BEGIN ty) (bvs, stack) =
    let
      val bv = mk_var (mk_string_fn bv_prefix [int_to_string (length bvs)], ty)
    in
      (bv :: bvs, stack)
    end
    | advance ABS_END (bv :: bvs, s1 :: stack) = (bvs, mk_abs (bv, s1) :: stack)
    | advance (CONSTANT tm) (bvs, stack) = (bvs, tm :: stack)
    | advance (BVAR n) (bvs, stack) = (bvs, mk_bv bvs n :: stack)
    | advance (FVAR fvc) (bvs, stack) = (bvs, mk_ho_pat bvs fvc :: stack)
    | advance _ _ = raise BUG "pats_to_term" "not a valid pattern list"

  (* A complete pattern list leaves exactly one term on the stack. *)
  fun final (_, [tm]) = tm
    | final _ = raise BUG "pats_to_term" "not a complete pattern list"

  (* Leaf action: pair the rebuilt term with the datum stored at the leaf. *)
  fun final_a a state = (final state, a)
in
  (* Rebuild a term from its patterns, starting from the bound vars bvs. *)
  fun pats_to_term_bvs (bvs, pats) = final (trans advance (bvs, []) pats)
  fun pats_to_term pats = pats_to_term_bvs ([], pats)

  (* Enumerate the stored entries as (term, data) pairs. This assumes
     tree_trans folds advance down each branch and applies final_a at every
     leaf (see hurdUtils). *)
  fun dest_discrim (DISCRIM (_, d)) =
    flatten (map (tree_trans advance final_a ([], [])) d)
end;
114
(* Expand the head of a work list by one level.
   The list mixes finished patterns (LEFT) with terms still to be encoded
   (RIGHT), and its head must be RIGHT tm:
   - bound variables become BVAR;
   - higher-order pattern variables become FVAR;
   - constants become CONSTANT;
   - applications and abstractions become their bracketing patterns around
     the still-unexpanded sub-terms.
   Entering an abstraction pushes its bound variable onto bvs;
   tm_pat_correct pops it again at ABS_END. *)
fun tm_pat_break (_, []) =
    raise BUG "tm_pat_break" "nothing to break"
  | tm_pat_break (_, LEFT _ :: _) =
    raise BUG "tm_pat_break" "can't break apart a pattern!"
  | tm_pat_break (bvs, RIGHT tm :: rest) =
    if is_bv bvs tm then
      (bvs, LEFT (BVAR (dest_bv bvs tm)) :: rest)
    else if is_ho_pat bvs tm then
      (bvs, LEFT (FVAR (dest_ho_pat bvs tm)) :: rest)
    else
      (case dest_term tm of
         VAR _ => raise BUG "tm_pat_break" "shouldn't be a var"
       | CONST _ => (bvs, LEFT (CONSTANT tm) :: rest)
       | COMB (f, x) =>
           (bvs,
            LEFT COMB_BEGIN :: RIGHT f :: RIGHT x :: LEFT COMB_END :: rest)
       | LAMB (v, body) =>
           (v :: bvs,
            LEFT (ABS_BEGIN (type_of v)) :: RIGHT body :: LEFT ABS_END
            :: rest));
136
(* Keep a bound-variable list in step with a pattern stream. Leaving an
   abstraction (ABS_END) drops the innermost bound variable; every other
   pattern leaves the list unchanged. *)
fun tm_pat_correct pat (bvs : term list) =
  case (pat, bvs) of
    (ABS_END, _ :: outer) => outer
  | (ABS_END, []) => raise BUG "tm_pat_correct" "no bvs left at an ABS_END"
  | _ => bvs;
141
(* Encoding terms as pattern lists. *)
local
  (* Drain the work list. RIGHT terms are expanded with tm_pat_break, and
     finished patterns are collected in reverse order in acc. *)
  fun collect acc (bvs, work) =
    case work of
      [] => acc
    | LEFT p :: more => collect (p :: acc) (tm_pat_correct p bvs, more)
    | RIGHT _ :: _ => collect acc (tm_pat_break (bvs, work))
in
  (* Pre-order pattern encoding of tm, relative to the bound vars bvs. *)
  fun term_to_pats_bvs (bvs, tm) = rev (collect [] (bvs, [RIGHT tm]))
  fun term_to_pats tm = term_to_pats_bvs ([], tm)
end;
152
(* Inserting an entry into a discrimination tree. *)
local
  (* Walk the pattern list down the forest, sharing any existing branch
     whose label is pat_eq to the next pattern. A new LEAF is added once
     the patterns run out. *)
  fun insert x trees pats =
    case (trees, pats) of
      (_, []) => LEAF x :: trees
    | ([], p :: ps) => [BRANCH (p, insert x [] ps)]
    | ((b as BRANCH (q, kids)) :: others, p :: ps) =>
        if pat_eq p q then BRANCH (q, insert x kids ps) :: others
        else b :: insert x others pats
    | (LEAF _ :: _, _ :: _) =>
        raise BUG "discrim_add" "expected a branch, got a leaf"
in
  (* Add the entry (tm, a) and bump the entry count. *)
  fun discrim_add (tm, a) (DISCRIM (i, d)) =
    DISCRIM (i + 1, insert a d (term_to_pats tm))
end;
165
(* Build a discrimination tree from a list of (term, data) entries by
   adding each one, via trans, starting from the empty tree. *)
val mk_discrim = fn entries => trans discrim_add empty_discrim entries;
167
(* One step of higher-order matching: compare the tree pattern pat with the
   object term's next pattern pat', which is supplied together with the
   object's currently open bound variables.
   The state is (sub, thks): the raw substitution so far, plus a stack of
   proof thunks, one per completed sub-term, most recent first. Closing
   brackets combine thunks with MK_COMB / MK_ABS, so that at the end a
   single thunk yields the theorem for the whole term (see
   discrim_ho_match).
   FVAR is matched elsewhere (via ho_raw_match), so reaching it here is a
   BUG. If the thunk stack is too short, the specific clauses do not apply
   and the final clause raises ERR.
   Raises HOL_ERR when the patterns cannot match. *)
fun pat_ho_match _ (FVAR _) _ =
  raise BUG "pat_ho_match" "can't match variables"
  | pat_ho_match sub_thks COMB_BEGIN (_, COMB_BEGIN) = sub_thks
  | pat_ho_match (sub, thks) (ABS_BEGIN ty) (_, ABS_BEGIN ty') =
  (type_raw_match ty ty' sub, thks)
  (* th1 is the rand's thunk; th2 (pushed earlier) is the rator's. *)
  | pat_ho_match (sub, th1 :: th2 :: thks) COMB_END (_, COMB_END) =
  (sub, (fn () => MK_COMB (th2 (), th1 ())) :: thks)
  (* The caller passes bvs before tm_pat_correct pops, so the head of the
     list is still the bound variable of the abstraction being closed. *)
  | pat_ho_match (sub, th :: thks) ABS_END (bv :: _, ABS_END) =
  (sub, (fn () => MK_ABS (GEN bv (th ()))) :: thks)
  | pat_ho_match (sub, thks) (BVAR n) (bvs, BVAR n') =
  if n = n' then (sub, (fn () => REFL (mk_bv bvs n)) :: thks)
  else raise ERR "pat_ho_match" "different bound vars"
  | pat_ho_match (sub, thks) (CONSTANT c) (_, CONSTANT c') =
  (raw_match' c c' sub, (fn () => REFL c') :: thks)
  | pat_ho_match _ _ _ =
  raise ERR "pat_ho_match" "pats fundamentally different";
184
(* Higher-order matching of a term against every entry of a discrimination
   tree. Besides a substitution, each match carries a proof thunk built from
   ho_raw_match results and the REFL / MK_COMB / MK_ABS steps in
   pat_ho_match. *)
local
  (* One step along a tree path. The state is ((bvs, work), (sub, thks)):
       bvs          the object term's currently open bound variables;
       work         its partially expanded pattern list;
       (sub, thks)  the substitution and proof-thunk stack built so far.
     A tree FVAR is matched against the whole next object sub-term with
     ho_raw_match. Otherwise the object sub-term is expanded one level
     (tm_pat_break) and compared pattern-for-pattern with pat_ho_match.
     pat_ho_match receives bvs before tm_pat_correct pops at ABS_END,
     because it needs that variable for MK_ABS. *)
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm::rest), (sub, thks)) =
    let
      val (sub', thk') = ho_raw_match (fv, fbs) (bvs, tm) sub
    in
      ((bvs, rest), (sub', thk' :: thks))
    end
    | advance pat (state as (_, RIGHT _::_), ho_sub) =
    advance pat (tm_pat_break state, ho_sub)
    | advance pat ((bvs, LEFT pat' :: rest), ho_sub) =
    let
      val ho_sub' = pat_ho_match ho_sub pat (bvs, pat')
    in
      ((tm_pat_correct pat' bvs, rest), ho_sub')
    end
    | advance _ ((_, []), _) =
    raise BUG "discrim_ho_match" "no pats left in list"

  (* At a leaf, the object term must be fully consumed and exactly one proof
     thunk must remain. The result pairs the finalized substitution with a
     thunk for the SYM of that theorem, plus the leaf datum. *)
  fun finally a ((_, []), (sub, [thk])) =
    SOME ((finalize_subst sub, fn () => SYM (thk ())), a)
    | finally _ _ = raise BUG "discrim_ho_match" "weird state at end"

  (* Match tm against one tree. The initial state has no bound variables,
     the unexpanded term, the empty substitution and no thunks. "total"
     presumably turns a HOL_ERR mismatch into NONE, so that
     tree_partial_trans prunes that branch. *)
  fun tree_match tm =
    tree_partial_trans
    (fn x => total (advance x))
    finally
    (([], [RIGHT tm]), (empty_raw_subst, []))
in
  (* All (ho_substitution, data) matches of tm. A HOL_ERR escaping here
     indicates an internal bug and is re-raised through err_BUG. *)
  fun discrim_ho_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
  handle (h as HOL_ERR _) => raise err_BUG "discrim_ho_match" h;
end;
217
(* One step of first-order matching: compare tree pattern pat with object
   pattern pat' and extend sub.
   - Bracketing patterns match their own kind.
   - Abstraction types and constants are matched through the raw
     substitution.
   - Bound variables must carry the same index.
   FVAR is handled by the caller (via fo_raw_match), so reaching it here is
   a BUG. Raises HOL_ERR on mismatch. *)
fun pat_fo_match sub pat pat' =
  case (pat, pat') of
    (FVAR _, _) => raise BUG "pat_fo_match" "can't match variables"
  | (COMB_BEGIN, COMB_BEGIN) => sub
  | (COMB_END, COMB_END) => sub
  | (ABS_END, ABS_END) => sub
  | (ABS_BEGIN ty, ABS_BEGIN ty') => type_raw_match ty ty' sub
  | (BVAR m, BVAR n) =>
      if m = n then sub else raise ERR "pat_fo_match" "different bound vars"
  | (CONSTANT c, CONSTANT c') => raw_match' c c' sub
  | _ => raise ERR "pat_fo_match" "pats fundamentally different";
229
(* First-order matching of a term against every entry of a discrimination
   tree. Each tree path is replayed against the object term's pattern
   stream, which is expanded lazily with tm_pat_break.
   - A tree FVAR is matched against a whole object sub-term with
     fo_raw_match.
   - All other patterns are compared one-for-one by pat_fo_match. *)
local
  (* One step along a tree path. The state is ((bvs, work), sub), where bvs
     are the object's open bound variables and work is its partially
     expanded pattern list. *)
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), sub) =
    let
      val sub' = fo_raw_match (fv, fbs) (bvs, tm) sub
    in
      ((bvs, rest), sub')
    end
    | advance pat (state as (_, RIGHT _::_), sub) =
    (* Not a variable: expand the object sub-term one level and retry. *)
    advance pat (tm_pat_break state, sub)
    | advance pat ((bvs, LEFT pat' :: rest), sub) =
    let
      val sub' = pat_fo_match sub pat pat'
    in
      ((tm_pat_correct pat' bvs, rest), sub')
    end
    | advance _ ((_, []), _) =
    raise BUG "discrim_fo_match" "no pats left in list"

  (* At a leaf the object term must be fully consumed. *)
  fun finally a ((_, []), sub) =
    SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "discrim_fo_match" "weird state at end"

  (* Match tm against one tree. "total" presumably turns a HOL_ERR mismatch
     into NONE, so that tree_partial_trans prunes that branch. *)
  fun tree_match tm =
    tree_partial_trans (fn x => total (advance x))
    finally (([], [RIGHT tm]), empty_raw_subst)
in
  (* All (substitution, data) matches of tm. A HOL_ERR escaping the
     traversal indicates an internal bug. As in discrim_ho_match, the
     original exception is handed to err_BUG instead of being discarded. *)
  fun discrim_fo_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
  handle (h as HOL_ERR _) => raise err_BUG "discrim_fo_match" h;
end;
260
261(* ------------------------------------------------------------------------- *)
262(* Variable Term Discriminators                                              *)
263(* Terms come with a list of variables that may be instantiated, the rest    *)
264(* should be treated as constants.                                           *)
265(* ------------------------------------------------------------------------- *)
266
(* A variable discriminator stores, with each datum, the variables that may
   be instantiated when that entry is matched. *)
type 'a vdiscrim = (vars * 'a) discrim;

val empty_vdiscrim = empty_discrim : 'a vdiscrim;

fun vdiscrim_size (d : 'a vdiscrim) = discrim_size d;
272
(* List the entries of a variable discriminator as ((vars, term), data). *)
fun dest_vdiscrim d =
  map (fn (tm : term, (vars : vars, a)) => ((vars, tm), a)) (dest_discrim d);
278
(* Add the entry ((vars, tm), a). The term tm is indexed as usual, and vars
   is stored next to the datum so that matches can be checked against it
   later. *)
fun vdiscrim_add ((vars : vars, tm), a) d = discrim_add (tm, (vars, a)) d;

(* Build a variable discriminator from a list of entries. *)
fun mk_vdiscrim entries = trans vdiscrim_add empty_vdiscrim entries;
286
(* Higher-order matches from a variable discriminator. Only matches whose
   substitution passes subst_vars_in_set against the entry's vars are kept,
   presumably meaning it instantiates only those variables. *)
fun vdiscrim_ho_match d tm =
  let
    fun keep (ho_sub as (sub, _) : ho_substitution, (vars, a)) =
      if subst_vars_in_set sub vars then SOME (ho_sub, a) else NONE
  in
    partial_map keep (discrim_ho_match d tm)
  end;
294
(* First-order matches from a variable discriminator. Only matches whose
   substitution passes subst_vars_in_set against the entry's vars are
   kept. *)
fun vdiscrim_fo_match d tm =
  let
    fun keep (sub : substitution, (vars, a)) =
      if subst_vars_in_set sub vars then SOME (sub, a) else NONE
  in
    partial_map keep (discrim_fo_match d tm)
  end;
302
303(* ------------------------------------------------------------------------- *)
304(* Ordered (Variable) Term Discriminators                                    *)
305(* Entries are returned in the order they arrived (latest first).            *)
306(* ------------------------------------------------------------------------- *)
307
(* An ordered discriminator tags every datum with its insertion index, so
   that results can be returned latest first. *)
type 'a ovdiscrim = (int * 'a) vdiscrim;

val empty_ovdiscrim = empty_vdiscrim : 'a ovdiscrim;

fun ovdiscrim_size (d : 'a ovdiscrim) = vdiscrim_size d;
313
local
  (* Bring each entry's insertion index to the front, sort by it descending
     (latest first), then drop the index again. *)
  fun latest_first entries =
    let
      val tagged = map (fn (x, (n : int, b)) => (n, (x, b))) entries
      val sorted = sort (fn (m, _) => fn (n, _) => m > n) tagged
    in
      map snd sorted
    end;
in
  (* Entries as ((vars, term), data), most recently added first. *)
  fun dest_ovdiscrim d = latest_first (dest_vdiscrim d);

  (* Higher-order matches of tm in d, most recently added first. *)
  fun ovdiscrim_ho_match d tm = latest_first (vdiscrim_ho_match d tm)
  handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_ho_match" h;

  (* First-order matches of tm in d, most recently added first. *)
  fun ovdiscrim_fo_match d tm = latest_first (vdiscrim_fo_match d tm)
  handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_fo_match" h;
end;
327
(* Add an entry, tagging its datum with the current size of the
   discriminator. The size grows by one on every add, so it serves as an
   increasing insertion index. *)
fun ovdiscrim_add (vars_tm, a) d =
  vdiscrim_add (vars_tm, (ovdiscrim_size d, a)) d;

(* Build an ordered discriminator from a list of entries. *)
fun mk_ovdiscrim entries = trans ovdiscrim_add empty_ovdiscrim entries;
331
332(* non-interactive mode
333*)
334end;
335