1structure monadsyntax :> monadsyntax =
2struct
3
4open HolKernel Parse boolLib
5local open optionTheory in end
6
(* Names used by the do-notation grammar and its desugaring.
   The "__"-prefixed strings name the do/od list form and the "<-" rule
   (see syntax_actions); monad_bind and monad_unitbind are the overloaded
   names that cleanseq desugars do-blocks into. *)
val monadseq_special = "__monad_sequence"
and monad_emptyseq_special = "__monad_emptyseq"
and monadassign_special = "__monad_assign"
and monad_unitbind = "monad_unitbind"
and monad_bind = "monad_bind"

(* Build a HOL_ERR attributed to this module and the given function. *)
fun ERR fname message = mk_HOL_ERR "monadsyntax" fname message
14
(* The operations that make up one monad for do-notation purposes.
   bind and unit are mandatory; the others are optional.  When ignorebind
   is NONE, gen_enable_monad synthesises one from bind via mk_unitbind. *)
type monadinfo = {
  unit : term,
  bind : term,
  ignorebind : term option,
  fail : term option,
  guard : term option,
  choice : term option
}
21
structure MonadInfo =
struct
  open ThyDataSexp
  type t = monadinfo

  (* Serialise a monadinfo as a 6-element list in the fixed order
       bind, ignorebind, unit, fail, choice, guard.
     fromSexp below decodes exactly this layout. *)
  fun toSexp {bind,ignorebind,unit,fail,choice,guard} =
      List [Term bind, Option (Option.map Term ignorebind),
            Term unit, Option (Option.map Term fail),
            Option (Option.map Term choice),
            Option (Option.map Term guard)]

  (* Decode one optional-term slot: NONE, or SOME of a Term payload.
     Any other payload is malformed and raises HOL_ERR. *)
  fun determOpt NONE = NONE
    | determOpt (SOME (Term t)) = SOME t
    | determOpt _ = raise ERR "fromSexp" "Expected term option"

  (* Inverse of toSexp.  Raises HOL_ERR if s is not the 6-element list
     that toSexp produces. *)
  fun fromSexp s =
    case s of
        List [Term bind, Option ign_opt, Term unit, Option failopt,
              Option choiceopt, Option guardopt] =>
          {bind = bind, ignorebind = determOpt ign_opt, unit = unit,
           fail = determOpt failopt, guard = determOpt guardopt,
           choice = determOpt choiceopt}
      | _ => raise ERR "fromSexp" "bad format - not a list of 6 elements"
end
43
(* In-memory table of known monads, keyed by monad name.  Starts empty and
   is extended by predeclare and by load_from_disk. *)
val monadDB : (string, MonadInfo.t) Binarymap.dict ref =
    ref (Binarymap.mkDict String.compare)
46
(* Encode one (name, monadinfo) entry for export as theory data: a
   one-element list whose element is the pair [name, info].  This is the
   list-of-pairs shape that load_from_disk decodes; presumably it is also
   what ThyDataSexp.alist_merge expects (TODO confirm). *)
fun write_keyval (nm, mi) =
  let
    val pair = ThyDataSexp.List [ThyDataSexp.String nm, MonadInfo.toSexp mi]
  in
    ThyDataSexp.List [pair]
  end
53
(* Theory-data loader.  data must be a list of [name, monadinfo] pairs, as
   produced by write_keyval.  Every entry is decoded before monadDB is
   updated, so a malformed entry raises HOL_ERR without changing the table.
   Entries are then inserted in order (Binarymap.insert replaces any
   existing binding for the same name). *)
fun load_from_disk {thyname, data} =
  let
    fun dest_keyval s =
      case s of
          ThyDataSexp.List [ThyDataSexp.String key, mi_sexp] =>
            (key, MonadInfo.fromSexp mi_sexp)
        | _ => raise ERR "load_from_disk" "keyval pair data looks bad"
    fun add_entry ((k, v), db) = Binarymap.insert (db, k, v)
  in
    case data of
        ThyDataSexp.List keyvals =>
          let
            val entries = map dest_keyval keyvals
            val current = !monadDB
          in
            monadDB := List.foldl add_entry current entries
          end
      | _ => raise ERR "load_from_disk" "data looks bad"
  end
74
(* Return the monad name from an exported [name, info] pair.  Used by
   uptodate_check when reporting discarded entries. *)
fun getMITname (ThyDataSexp.List [ThyDataSexp.String k, _]) = k
  | getMITname _ = raise ERR "getMITname" "Shouldn't happen"
