1structure binderLib :> binderLib =
2struct
3
4open HolKernel Parse boolLib
5open BasicProvers simpLib
6
7open nomsetTheory
8
9open NEWLib
(* Local replacement for the Parse structure: everything from the real Parse
   is re-exported, but Type and Term are rebound to the parsers that
   parse_from_grammars builds from nomsetTheory.nomset_grammars.
   NOTE(review): presumably this makes term/type quotations later in this
   file parse against the nomset grammar rather than the ambient one. *)
structure Parse = struct
  open Parse
  val (Type,Term) = parse_from_grammars nomsetTheory.nomset_grammars
end
14open Parse
15
(* Term printer used when building error and warning messages:
   term_to_string run with Parse.current_backend set to
   PPBackEnd.raw_terminal and with the "Unicode" trace set to 0. *)
val tmToString =
    trace ("Unicode", 0)
          (with_flag (Parse.current_backend, PPBackEnd.raw_terminal)
                     term_to_string)
20
21
(* ERR fname text: raise a HOL_ERR whose origin structure is "binderLib",
   whose origin function is fname and whose message is text.  The
   application itself raises, so its result type is unconstrained. *)
fun ERR fname text =
    raise HOL_ERR {origin_structure = "binderLib",
                   origin_function = fname,
                   message = text}
25
(* Information recorded in type_db for one type operator. *)
datatype nominaltype_info =
         NTI of { recursion_thm : thm option,
                  (* recursion theorems are stored in SPEC_ALL form, with
                     preconditions as one big set of conjuncts (rather than
                     iterated implications) *)
                  nullfv : term,
                  (* used by with_info_prove_recn_exists (after type
                     instantiation to the range type) as the result for
                     constructors that have no user-supplied clause *)
                  pm_constant : term,
                  (* the permutation-action term: find_avoids passes it as
                     the first argument of "supp", and
                     with_info_prove_recn_exists substitutes it for the
                     variable "apm" of the recursion theorem *)
                  pm_rewrites : thm list,
                  fv_rewrites : thm list,
                  (* both lists are handed to the simplifier when the
                     recursion theorem's precondition is discharged;
                     pm_rewrites is also used by FRESH_TAC *)
                  binders : (term * int * thm) list }
                  (* as used by FRESH_TAC: the int is a 0-based index into
                     the argument list of a matched term, and the theorem
                     is taken apart as "!.. pre ==> !.. l = r"; the term
                     component is ignored by the code in this file *)
36
(* Field projections for nominaltype_info values. *)
fun nti_null_fv (NTI info) = #nullfv info
fun nti_perm_t (NTI info) = #pm_constant info
fun nti_recursion (NTI info) = #recursion_thm info
40
41
42
43
44(* ----------------------------------------------------------------------
45    prove_recursive_term_function_exists tm
46
47    for the 'a nc type tm would be roughly of the form
48
49       (!s. f (VAR s) = var_rhs s) /\
50       (!k. f (CON k) = con_rhs k) /\
51       (!t1 t2. f (t1 @@ t2) = app_rhs t1 t2 (f t1) (f t2)) /\
52       (!u t. f (LAM u t) = lam_rhs t (f t))
53
54    though not necessarily with the conjuncts in that particular
55    order, or with all of them present.  The universal quantifications
56    can be omitted for any or all of the conjuncts. This code attempts
57    to prove that such a function actually exists.  The RHS of any
58    LAMBDA clause comes in for particular attention because this might
    attempt a recursion that is not justified.  This function
60    attempts to use the simplifier and the stateful simpset (srw_ss()) to
61    prove that the side-conditions on the recursion theorem for the given
62    type actually hold.
63
64   ---------------------------------------------------------------------- *)
65
(* Mutable table from kernel type names ({Thy, Name}, ordered by
   KernelSig.name_compare) to the nominaltype_info for that type operator.
   It starts empty; nothing in the visible part of this file adds entries,
   so population presumably happens elsewhere. *)
val type_db =
    ref (Binarymap.mkDict KernelSig.name_compare :
         (KernelSig.kernelname,nominaltype_info) Binarymap.dict)

(* The HOL string type, and the type of sets of strings built from it. *)
val string_ty = stringSyntax.string_ty
val stringset_ty = pred_setSyntax.mk_set_type string_ty
72
(* mk_icomb (f, x): build the application "f x" after type-instantiating f
   so that its domain type matches the type of x.  Fails (HOL_ERR) if f does
   not have a function type or the types cannot be matched. *)
