1structure Preterm :> Preterm =
2struct
3
4open Feedback Lib GrammarSpecials;
5open errormonad typecheck_error
6
(* Error constructors for this structure: ERR builds a HOL_ERR whose origin
   structure is "Preterm"; ERRloc additionally records a source location. *)
val ERR = mk_HOL_ERR "Preterm"
val ERRloc = mk_HOL_ERRloc "Preterm"

(* Local abbreviations for the pretype, HOL type and HOL term types. *)
type pretype = Pretype.pretype
type hol_type = Type.hol_type
type term = Term.term
(* Record bundling a name, a pretype, overloading information and a source
   location.  NOTE(review): presumably mirrors the fields of an Overloaded
   preterm node; it is not referenced in the code visible here. *)
type overinfo = {Name:string, Ty:pretype,
                 Info:Overload.overloaded_op_info, Locn:locn.locn}
(* Collect the type variables occurring in any of the terms in tlist,
   merging the per-term lists with Lib.union. *)
fun tmlist_tyvs tlist = let
  fun add_tyvs (tm, seen) = Lib.union (Term.type_vars_in_term tm) seen
in
  List.foldl add_tyvs [] tlist
end
17
type 'a in_env = 'a Pretype.in_env

(* When true, type-checking failure messages are printed.  Registered as the
   boolean trace "show_typecheck_errors". *)
val show_typecheck_errors = ref true
val _ = register_btrace ("show_typecheck_errors", show_typecheck_errors)

(* Print s with Lib.say, but only if show_typecheck_errors is set. *)
fun tcheck_say s =
  if not (!show_typecheck_errors) then () else Lib.say s

(* NOTE(review): not read or written in the code visible here; presumably
   records the most recent type-checking error. *)
val last_tcerror : error option ref = ref NONE

(* Error monad over the Pretype environment, failing with a type-checking
   error paired with its location. *)
type 'a errM = (Pretype.Env.t,'a,tcheck_error * locn.locn) errormonad.t
(* Sequence monad over the same environment, used for the backtracking
   phase of overloading resolution. *)
type 'a seqM = (Pretype.Env.t,'a) seqmonad.seqmonad
28
(* Run the error-monad computation errM in environment env and return its
   value, discarding the final environment.  A failure is turned into an
   exception with mkExn and raised. *)
fun smash errM env =
  (case errM env of
       Some (_, result) => result
     | Error err => raise mkExn err)
33
34open Preterm_dtype
35
(* Split a preterm of the shape  min$= l r  into (l, r).  Constrained
   wrappers at the top are looked through; anything else raises HOL_ERR. *)
fun pdest_eq (Constrained{Ptm, ...}) = pdest_eq Ptm
  | pdest_eq (Comb{Rator = Comb{Rator = Const{Name = "=", Thy = "min", ...},
                                Rand = l, ...},
                   Rand = r, ...}) = (l, r)
  | pdest_eq _ = raise ERR "pdest_eq" "Preterm is not an equality"

(* The left-hand side of an equality preterm (see pdest_eq). *)
fun lhs pt = #1 (pdest_eq pt)
45
(* Strip leading universal quantifications (bool$! applied to an abstraction)
   off a preterm, looking through Constrained wrappers.  Returns the bound
   variables, outermost first, together with the remaining body. *)
fun strip_pforall pt = let
  fun strip bvs (Comb{Rator = Const{Name = "!", Thy = "bool", ...},
                      Rand = Abs{Bvar, Body, ...}, ...}) =
        strip (Bvar :: bvs) Body
    | strip bvs (Constrained{Ptm, ...}) = strip bvs Ptm
    | strip bvs body = (List.rev bvs, body)
in
  strip [] pt
end
56
(* Return the variable at the head of an application spine, looking through
   Constrained wrappers.  An antiquoted HOL variable is converted into a Var
   preterm with no location.  Any other kind of head raises HOL_ERR. *)
fun head_var pt =
  case pt of
      Var _ => pt
    | Comb{Rator, ...} => head_var Rator
    | Constrained{Ptm, ...} => head_var Ptm
    | Antiq{Tm, ...} =>
        let
          val (nm, ty) = Term.dest_var Tm
              handle HOL_ERR _ =>
                     raise ERR "head_var" "Head is an antiquoted non-var"
        in
          Var{Name = nm, Ty = Pretype.fromType ty, Locn = locn.Loc_None}
        end
    | Const _ => raise ERR "head_var" "Head is a constant"
    | Overloaded _ => raise ERR "head_var" "Head is an Overloaded"
    | Abs _ => raise ERR "head_var" "Head is an abstraction"
    | Pattern _ => raise ERR "head_var" "Head is a Pattern"
76
77
val op--> = Pretype.mk_fun_ty

(* The pretype of a preterm, computed in the error monad.  Leaf-like nodes
   return their stored type.  An application's type is derived from its
   operator's type via Pretype.chase.  An abstraction's type is the function
   type from its bound variable's type to its body's type.  Pattern nodes
   use the type of the wrapped preterm. *)
fun ptype_of ptm =
  case ptm of
      Var{Ty, ...} => return Ty
    | Const{Ty, ...} => return Ty
    | Constrained{Ty, ...} => return Ty
    | Overloaded{Ty, ...} => return Ty
    | Antiq{Tm, ...} => return (Pretype.fromType (Term.type_of Tm))
    | Pattern{Ptm, ...} => ptype_of Ptm
    | Comb{Rator, ...} => ptype_of Rator >- Pretype.chase
    | Abs{Bvar, Body, ...} =>
        lift2 (fn dom => fn rng => dom --> rng) (ptype_of Bvar) (ptype_of Body)
88
(* Take apart a Var preterm as (name, pretype, location); any other preterm
   raises HOL_ERR. *)
fun dest_ptvar (Var{Name, Ty, Locn}) = (Name, Ty, Locn)
  | dest_ptvar _ = raise ERR "dest_ptvar" "Preterm is not a variable"
