1structure Preterm :> Preterm =
2struct
3
4open Feedback Lib GrammarSpecials;
5open errormonad typecheck_error
6
(* Exception builders whose origin structure is fixed to "Preterm":
   ERR function message, and ERRloc function location message. *)
fun ERR func msg = mk_HOL_ERR "Preterm" func msg
fun ERRloc func loc msg = mk_HOL_ERRloc "Preterm" func loc msg
9
type pretype = Pretype.pretype
type hol_type = Type.hol_type
type term = Term.term
type overinfo = {Name : string,
                 Ty : pretype,
                 Info : Overload.overloaded_op_info,
                 Locn : locn.locn}

(* All type variables occurring in the terms of tlist, accumulated with
   Lib.union while walking the list from left to right. *)
fun tmlist_tyvs tlist = let
  fun collect acc [] = acc
    | collect acc (tm :: rest) =
        collect (Lib.union (Term.type_vars_in_term tm) acc) rest
in
  collect [] tlist
end
17
type 'a in_env = 'a Pretype.in_env

(* Gates the diagnostics printed by tcheck_say; on by default and exposed
   as the "show_typecheck_errors" trace. *)
val show_typecheck_errors = ref true
val _ = register_btrace ("show_typecheck_errors", show_typecheck_errors)

(* Print s with Lib.say unless show_typecheck_errors has been switched off. *)
fun tcheck_say s =
  if not (!show_typecheck_errors) then () else Lib.say s

(* Slot for the most recent type-checking error (not assigned anywhere in
   the visible part of this file). *)
val last_tcerror : error option ref = ref NONE
25
type 'a errM = (Pretype.Env.t,'a,tcheck_error * locn.locn) errormonad.t
type 'a seqM = (Pretype.Env.t,'a) seqmonad.seqmonad

(* Run the error-monad computation m from environment env and return its
   value, dropping the final environment.  An Error outcome is converted to
   an exception with mkExn and raised. *)
fun smash m env =
  case m env of
      Some (_, result) => result
    | Error err => raise mkExn err
33
34open Preterm_dtype
35
(* Split the preterm of an equality  l = r  (the constant min$=) into
   (l, r).  Outer Constrained wrappers are stripped first; anything else
   raises HOL_ERR. *)
fun pdest_eq (Constrained{Ptm, ...}) = pdest_eq Ptm
  | pdest_eq (Comb{Rator = Comb{Rator = Const{Name = "=", Thy = "min", ...},
                                Rand = l, ...},
                   Rand = r, ...}) = (l, r)
  | pdest_eq _ =
      raise mk_HOL_ERR "Preterm" "pdest_eq" "Preterm is not an equality"

(* Left-hand side of an equality preterm (see pdest_eq). *)
fun lhs pt = #1 (pdest_eq pt)
45
(* Peel off leading universal quantifications (bool$! applied to an
   abstraction), looking through Constrained wrappers along the way.
   Returns the bound-variable preterms, outermost first, together with the
   remaining body. *)
fun strip_pforall pt = let
  fun strip bvs (Constrained{Ptm, ...}) = strip bvs Ptm
    | strip bvs (Comb{Rator = Const{Name = "!", Thy = "bool", ...},
                      Rand = Abs{Bvar, Body, ...}, ...}) =
        strip (Bvar :: bvs) Body
    | strip bvs body = (List.rev bvs, body)
in
  strip [] pt
end
56
(* The variable at the head of an application spine.  Comb operators and
   Constrained wrappers are followed.  An antiquoted HOL variable is turned
   into a Var preterm with no location.  Any other kind of head raises
   HOL_ERR. *)
fun head_var pt = let
  fun fail msg = raise mk_HOL_ERR "Preterm" "head_var" msg
in
  case pt of
      Var _ => pt
    | Comb{Rator, ...} => head_var Rator
    | Constrained{Ptm, ...} => head_var Ptm
    | Antiq{Tm, ...} =>
        let
          val (name, ty) = Term.dest_var Tm
              handle HOL_ERR _ => fail "Head is an antiquoted non-var"
        in
          Var{Name = name, Ty = Pretype.fromType ty, Locn = locn.Loc_None}
        end
    | Const _ => fail "Head is a constant"
    | Overloaded _ => fail "Head is an Overloaded"
    | Abs _ => fail "Head is an abstraction"
    | Pattern _ => fail "Head is a Pattern"
end
76
77
val op--> = Pretype.mk_fun_ty

(* The pretype of a preterm, as an error-monad computation.  Leaf-like
   nodes return their recorded type, and antiquotes convert the HOL type of
   their term.  An application's type is its operator's type passed through
   Pretype.chase.  An abstraction's type is the function type from its bound
   variable's type to its body's type. *)
fun ptype_of pt =
  case pt of
      Var{Ty, ...} => return Ty
    | Const{Ty, ...} => return Ty
    | Constrained{Ty, ...} => return Ty
    | Overloaded{Ty, ...} => return Ty
    | Antiq{Tm, ...} => return (Pretype.fromType (Term.type_of Tm))
    | Pattern{Ptm, ...} => ptype_of Ptm
    | Comb{Rator, ...} => ptype_of Rator >- Pretype.chase
    | Abs{Bvar, Body, ...} =>
        lift2 (fn dom => fn rng => dom --> rng) (ptype_of Bvar) (ptype_of Body)
88
(* Take a Var preterm apart into (name, pretype, location).  Anything else,
   including a Constrained variable, raises HOL_ERR. *)
fun dest_ptvar (Var{Name, Ty, Locn}) = (Name, Ty, Locn)
  | dest_ptvar _ =
      raise mk_HOL_ERR "Preterm" "dest_ptvar" "Preterm is not a variable"
93
(* Fold a non-empty list of preterms with the binary operator preterm opn,
   associating to the right:  [a, b, c]  becomes  opn a (opn b c).
   A singleton list is returned unchanged.  The new Comb nodes carry no
   location.  Raises HOL_ERR on an empty list; the error origin now names
   this function rather than "list_mk_rbinop". *)