fun mk_icomb (f, x) = let
  val dom = #1 (dom_rng (type_of f))
  val theta = match_type dom (type_of x)
in
  mk_comb (Term.inst theta f, x)
end
79
(* findmap_terms f t: apply f to every sub-term of t (including t itself and
   terms under abstractions, but not the bound variables themselves) and
   collect the values v for which f returns SOME v.  Sub-terms are visited
   in pre-order, operator before operand; each result is consed onto the
   accumulator, so the returned list is in reverse visiting order. *)
fun findmap_terms f t = let
  (* immediate sub-terms still to be visited *)
  fun subterms tm =
      case dest_term tm of
        COMB(rator_t, rand_t) => [rator_t, rand_t]
      | LAMB(_, body) => [body]
      | _ => []
  fun walk found [] = found
    | walk found (tm :: pending) = let
        val found' = case f tm of
                       SOME r => r :: found
                     | NONE => found
      in
        walk found' (subterms tm @ pending)
      end
in
  walk [] [t]
end
97
(* isdbty ty: true iff the type operator of ty has an entry in type_db.
   Types that dest_thy_type rejects with HOL_ERR give false. *)
fun isdbty ty = let
  val info = dest_thy_type ty
  val key = {Thy = #Thy info, Name = #Tyop info}
in
  case Binarymap.peek (!type_db, key) of
    SOME _ => true
  | NONE => false
end handle HOL_ERR _ => false
103
(* tylookup ty: the type_db entry for the type operator of ty, if any.
   Types that dest_thy_type rejects with HOL_ERR give NONE. *)
fun tylookup ty =
    (case dest_thy_type ty of
       {Thy = thy, Tyop = tyop, ...} =>
         Binarymap.peek (!type_db, {Thy = thy, Name = tyop}))
    handle HOL_ERR _ => NONE
109
(* get_pm_rewrites ty: the pm_rewrites recorded for ty in type_db, or the
   empty list when ty has no entry. *)
fun get_pm_rewrites ty =
    case tylookup ty of
      SOME (NTI {pm_rewrites, ...}) => pm_rewrites
    | NONE => []
114
(* findfirst f t: visit the sub-terms of t in pre-order (operator before
   operand, descending into abstraction bodies) and return the first
   non-NONE result of f; NONE if f rejects every sub-term.  The search does
   not descend into a term that f accepts. *)
fun findfirst f t = let
  fun subterms tm =
      case dest_term tm of
        COMB(rator_t, rand_t) => [rator_t, rand_t]
      | LAMB(_, body) => [body]
      | _ => []
  fun search [] = NONE
    | search (tm :: pending) =
        (case f tm of
           NONE => search (subterms tm @ pending)
         | found => found)
in
  search [t]
end
134
(* find_top_terms P t: collect the outermost sub-terms of t that satisfy P.
   Once a term satisfies P it is recorded and its own sub-terms are not
   examined.  Matches are consed onto the accumulator as they are met in a
   pre-order walk, so the result is in reverse visiting order. *)
fun find_top_terms P t = let
  fun scan hits [] = hits
    | scan hits (tm :: pending) =
        if P tm then scan (tm :: hits) pending
        else
          case dest_term tm of
            COMB(rator_t, rand_t) => scan hits (rator_t :: rand_t :: pending)
          | LAMB(_, body) => scan hits (body :: pending)
          | _ => scan hits pending
in
  scan [] [t]
end
151
val supp_t = prim_mk_const{Name = "supp", Thy = "nomset"}

(* find_avoids (t, (strs, sets)): fold step that extends two term sets with
   what is found in t.
     strs gains the outermost string-typed sub-terms of t that are variables
          or string literals;
     sets gains "supp pm v" for every variable v in t whose type is
          registered in type_db (pm being that type's pm_constant), plus
          every variable of type string set occurring in t. *)
fun find_avoids (t, (strs, sets)) = let
  open stringSyntax
  (* SOME (supp pm v) for a variable v of a type_db type, NONE otherwise;
     any HOL_ERR along the way also gives NONE *)
  fun supp_of_var v =
      (if not (is_var v) then NONE
       else let
           val {Thy,Tyop,...} = dest_thy_type (type_of v)
         in
           Option.map
             (fn NTI data =>
                 mk_icomb (mk_icomb (supp_t, #pm_constant data), v))
             (Binarymap.peek (!type_db, {Thy=Thy,Name=Tyop}))
         end)
      handle HOL_ERR _ => NONE
  fun has_type ty tm = type_of tm = ty
  val name_terms =
      List.filter (fn s => is_var s orelse is_string_literal s)
                  (find_top_terms (has_type string_ty) t)
  val set_vars = List.filter is_var (find_terms (has_type stringset_ty) t)
  val supports = findmap_terms supp_of_var t
in
  (HOLset.addList (strs, name_terms),
   HOLset.addList (HOLset.addList (sets, supports), set_vars))
end
175
(* goalnames (asl, w): the names of all free variables of the goal
   (conclusion and assumptions), without duplicates, in String.compare
   order. *)
fun goalnames (asl, w) = let
  val free = FVL (w::asl) empty_tmset
  val names = HOLset.foldl (fn (v, acc) => HOLset.add (acc, #1 (dest_var v)))
                           (HOLset.empty String.compare)
                           free
in
  HOLset.listItems names
end
182
(* FRESH_TAC: tactic that renames the bound name of binder terms in the goal
   to a freshly chosen one.

   It repeatedly looks (in the conclusion first, then the assumptions) for a
   term whose type is in type_db and which matches the left-hand side of one
   of that type's `binders` theorems, picks a new variable name, introduces
   it with NEW_TAC so that it avoids the strings and sets collected from the
   goal, and then rewrites the goal with the instantiated binder theorem.

   Raises HOL_ERR (origin "FRESH_TAC", "No binders present in goal") when
   the first search finds nothing; later rounds are wrapped in TRY. *)
fun FRESH_TAC (g as (asl, w)) = let
  (* strings and sets the fresh names must stay away from *)
  val (strs, sets) = List.foldl find_avoids (empty_tmset, empty_tmset) (w::asl)
  (* sets s for which an assumption of the form FINITE s is present *)
  val finite_sets = List.mapPartial (total pred_setSyntax.dest_finite) asl
  (* keep non-variable set terms, and set variables known to be finite *)
  fun filterthis t = not (is_var t) orelse mem t finite_sets
  val sets' = List.filter filterthis (HOLset.listItems sets)
  (* get_binder already_done t: SOME (t, i, th) if t matches the LHS of a
     binder theorem th for its type and its i-th argument has not already
     been dealt with *)
  fun get_binder already_done t =
      case tylookup (type_of t) of
        NONE => NONE
      | SOME (NTI info) => let
          val bs = #binders info
          fun checkthis (_,i,th) = let
            (* th is taken apart as  !.. pre ==> !.. l = r  *)
            val l = lhs (#2 (strip_forall
                               (#2 (dest_imp (#2 (strip_forall (concl th)))))))
            val argi = List.nth(#2 (strip_comb t), i)
          in
            if can (match_term l) t andalso
               not (HOLset.member(already_done, argi))
            then SOME (t, i, th)
            else NONE
          end handle Subscript => NONE
          (* Subscript: t has fewer than i+1 arguments *)
        in
          get_first checkthis bs
        end
  (* do_one used strs sets': handle one binder occurrence, then try again
     with the new name added both to `used` and to the strings to avoid *)
  fun do_one used strs sets' (g as (asl, w)) =
      case get_first (findfirst (get_binder used)) (w::asl) of
        NONE => raise ERR "FRESH_TAC" "No binders present in goal"
      | SOME (t, i, th) => let
          (* the name currently in the bound position *)
          val old = List.nth (#2 (strip_comb t), i)
          val (pre,l,r) = let
            val (_, b) = strip_forall (concl th)
            val (pre, eq) = dest_imp b
            val (l,r) = dest_eq (#2 (strip_forall eq))
          in
            (pre,l,r)
          end
          (* base the new name on the old one, whether it is a string
             literal or a variable *)
          val base = case total Literal.dest_string_lit old of
                       NONE => #1 (dest_var old)
                     | SOME s => s
          val newname = Lexis.gen_variant Lexis.tmvar_vary (goalnames g) base
          (* the variable standing for the new name on the theorem's RHS *)
          val new_in_thm = List.nth (#2 (strip_comb r), i)
          val new_t = mk_var(newname, type_of new_in_thm)
          (* match the theorem's LHS against t, then plug in the new name *)
          val th' = PART_MATCH (lhs o #2 o strip_forall o #2 o dest_imp)
                               th
                               t
          val th' = INST [new_in_thm |-> new_t] th'
          (* union of the literal set of strings with all the kept sets *)
          val avoid_t = let
            open pred_setSyntax
          in
            List.foldl mk_union (mk_set (HOLset.listItems strs)) sets'
          end
          (* given the proved precondition, rewrite the goal with the
             resulting equation and simplify *)
          fun freshcont freshthm = let
            val th = MP th' freshthm
          in
            SUBST_ALL_TAC th THEN
            ASM_SIMP_TAC (srw_ss()) (basic_swapTheory.swapstr_def ::
                                     get_pm_rewrites (type_of t))
          end

          val tac =
              NEW_TAC newname avoid_t THEN
              SUBGOAL_THEN (#1 (dest_imp (concl th'))) freshcont THENL [
                ASM_SIMP_TAC (srw_ss()) [],
                TRY (do_one (HOLset.add(used,new_t))
                            (HOLset.add(strs,new_t))
                            sets')
              ]
        in
          tac g
        end
in
  do_one empty_tmset strs sets' g
end
255
(* recthm_for_type ty: the recursion theorem stored in type_db for the type
   operator of ty.  NONE if ty has no entry, if the entry has no recursion
   theorem, or if dest_thy_type rejects ty. *)
fun recthm_for_type ty = let
  val {Tyop,Thy,...} = dest_thy_type ty
in
  case Binarymap.peek (!type_db, {Name=Tyop,Thy=Thy}) of
    SOME (NTI {recursion_thm,...}) => recursion_thm
  | NONE => NONE
end handle HOL_ERR _ => NONE
263
(* find_constructors recthm: recthm is expected to have the shape
   "pre ==> ?hom. body".  Returns one pair per defining equation: the head
   of the argument on the equation's left-hand side (the constructor) and
   the head of its right-hand side.  The equations are the first conjunct
   of body when that conjunct is itself a conjunction, otherwise all of
   body. *)
fun find_constructors recthm = let
  val existential = #2 (dest_imp (concl recthm))
  val body = #2 (dest_exists existential)
  val (first_conj, _) = dest_conj body
  val eqns = if is_conj first_conj then first_conj else body
  fun heads clause = let
    (* drop quantifiers and any antecedents to reach the equation *)
    val eqn = #2 (strip_imp (#2 (strip_forall clause)))
    fun head tm = #1 (strip_comb tm)
  in
    (head (rand (lhs eqn)), head (rhs eqn))
  end
in
  map heads (strip_conj eqns)
end
280
(* check_for_errors tm: validate a user's recursive definition request.

   tm is a conjunction of (optionally universally quantified) clauses.  The
   checks, in order, are:
     - every clause is an equation;
     - every clause's left-hand side applies the same function term f;
     - f is applied to exactly one argument in every clause;
     - a recursion theorem is registered for f's domain type;
     - the head of each clause's argument is one of that theorem's
       constructors;
     - every argument of those constructors is a variable.
   Each failure raises HOL_ERR with origin function
   "prove_recursive_term_function_exists".

   Returns (f, conjs) where conjs are the clauses with their outer
   universal quantifiers stripped. *)
fun check_for_errors tm = let
  val conjs = map (#2 o strip_forall) (strip_conj tm)
  val _ = List.all is_eq conjs orelse
          ERR "prove_recursive_term_function_exists"
              "All conjuncts must be equations"
  val f = rator (lhs (hd conjs))
  val _ = List.all (fn t => rator (lhs t) = f) conjs orelse
          ERR "prove_recursive_term_function_exists"
              "Must define same constant in all equations"
  val _ = List.all (fn t => length (#2 (strip_comb (lhs t))) = 1) conjs orelse
          ERR "prove_recursive_term_function_exists"
              "Function being defined must be applied to one argument"
  val dom_ty = #1 (dom_rng (type_of f))
  val recthm = valOf (recthm_for_type dom_ty)
               handle Option => ERR "prove_recursive_term_function_exists"
                                    ("No recursion theorem for type "^
                                     type_to_string dom_ty)
  val constructors = map #1 (find_constructors recthm)
  (* head of the argument that f is applied to in clause t *)
  fun clause_head t = #1 (strip_comb (rand (lhs t)))
  val () =
      case List.find
           (fn t => List.all
                      (fn c => not (same_const c (clause_head t)))
                      constructors) conjs of
        NONE => ()
      | SOME t => ERR "prove_recursive_term_function_exists"
                      ("Unknown constructor "^
                       tmToString (clause_head t))
  val () =
      case get_first
           (fn t => let val (c, args) = strip_comb (rand (lhs t))
                    in
                      case List.find (not o is_var) args of
                        NONE => NONE
                      | SOME v => SOME (v, c)
                    end) conjs of
        NONE => ()
      | SOME (v, c) =>
        (* the message used to contain a stray "^" inside the literal,
           producing e.g. "LAM^'s argument ..." *)
        ERR "prove_recursive_term_function_exists"
            (#1 (dest_const c)^"'s argument "^tmToString v^
             " is not a variable")
in
  (f, conjs)
end
326
(* Permutation action built from  K I  via mk_pmact, at a polymorphic type. *)
val null_apm = ``nomset$mk_pmact (combin$K combin$I) : 'a nomset$pmact``

(* Fallback nominaltype_info, used by prove_recursive_term_function_exists0
   when the range type has no type_db entry (or when the proof with the
   range type's own entry fails): null_apm as the permutation action, an
   arbitrary (ARB) null value, no recursion theorem and no binders. *)
val nameless_nti =
    NTI { nullfv = mk_arb alpha,
          pm_constant = null_apm,
          pm_rewrites = [nomsetTheory.discretepm_raw, combinTheory.K_THM,
                         combinTheory.I_THM],
          fv_rewrites = [],
          recursion_thm = NONE,
          binders = []}
336
337
(* Mutable, initially empty dictionary keyed by strings.
   NOTE(review): nothing in the visible part of this file reads or updates
   it; confirm whether it is still needed. *)
val range_database = ref (Binarymap.mkDict String.compare)
339
(* force_inst {src, to_munge}: type-instantiate to_munge so that its type
   becomes the type of src.  Fails (HOL_ERR) if the types do not match. *)
fun force_inst {src, to_munge} =
    Term.inst (match_type (type_of to_munge) (type_of src)) to_munge
345
(* The type of lists of pairs of strings. *)
val cpm_ty =
    listSyntax.mk_list_type
      (pairSyntax.mk_prod (stringSyntax.string_ty, stringSyntax.string_ty))
350
(* mk_perm_ty ty: the type  ty nomset$pmact. *)
fun mk_perm_ty ty =
    mk_thy_type {Thy = "nomset", Tyop = "pmact", Args = [ty]}
352
(* Raised by with_info_prove_recn_exists when the recursion theorem's
   precondition could not be reduced to truth; carries what is left of that
   precondition. *)
exception InfoProofFailed of term

(* with_info_prove_recn_exists f finisher dom_ty rng_ty lookup info

   Instantiate the recursion theorem for dom_ty so that it asserts the
   existence of a function satisfying the user's clauses, then discharge
   its precondition.

     f         the function variable being defined
     finisher  a conversion applied to the precondition after the two
               simplification passes
     dom_ty    domain type of f; must have a recursion theorem (valOf is
               applied to the lookup result)
     rng_ty    range type of f
     lookup    maps a constructor to SOME (user_constructor, (args, rhs))
               when the user gave a clause for it, NONE otherwise
     info      nominaltype_info describing the range type

   Returns the instantiated theorem with the precondition removed.  Raises
   InfoProofFailed if MP against TRUTH fails with HOL_ERR. *)
fun with_info_prove_recn_exists f finisher dom_ty rng_ty lookup info = let
  val NTI {nullfv, pm_rewrites, fv_rewrites, pm_constant, ...} = info
  (* NOTE(review): mk_simple_abstraction is not used below *)
  fun mk_simple_abstraction (c, (cargs, r)) = list_mk_abs(cargs, r)
  val recthm = valOf (recthm_for_type dom_ty)
  (* the existentially bound function variable of the recursion theorem *)
  val hom_t = #1 (dest_exists (#2 (dest_imp (concl recthm))))
  val base_inst = match_type (type_of hom_t) (type_of f)
  (* the null value, type-instantiated to the range type *)
  val nullfv = let
    val i = match_type (type_of nullfv) rng_ty
  in
    Term.inst i nullfv
  end
  val constructors = find_constructors recthm
  (* build the substitution entry for one (constructor, rhs-head) pair of
     the recursion theorem *)
  fun do_a_constructor (c, fnterm) =
      case lookup c of
        SOME (user_c, (args, r_term)) => let
          (* arguments of the domain type are the recursive positions *)
          fun hasdom_ty t = Type.compare(type_of t, dom_ty) = EQUAL
          val rec_args = filter hasdom_ty args
          (* replace each recursive call "f arg" by a fresh variable ... *)
          val f_applied =
              map (fn t => mk_comb(f, t) |-> genvar rng_ty) rec_args
          val new_body = Term.subst f_applied r_term
          val base_abs = list_mk_abs (args, new_body)
          (* ... and abstract over those variables first, then the args *)
          val with_reccalls = list_mk_abs (map #residue f_applied, base_abs)
        in
          force_inst {src = with_reccalls, to_munge = fnterm} |-> with_reccalls
        end
      | NONE => let
          (* no user clause: use a function that ignores every argument and
             returns nullfv *)
          val fnterm' = Term.inst base_inst fnterm
          fun build_abs ty = let
            val (d,r) = dom_rng ty
          in
            mk_abs(genvar (Type.type_subst base_inst d), build_abs r)
          end handle HOL_ERR _ => nullfv
        in
          fnterm' |-> build_abs (type_of fnterm)
        end
  val recursion_exists0 =
      INST (map do_a_constructor constructors) (INST_TYPE base_inst recthm)
  (* instantiate the variable "apm" with the range type's permutation
     action *)
  val other_var_inst = let
    val apm = mk_var("apm", mk_perm_ty rng_ty)
    val pm' = force_inst {src = apm, to_munge = pm_constant}
  in
    [apm |-> pm']
  end
  val recursion_exists1 = INST other_var_inst recursion_exists0
  (* in the conclusion: beta-reduce the right-hand operand of every
     conjunct under the existential, then rename the bound variable to f *)
  val recursion_exists =
      CONV_RULE (RAND_CONV
                   (BINDER_CONV
                      (EVERY_CONJ_CONV
                         (STRIP_QUANT_CONV (RAND_CONV LIST_BETA_CONV))) THENC
                      RAND_CONV (ALPHA_CONV f)))
                   recursion_exists1
  val rewrites = pm_rewrites @ fv_rewrites
  (* simplify the antecedent: first with the type's rewrites, then with
     SUBSET_DEF, then with the caller's finisher *)
  val precondition_discharged =
      CONV_RULE
        (LAND_CONV (SIMP_CONV (srw_ss()) rewrites THENC
                    SIMP_CONV (srw_ss()) [pred_setTheory.SUBSET_DEF] THENC
                    finisher))
        recursion_exists
in
  (* succeeds only if the antecedent has been reduced to T *)
  MP precondition_discharged TRUTH
     handle HOL_ERR _ =>
            raise InfoProofFailed (#1 (dest_imp
                                         (concl precondition_discharged)))
end;
418
419
420
(* prove_recursive_term_function_exists0 fin tm

   tm is a conjunction of defining clauses (see check_for_errors); fin is a
   conversion passed on to with_info_prove_recn_exists as its finisher.

   The proof is first attempted with the type_db entry of the function's
   range type.  If that raises InfoProofFailed a warning is issued and the
   proof is retried with nameless_nti; nameless_nti is used directly when
   the range type has no entry.  Results obtained with nameless_nti are
   rewritten with discretepm_thm.

   Raises HOL_ERR (origin "prove_recursive_term_function_exists") for
   malformed input, for two clauses on the same constructor, and when the
   precondition cannot be proved. *)
fun prove_recursive_term_function_exists0 fin tm = let
  val (f, conjs) = check_for_errors tm
  val (dom_ty, rng_ty) = dom_rng (type_of f)

  (* association list keyed by constructor; entries stay in the order of
     the clauses in the original definition request *)
  fun add_clause (eqn, clauses) = let
    val (c, cargs) = strip_comb (rand (lhs eqn))
    val entry = (c, (cargs, rhs eqn))
    fun append [] = [entry]
      | append ((old as (c', _)) :: rest) =
          if same_const c c' then
            ERR "prove_recursive_term_function_exists"
                ("Two equations for constructor " ^
                 #1 (dest_const c))
          else old :: append rest
  in
    append clauses
  end
  val clauses = List.foldl add_clause [] conjs
  fun find_eqn c = List.find (fn (c', _) => same_const c c') clauses

  val callthis = with_info_prove_recn_exists f fin dom_ty rng_ty find_eqn
  fun with_null_range () =
      callthis nameless_nti |> REWRITE_RULE [discretepm_thm]
  (* type_db entry for the range type, if it has one *)
  val range_info =
      case Lib.total dest_thy_type rng_ty of
        SOME {Tyop, Thy, ...} => Binarymap.peek(!type_db, {Name=Tyop,Thy=Thy})
      | NONE => NONE
in
  case range_info of
    NONE => with_null_range ()
  | SOME i =>
      (callthis i
       handle InfoProofFailed goal =>
              (HOL_WARNING
                 "binderLib"
                 "prove_recursive_term_function_exists"
                 ("Couldn't prove function with swap over range - \n\
                  \goal was "^tmToString goal^"\n\
                  \trying null range assumption");
               with_null_range ()))
end handle InfoProofFailed tm =>
           raise ERR "prove_recursive_term_function_exists"
                     ("Couldn't prove function with swap over range - \n\
                      \goal was "^tmToString tm)
468
(* strip_tyannote acc a: peel nested Absyn.TYPED wrappers off a.  Returns
   the annotation types, innermost first, appended after the reverse of
   the initial acc, paired with the unannotated absyn. *)
fun strip_tyannote acc a =
    case a of
      Absyn.TYPED(_, inner, ty) => strip_tyannote (ty::acc) inner
    | _ => (List.rev acc, a)
471
(* head_sym a: a is the absyn of a conjunction of (possibly quantified)
   equations.  Returns the absyn at the head of the left-hand side of the
   first equation, with type annotations removed. *)
fun head_sym a = let
  val first_clause = hd (Absyn.strip_conj a)
  val eqn = #2 (Absyn.strip_forall first_clause)
  (* eqn is ((= lhs) rhs): take the operator part, then its operand *)
  val eq_on_lhs = #1 (Absyn.dest_app eqn)
  val lhs_absyn = #2 (Absyn.dest_app eq_on_lhs)
  val bare_lhs = #2 (strip_tyannote [] lhs_absyn)
  val head = #1 (Absyn.strip_app bare_lhs)
in
  #2 (strip_tyannote [] head)
end
482
local
  (* Shared worker: obtain "?f. clauses /\ ..." from
     prove_recursive_term_function_exists0, show that the user's term tm
     follows from the first conjunct of its body, and return |- ?f. tm. *)
  fun exists_with fin tm = let
    val witness_thm = prove_recursive_term_function_exists0 fin tm
    val (f_v, f_body) = dest_exists (concl witness_thm)
    val clauses_thm = CONJUNCT1 (ASSUME f_body)
    val tm_holds = EQT_ELIM (SIMP_CONV bool_ss [clauses_thm] tm)
    val packaged = EXISTS (mk_exists(f_v, tm), f_v) tm_holds
  in
    CHOOSE (f_v, witness_thm) packaged
  end handle InfoProofFailed tm =>
             raise ERR "prove_recursive_term_function_exists"
                       ("Couldn't prove function with swap over range - \n\
                        \goal was "^tmToString tm)
in

(* prove_recursive_term_function_exists tm: prove |- ?f. tm using ALL_CONV
   as the finishing conversion. *)
fun prove_recursive_term_function_exists tm = exists_with ALL_CONV tm

(* prove_recursive_term_function_exists' fin tm: as above, with the
   caller's finishing conversion fin. *)
fun prove_recursive_term_function_exists' fin tm = exists_with fin tm

end
506
507
(* define_wrapper worker q

   Define a new constant from the quotation q, a conjunction of recursive
   clauses for a function over a binder type.

     worker  maps the parsed term to an existence theorem "?f. ..."
     q       the quotation containing the defining clauses

   Side effects on the current theory: a new specification stored as
   <name>_def, a theorem <name>_thm, and a theorem <name>_equivariant
   which is also exported as a rewrite.

   Returns (<name>_thm, <name>_equivariant).  Raises HOL_ERR with origin
   "define_recursive_term_function" if the head of the first clause is not
   an identifier. *)
fun define_wrapper worker q = let
  val a = Absyn q
  val f = head_sym a
  val fstr = case f of
               Absyn.IDENT(_, s) => s
             | x => ERR "define_recursive_term_function" "invalid head symbol"
  (* hide the name while parsing, and put the previous overload information
     back afterwards, also when parsing fails *)
  val restore_this = hide fstr
  fun restore() = Parse.update_overload_maps fstr restore_this
  val tm = Parse.absyn_to_term (Parse.term_grammar()) a
           handle e => (restore(); raise e)
  val _ = restore()
  val f_thm0 = BETA_RULE (worker tm)
  val (f_t0, th_body0) = dest_exists (concl f_thm0)
  (* rename the existential witness to the user's function name *)
  val f_t = mk_var(fstr, type_of f_t0)
  val th_body = subst [f_t0 |-> f_t] th_body0
  (* quantify a clause over all of its free variables except the function *)
  fun defining_conj c = let
    val fvs = List.filter (fn v => #1 (dest_var v) <> fstr) (free_vars c)
  in
    list_mk_forall(fvs, c)
  end
  val defining_term0 = list_mk_conj(map defining_conj (strip_conj tm))
  (* the user's clauses together with the second conjunct of the worker's
     theorem body *)
  val defining_term = mk_conj(defining_term0, rand th_body)
  val base_definition = let
    (* this checks that the user-provided term is a consequence of the
       existential theorem that prove_recursive_term_function_exists0
       returns.  It might go wrong if extra implications appear as
       side-conditions to equations on binder-constructors.  This can't
       happen in this version of the code yet (because we don't handle
       parameters or side sets *)
    val definition_ok0 = default_prover(mk_imp(th_body, defining_term),
                                        SIMP_TAC bool_ss [])
  in
    CHOOSE (f_t, f_thm0)
           (EXISTS(mk_exists(f_t, defining_term), f_t)
                  (UNDISCH definition_ok0))
  end handle HOL_ERR _ => f_thm0
  (* on any HOL_ERR above, fall back to the worker's own theorem *)
  (* feel that having the equation the other way (non-GSYMed) 'round may be
     better in many theorem-proving circumstances.  But for the moment,
     this way round provides backwards compatibility. *)
  val base_definition =
      CONV_RULE (QUANT_CONV (RAND_CONV (ONCE_REWRITE_CONV [EQ_SYM_EQ])))
                base_definition

  val f_def =
      new_specification (fstr^"_def", [fstr], base_definition)
  val _ = add_const fstr
  val f_const = prim_mk_const {Name = fstr, Thy = current_theory()}
  (* the user's clauses stated for the new constant; if they cannot be
     re-proved from f_def, use the first conjunct of f_def instead *)
  val f_thm = save_thm(fstr^"_thm",
                       default_prover(subst [f_t |-> f_const] tm,
                                      SIMP_TAC bool_ss [f_def])
                       handle HOL_ERR _ => CONJUNCT1 f_def)
  val f_invariants = let
    val interesting_bit =
        f_def |> CONJUNCT2
              |> PURE_REWRITE_RULE [combinTheory.K_THM, combinTheory.I_THM]
    (* NOTE(review): l, r, lf and rf are not used afterwards, but these
       destructors raise HOL_ERR if the theorem is not an (optionally
       quantified, optionally conditional) equation, so they act as a
       shape check *)
    val (l,r) =
        interesting_bit |> concl |> strip_forall |> #2
                        |> strip_imp |> #2
                        |> dest_eq
    val (lf,_) = strip_comb l
    val (rf, _) = strip_comb r
    val nm = fstr^"_equivariant"
  in
    save_thm(nm, interesting_bit) before export_rewrites [nm]
  end
in
  (f_thm, f_invariants)
end
576
577
(* define_recursive_term_function q: define_wrapper with the default worker,
   i.e. prove_recursive_term_function_exists0 finished by ALL_CONV. *)
fun define_recursive_term_function q = let
  val worker = prove_recursive_term_function_exists0 ALL_CONV
in
  define_wrapper worker q
end
580
(* define_recursive_term_function' fin q: as define_recursive_term_function,
   but with the caller's finishing conversion fin. *)
fun define_recursive_term_function' fin q = let
  val worker = prove_recursive_term_function_exists0 fin
in
  define_wrapper worker q
end
583
584
(* recursive_term_function_existence tm: the existence theorem produced by
   prove_recursive_term_function_exists0 with ALL_CONV as the finisher. *)
fun recursive_term_function_existence tm =
    prove_recursive_term_function_exists0 ALL_CONV tm
587
588
589
590
591
592end (* struct *)
593