83
(* Re-validate exported monad data: each [name, info] entry that fails
   ThyDataSexp.uptodate is dropped, and a warning names the discarded
   monads.  When nothing is dropped, t itself is returned unchanged.
   Anything other than a ThyDataSexp.List is an internal error. *)
fun uptodate_check t =
  case t of
      ThyDataSexp.List entries =>
      let
        val (good, bad) = partition ThyDataSexp.uptodate entries
      in
        case bad of
            [] => t
          | _ =>
            let
              val monad_names = map getMITname bad
            in
              HOL_WARNING "monadsyntax" "uptodate_check"
                          ("Monad information for: " ^
                           String.concatWith ", " monad_names ^ " discarded");
              ThyDataSexp.List good
            end
      end
    (* the error is attributed to this module (it was previously
       mislabelled as TypeBase's) *)
    | _ => raise Fail "monadsyntax.uptodate_check : shouldn't happen"
103
104
(* Theory-delta hook.  When a constant or type operator is created or
   deleted, run uptodate_check on the stored data (presumably because such
   changes can invalidate stored terms).  Any other delta leaves t as is. *)
fun check_thydelta (t, tdelta) =
  let
    fun invalidating (TheoryDelta.NewConstant _) = true
      | invalidating (TheoryDelta.NewTypeOp _) = true
      | invalidating (TheoryDelta.DelConstant _) = true
      | invalidating (TheoryDelta.DelTypeOp _) = true
      | invalidating _ = false
  in
    if invalidating tdelta then uptodate_check t else t
  end
116
(* Register the "MonadInfoDB" theory-data kind and keep its export
   function; declare_monad uses it to record monads in theory data. *)
val {export = export_minfo, ...} =
    ThyDataSexp.new {
      thydataty = "MonadInfoDB",
      merge = ThyDataSexp.alist_merge,
      load = load_from_disk,
      other_tds = check_thydelta
    }
122
(* Add or replace a monad in the in-memory table only; nothing is
   exported as theory data. *)
fun predeclare (name, info) =
    monadDB := Binarymap.insert (!monadDB, name, info)

(* Add a monad to the in-memory table and export it as theory data. *)
fun declare_monad entry =
  let
    val () = predeclare entry
  in
    export_minfo (write_keyval entry)
  end

(* Every entry of the monad table, as returned by Binarymap.listItems. *)
fun all_monads () =
  let
    val db = !monadDB
  in
    Binarymap.listItems db
  end
127
128
(* Convert the absyn on the left of "<-" into a varstruct.  Accepted:
   identifiers, antiquotations, pairs built with "," (recursively) and
   type-annotated varstructs.  Qualified identifiers and any other shape
   raise a located HOL_ERR. *)
fun to_vstruct a = let
  open Absyn
  fun fail_at loc msg = mk_HOL_ERRloc "Absyn" "to_vstruct" loc msg
in
  case a of
    IDENT x => VIDENT x
  | AQ x => VAQ x
  | TYPED (loc, inner, pty) => VTYPED (loc, to_vstruct inner, pty)
  | APP (loc, APP (_, IDENT (_, ","), lhs), rhs) =>
      VPAIR (loc, to_vstruct lhs, to_vstruct rhs)
  | QIDENT (loc, _, _) =>
      raise fail_at loc "Qualified identifiers can't be varstructs"
  | _ => raise fail_at (locn_of_absyn a) "Bad form of varstruct"
end
144
(* Split a do-block action.  If a is "v <- m" (the binary operator
   monadassign_special), return (SOME vstruct-of-v, m); otherwise return
   (NONE, a) unchanged. *)
fun clean_action a =
  case a of
    Absyn.APP (_, Absyn.APP (_, Absyn.IDENT (_, opname), lhs), rhs) =>
      if opname <> monadassign_special then (NONE, a)
      else (SOME (to_vstruct lhs), rhs)
  | _ => (NONE, a)
157
(* Desugaring of do-notation absyn.  A do-block is parsed (see the list form
   in syntax_actions) as a right-nested sequence built with the
   monadseq_special "cons" operator and terminated by
   monad_emptyseq_special; "v <- m" is the binary operator
   monadassign_special.  These three mutually recursive functions rewrite
   such sequences into applications of monad_bind / monad_unitbind.

   cleanseq a: if a is  monadseq_special act rest, return SOME of the
   desugared sequence, else NONE.  The action is cleaned with clean_do true
   (so "<-" is permitted) and split by clean_action:
     - "v <- m" followed by rest   becomes  monad_bind m (\v. rest')
     - m followed by rest          becomes  monad_unitbind m rest'
     - last action (rest empty)    becomes  the action itself.
   NOTE(review): in the last case any "v <-" pattern is silently dropped,
   i.e.  do v <- m od  desugars to just  m. *)
fun cleanseq a = let
  open Absyn
in
  case a of
    APP(loc1, APP(loc2, IDENT(loc3, s), arg1), arg2) => let
    in
      if s = monadseq_special then let
          val (bv, arg1') = clean_action (clean_do true arg1)
          val arg2' = clean_actions arg2
        in
          case arg2' of
            NONE => SOME arg1'
          | SOME a => let
            in
              case bv of
                NONE => SOME (APP(loc1,
                                  APP(loc2,
                                      IDENT(loc3, monad_unitbind),
                                      arg1'),
                                  a))
              | SOME b => SOME (APP(loc1,
                                    APP(loc2,
                                        IDENT(loc3, monad_bind),
                                        arg1'),
                                    LAM(locn_of_absyn a, b, a)))
            end
        end
      else NONE
    end
  | _ => NONE
end
(* clean_do indo a: desugar every do-block occurring anywhere in a.
   indo is true when cleaning an action of a do-block.  The local
   rebinding of clean_do passes the same indo to every sub-term.  A bare
   "<-" application is rejected only when indo is false.  NOTE(review): a
   "<-" nested inside an action therefore escapes this check.  A stand-alone
   monad_emptyseq_special (an empty do-od) always raises HOL_ERR. *)
and clean_do indo a = let
  open Absyn
  val clean_do = clean_do indo
in
  case cleanseq a of
    SOME a => a
  | NONE => let
    in
      case a of
        APP(l,arg1 as APP(_,IDENT(_,s),_),arg2) =>
        if s = monadassign_special andalso not indo then
          raise mk_HOL_ERRloc "monadsyntax" "clean_do" l
                              "Bare monad assign arrow illegal"
        else APP(l,clean_do arg1,clean_do arg2)
      | APP(l,a1,a2) => APP(l,clean_do a1, clean_do a2)
      | LAM(l,v,a) => LAM(l,v,clean_do a)
      | IDENT(loc,s) => if s = monad_emptyseq_special then
                          raise mk_HOL_ERRloc "monadsyntax" "clean_do" loc
                                              "Empty do-od pair illegal"
                        else a
      | TYPED(l,a,pty) => TYPED(l,clean_do a, pty)
      | _ => a
    end
end
(* clean_actions a: clean the tail of a do-sequence.  Returns NONE for the
   empty-sequence terminator, SOME of the desugared remainder if a is a
   further sequence, and otherwise SOME a (returned without further
   cleaning). *)
and clean_actions a = let
  open Absyn
in
  case cleanseq a of
    SOME a => SOME a
  | NONE => let
    in
      case a of
        IDENT(loc,s) => if s = monad_emptyseq_special then NONE
                        else SOME a
      | a => SOME a
    end
end
226
(* Absyn post-processor: desugar all do-blocks in the absyn.  The grammar
   argument is not used.  Top-level "<-" (outside a do-block) is rejected
   by clean_do. *)
fun transform_absyn _ absyn = clean_do false absyn
228
229
230
(* Recognise t as an application of an overloaded monad operator.
   Returns SOME (name, m, k) when t is  monad_unitbind m k,  or when t is
   monad_bind m k  with k a (paired) lambda abstraction.  Returns NONE in
   every other case, including when analysing t raises HOL_ERR. *)
fun dest_bind G t =
  let
    open term_pp_types
    val (head, argl) =
        valOf (Overload.oi_strip_comb (term_grammar.overload_info G) t)
        handle Option => raise UserPP_Failed
    val (m, k) =
        case argl of
            [m0, k0] => (m0, k0)
          | _ => raise UserPP_Failed
    (* the head is a fake-constant variable carrying the overload name *)
    val opname =
        #fake (valOf (GrammarSpecials.dest_fakeconst_name (#1 (dest_var head))))
        handle HOL_ERR _ => raise UserPP_Failed
             | Option => raise UserPP_Failed
  in
    if opname = monad_unitbind orelse
       (opname = monad_bind andalso pairSyntax.is_pabs k)
    then SOME (opname, m, k)
    else NONE
  end handle HOL_ERR _ => NONE
           | term_pp_types.UserPP_Failed => NONE
252
253
(* User pretty-printer for do-notation.  When dest_bind recognises t as a
   monad_bind / monad_unitbind application, print the whole right-nested
   chain of binds as
     do a1; v <- a2; ...; an od
   Otherwise raise UserPP_Failed (presumably so that the standard printer
   is used instead).
   (An unused local that built a complete min-grammar printer on every call
   has been removed.) *)
fun print_monads (tyg, tmg) backend sysprinter ppfns (p,l,r) depth t = let
  open term_pp_types term_grammar smpp term_pp_utils
  infix >>
  val ppfns = ppfns : ppstream_funs
  val {add_string=strn,add_break=brk,ublock,...} = ppfns
  val (prname, arg1, arg2) = valOf (dest_bind tmg t)
                             handle Option => raise UserPP_Failed
  (* print a sub-term with the depth budget reduced by one *)
  fun syspr bp gravs t =
    sysprinter {gravs = gravs, binderp = bp, depth = depth - 1} t
  (* Print one action: a bare monadic term, or "v <- action".  In the
     latter case v's free variables are added to the bound-variable set,
     which is only restored after the final "od". *)
  fun pr_action (v, action) =
      case v of
        NONE => syspr false (Top,Top,Top) action
      | SOME v => let
          val bvars = free_vars v
        in
          addbvs bvars >>
          ublock PP.INCONSISTENT 0
            (syspr true (Top,Top,Prec(100, "monad_assign")) v >>
             strn " " >> strn "<-" >> brk(1,2) >>
             syspr false (Top,Prec(100, "monad_assign"),Top) action)
        end
  (* Split one bind into ((pattern option, action), continuation).  For
     monad_bind the continuation's (paired) lambda supplies the pattern.
     If it is not an abstraction, there is no pattern and the continuation
     is kept whole. *)
  fun brk_bind binder arg1 arg2 =
      if binder = monad_bind then let
              val (v,body) = (SOME ## I) (pairSyntax.dest_pabs arg2)
                             handle HOL_ERR _ => (NONE, arg2)
        in
          ((v, arg1), body)
        end
      else ((NONE, arg1), arg2)
  (* Flatten the chain of binds into a list of actions; the first term
     that is not a bind becomes the final action. *)
  fun strip acc t =
      case dest_bind tmg t of
        NONE => List.rev ((NONE, t) :: acc)
      | SOME (prname, arg1, arg2) => let
          val (arg1', arg2') = brk_bind prname arg1 arg2
        in
          strip (arg1'::acc) arg2'
        end
  val (arg1',arg2') = brk_bind prname arg1 arg2
  val actions = strip [arg1'] arg2'
in
  ublock PP.CONSISTENT 0
    (strn "do" >> brk(1,2) >>
     getbvs >- (fn oldbvs =>
     pr_list pr_action (strn ";" >> brk(1,2)) actions >>
     brk(1,0) >>
     strn "od" >> setbvs oldbvs))
end
302
(* Register print_monads and transform_absyn under the same names that
   syntax_actions installs in the grammar.  Presumably this lets grammars
   that refer to these names locate the code (TODO confirm). *)
val _ =
  (term_grammar.userSyntaxFns.register_userPP
     {name = "monadsyntax.print_monads", code = print_monads};
   term_grammar.userSyntaxFns.register_absynPostProcessor
     {name = "monadsyntax.transform_absyn", code = transform_absyn})
311
(* Apply the four grammar changes that make up do-notation, in order:
     al  - the "do ...; ... od" list form over monadseq_special and
           monad_emptyseq_special,
     ar  - the non-associative infix "<-" rule for monadassign_special,
     aup - the user printer print_monads, keyed on the pattern x:'a,
     app - the absyn post-processor transform_absyn.
   Callers pass either the temp_ or the permanent variants of each action. *)
fun syntax_actions al ar aup app =
  let
    val do_od_listform =
        {block_info = (PP.CONSISTENT,0),
         cons = monadseq_special,
         nilstr = monad_emptyseq_special,
         leftdelim = [TOK "do", BreakSpace(1,2)],
         rightdelim = [TOK "od"],
         separator = [TOK ";", BreakSpace(1,0)]}
    val assign_rule =
        {block_style = (AroundEachPhrase, (PP.INCONSISTENT, 2)),
         fixity = Infix(NONASSOC, 100),
         paren_style = OnlyIfNecessary,
         pp_elements = [BreakSpace(1,0), TOK "<-", HardSpace 1],
         term_name = monadassign_special}
  in
    al do_od_listform;
    ar assign_rule;
    aup ("monadsyntax.print_monads", ``x:'a``, print_monads);
    app ("monadsyntax.transform_absyn", transform_absyn)
  end
326
(* Install do-notation using the temp_ variant of every grammar action. *)
fun temp_add_monadsyntax () =
  let
    val lform = temp_add_listform
    val rule = temp_add_rule
    val printer = temp_add_user_printer
    val postproc = temp_add_absyn_postprocessor
  in
    syntax_actions lform rule printer postproc
  end
330
(* Grammar term-name of the do/od list form; gen_disable_syntax uses it to
   remove the "do" token rule. *)
val monad_lform_name =
    GrammarSpecials.mk_lform_name {nilstr = monad_emptyseq_special,
                                   cons = monadseq_special}
336
(* Derive an "ignore the result" sequencing operator from a bind term:
     \m1 m2. mbind m1 (K m2)
   The type of mbind must have the shape  mty1 -> (a -> mty2) -> r.
   Used when a monad supplies no ignorebind. *)
fun mk_unitbind mbind =
  let
    val (m1_ty, cont_to_result) = dom_rng (type_of mbind)
    val (cont_ty, m2_ty) = dom_rng cont_to_result
    val discarded_ty = #1 (dom_rng cont_ty)
    val v1 = mk_var ("m1", m1_ty)
    val v2 = mk_var ("m2", m2_ty)
    val k_inst = inst [alpha |-> m2_ty, beta |-> discarded_ty] combinSyntax.K_tm
  in
    list_mk_abs ([v1, v2], list_mk_comb (mbind, [v1, mk_comb (k_inst, v2)]))
  end
349
(* Look up the monad registered under s.  Raises HOL_ERR, attributed to
   fname, if none is registered. *)
fun getMI fname s =
  let
    val found = Binarymap.peek (!monadDB, s)
  in
    case found of
        SOME mi => mi
      | NONE => raise ERR fname ("No such monad defined: "^s)
  end
354
(* Overload the generic monad names onto the operations of the monad
   registered under s.  If no monad is registered under s, getMI raises
   HOL_ERR tagged with fname.
     - monad_bind, monad_unitbind, and the optional "++" (choice) and
       "assert" (guard) go through ovl;
     - "return" and the optional "fail" go through iovl, so that they do
       not contaminate normal uses of those names.
   When the monad has no ignorebind, one is synthesised from bind.  That
   term is computed before any overloading is installed, so a malformed
   bind fails without leaving a partial set of overloads behind.  The
   overload names come from the module constants that the parser and
   printer use, so they cannot drift apart. *)
fun gen_enable_monad fname iovl ovl s =
  let
    val {bind,ignorebind,unit,fail,choice,guard} = getMI fname s
    val unitbind =
        case ignorebind of NONE => mk_unitbind bind | SOME ib => ib
  in
    ovl (monad_bind, bind);
    ovl (monad_unitbind, unitbind);
    iovl ("return", unit);
    Option.app (fn f => iovl ("fail", f)) fail;
    Option.app (fn c => ovl ("++", c)) choice;
    Option.app (fn g => ovl ("assert", g)) guard
  end
368
(* Remove the overloadings that gen_enable_monad installs for the monad
   registered under s.  Each removal is  rmovl name term.  getMI raises
   HOL_ERR tagged with fname if s is unknown.  The monad_unitbind term is
   rebuilt the same way gen_enable_monad builds it, so the same term is
   removed.  The overload names come from the shared module constants. *)
fun gen_disable_monad fname rmovl s =
  let
    val {bind,ignorebind,unit,fail,choice,guard} = getMI fname s
    val unitbind =
        case ignorebind of NONE => mk_unitbind bind | SOME ib => ib
  in
    rmovl monad_bind bind;
    rmovl monad_unitbind unitbind;
    rmovl "return" unit;
    Option.app (rmovl "fail") fail;
    Option.app (rmovl "++") choice;
    Option.app (rmovl "assert") guard
  end
381
(* Overload s onto t with raw, without making that overloading the
   preferred parse target.  Goal: t should still print with monad syntax,
   but the name s should not be preferred when parsing.  So when t is a
   constant whose own name already appears among t's print-map entries,
   raw is applied again to (constant-name, t).  The intent is that s ranks
   higher than the constant's raw name. *)
fun gen_inferior_overload_on raw (s, t) =
  let
    fun reassert_own_name () =
      let
        val {Name = cname, ...} = dest_thy_const t
        val oinfo = term_grammar.overload_info (term_grammar())
        val matches = Overload.PrintMap.match (Overload.raw_print_map oinfo, t)
        fun mentions_const (_, (_, nm, _)) = nm = cname
      in
        if List.exists mentions_const matches then raw (cname, t) else ()
      end
  in
    raw (s, t);
    if is_const t then reassert_own_name () else ()
  end
399
(* Public enable/disable entry points, each taking a monad name.
   The weak_ variants route the bind-style names through
   gen_inferior_overload_on; the temp_ variants use the temp_ overloading
   functions. *)
fun enable_monad s =
    gen_enable_monad "enable_monad" inferior_overload_on overload_on s

fun weak_enable_monad s =
    gen_enable_monad "weak_enable_monad"
                     inferior_overload_on
                     (gen_inferior_overload_on inferior_overload_on)
                     s

fun disable_monad s =
    gen_disable_monad "disable_monad" gen_remove_ovl_mapping s

fun temp_weak_enable_monad s =
    gen_enable_monad "temp_weak_enable_monad"
                     temp_inferior_overload_on
                     (gen_inferior_overload_on temp_inferior_overload_on)
                     s

fun temp_enable_monad s =
    gen_enable_monad "temp_enable_monad"
                     temp_inferior_overload_on
                     temp_overload_on
                     s

fun temp_disable_monad s =
    gen_disable_monad "temp_disable_monad" temp_gen_remove_ovl_mapping s
417
(* Undo syntax_actions using the supplied removal functions.  It removes
   the "do" list-form rule, the "<-" rule, the print_monads user printer,
   and finally calls remove_postproc on the absyn post-processor's name
   (whose result is returned). *)
fun gen_disable_syntax remove_tok remove_printer remove_postproc =
  let
    val _ = remove_tok {term_name = monad_lform_name, tok = "do"}
    val _ = remove_tok {term_name = monadassign_special, tok = "<-"}
    val _ = remove_printer "monadsyntax.print_monads"
  in
    remove_postproc "monadsyntax.transform_absyn"
  end
423
(* Remove do-notation from the grammar (permanent and temp_ versions).
   NOTE(review): the absyn post-processor is not removed; the last argument
   passed to gen_disable_syntax is a no-op. *)
fun disable_monadsyntax () =
  let
    val skip_postproc = fn (_ : string) => ()
  in
    gen_disable_syntax remove_termtok remove_user_printer skip_postproc
  end

fun temp_disable_monadsyntax () =
  let
    val skip_postproc = fn (_ : string) => ()
  in
    gen_disable_syntax temp_remove_termtok temp_remove_user_printer
                       skip_postproc
  end
428
(* Non-temporary printer / post-processor actions for syntax_actions.
   Besides adding the named printer or post-processor, each records an ML
   dependency on "monadsyntax".  The code component of the argument is
   ignored; only the name is passed on. *)
fun aup (name, pattern, _) =
  let
    val () = add_ML_dependency "monadsyntax"
  in
    add_user_printer (name, pattern)
  end

fun aap (name, _) =
  let
    val () = add_ML_dependency "monadsyntax"
  in
    add_absyn_postprocessor name
  end
434
(* Install do-notation with the non-temporary grammar actions (printer and
   post-processor go through aup/aap, which add the ML dependency). *)
fun add_monadsyntax () =
    syntax_actions add_listform add_rule aup aap

(* Alternative names for the two installers. *)
val enable_monadsyntax = add_monadsyntax
and temp_enable_monadsyntax = temp_add_monadsyntax
439
(* LaTeX renderings for the do-notation tokens ("<-", "do", "od"). *)
val _ =
  List.app (fn entry => ignore (TexTokenMap.temp_TeX_notation entry))
    [{hol = "<-", TeX = ("\\HOLTokenLeftmap{}", 1)},
     {hol = "do", TeX = ("\\HOLKeyword{do}", 2)},
     {hol = "od", TeX = ("\\HOLKeyword{od}", 2)}]
444
(* Register the option monad in the in-memory table (not exported),
   using constants from optionTheory. *)
val _ =
  let
    fun optc nm = prim_mk_const {Name = nm, Thy = "option"}
  in
    predeclare ("option",
                {bind = optc "OPTION_BIND",
                 ignorebind = SOME (optc "OPTION_IGNORE_BIND"),
                 unit = optc "SOME",
                 fail = SOME (optc "NONE"),
                 guard = SOME (optc "OPTION_GUARD"),
                 choice = SOME (optc "OPTION_CHOICE")})
  end
455
456end (* struct *)
457