fun plist_mk_rbinop opn pts =
    case List.rev pts of
        [] => raise mk_HOL_ERR "Preterm" "plist_mk_rbinop" "Empty list"
      | last :: rest_rev =>
        let
          fun mk_app (pt, acc) =
              Comb{Rator = Comb{Rator = opn, Rand = pt, Locn = locn.Loc_None},
                   Rand = acc, Locn = locn.Loc_None}
        in
          (* walking the reversed prefix builds the right-nested spine *)
          List.foldl mk_app last rest_rev
        end
106
val bogus = locn.Loc_None

(* Translate the term t into a preterm whose nodes all carry the location
   bogus.  Every HOL type is converted with Pretype.fromType and then passed
   through Pretype.rename_tv with the avoid list avds (presumably to replace
   type variables by unification variables -- confirm in Pretype).  The
   renaming state starts at [] and is discarded at the end. *)
fun term_to_preterm avds t = let
  fun gen ty = Pretype.rename_tv avds (Pretype.fromType ty)
  open HolKernel
  fun recurse tm =
      case dest_term tm of
        VAR (name, ty) =>
          gen ty >- (fn pty =>
          return (Var{Name = name, Ty = pty, Locn = bogus}))
      | CONST {Name, Thy, Ty} =>
          gen Ty >- (fn pty =>
          return (Const{Name = Name, Thy = Thy, Ty = pty, Locn = bogus}))
      | COMB (fun_tm, arg_tm) =>
          recurse fun_tm >- (fn fun_pt =>
          recurse arg_tm >- (fn arg_pt =>
          return (Comb{Rator = fun_pt, Rand = arg_pt, Locn = bogus})))
      | LAMB (bv_tm, body_tm) =>
          recurse bv_tm >- (fn bv_pt =>
          recurse body_tm >- (fn body_pt =>
          return (Abs{Bvar = bv_pt, Body = body_pt, Locn = bogus})))
in
  lift #2 (addState [] (recurse t))
end
127
128
129
130(*---------------------------------------------------------------------------
131     Read the location from a preterm.
132 ---------------------------------------------------------------------------*)
133
(* The source location stored in any preterm node. *)
fun locn pt =
  case pt of
      Var{Locn, ...} => Locn
    | Const{Locn, ...} => Locn
    | Overloaded{Locn, ...} => Locn
    | Comb{Locn, ...} => Locn
    | Abs{Locn, ...} => Locn
    | Constrained{Locn, ...} => Locn
    | Antiq{Locn, ...} => Locn
    | Pattern{Locn, ...} => Locn
142
143(*---------------------------------------------------------------------------
144     Location-ignoring equality for preterms.
145 ---------------------------------------------------------------------------*)
(* Equality on overloading info records.  Base types and avoid lists are
   compared with =.  The candidate operator lists must have the same length
   and be pairwise alpha-convertible. *)
fun infoeq {base_type = ty1, actual_ops = tms1, tyavoids = avoid1}
           {base_type = ty2, actual_ops = tms2, tyavoids = avoid2} =
  let
    fun same_op (a, b) = Term.aconv a b
  in
    ty1 = ty2 andalso avoid1 = avoid2 andalso
    ListPair.allEq same_op (tms1, tms2)
  end
150
(* Structural equality on preterms that ignores locations.  Antiquoted
   terms are compared up to alpha-conversion.  Nodes of different kinds
   are never equal. *)
fun eq p1 p2 =
  case (p1, p2) of
      (Var{Name = n1, Ty = ty1, ...}, Var{Name = n2, Ty = ty2, ...}) =>
        n1 = n2 andalso ty1 = ty2
    | (Const{Name = n1, Thy = thy1, Ty = ty1, ...},
       Const{Name = n2, Thy = thy2, Ty = ty2, ...}) =>
        n1 = n2 andalso thy1 = thy2 andalso ty1 = ty2
    | (Overloaded{Name = n1, Ty = ty1, Info = i1, ...},
       Overloaded{Name = n2, Ty = ty2, Info = i2, ...}) =>
        n1 = n2 andalso ty1 = ty2 andalso infoeq i1 i2
    | (Comb{Rator = f1, Rand = x1, ...}, Comb{Rator = f2, Rand = x2, ...}) =>
        eq f1 f2 andalso eq x1 x2
    | (Abs{Bvar = v1, Body = b1, ...}, Abs{Bvar = v2, Body = b2, ...}) =>
        eq v1 v2 andalso eq b1 b2
    | (Constrained{Ptm = t1, Ty = ty1, ...},
       Constrained{Ptm = t2, Ty = ty2, ...}) =>
        eq t1 t2 andalso ty1 = ty2
    | (Antiq{Tm = tm1, ...}, Antiq{Tm = tm2, ...}) => Term.aconv tm1 tm2
    | (Pattern{Ptm = t1, ...}, Pattern{Ptm = t2, ...}) => eq t1 t2
    | _ => false
166
(* The Var underneath any Constrained wrappers on an abstraction's bound
   variable.  Raises HOL_ERR (origin "ptfvs") if the bound variable is not
   a variable. *)
fun isolate_var (Constrained{Ptm, ...}) = isolate_var Ptm
  | isolate_var (ptv as Var _) = ptv
  | isolate_var _ = raise ERR "ptfvs" "Mal-formed abstraction"
172
(* Free Var preterms of pt, compared with eq.  Variables bound by an
   abstraction are removed.  For a case-arrow  pat -> rhs  the variables of
   pat are removed from those of rhs, and the pattern's own variables are
   not reported.  Constrained wrappers are looked through.  Constants,
   antiquotes, Overloaded and Pattern nodes contribute nothing. *)
fun ptfvs pt =
    case pt of
        Var _ => [pt]
      | Comb{Rator = opr as Comb{Rator = Const{Name, ...}, Rand = pat, ...},
             Rand = rhs, ...} =>
          if Name = GrammarSpecials.case_arrow_special then
            op_set_diff eq (ptfvs rhs) (ptfvs pat)
          else
            op_union eq (ptfvs opr) (ptfvs rhs)
      | Comb{Rator, Rand, ...} => op_union eq (ptfvs Rator) (ptfvs Rand)
      | Abs{Bvar, Body, ...} => op_set_diff eq (ptfvs Body) [isolate_var Bvar]
      | Constrained{Ptm, ...} => ptfvs Ptm
      | _ => []
