structure sumSyntax :> sumSyntax =
struct

(* Reference sumTheory so that it is loaded (and the "sum" theory present)
   before the constants below are looked up; nothing is brought into scope. *)
local open sumTheory in end;

open HolKernel Abbrev;

val ERR = mk_HOL_ERR "sumSyntax"

(*---------------------------------------------------------------------------*)
(* Sum types                                                                 *)
(*---------------------------------------------------------------------------*)

(* mk_sum (ty1, ty2) builds the HOL type ty1 + ty2, i.e. the type operator
   "sum" of theory "sum" applied to [ty1, ty2]. *)
fun mk_sum (ty1, ty2) =
   mk_thy_type {Tyop = "sum", Thy = "sum", Args = [ty1, ty2]}

(* dest_sum ty is the inverse of mk_sum: it returns (ty1, ty2) when ty is
   sum$sum applied to exactly two arguments, and otherwise raises the
   exception built by ERR "dest_sum" "not a sum type". *)
fun dest_sum ty =
   case total dest_thy_type ty of
      SOME {Tyop = "sum", Thy = "sum", Args = [ty1, ty2]} => (ty1, ty2)
    | _ => raise ERR "dest_sum" "not a sum type"

(* Iterated forms built from the generic Lib combinators strip_binop,
   spine_binop and end_itlist; see those for how nesting is traversed. *)
val strip_sum = strip_binop dest_sum
val spine_sum = spine_binop (total dest_sum)
val list_mk_sum = end_itlist (curry mk_sum)

(*---------------------------------------------------------------------------*)
(* Case analysis                                                             *)
(*---------------------------------------------------------------------------*)

(* The constant sum$sum_CASE at its generic type
     (beta + gamma) -> (beta -> alpha) -> (gamma -> alpha) -> alpha *)
val sum_case_tm =
   mk_thy_const
      {Name = "sum_CASE",
       Thy = "sum",
       Ty = mk_sum (beta, gamma) --> (beta --> alpha) -->
            (gamma --> alpha) --> alpha}

(* mk_sum_case (f, g, s) builds sum_CASE s f g.  The type variables of
   sum_case_tm are instantiated from the domains of f and g and the range of
   f; the application itself is type-checked by list_mk_comb, which raises
   if s, f and g do not fit together. *)
fun mk_sum_case (f, g, s) =
   let
      val (df, r) = dom_rng (type_of f)
      val (dg, _) = dom_rng (type_of g)
   in
      list_mk_comb
         (inst [alpha |-> r, beta |-> df, gamma |-> dg] sum_case_tm, [s, f, g])
   end

(*---------------------------------------------------------------------------*)
(* Unary operators: tests, projections and injections                       *)
(*---------------------------------------------------------------------------*)

val monop = HolKernel.syntax_fns1 "sum"

val (isl_tm, mk_isl, dest_isl, is_isl) = monop "ISL"
val (isr_tm, mk_isr, dest_isr, is_isr) = monop "ISR"
val (outl_tm, mk_outl, dest_outl, is_outl) = monop "OUTL"
val (outr_tm, mk_outr, dest_outr, is_outr) = monop "OUTR"

(* The injections need an extra type argument, because the argument term
   only fixes one side of the sum:
     mk_inl (t, ty)  builds  INL t : type_of t + ty
     dest_inl tm     returns (t, ty), with ty the right-hand summand of the
                     type of tm. *)
val (inl_tm, mk_inl, dest_inl, is_inl) =
   HolKernel.syntax_fns
     {n = 1,
      dest = fn tm1 => fn e => fn t =>
               (HolKernel.dest_monop tm1 e t, snd (dest_sum (type_of t))),
      make = fn tm => fn (t, ty) =>
               Term.mk_comb
                 (Term.inst [Type.alpha |-> type_of t, Type.beta |-> ty] tm, t)}
     "sum" "INL"

(*   mk_inr (t, ty)  builds  INR t : ty + type_of t
     dest_inr tm     returns (t, ty), with ty the left-hand summand of the
                     type of tm. *)
val (inr_tm, mk_inr, dest_inr, is_inr) =
   HolKernel.syntax_fns
     {n = 1,
      dest = fn tm1 => fn e => fn t =>
               (HolKernel.dest_monop tm1 e t, fst (dest_sum (type_of t))),
      make = fn tm => fn (t, ty) =>
               Term.mk_comb
                 (Term.inst [Type.alpha |-> ty, Type.beta |-> type_of t] tm, t)}
     "sum" "INR"

(*---------------------------------------------------------------------------*)
(* Lifting sums                                                              *)
(*---------------------------------------------------------------------------*)

(* lift_sum ty f g v lifts the ML sum value v to a HOL term:
     Lib.INL x  becomes  INL (f x)
     Lib.INR y  becomes  INR (g y)
   The INL/INR constants are instantiated against ty with TypeBasePure.cinst
   (presumably ty is the target HOL sum type -- confirm against callers).
   The ML-side sum is the library type Lib.sum.  This file must not declare
   its own sum datatype: SML datatypes are generative, so a local copy would
   be a different type, and values built with Lib.INL / Lib.INR could not be
   passed to lift_sum. *)
fun lift_sum ty =
   let
      val inl = TypeBasePure.cinst ty inl_tm
      val inr = TypeBasePure.cinst ty inr_tm
      fun lift f _ (Lib.INL x) = mk_comb (inl, f x)
        | lift _ g (Lib.INR y) = mk_comb (inr, g y)
   in
      lift
   end

end
82