93
(* Build a right-associated chain of binary applications of opn:
     plist_mk_rbinop opn [t1, t2, ..., tn] = opn t1 (opn t2 (... tn))
   Every Comb node created gets locn.Loc_None.  An empty list raises
   HOL_ERR. *)
fun plist_mk_rbinop opn pts = let
  fun app2 l r =
      Comb{Rator = Comb{Rator = opn, Rand = l, Locn = locn.Loc_None},
           Rand = r, Locn = locn.Loc_None}
  fun build [] = raise ERR "list_mk_rbinop" "Empty list"
    | build [last] = last
    | build (x :: rest) = app2 x (build rest)
in
  build pts
end
106
val bogus = locn.Loc_None

(* Convert a HOL term into a preterm with no location information.  Each HOL
   type is turned into a pretype with Pretype.rename_tv avds.  The renaming
   state is threaded through the traversal (operator before operand, bound
   variable before body), starting from the empty list; only the converted
   preterm is returned. *)
fun term_to_preterm avds t = let
  fun gen ty = Pretype.rename_tv avds (Pretype.fromType ty)
  open HolKernel
  fun recurse tm =
      case dest_term tm of
        VAR (nm, ty) =>
          gen ty >- (fn pty =>
          return (Var{Name = nm, Ty = pty, Locn = bogus}))
      | CONST {Name, Thy, Ty} =>
          gen Ty >- (fn pty =>
          return (Const{Name = Name, Thy = Thy, Ty = pty, Locn = bogus}))
      | COMB (rator_tm, rand_tm) =>
          recurse rator_tm >- (fn rator' =>
          recurse rand_tm >- (fn rand' =>
          return (Comb{Rator = rator', Rand = rand', Locn = bogus})))
      | LAMB (bv, body) =>
          recurse bv >- (fn bv' =>
          recurse body >- (fn body' =>
          return (Abs{Bvar = bv', Body = body', Locn = bogus})))
in
  lift #2 (addState [] (recurse t))
end
127
128
129
130(*---------------------------------------------------------------------------
131     Read the location from a preterm.
132 ---------------------------------------------------------------------------*)
133
fun locn ptm =
  case ptm of
      Var{Locn, ...} => Locn
    | Const{Locn, ...} => Locn
    | Overloaded{Locn, ...} => Locn
    | Comb{Locn, ...} => Locn
    | Abs{Locn, ...} => Locn
    | Constrained{Locn, ...} => Locn
    | Antiq{Locn, ...} => Locn
    | Pattern{Locn, ...} => Locn
142
143(*---------------------------------------------------------------------------
144     Location-ignoring equality for preterms.
145 ---------------------------------------------------------------------------*)
(* Equality on overloading information: equal base types, equal type-avoid
   lists, and candidate-operator lists of equal length that agree pointwise
   up to alpha-conversion. *)
fun infoeq {base_type = ty1, actual_ops = tms1, tyavoids = avds1}
           {base_type = ty2, actual_ops = tms2, tyavoids = avds2} =
  let
    fun same_ops () =
        ListPair.allEq (fn (a, b) => Term.aconv a b) (tms1, tms2)
  in
    ty1 = ty2 andalso avds1 = avds2 andalso same_ops ()
  end
150
(* Structural equality of preterms that ignores locations.  Pretypes are
   compared with (=), antiquoted terms up to alpha-conversion, and
   overloading information with infoeq.  Nodes of different kinds are never
   equal. *)
fun eq p1 p2 =
  case (p1, p2) of
      (Var{Name = n1, Ty = ty1, ...}, Var{Name = n2, Ty = ty2, ...}) =>
        n1 = n2 andalso ty1 = ty2
    | (Const{Name = n1, Thy = thy1, Ty = ty1, ...},
       Const{Name = n2, Thy = thy2, Ty = ty2, ...}) =>
        n1 = n2 andalso thy1 = thy2 andalso ty1 = ty2
    | (Overloaded{Name = n1, Ty = ty1, Info = i1, ...},
       Overloaded{Name = n2, Ty = ty2, Info = i2, ...}) =>
        n1 = n2 andalso ty1 = ty2 andalso infoeq i1 i2
    | (Comb{Rator = f1, Rand = x1, ...}, Comb{Rator = f2, Rand = x2, ...}) =>
        eq f1 f2 andalso eq x1 x2
    | (Abs{Bvar = v1, Body = b1, ...}, Abs{Bvar = v2, Body = b2, ...}) =>
        eq v1 v2 andalso eq b1 b2
    | (Constrained{Ptm = t1, Ty = ty1, ...},
       Constrained{Ptm = t2, Ty = ty2, ...}) =>
        eq t1 t2 andalso ty1 = ty2
    | (Antiq{Tm = tm1, ...}, Antiq{Tm = tm2, ...}) => Term.aconv tm1 tm2
    | (Pattern{Ptm = t1, ...}, Pattern{Ptm = t2, ...}) => eq t1 t2
    | _ => false
166
(* Return the Var underneath any Constrained wrappers of an abstraction's
   bound-variable preterm.  Anything else raises HOL_ERR; the error is
   attributed to ptfvs, which calls this function. *)
fun isolate_var (Constrained{Ptm, ...}) = isolate_var Ptm
  | isolate_var (ptv as Var _) = ptv
  | isolate_var _ = raise ERR "ptfvs" "Mal-formed abstraction"
172
(* Free variables of a preterm, combined with op_union / op_set_diff modulo
   eq.  An abstraction's bound variable (found with isolate_var) is removed
   from its body's variables.  For a case arrow  pat -> rhs  (the constant
   named GrammarSpecials.case_arrow_special applied to pat and rhs), the
   variables of pat are removed from those of rhs.  Only Var, Comb, Abs and
   Constrained nodes contribute variables. *)