190
(* Split an application spine  f a1 ... an  into (f, [a1, ..., an]).  Only
   Comb nodes are followed; Constrained wrappers are not looked through. *)
fun strip_pcomb pt = let
  fun strip (Comb{Rator, Rand, ...}, args) = strip (Rator, Rand :: args)
    | strip (head, args) = (head, args)
in
  strip (pt, [])
end
199
200
201(* ----------------------------------------------------------------------
202
203     Simple map from a preterm to a term. The argument "shr" maps from
204     pretypes to types. Overloaded nodes cause failure if one is
205     encountered, as Overloaded nodes should be gone by the time clean
206     is called.
207
208     shr takes a location for now, until Preterm has a location built-in.
209
     Performs the beta-reduction that arises when a Pattern term is
     applied to arguments.
211
212   ---------------------------------------------------------------------- *)
213
214
215
(* clean shr: a function mapping preterms to terms, with pretypes converted
   by  shr locn pty.  Arguments of an application are converted before its
   head.  An Overloaded node raises HOL_ERR with origin "clean". *)
fun clean shr = let
  open Term
  (* Instantiate the leading abstractions of the pattern term pat with as
     many of args as there are abstractions (by substitution, not by
     building beta-redexes), then apply the result to any leftover args. *)
  fun apply_pattern pat args = let
    val (bvs, _) = strip_abs pat
    val theta = ListPair.map (fn (bv, arg) => bv |-> arg) (bvs, args)
    val n = length theta
    val body = funpow n (#2 o dest_abs) pat
  in
    list_mk_comb (Term.subst theta body, List.drop (args, n))
  end
  fun cl t =
      case t of
        Var{Name, Ty, Locn} => mk_var (Name, shr Locn Ty)
      | Const{Name, Thy, Ty, Locn} =>
          mk_thy_const {Name = Name, Thy = Thy, Ty = shr Locn Ty}
      | Comb _ =>
          let
            val (head, raw_args) = strip_pcomb t
            val args = map cl raw_args
          in
            case head of
              Pattern{Ptm, ...} => apply_pattern (cl Ptm) args
            | _ => list_mk_comb (cl head, args)
          end
      | Abs{Bvar, Body, ...} => mk_abs (cl Bvar, cl Body)
      | Antiq{Tm, ...} => Tm
      | Constrained{Ptm, ...} => cl Ptm
      | Pattern{Ptm, ...} => cl Ptm
      | Overloaded{Locn, ...} =>
          raise ERRloc "clean" Locn "Overload term remains"
in
  cl
end
249
(* Monadic test: does the pretype ty contain an unbound unification
   variable?  Forwards to Pretype.has_unbound_uvar. *)
fun has_free_uvar ty = Pretype.has_unbound_uvar ty
251
(* The pretype variables in a preterm, as a monadic list merged with
   Lib.union.  Antiquotes contribute the names of their HOL type variables.
   Raises Fail on an Overloaded node. *)
fun tyVars ptm =
  let
    fun join m1 m2 = lift2 Lib.union m1 m2
  in
    case ptm of
      Var{Ty, ...} => Pretype.tyvars Ty
    | Const{Ty, ...} => Pretype.tyvars Ty
    | Comb{Rator, Rand, ...} => join (tyVars Rator) (tyVars Rand)
    | Abs{Bvar, Body, ...} => join (tyVars Bvar) (tyVars Body)
    | Constrained{Ptm, Ty, ...} => join (tyVars Ptm) (Pretype.tyvars Ty)
    | Pattern{Ptm, ...} => tyVars Ptm
    | Antiq{Tm, ...} =>
        return (map Type.dest_vartype (Term.type_vars_in_term Tm))
    | Overloaded _ => raise Fail "Preterm.tyVars: applied to \
                                 \Overloaded"
  end;
264
265
266(*---------------------------------------------------------------------------
267    Translate a preterm to a term. Will "guess type variables"
268    (assign names to type variables created during type inference),
269    if a flag is set. No "Overloaded" nodes are allowed in the preterm:
270    overloading resolution should already have gotten rid of them.
271 ---------------------------------------------------------------------------*)
272
(* Expose the tyvar-guess notification flag and the overload-guessing flag
   as boolean traces, registered in this order. *)
val _ =
    List.app register_btrace
             [("notify type variable guesses", Globals.notify_on_tyvar_guess),
              ("guess overloads", Globals.guessing_overloads)]
279
(* to_term tm: an in_env computation that turns the preterm tm into a term.
   tm must not contain Overloaded nodes.  A Pattern at the head of an
   application has its leading abstractions instantiated with the
   arguments by substitution.
   - If Globals.guessing_tyvars is set, every pretype goes through
     Pretype.replace_null_links and Pretype.clean.  The state threads a list
     of used type-variable names, seeded with tyVars tm.  When new names
     were added and both Globals.notify_on_tyvar_guess and
     Globals.interactive are set, the new names are reported with HOL_MESG.
     Monadic failures come back as Error.
   - Otherwise any pretype that still contains an unbound unification
     variable raises HOL_ERR ("Unconstrained type variable ...") at the
     node's location. *)
fun to_term (tm : preterm) : term in_env =
    if !Globals.guessing_tyvars then
      let
        fun cleanup tm = let
          infix >> >-
          (* Lift a computation over the environment alone to one over the
             (environment, used-names) pair, leaving the names untouched. *)
          fun usedLift m (E,used) =
            case m E of
                Error e => Error e
              | Some (E', result) => Some ((E',used), result)
          fun clean0 pty = lift Pretype.clean (Pretype.remove_made_links pty)
          val clean = usedLift o clean0
        in
          case tm of
            Var{Name,Ty,...} => lift (fn ty => Term.mk_var(Name, ty))
                                     (Pretype.replace_null_links Ty >- clean)
          | Const{Name,Thy,Ty,...} =>
              lift (fn ty => Term.mk_thy_const{Name=Name,Thy=Thy,Ty=ty})
                   (Pretype.replace_null_links Ty >- clean)
          | Comb{Rator, Rand,...} => let
              val (f, args) = strip_pcomb tm
              open Term
            in
              case f of
                Pattern{Ptm,...} => let
                  (* instantiate f_t's leading abstractions with args by
                     substitution; leftover args are applied normally *)
                  fun doit f_t args = let
                    val (bvs, _) = strip_abs f_t
                    val inst = ListPair.map Lib.|-> (bvs, args)
                    val res0 = funpow (length inst) (#2 o dest_abs) f_t
                  in
                    list_mk_comb(Term.subst inst res0,
                                 List.drop(args, length inst))
                  end
                in
                  cleanup Ptm >- (fn f =>
                  mmap cleanup args >- (fn args' =>
                  return (doit f args')))
                end
              | _ => cleanup f >- (fn f_t =>
                     mmap cleanup args >- (fn args' =>
                     return (list_mk_comb(f_t, args'))))
            end
          | Abs{Bvar, Body,...} => cleanup Bvar >- (fn Bvar'
                                => cleanup Body >- (fn Body'
                                => return (Term.mk_abs(Bvar', Body'))))
          | Antiq{Tm,...} => return Tm
          | Constrained{Ptm,...} => cleanup Ptm
          | Overloaded _ => raise ERRloc "to_term" (locn tm)
                                         "applied to Overloaded"
          | Pattern{Ptm,...} => cleanup Ptm
        end
        (* Run m from state (e, vars); return the final environment paired
           with (result, final used-names list). *)
        fun addV m vars e =
          case m (e,vars) of
              Error e => Error e
            | Some ((e',v'), r) => Some (e', (r,v'))
        val V = tyVars tm >-
                (fn vs => lift (fn x => (vs,x)) (addV (cleanup tm) vs))
      in
        fn e =>
           case V e of
               Error e => Error e
             | Some (e', (vs0, (tm, vs))) =>
               let
                 (* NOTE(review): assumes newly invented names are prepended
                    to the used-names list, so the first
                    (length vs - length vs0) entries are the new ones --
                    confirm against Pretype.replace_null_links. *)
                 val guessed_vars = List.take(vs, length vs - length vs0)
                 val _ =
                     if not (null guessed_vars) andalso
                        !Globals.notify_on_tyvar_guess andalso
                        !Globals.interactive
                     then
                       Feedback.HOL_MESG
                         (String.concat
                            ("inventing new type variable names: "
                             :: Lib.commafy (List.rev guessed_vars)))
                     else ()
               in
                 Some (e', tm)
               end
      end
    else
      let
        (* Convert a pretype to a type in environment env.  Raises HOL_ERR
           at location l if the pretype still has an unbound unification
           variable. *)
        fun shr env l ty =
            if smash (has_free_uvar ty) env then
              raise ERRloc "typecheck.to_term" l
                           "Unconstrained type variable (and Globals.\
                           \guessing_tyvars is false)"
            else smash (lift Pretype.clean (Pretype.remove_made_links ty))
                       env
      in
        (fn e => Some (e, clean (shr e) tm))
      end
369
370
371
372
373(*---------------------------------------------------------------------------*
374 *                                                                           *
 * Overloading removal.  The function "remove_overloading_phase1" will       *
376 * replace Overloaded _ nodes with Consts where it can be shown that only    *
377 * one of the possible constants has a type compatible with the type of the  *
378 * term as it has been inferred during the previous phase of type inference. *
379 * This may in turn constrain other overloaded terms elsewhere in the tree.  *
380 *                                                                           *
381 *---------------------------------------------------------------------------*)
382
(* In earlier stages, the base_type of any overloaded preterm will have
   become more instantiated through the process of type inference.  This
   first phase of resolving overloading removes those operators that are
   no longer compatible with this type.  If this results in no operators,
   this is an error.  If it results in one operator, this can be chosen
   as the result.  If more than one remains, the node is passed on so that
   later phases can figure out which are possible given all the other
   overloaded sub-terms in the term. *)
local
  (* The unused local operator >~ (optmonad.>-) previously declared here has
     been dropped; nothing in this local block referred to it. *)
  open errormonad
in
(* filterM PM l: the elements h of l, in their original order, for which
   the monadic predicate PM h yields true.  PM runs on the elements from
   left to right. *)
fun filterM PM l =
  case l of
      [] => return l
    | h::t => PM h >- (fn b => if b then lift (cons h) (filterM PM t)
                               else filterM PM t)

(* remove_overloading_phase1 ptm walks ptm through Comb, Abs (both bound
   variable and body) and Constrained nodes.  At each Overloaded node it
   keeps only the candidate operators whose types, after
   Pretype.rename_typevars, satisfy Pretype.can_unify with the node's type:
     - none left   : error OvlNoType at the node's location;
     - exactly one : a Const (if the candidate is a constant) or a Pattern
                     wrapping the candidate term, after unifying the
                     candidate's type with the node's type;
     - several     : the Overloaded node, with actual_ops cut down to the
                     survivors.
   All other nodes are returned unchanged. *)
fun remove_overloading_phase1 ptm =
  case ptm of
    Comb{Rator, Rand, Locn} =>
      lift2 (fn t1 => fn t2 => Comb{Rator = t1, Rand = t2, Locn = Locn})
            (remove_overloading_phase1 Rator)
            (remove_overloading_phase1 Rand)
  | Abs{Bvar, Body, Locn} =>
      lift2 (fn t1 => fn t2 => Abs{Bvar = t1, Body = t2, Locn = Locn})
            (remove_overloading_phase1 Bvar)
            (remove_overloading_phase1 Body)
  | Constrained{Ptm, Ty, Locn} =>
      lift (fn t => Constrained{Ptm = t, Ty = Ty, Locn = Locn})
           (remove_overloading_phase1 Ptm)
  | Overloaded{Name,Ty,Info,Locn} => let
      (* Can candidate t's type (type variables renamed, avoiding those in
         its free variables) be unified with this node's type? *)
      fun testfn t = let
        open Term
        val possty = type_of t
        val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
        val pty0 = Pretype.fromType possty
      in
        Pretype.rename_typevars avds pty0 >- Pretype.can_unify Ty
      end
      fun after_filter possible_ops =
        case possible_ops of
            [] => error (OvlNoType(Name,Pretype.toType Ty), Locn)
          | [t] =>
            let
              open Term
            in
              if is_const t then
                let
                  val {Ty = ty,Name,Thy} = dest_thy_const t
                  val ptyM = Pretype.rename_typevars [] (Pretype.fromType ty)
                in
                  ptyM >- Pretype.unify Ty >>
                  return (Const{Name=Name, Thy=Thy, Ty=Ty, Locn=Locn})
            end
              else
                let
                  val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
                in
                  term_to_preterm avds t >- (fn ptm =>
                  ptype_of ptm >- (fn pty =>
                  Pretype.unify Ty pty >>
                  return (Pattern{Ptm = ptm, Locn = Locn})))
                end
            end
          | _ =>
            return
              (Overloaded{Name=Name, Ty=Ty,
                          Info=Overload.fupd_actual_ops (fn _ => possible_ops)
                                                        Info,
                          Locn=Locn})
    in
      filterM testfn (#actual_ops Info) >- after_filter
  end
  | _ => return ptm

end (* local *)
461
462
(* remove_overloading ptm: every way of resolving the Overloaded nodes left
   in ptm, as a lazy sequence in the seqmonad.  Each Overloaded node is
   replaced by each of its candidate operators in turn (combined with
   tryall), as long as the candidate's type unifies with the node's type.
   A constant becomes a Const node; any other term becomes a Pattern
   wrapping its preterm translation.  Only Comb nodes (operator and
   operand), Abs bodies (NOTE(review): unlike remove_overloading_phase1,
   not bound variables) and Constrained nodes are traversed. *)
val remove_overloading : preterm -> preterm seqM = let
  open seqmonad Term
  infix >- >> ++
  (* Error-monad unification lifted into the sequence monad. *)
  fun unify t1 t2 = fromErr (Pretype.unify t1 t2)

  fun recurse ptm =
    case ptm of
        Overloaded {Name,Ty,Info,Locn} =>
        let
          val actual_ops = #actual_ops Info
          (* One candidate resolution of this node by the operator t. *)
          fun try t =
            if is_const t then
              let
                val {Ty=ty,Name=nm,Thy=thy} = Term.dest_thy_const t
                val pty0 = Pretype.fromType ty
              in
                fromErr (Pretype.rename_typevars [] pty0) >- unify Ty >>
                return (Const{Name=nm, Ty=Ty, Thy=thy, Locn=Locn})
              end
            else
              let
                (* keep type variables that occur in t's free variables *)
                val avds = map Type.dest_vartype (tmlist_tyvs (free_vars t))
              in
                fromErr (term_to_preterm avds t) >- (fn ptm =>
                fromErr (ptype_of ptm) >- (fn pty =>
                unify Ty pty >>
                return (Pattern{Ptm = ptm, Locn = Locn})))
              end
        in
          tryall try actual_ops
        end
      | Comb{Rator, Rand, Locn} =>
          lift2 (fn t1 => fn t2 => Comb{Rator=t1,Rand=t2,Locn=Locn})
                (recurse Rator) (recurse Rand)
      | Abs{Bvar, Body, Locn} =>
          lift (fn t => Abs{Bvar=Bvar, Body=t, Locn=Locn}) (recurse Body)
      | Constrained{Ptm,Ty,Locn} =>
          lift (fn t => Constrained{Ptm=t, Ty=Ty, Locn=Locn}) (recurse Ptm)
      | _ => return ptm

(* Disabled warning about terms with many overloaded symbols, kept for
   reference: *)
(*
  val overloads = overloaded_subterms [] ptm
  val _ = if length overloads >= 30
          then HOL_WARNING "Preterm" "remove_overloading"
                           "many overloaded symbols in term: \
                           \overloading resolution might take a long time."
          else ()
*)
in
  recurse
end
514
(* Non-backtracking overloading resolution.  This version loses the
   sequence/lazy-list backtracking of the parse.  It runs phase 1 and then
   commits to the first resolution produced by remove_overloading, using
   seqmonad.toError.  If there is none, it fails with OvlFail at an unknown
   location.  The boolean paired with the result is passed on to
   report_ovl_ambiguity by callers. *)
fun do_overloading_removal ptm =
  let
    open errormonad
    val first_resolution = seqmonad.toError (OvlFail, locn.Loc_Unknown)
  in
    remove_overloading_phase1 ptm >- (fn ptm' =>
    first_resolution (remove_overloading ptm'))
  end
523
(* React to a possibly ambiguous overloading resolution.  b is true if
   multiple resolutions weren't possible, and then nothing happens.
   Otherwise the outcome depends on the flags:
     - overload guessing disabled: error OvlTooMany;
     - notification on and interactive: print a message, then succeed;
     - any other case: succeed silently. *)
fun report_ovl_ambiguity b env =
  if b then ok env
  else if not (!Globals.guessing_overloads) then
    error (OvlTooMany, locn.Loc_None) env
  else if !Globals.notify_on_tyvar_guess andalso !Globals.interactive then
    (Feedback.HOL_MESG "more than one resolution of overloading was possible";
     ok env)
  else ok env
537
(* Remove applications of the nat/string "elim" magic constants, replacing
   each by its (recursively cleaned) argument.  Comb, Abs and Constrained
   nodes are traversed.  Raises Fail on an Overloaded node. *)
fun remove_elim_magics ptm =
  let
    fun is_elim_magic name =
        name = nat_elim_term orelse name = string_elim_term
  in
    case ptm of
      Comb{Rator = Rator as Const{Name, ...}, Rand, Locn} =>
        if is_elim_magic Name then remove_elim_magics Rand
        else Comb{Rator = Rator, Rand = remove_elim_magics Rand, Locn = Locn}
    | Comb{Rator, Rand, Locn} =>
        Comb{Rator = remove_elim_magics Rator,
             Rand = remove_elim_magics Rand, Locn = Locn}
    | Abs{Bvar, Body, Locn} =>
        Abs{Bvar = remove_elim_magics Bvar,
            Body = remove_elim_magics Body, Locn = Locn}
    | Constrained{Ptm, Ty, Locn} =>
        Constrained{Ptm = remove_elim_magics Ptm, Ty = Ty, Locn = Locn}
    | Overloaded _ => raise Fail "Preterm.remove_elim_magics on Overloaded"
    | _ => ptm   (* Var, Const, Antiq and Pattern are left as they are *)
  end
555
556
(* Resolve overloading without backtracking, then remove the elim magics
   from the resolved preterm.  The accompanying boolean is kept as is. *)
fun overloading_resolution (ptm : preterm) : (preterm * bool) errM =
  let
    fun strip_magics (t, b) = (remove_elim_magics t, b)
  in
    errormonad.lift strip_magics (do_overloading_removal ptm)
  end
561
(* Backtracking overloading resolution.  Phase 1 runs in the error monad;
   every full resolution then comes from remove_overloading, each with its
   elim magics removed. *)
fun overloading_resolutionS ptm =
  let
    open seqmonad
    val resolutions =
        fromErr (remove_overloading_phase1 ptm) >- remove_overloading
  in
    lift remove_elim_magics resolutions
  end
570
571(*---------------------------------------------------------------------------
572 * Type inference for HOL terms. Looks ugly because of error messages, but is
573 * actually very simple, given side-effecting unification.
574 *---------------------------------------------------------------------------*)
575
(* Names of the constructors that can appear inside a numeral spine. *)
fun isnumrator_name nm =
  List.exists (fn s => s = nm)
              ["BIT1", "BIT2", "NUMERAL", fromNum_str, nat_elim_term]

(* Is the preterm a Const or Overloaded node carrying one of those names? *)
fun isnumrator ptm =
  case ptm of
    Const{Name, ...} => isnumrator_name Name
  | Overloaded{Name, ...} => isnumrator_name Name
  | _ => false

(* Is the preterm a numeral: a chain of numeral constructors ending in
   "0" or "ZERO" (each as a Const or an Overloaded node)? *)
fun isnum ptm = let
  fun is_zero nm = nm = "0" orelse nm = "ZERO"
in
  case ptm of
    Const{Name, ...} => is_zero Name
  | Overloaded{Name, ...} => is_zero Name
  | Comb{Rator, Rand, ...} => isnumrator Rator andalso isnum Rand
  | _ => false
end
588
(* Atomic preterms: variables, constants, Overloaded nodes, numerals, and
   any of these under Constrained.  The type-error reporting code does not
   print a separate type for atoms. *)
fun is_atom ptm =
  case ptm of
    Var _ => true
  | Const _ => true
  | Overloaded _ => true
  | Constrained{Ptm, ...} => is_atom Ptm
  | Comb _ => isnum ptm
  | _ => false
595
596
local
  fun default_typrinter x = "<hol_type>"
  fun default_tmprinter x = "<term>"
  open errormonad
  infix ++?
  (* Best-effort translation of a preterm into a term, for use in error
     messages.  It resolves overloading and converts, with tyvar-guess
     notifications switched off.  smash raises if either step fails. *)
  fun smashTm ptm =
    Lib.with_flag (Globals.notify_on_tyvar_guess, false)
                  (smash (overloading_resolution ptm >- (to_term o #1)))
  (* Run f () with Globals.show_types set to true.  Lib.with_flag restores
     the previous value afterwards, including when f raises (for example
     from a user-supplied printer). *)
  fun with_types_shown f = Lib.with_flag (Globals.show_types, true) f ()
in
(* typecheck_phase1 printers ptm: an errM computation that checks ptm
   bottom-up.
     - Application: the operator's pretype is unified with
       (argument pretype --> fresh unification variable).
     - Constrained node: the subterm's pretype is unified with the
       constraint.
     - Abs node: only its components are checked.
     - Any other node: succeeds at once.
   When a unification fails, a diagnostic is built with the supplied term
   and type printers (placeholder strings when printers is NONE).  The
   diagnostic includes the unifier's own error message.  It is printed with
   tcheck_say, and AppFail / ConstrainFail is returned. *)
fun typecheck_phase1 printers = let
  val (ptm, pty) =
      case printers of
        SOME (x,y) => let
          val typrint = y
          fun tmprint tm =
              if Term.is_const tm then x tm ^ " " ^ y (Term.type_of tm)
              else x tm
        in
          (tmprint, typrint)
        end
      | NONE => (default_tmprinter, default_typrinter)
  fun check (Comb{Rator, Rand, ...}) =
    check Rator >> check Rand >>
    ptype_of Rator >- (fn rator_ty =>
    ptype_of Rand >- (fn rand_ty =>
    Pretype.new_uvar >- (fn range_var =>
    (Pretype.unify rator_ty (rand_ty --> range_var)) ++?
     (fn unify_error => fn env =>
          let
            val (Rator', Rand', message) = with_types_shown (fn () =>
              let
                val Rator' = smashTm Rator env
                val Rand'  = smashTm Rand env
              in
                (Rator', Rand',
                 String.concat
                   [
                    "\nType inference failure: unable to infer a type \
                    \for the application of\n\n",
                    ptm Rator',
                    "\n\n"^locn.toString (locn Rator)^"\n\n",
                    if is_atom Rator then ""
                    else ("which has type\n\n" ^
                          pty(Term.type_of Rator') ^ "\n\n"),

                    "to\n\n",
                    ptm Rand',
                    "\n\n"^locn.toString (locn Rand)^"\n\n",

                    if is_atom Rand then ""
                    else ("which has type\n\n" ^
                          pty(Term.type_of Rand') ^ "\n\n"),

                    "unification failure message: " ^
                    errorMsg (#1 unify_error) ^ "\n"])
              end)
          in
            tcheck_say message;
            Error (AppFail(Rator',Rand',message), locn Rand)
          end))))
    | check (Abs{Bvar, Body, ...}) = (check Bvar >> check Body)
    | check (Constrained{Ptm,Ty,Locn}) =
        check Ptm >> ptype_of Ptm >- (fn ptyp =>
        (Pretype.unify ptyp Ty ++?
         (fn unify_error => fn env =>
             let
               val (real_term, real_type, message) = with_types_shown (fn () =>
                 let
                   val real_term = smashTm Ptm env
                   val real_type = Pretype.toType Ty
                 in
                   (real_term, real_type,
                    String.concat
                      [
                       "\nType inference failure: the term\n\n",
                       ptm real_term,
                       "\n\n", locn.toString (locn Ptm), "\n\n",
                       if is_atom Ptm then ""
                       else("which has type\n\n" ^
                            pty(Term.type_of real_term) ^ "\n\n"),
                       "can not be constrained to be of type\n\n",
                       pty real_type,
                       (* report the unifier's error, as the Comb case does *)
                       "\n\nunification failure message: " ^
                       errorMsg (#1 unify_error) ^ "\n"])
                 end)
             in
               tcheck_say message;
               Error(ConstrainFail(real_term, real_type, message), Locn)
             end)))
    | check _ = ok
in
  check
end
end (* local *)
690
(* Short name for typecheck_phase1: TC printers ptm. *)
fun TC printers = typecheck_phase1 printers
692
693(* ---------------------------------------------------------------------- *)
694(* function to do the equivalent of strip_conj, but where the "conj" is   *)
(* the magic binary operator bool$<GrammarSpecials.case_split_special>    *)
696(* ---------------------------------------------------------------------- *)
697
698open HolKernel
(* dest_binop n c t: take t apart as the constant bool$c applied to exactly
   two arguments, returning them as a pair.  n names the construct (e.g.
   "split", "arrow") and is used only in error messages.  On any mismatch
   it raises HOL_ERR with origin "dest_case_" ^ n.  Previously the origin
   was sometimes spelled "dest_case" ^ n, with no underscore. *)
fun dest_binop n c t = let
  val fname = "dest_case_" ^ n
  fun fail msg = raise ERR fname msg
  val (f, args) = strip_comb t
  val {Name, Thy, ...} = dest_thy_const f
      handle HOL_ERR _ => fail ("Not a " ^ n ^ " term")
  val _ = (Name = c andalso Thy = "bool") orelse fail ("Not a " ^ n ^ " term")
in
  case args of
    [l, r] => (l, r)
  | _ => fail ("case " ^ n ^ " 'op' with bad # of args")
end
711
(* Destructors for the parser's magic case-split and case-arrow operators. *)
fun dest_case_split t = dest_binop "split" case_split_special t
fun dest_case_arrow t = dest_binop "arrow" case_arrow_special t
714
(* Flatten a tree of magic case-split applications into the list of its
   leaves, left to right (the analogue of strip_conj). *)
fun strip_splits t0 = let
  fun trav acc t =
      case (SOME (dest_case_split t) handle HOL_ERR _ => NONE) of
        SOME (l, r) => trav (trav acc r) l
      | NONE => t :: acc
in
  trav [] t0
end
724
(* Build the conjunction  t1 /\ t2  directly from the constant bool$/\. *)
fun mk_conj (t1, t2) = let
  val bool_ty = Type.bool
  val conj_tm = mk_thy_const {Name = "/\\", Thy = "bool",
                              Ty = bool_ty --> bool_ty --> bool_ty}
  val partial = mk_comb (conj_tm, t1)
in
  mk_comb (partial, t2)
end
731
(* Right-nested conjunction of a non-empty list of terms; raises HOL_ERR on
   an empty list. *)
fun list_mk_conj tms =
  case List.rev tms of
    [] => raise ERR "list_mk_conj" "empty list"
  | last :: rest => List.foldl mk_conj last rest
(* Build the equation  t1 = t2  using min$= at the type of t1. *)
fun mk_eq (lhs_t, rhs_t) = let
  val a = type_of lhs_t
  val equality = mk_thy_const {Name = "=", Thy = "min",
                               Ty = a --> (a --> Type.bool)}
  val applied_once = mk_comb (equality, lhs_t)
in
  mk_comb (applied_once, rhs_t)
end
741
(* Work items for remove_case_magic0's explicit-stack traversal. *)
datatype rcm_action =
    Input of term        (* a term still to be traversed *)
  | Ab of term * term    (* rebuild an abstraction: bound var, original *)
  | Cmb of int * term    (* rebuild an application: # of args, original *)

(* Traversal results: Ch holds a rebuilt (changed) term, Unch the original,
   untouched term. *)
datatype rcm_out = Ch of term | Unch of term

fun is_unch out = case out of Unch _ => true | Ch _ => false
fun dest_out out = case out of Ch t => t | Unch t => t
(* Split list into (longest prefix whose elements all satisfy P, rest).
   P is applied left to right and stops at the first element that fails. *)
fun Pprefix P list = let
  fun count_ok n [] = n
    | count_ok n (x :: xs) = if P x then count_ok (n + 1) xs else n
  val k = count_ok 0 list
in
  (List.take (list, k), List.drop (list, k))
end
757
(* Rebuild an application from its traversed head (outf) and arguments
   (outargs).  orig is the original application term.
     - Changed head: the whole application is rebuilt.
     - Unchanged head and arguments: orig itself is reused (Unch).
     - Otherwise: orig's longest prefix of unchanged arguments is kept, and
       the remaining arguments are reapplied. *)
fun recomb (outf, outargs, orig) = let
  fun apply_all (f, outs) =
      List.foldl (fn (out, acc) => mk_comb (acc, dest_out out)) f outs
in
  case outf of
    Ch f => Ch (apply_all (f, outargs))
  | Unch _ =>
      (case #2 (Pprefix is_unch outargs) of
         [] => Unch orig
       | changed =>
           Ch (apply_all (funpow (length changed) rator orig, changed)))
end
771
(* Replace the parser's case magic with real case expressions.  Case magic
   is bool$<core_case_special> applied to a scrutinee and to a
   <case_split_special>-separated list of <case_arrow_special> pat/body
   pairs.  Each such occurrence is compiled by passing the equations
   fakef pat = body  to GrammarSpecials.compile_pattern_match and then
   substituting the scrutinee for the resulting function's argument.  Extra
   arguments beyond the first two are applied to the result.
   The traversal uses an explicit work list (actions) and an output stack
   (acc) of rcm_out values, so unchanged subterms are reused as-is rather
   than rebuilt.  (Presumably the explicit stack also avoids deep recursion
   on large terms -- unconfirmed.) *)
fun remove_case_magic0 tm0 = let
  fun traverse acc actions =
      case actions of
        [] => dest_out (hd acc)
      | act :: rest => let
        in
          case act of
            Input t => let
            in
              if is_abs t then let
                  val (v,body) = dest_abs t
                in
                  traverse acc (Input body :: Ab (v,t) :: rest)
                end
              else if is_comb t then let
                  val (f,args) = strip_comb t
                  val in_args = map Input args
                in
                  (* arguments are processed before the head, so when Cmb
                     runs the head's output is on top of acc, followed by
                     the arguments' outputs in reverse order *)
                  traverse acc (in_args @
                                [Input f, Cmb(length args, t)] @ rest)
                end
              else
                traverse (Unch t::acc) rest
            end
          | Ab (v,orig) => let
            in
              case acc of
                Ch bod' :: acc0 => traverse (Ch (mk_abs(v,bod'))::acc0)
                                            rest
              | Unch _ :: acc0 => traverse (Unch orig :: acc0) rest
              | [] => raise Fail "Preterm.rcm: inv failed!"
            end
          | Cmb(arglen, orig) => let
              val out_f = hd acc
              val f = dest_out out_f
              val acc0 = tl acc
              val acc_base = List.drop(acc0, arglen)
              val out_args = List.rev (List.take(acc0, arglen))
              val args = map dest_out out_args
              val newt = let
                (* a non-constant head gets a dummy record that cannot match *)
                val {Name,Thy,Ty} = dest_thy_const f
                    handle HOL_ERR _ => {Name = "", Thy = "", Ty = alpha}
              in
                if Name = core_case_special andalso Thy = "bool" then let
                    val _ = length args >= 2 orelse
                            raise ERR "remove_case_magic"
                                      "case constant has wrong # of args"
                    val split_on_t = hd args
                    val cases = strip_splits (hd (tl args))
                    val patbody_pairs = map dest_case_arrow cases
                        handle HOL_ERR _ =>
                               raise ERR "remove_case_magic"
                                         ("Case expression has invalid syntax \
                                          \where there should be arrows")
                    val split_on_t_ty = type_of split_on_t
                    val result_ty =
                        type_of (list_mk_comb(f, List.take(args,2)))
                    (* a fresh function variable standing for the case
                       function while the equations are compiled *)
                    val fakef = genvar (split_on_t_ty --> result_ty)
                    val fake_eqns =
                        list_mk_conj(map (fn (l,r) =>
                                             mk_eq(mk_comb(fakef, l), r))
                                         patbody_pairs)
                    val functional =
                        GrammarSpecials.compile_pattern_match fake_eqns
                    (* NOTE(review): assumes functional has the shape
                       \f. \v. case_t0 -- inferred from the two dest_abs
                       calls below; confirm against compile_pattern_match. *)
                    val func = snd (dest_abs functional)
                    val (v,case_t0) = dest_abs func
                    val case_t = subst [v |-> split_on_t] case_t0
                  in
                    Ch (list_mk_comb(case_t, tl (tl args)))
                  end
                else
                  recomb(out_f, out_args, orig)
              end (* newt *)
            in
              traverse (newt::acc_base) rest
            end (* Cmb *)
        end (* act :: rest *) (* end traverse *)
in
  traverse [] [Input tm0]
end
852
(* Remove case magic from tm, but only once the case machinery in
   GrammarSpecials has been initialised.  Before that tm is returned
   unchanged. *)
fun remove_case_magic tm =
    if not (GrammarSpecials.case_initialised ()) then tm
    else remove_case_magic0 tm
856
(* Hook applied to each term produced by typecheck/typecheckS, after case
   magic has been removed.  It is the identity by default. *)
val post_process_term : (term -> term) ref = ref I;
858
(* Full non-backtracking type checking of ptm0 in the error monad.  The
   steps are:
     1. run the phase-1 checks, reporting failures with the printers pfns;
     2. resolve overloading and pass its ambiguity flag to
        report_ovl_ambiguity;
     3. translate to a term;
     4. remove case magic;
     5. apply the current !post_process_term hook. *)
fun typecheck pfns ptm0 =
  let
    open errormonad
    val resolved_term =
        TC pfns ptm0 >>
        overloading_resolution ptm0 >- (fn (ptm, b) =>
        report_ovl_ambiguity b >> to_term ptm)
    fun finish t = !post_process_term (remove_case_magic t)
  in
    lift finish resolved_term
  end
869
(* Backtracking type checking: the sequence of all terms obtainable from
   ptm, one per overloading resolution.  Phase-1 diagnostics are silenced.
   Each term has its case magic removed and then goes through the
   post-processing hook, which is read once, when typecheckS is called. *)
fun typecheckS ptm =
  let
    open seqmonad
    val post = !post_process_term
    val quiet_tc =
        errormonad.with_flagM (show_typecheck_errors, false) (TC NONE ptm)
    fun finish t = post (remove_case_magic t)
  in
    lift finish
         (fromErr quiet_tc >> overloading_resolutionS ptm >- (fn ptm' =>
          fromErr (to_term ptm')))
  end
879
880
881end; (* Preterm *)
882