fun ptfvs pt =
  case pt of
      Var _ => [pt]
    | Constrained{Ptm, ...} => ptfvs Ptm
    | Abs{Bvar, Body, ...} => op_set_diff eq (ptfvs Body) [isolate_var Bvar]
    | Comb{Rator = f as Comb{Rator = Const{Name, ...}, Rand = pat, ...},
           Rand = rhs, ...} =>
        if Name = GrammarSpecials.case_arrow_special then
          op_set_diff eq (ptfvs rhs) (ptfvs pat)
        else
          op_union eq (ptfvs f) (ptfvs rhs)
    | Comb{Rator, Rand, ...} => op_union eq (ptfvs Rator) (ptfvs Rand)
    | _ => []
190
(* Split an application spine into its head and its arguments:
     strip_pcomb (f a1 ... an) = (f, [a1, ..., an])
   Only Comb nodes are unwound; Constrained wrappers are not looked
   through. *)
fun strip_pcomb pt = let
  fun unwind args (Comb{Rator, Rand, ...}) = unwind (Rand :: args) Rator
    | unwind args head = (head, args)
in
  unwind [] pt
end
199
200
201(* ----------------------------------------------------------------------
202
203     Simple map from a preterm to a term. The argument "shr" maps from
204     pretypes to types. Overloaded nodes cause failure if one is
205     encountered, as Overloaded nodes should be gone by the time clean
206     is called.
207
     shr is also given the location of the node being converted, so
     that it can report errors against that location.

     Handles the beta-reduction required when a Pattern term is applied
     to arguments.
211
212   ---------------------------------------------------------------------- *)
213
214
215
fun clean shr = let
  open Term
  (* Substitute args for the leading bound variables of f_t, as far as both
     lists go, and apply the result to any arguments that are left over. *)
  fun beta_into f_t args = let
    val (bvs, _) = strip_abs f_t
    val theta = ListPair.map (fn (bv, arg) => bv |-> arg) (bvs, args)
    val n = length theta
    val body = funpow n (#2 o dest_abs) f_t
  in
    list_mk_comb (Term.subst theta body, List.drop (args, n))
  end
  fun cl t =
      case t of
        Var{Name, Ty, Locn} => mk_var (Name, shr Locn Ty)
      | Const{Name, Thy, Ty, Locn} =>
          mk_thy_const {Name = Name, Thy = Thy, Ty = shr Locn Ty}
      | Comb _ =>
          let
            val (f, args0) = strip_pcomb t
            (* arguments are converted before the head *)
            val args = map cl args0
          in
            case f of
              Pattern{Ptm, ...} => beta_into (cl Ptm) args
            | _ => list_mk_comb (cl f, args)
          end
      | Abs{Bvar, Body, ...} => mk_abs (cl Bvar, cl Body)
      | Antiq{Tm, ...} => Tm
      | Constrained{Ptm, ...} => cl Ptm
      | Overloaded{Locn, ...} =>
          raise ERRloc "clean" Locn "Overload term remains"
      | Pattern{Ptm, ...} => cl Ptm
in
  cl
end
249
val has_free_uvar = Pretype.has_unbound_uvar

(* The names of the type variables occurring in a preterm, computed in the
   error monad as a union over all sub-terms.  Antiquoted HOL terms
   contribute the names of their type variables.  Overloaded nodes are not
   expected here and raise Fail. *)
fun tyVars ptm =
  let
    fun both m1 m2 = lift2 Lib.union m1 m2
  in
    case ptm of
      Var{Ty, ...} => Pretype.tyvars Ty
    | Const{Ty, ...} => Pretype.tyvars Ty
    | Comb{Rator, Rand, ...} => both (tyVars Rator) (tyVars Rand)
    | Abs{Bvar, Body, ...} => both (tyVars Bvar) (tyVars Body)
    | Constrained{Ptm, Ty, ...} => both (tyVars Ptm) (Pretype.tyvars Ty)
    | Pattern{Ptm, ...} => tyVars Ptm
    | Antiq{Tm, ...} =>
        return (map Type.dest_vartype (Term.type_vars_in_term Tm))
    | Overloaded _ =>
        raise Fail "Preterm.tyVars: applied to \
                   \Overloaded"
  end;
264
265
266(*---------------------------------------------------------------------------
267    Translate a preterm to a term. Will "guess type variables"
268    (assign names to type variables created during type inference),
269    if a flag is set. No "Overloaded" nodes are allowed in the preterm:
270    overloading resolution should already have gotten rid of them.
271 ---------------------------------------------------------------------------*)
272
(* Expose Globals.notify_on_tyvar_guess as the boolean trace
   "notify type variable guesses". *)
val _ =
    register_btrace ("notify type variable guesses",
                     Globals.notify_on_tyvar_guess)

(* to_term tm : term in_env
   Convert a preterm into a HOL term, running in the Pretype environment.
   - If !Globals.guessing_tyvars: every pretype is passed through
     Pretype.replace_null_links, threading a list of already-used type
     variable names seeded with tyVars tm, and then cleaned.  Names added to
     that list are reported with HOL_MESG when both
     !Globals.notify_on_tyvar_guess and !Globals.interactive are set.
   - Otherwise: a pretype that still has an unbound unification variable
     (has_free_uvar) raises HOL_ERR "Unconstrained type variable ...", and
     the conversion is done by clean.
   Overloaded nodes must already have been resolved; they cause an
   exception in either mode. *)
fun to_term (tm : preterm) : term in_env =
    if !Globals.guessing_tyvars then
      let
        (* cleanup runs in a state (environment, used names).  usedLift
           runs an environment-only computation in that state, leaving the
           used-names component unchanged. *)
        fun cleanup tm = let
          infix >> >-
          fun usedLift m (E,used) =
            case m E of
                Error e => Error e
              | Some (E', result) => Some ((E',used), result)
          fun clean0 pty = lift Pretype.clean (Pretype.remove_made_links pty)
          val clean = usedLift o clean0
        in
          case tm of
            Var{Name,Ty,...} => lift (fn ty => Term.mk_var(Name, ty))
                                     (Pretype.replace_null_links Ty >- clean)
          | Const{Name,Thy,Ty,...} =>
              lift (fn ty => Term.mk_thy_const{Name=Name,Thy=Thy,Ty=ty})
                   (Pretype.replace_null_links Ty >- clean)
          | Comb{Rator, Rand,...} => let
              val (f, args) = strip_pcomb tm
              open Term
            in
              case f of
                (* A Pattern head is beta-reduced into its arguments, as
                   clean does. *)
                Pattern{Ptm,...} => let
                  fun doit f_t args = let
                    val (bvs, _) = strip_abs f_t
                    val inst = ListPair.map Lib.|-> (bvs, args)
                    val res0 = funpow (length inst) (#2 o dest_abs) f_t
                  in
                    list_mk_comb(Term.subst inst res0,
                                 List.drop(args, length inst))
                  end
                in
                  cleanup Ptm >- (fn f =>
                  mmap cleanup args >- (fn args' =>
                  return (doit f args')))
                end
              | _ => cleanup f >- (fn f_t =>
                     mmap cleanup args >- (fn args' =>
                     return (list_mk_comb(f_t, args'))))
            end
          | Abs{Bvar, Body,...} => cleanup Bvar >- (fn Bvar'
                                => cleanup Body >- (fn Body'
                                => return (Term.mk_abs(Bvar', Body'))))
          | Antiq{Tm,...} => return Tm
          | Constrained{Ptm,...} => cleanup Ptm
          | Overloaded _ => raise ERRloc "to_term" (locn tm)
                                         "applied to Overloaded"
          | Pattern{Ptm,...} => cleanup Ptm
        end
        (* addV m vars: run m starting from (e, vars), and return the final
           environment with m's result paired with the final names list. *)
        fun addV m vars e =
          case m (e,vars) of
              Error e => Error e
            | Some ((e',v'), r) => Some (e', (r,v'))
        (* V yields (initial names vs0, (term, final names vs)). *)
        val V = tyVars tm >-
                (fn vs => lift (fn x => (vs,x)) (addV (cleanup tm) vs))
      in
        fn e =>
           case V e of
               Error e => Error e
             | Some (e', (vs0, (tm, vs))) =>
               let
                 (* NOTE(review): assumes names invented during cleanup are
                    prepended to the used list, so the guessed ones are its
                    first (length vs - length vs0) elements; confirm against
                    Pretype.replace_null_links. *)
                 val guessed_vars = List.take(vs, length vs - length vs0)
                 val _ =
                     if not (null guessed_vars) andalso
                        !Globals.notify_on_tyvar_guess andalso
                        !Globals.interactive
                     then
                       Feedback.HOL_MESG
                         (String.concat
                            ("inventing new type variable names: "
                             :: Lib.commafy (List.rev guessed_vars)))
                     else ()
               in
                 Some (e', tm)
               end
      end
    else
      let
        (* shr env l ty: resolve ty in env, failing at location l if it
           still contains an unbound unification variable. *)
        fun shr env l ty =
            if smash (has_free_uvar ty) env then
              raise ERRloc "typecheck.to_term" l
                           "Unconstrained type variable (and Globals.\
                           \guessing_tyvars is false)"
            else smash (lift Pretype.clean (Pretype.remove_made_links ty))
                       env
      in
        (fn e => Some (e, clean (shr e) tm))
      end
366
367
368
369
370(*---------------------------------------------------------------------------*
371 *                                                                           *
 * Overloading removal.  The function "remove_overloading_phase1" will       *
373 * replace Overloaded _ nodes with Consts where it can be shown that only    *
374 * one of the possible constants has a type compatible with the type of the  *
375 * term as it has been inferred during the previous phase of type inference. *
376 * This may in turn constrain other overloaded terms elsewhere in the tree.  *
377 *                                                                           *
378 *---------------------------------------------------------------------------*)
379
(* In earlier stages, the base_type of any overloaded preterms will have
   become more instantiated through the process of type inference.  This
   first phase of resolving overloading removes those operators that are
   no longer compatible with this type.  If this results in no operators,
   this is an error.  If it results in one operator, this can be chosen
   as the result.  If more than one remains, the Overloaded node is kept,
   with its list of operators reduced, so that later phases can figure
   out which are possible given all the other overloaded sub-terms in the
   term. *)
local
  open errormonad
  (* NOTE(review): >~ (optmonad bind) is declared here but is not used in
     this local block. *)
  infix >~
  val op>~ = optmonad.>-
in
(* filterM PM l: monadic filter.  Keeps the elements h of l for which the
   computation PM h yields true; predicates run left to right. *)
fun filterM PM l =
  case l of
      [] => return l
    | h::t => PM h >- (fn b => if b then lift (cons h) (filterM PM t)
                               else filterM PM t)

(* remove_overloading_phase1 ptm: for every Overloaded node, discard the
   candidate operators whose type (with type variables renamed away from
   those of the candidate's free variables) cannot unify with the node's
   pretype.
   - No candidates left: OvlNoType error at the node's location.
   - One candidate: the node is replaced after unifying the candidate's
     type with the node's.  A constant becomes a Const node; any other term
     becomes a Pattern wrapping its term_to_preterm translation.
   - Several candidates: the node is kept with the reduced list.
   Other nodes are traversed structurally (Comb, Abs, Constrained) or
   returned unchanged. *)
fun remove_overloading_phase1 ptm =
  case ptm of
    Comb{Rator, Rand, Locn} =>
      lift2 (fn t1 => fn t2 => Comb{Rator = t1, Rand = t2, Locn = Locn})
            (remove_overloading_phase1 Rator)
            (remove_overloading_phase1 Rand)
  | Abs{Bvar, Body, Locn} =>
      lift2 (fn t1 => fn t2 => Abs{Bvar = t1, Body = t2, Locn = Locn})
            (remove_overloading_phase1 Bvar)
            (remove_overloading_phase1 Body)
  | Constrained{Ptm, Ty, Locn} =>
      lift (fn t => Constrained{Ptm = t, Ty = Ty, Locn = Locn})
           (remove_overloading_phase1 Ptm)
  | Overloaded{Name,Ty,Info,Locn} => let
      (* Can candidate t's type unify with Ty?  The candidate's type
         variables are renamed avoiding those of its free variables. *)
      fun testfn t = let
        open Term
        val possty = type_of t
        val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
        val pty0 = Pretype.fromType possty
      in
        Pretype.rename_typevars avds pty0 >- Pretype.can_unify Ty
      end
      fun after_filter possible_ops =
        case possible_ops of
            [] => error (OvlNoType(Name,Pretype.toType Ty), Locn)
          | [t] =>
            let
              open Term
            in
              if is_const t then
                let
                  val {Ty = ty,Name,Thy} = dest_thy_const t
                  val ptyM = Pretype.rename_typevars [] (Pretype.fromType ty)
                in
                  ptyM >- Pretype.unify Ty >>
                  return (Const{Name=Name, Thy=Thy, Ty=Ty, Locn=Locn})
            end
              else
                let
                  val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
                in
                  term_to_preterm avds t >- (fn ptm =>
                  ptype_of ptm >- (fn pty =>
                  Pretype.unify Ty pty >>
                  return (Pattern{Ptm = ptm, Locn = Locn})))
                end
            end
          | _ =>
            return
              (Overloaded{Name=Name, Ty=Ty,
                          Info=Overload.fupd_actual_ops (fn _ => possible_ops)
                                                        Info,
                          Locn=Locn})
    in
      filterM testfn (#actual_ops Info) >- after_filter
  end
  | _ => return ptm

end (* local *)
458
459
(* remove_overloading ptm: backtracking phase of overloading resolution, in
   the sequence monad.  For each remaining Overloaded node, tryall is
   applied over its candidate operators.  A candidate is accepted when its
   type unifies with the node's pretype: constants become Const nodes, and
   other terms become Pattern nodes wrapping their term_to_preterm
   translation.  Comb, Abs (body only) and Constrained nodes are traversed;
   everything else is returned unchanged. *)
val remove_overloading : preterm -> preterm seqM = let
  open seqmonad Term
  infix >- >> ++
  (* unification lifted from the error monad into the sequence monad *)
  fun unify t1 t2 = fromErr (Pretype.unify t1 t2)

  fun recurse ptm =
    case ptm of
        Overloaded {Name,Ty,Info,Locn} =>
        let
          val actual_ops = #actual_ops Info
          (* Attempt to resolve the node to candidate t. *)
          fun try t =
            if is_const t then
              let
                val {Ty=ty,Name=nm,Thy=thy} = Term.dest_thy_const t
                val pty0 = Pretype.fromType ty
              in
                fromErr (Pretype.rename_typevars [] pty0) >- unify Ty >>
                return (Const{Name=nm, Ty=Ty, Thy=thy, Locn=Locn})
              end
            else
              let
                val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
              in
                fromErr (term_to_preterm avds t) >- (fn ptm =>
                fromErr (ptype_of ptm) >- (fn pty =>
                unify Ty pty >>
                return (Pattern{Ptm = ptm, Locn = Locn})))
              end
        in
          tryall try actual_ops
        end
      | Comb{Rator, Rand, Locn} =>
          lift2 (fn t1 => fn t2 => Comb{Rator=t1,Rand=t2,Locn=Locn})
                (recurse Rator) (recurse Rand)
      (* Bound variables are not searched; only the body is. *)
      | Abs{Bvar, Body, Locn} =>
          lift (fn t => Abs{Bvar=Bvar, Body=t, Locn=Locn}) (recurse Body)
      | Constrained{Ptm,Ty,Locn} =>
          lift (fn t => Constrained{Ptm=t, Ty=Ty, Locn=Locn}) (recurse Ptm)
      | _ => return ptm

(*
  val overloads = overloaded_subterms [] ptm
  val _ = if length overloads >= 30
          then HOL_WARNING "Preterm" "remove_overloading"
                           "many overloaded symbols in term: \
                           \overloading resolution might take a long time."
          else ()
*)
in
  recurse
end
511
512(* this version loses the sequence/lazy-list backtracking of the parse *)
(* Run phase 1 of overloading resolution, then the backtracking phase.  The
   backtracking phase's sequence of results is collapsed into the error monad
   with seqmonad.toError, and failure is reported as OvlFail at
   locn.Loc_Unknown. *)
fun do_overloading_removal ptm =
  let
    open errormonad
    val collapse = seqmonad.toError (OvlFail, locn.Loc_Unknown)
    fun phase2 pt = collapse (remove_overloading pt)
  in
    remove_overloading_phase1 ptm >- phase2
  end
520
(* b is true if multiple resolutions weren't possible.
   When b is false (overloading was ambiguous):
   - if overload guessing is off, fail with OvlTooMany;
   - otherwise, if notify_on_tyvar_guess and interactive are both set,
     print a message.
   In every non-failing case the environment is passed through unchanged. *)
fun report_ovl_ambiguity b env =
  if b then ok env
  else if not (!Globals.guessing_overloads) then
    error (OvlTooMany, locn.Loc_None) env
  else if !Globals.notify_on_tyvar_guess andalso !Globals.interactive then
    (Feedback.HOL_MESG "more than one resolution of overloading was possible";
     ok env)
  else ok env
534
(* Strip applications of the nat_elim_term magic constant, so that
   (nat_elim_term x) becomes x throughout the preterm.  Only a Const operator
   with that name is recognised.  Var, Const, Antiq and Pattern nodes are
   returned as they are.  Overloaded nodes are not expected at this stage and
   raise Fail. *)
fun remove_elim_magics ptm =
  case ptm of
    Comb{Rator, Rand, Locn} =>
      (case Rator of
         Const{Name, ...} =>
           if Name = nat_elim_term then remove_elim_magics Rand
           else Comb{Rator = Rator, Rand = remove_elim_magics Rand,
                     Locn = Locn}
       | _ => Comb{Rator = remove_elim_magics Rator,
                   Rand = remove_elim_magics Rand, Locn = Locn})
  | Abs{Bvar, Body, Locn} =>
      Abs{Bvar = remove_elim_magics Bvar, Body = remove_elim_magics Body,
          Locn = Locn}
  | Constrained{Ptm, Ty, Locn} =>
      Constrained{Ptm = remove_elim_magics Ptm, Ty = Ty, Locn = Locn}
  | Overloaded _ => raise Fail "Preterm.remove_elim_magics on Overloaded"
  | _ => ptm
551
552
(* Resolve overloading (non-backtracking) and strip nat_elim_term magic from
   the resolved preterm.  The boolean from do_overloading_removal is passed
   through unchanged. *)
fun overloading_resolution (ptm : preterm) : (preterm * bool) errM =
  let
    fun tidy (resolved, flag) = (remove_elim_magics resolved, flag)
  in
    errormonad.lift tidy (do_overloading_removal ptm)
  end
557
(* Backtracking overloading resolution: phase 1 (lifted from the error monad)
   followed by remove_overloading.  nat_elim_term magic is stripped from
   every resolution in the resulting sequence. *)
fun overloading_resolutionS ptm =
  let
    open seqmonad
    val resolutions =
        fromErr (remove_overloading_phase1 ptm) >- remove_overloading
  in
    lift remove_elim_magics resolutions
  end
566
567(*---------------------------------------------------------------------------
568 * Type inference for HOL terms. Looks ugly because of error messages, but is
569 * actually very simple, given side-effecting unification.
570 *---------------------------------------------------------------------------*)
571
(* Names of constants that may act as operators inside a numeral. *)
fun isnumrator_name nm =
  List.exists (fn s => s = nm)
              ["BIT1", "BIT2", "NUMERAL", fromNum_str, nat_elim_term]

(* A Const or Overloaded node whose name is a numeral operator. *)
fun isnumrator ptm =
  case ptm of
      Const{Name, ...} => isnumrator_name Name
    | Overloaded{Name, ...} => isnumrator_name Name
    | _ => false

(* A numeral: a Const/Overloaded named "0" or "ZERO", or a numeral operator
   applied to a numeral. *)
fun isnum ptm =
  let
    fun is_zero nm = (nm = "0") orelse (nm = "ZERO")
  in
    case ptm of
        Const{Name, ...} => is_zero Name
      | Overloaded{Name, ...} => is_zero Name
      | Comb{Rator, Rand, ...} => isnumrator Rator andalso isnum Rand
      | _ => false
  end

(* Variables, constants, overloaded constants and numerals (looking through
   Constrained) count as atoms.  The type-checker's error messages omit the
   "which has type" part for atoms. *)
fun is_atom ptm =
  case ptm of
      Var _ => true
    | Const _ => true
    | Overloaded _ => true
    | Constrained{Ptm, ...} => is_atom Ptm
    | Comb _ => isnum ptm
    | _ => false
591
592
local
  fun default_typrinter x = "<hol_type>"
  fun default_tmprinter x = "<term>"
  open errormonad
  infix ++?
  fun smashTm ptm =
    Lib.with_flag (Globals.notify_on_tyvar_guess, false)
                  (smash (overloading_resolution ptm >- (to_term o #1)))
  (* Evaluate f () with Globals.show_types temporarily set to true.
     Lib.with_flag restores the previous value on exit, including when f
     raises (e.g. from a user-supplied printer), so a failure while building
     a diagnostic cannot leave show_types switched on. *)
  fun with_types_shown f = Lib.with_flag (Globals.show_types, true) f ()
in
(* typecheck_phase1 printers ptm: first type-checking pass over ptm in the
   error monad.
   - Each application  f x  unifies f's type with  (type of x) --> 'r  for a
     fresh unification variable 'r.
   - Each Constrained node unifies its sub-term's type with the constraint.
   - On failure a diagnostic is built (printed with tcheck_say, so subject
     to show_typecheck_errors), and AppFail / ConstrainFail is returned.
     The message includes the unifier's own error message.
   printers: optional (term printer, type printer) pair used in diagnostics;
   without it, placeholder strings are printed. *)
fun typecheck_phase1 printers = let
  val (ptm, pty) =
      case printers of
        SOME (x,y) => let
          val typrint = y
          fun tmprint tm =
              if Term.is_const tm then x tm ^ " " ^ y (Term.type_of tm)
              else x tm
        in
          (tmprint, typrint)
        end
      | NONE => (default_tmprinter, default_typrinter)
  fun check(Comb{Rator, Rand, Locn}) =
    check Rator >> check Rand >>
    ptype_of Rator >- (fn rator_ty =>
    ptype_of Rand >- (fn rand_ty =>
    Pretype.new_uvar >- (fn range_var =>
    (Pretype.unify rator_ty (rand_ty --> range_var)) ++?
     (fn unify_error => fn env =>
          let
            val (Rator', Rand', message) = with_types_shown (fn () =>
              let
                val Rator' = smashTm Rator env
                val Rand'  = smashTm Rand env
                val message =
                    String.concat
                      [
                       "\nType inference failure: unable to infer a type \
                       \for the application of\n\n",
                       ptm Rator',
                       "\n\n"^locn.toString (locn Rator)^"\n\n",
                       if is_atom Rator then ""
                       else ("which has type\n\n" ^
                             pty(Term.type_of Rator') ^ "\n\n"),

                       "to\n\n",
                       ptm Rand',
                       "\n\n"^locn.toString (locn Rand)^"\n\n",

                       if is_atom Rand then ""
                       else ("which has type\n\n" ^
                             pty(Term.type_of Rand') ^ "\n\n"),

                       "unification failure message: " ^
                       errorMsg (#1 unify_error) ^ "\n"]
              in
                (Rator', Rand', message)
              end)
          in
            tcheck_say message;
            Error (AppFail(Rator',Rand',message), locn Rand)
          end))))
    | check (Abs{Bvar, Body, Locn}) = (check Bvar >> check Body)
    | check (Constrained{Ptm,Ty,Locn}) =
        check Ptm >> ptype_of Ptm >- (fn ptyp =>
        (Pretype.unify ptyp Ty ++?
         (fn unify_error => fn env =>
             let
               val (real_term, real_type, message) = with_types_shown (fn () =>
                 let
                   val real_term = smashTm Ptm env
                   val real_type = Pretype.toType Ty
                   (* Report the unifier's own message, as the application
                      case does, instead of a "???" placeholder. *)
                   val message =
                       String.concat
                         [
                          "\nType inference failure: the term\n\n",
                          ptm real_term,
                          "\n\n", locn.toString (locn Ptm), "\n\n",
                          if is_atom Ptm then ""
                          else("which has type\n\n" ^
                               pty(Term.type_of real_term) ^ "\n\n"),
                          "can not be constrained to be of type\n\n",
                          pty real_type,
                          "\n\nunification failure message: " ^
                          errorMsg (#1 unify_error) ^ "\n"]
                 in
                   (real_term, real_type, message)
                 end)
             in
               tcheck_say message;
               Error(ConstrainFail(real_term, real_type, message), Locn)
             end)))
    | check _ = ok
in
  check
end
end (* local *)
686
(* Short name for typecheck_phase1. *)
fun TC printers = typecheck_phase1 printers
688
689(* ---------------------------------------------------------------------- *)
690(* function to do the equivalent of strip_conj, but where the "conj" is   *)
(* the magic binary operator bool$<GrammarSpecials.case_split_special>    *)
692(* ---------------------------------------------------------------------- *)
693
694open HolKernel
(* dest_binop n c t: if t is the constant bool$c applied to exactly two
   arguments l and r, return (l, r).  Otherwise raise HOL_ERR, attributed
   to the function "dest_case_" ^ n.  n is a short description of the
   operator ("split", "arrow") used in the error messages.  All three
   failure paths use the same function name; previously two of them used
   "dest_case" ^ n (e.g. "dest_casesplit"). *)
fun dest_binop n c t = let
  val fname = "dest_case_" ^ n
  val (f,args) = strip_comb t
  val {Name,Thy,...} = dest_thy_const f
      handle HOL_ERR _ => raise ERR fname ("Not a "^n^" term")
  val _ = (Name = c andalso Thy = "bool") orelse
          raise ERR fname ("Not a "^n^" term")
in
  case args of
      [l, r] => (l, r)
    | _ => raise ERR fname ("case "^n^" 'op' with bad # of args")
end

val dest_case_split = dest_binop "split" case_split_special
val dest_case_arrow = dest_binop "arrow" case_arrow_special
710
(* Flatten a tree of case_split_special applications into its leaves, left
   to right.  A term that is not a split is a single leaf. *)
fun strip_splits t0 = let
  fun leaves t acc =
      case (SOME (dest_case_split t) handle HOL_ERR _ => NONE) of
          NONE => t :: acc
        | SOME (l, r) => leaves l (leaves r acc)
in
  leaves t0 []
end
720
(* Build  t1 /\ t2  from the constant bool$/\ created with mk_thy_const. *)
fun mk_conj (t1, t2) = let
  val bool_ty = Type.bool
  val conj_c = mk_thy_const{Name = "/\\", Thy = "bool",
                            Ty = bool_ty --> bool_ty --> bool_ty}
in
  mk_comb (mk_comb (conj_c, t1), t2)
end

(* Right-associated conjunction of a non-empty list of terms; the empty list
   raises HOL_ERR. *)
fun list_mk_conj tms =
  case tms of
      [] => raise ERR "list_mk_conj" "empty list"
    | [tm] => tm
    | tm :: rest => mk_conj (tm, list_mk_conj rest)

(* Build the equation  l = r  from min$=, instantiated at the type of l. *)
fun mk_eq (lhs_t, rhs_t) = let
  val ty = type_of lhs_t
  val eq_c = mk_thy_const{Name = "=", Thy = "min", Ty = ty --> ty --> Type.bool}
in
  mk_comb (mk_comb (eq_c, lhs_t), rhs_t)
end
737
(* Work items for the explicit-stack traversal in remove_case_magic0:
   Input t processes subterm t, Ab (v, orig) rebuilds an abstraction over v,
   and Cmb (n, orig) rebuilds an application with n arguments. *)
datatype rcm_action = Input of term
                    | Ab of term * term
                    | Cmb of int * term
(* Result of processing a subterm: Ch holds a rebuilt term, Unch holds a
   term that was left as it was. *)
datatype rcm_out = Ch of term | Unch of term
fun is_unch out = (case out of Unch _ => true | Ch _ => false)
fun dest_out out = (case out of Ch tm => tm | Unch tm => tm)
(* Pprefix P list = (pfx, rest): pfx is the longest prefix of list whose
   elements all satisfy P, and rest is what follows it. *)
fun Pprefix P list = let
  fun scan _ [] = (list, [])
    | scan taken (remaining as x :: xs) =
        if not (P x) then (List.rev taken, remaining)
        else scan (x :: taken) xs
in
  scan [] list
end
753
(* Rebuild an application from its processed head outf and processed
   arguments outargs; orig is the original application term.
   - Changed head: every argument is applied to it again.
   - Unchanged head: the longest unchanged prefix of arguments is reused
     from orig (the remaining arguments are stripped off orig with rator),
     and only the remaining arguments are re-applied.  If every argument is
     unchanged, the result is Unch orig. *)
fun recomb (outf, outargs, orig) = let
  fun apply_outs base outs =
      List.foldl (fn (out, acc) => mk_comb (acc, dest_out out)) base outs
in
  case outf of
    Unch _ =>
      (case #2 (Pprefix is_unch outargs) of
           [] => Unch orig
         | suffix =>
             Ch (apply_outs (funpow (length suffix) rator orig) suffix))
  | Ch f => Ch (apply_outs f outargs)
end
767
(* remove_case_magic0 tm0: replace each application of the magic constant
   bool$<core_case_special> by a compiled case term.
   The term is walked with an explicit work list (actions) and an output
   stack (acc):
     Input t       - process subterm t; non-abstraction, non-application
                     terms are pushed as Unch t
     Ab (v, orig)  - rebuild the abstraction orig over v from the processed
                     body on top of acc
     Cmb (n, orig) - pop the processed head and its n processed arguments
                     and rebuild the application orig
   When a Cmb's head is the case constant, its second argument is split
   into pattern/body pairs (strip_splits, dest_case_arrow).  Equations
   fakef pat = body over a fresh variable fakef are compiled with
   GrammarSpecials.compile_pattern_match, the compiled body is instantiated
   with the first argument (the term being split on), and the result is
   applied to any arguments beyond the second. *)
fun remove_case_magic0 tm0 = let
  fun traverse acc actions =
      case actions of
        (* no work left: the result is on top of the output stack *)
        [] => dest_out (hd acc)
      | act :: rest => let
        in
          case act of
            Input t => let
            in
              if is_abs t then let
                  val (v,body) = dest_abs t
                in
                  traverse acc (Input body :: Ab (v,t) :: rest)
                end
              else if is_comb t then let
                  val (f,args) = strip_comb t
                  val in_args = map Input args
                in
                  (* arguments first, then the head, then the rebuild step *)
                  traverse acc (in_args @
                                [Input f, Cmb(length args, t)] @ rest)
                end
              else
                traverse (Unch t::acc) rest
            end
          | Ab (v,orig) => let
            in
              case acc of
                Ch bod' :: acc0 => traverse (Ch (mk_abs(v,bod'))::acc0)
                                            rest
              (* body unchanged: reuse the original abstraction *)
              | Unch _ :: acc0 => traverse (Unch orig :: acc0) rest
              | [] => raise Fail "Preterm.rcm: inv failed!"
            end
          | Cmb(arglen, orig) => let
              val out_f = hd acc
              val f = dest_out out_f
              val acc0 = tl acc
              val acc_base = List.drop(acc0, arglen)
              (* arguments were pushed in order, so they sit reversed *)
              val out_args = List.rev (List.take(acc0, arglen))
              val args = map dest_out out_args
              val newt = let
                (* non-constant heads get a dummy name that never matches *)
                val {Name,Thy,Ty} = dest_thy_const f
                    handle HOL_ERR _ => {Name = "", Thy = "", Ty = alpha}
              in
                if Name = core_case_special andalso Thy = "bool" then let
                    val _ = length args >= 2 orelse
                            raise ERR "remove_case_magic"
                                      "case constant has wrong # of args"
                    val split_on_t = hd args
                    val cases = strip_splits (hd (tl args))
                    val patbody_pairs = map dest_case_arrow cases
                        handle HOL_ERR _ =>
                               raise ERR "remove_case_magic"
                                         ("Case expression has invalid syntax \
                                          \where there should be arrows")
                    val split_on_t_ty = type_of split_on_t
                    val result_ty =
                        type_of (list_mk_comb(f, List.take(args,2)))
                    val fakef = genvar (split_on_t_ty --> result_ty)
                    val fake_eqns =
                        list_mk_conj(map (fn (l,r) =>
                                             mk_eq(mk_comb(fakef, l), r))
                                         patbody_pairs)
                    val functional =
                        GrammarSpecials.compile_pattern_match fake_eqns
                    val func = snd (dest_abs functional)
                    val (v,case_t0) = dest_abs func
                    val case_t = subst [v |-> split_on_t] case_t0
                  in
                    Ch (list_mk_comb(case_t, tl (tl args)))
                  end
                else
                  recomb(out_f, out_args, orig)
              end (* newt *)
            in
              traverse (newt::acc_base) rest
            end (* Cmb *)
        end (* act :: rest *) (* end traverse *)
in
  traverse [] [Input tm0]
end
848
(* Compile case-expression magic with remove_case_magic0, but only once
   GrammarSpecials.case_initialised () holds; before that the term is
   returned unchanged. *)
fun remove_case_magic tm =
    if not (GrammarSpecials.case_initialised ()) then tm
    else remove_case_magic0 tm
852
(* Hook applied to every term produced by typecheck and typecheckS; the
   identity by default. *)
val post_process_term = ref (I : term -> term);

(* Full type-checking pipeline in the error monad: phase-1 checking (TC),
   overloading resolution, the ambiguity report, conversion to a term,
   case-magic removal, and finally !post_process_term (dereferenced when
   the computation runs). *)
fun typecheck pfns ptm0 =
  let
    open errormonad
    fun resolve_and_convert (ptm, b) = report_ovl_ambiguity b >> to_term ptm
    val pipeline =
        TC pfns ptm0 >> overloading_resolution ptm0 >- resolve_and_convert
    fun finish t e = errormonad.Some (e, !post_process_term t)
  in
    lift remove_case_magic pipeline >- finish
  end
865
(* Backtracking variant of typecheck.  Phase-1 checking runs without
   printers and with show_typecheck_errors switched off.  Overloading is
   resolved with overloading_resolutionS, giving a sequence of candidate
   terms.  Each candidate goes through remove_case_magic and then
   !post_process_term, dereferenced when typecheckS is applied. *)
fun typecheckS ptm =
  let
    open seqmonad
    val quiet_phase1 =
        errormonad.with_flagM (show_typecheck_errors, false) (TC NONE ptm)
    val finish = !post_process_term o remove_case_magic
    fun convert resolved = fromErr (to_term resolved)
  in
    lift finish (fromErr quiet_phase1 >> overloading_resolutionS ptm >- convert)
  end
875
876
877end; (* Preterm *)
878