structure sptreeSyntax :> sptreeSyntax =
struct

open HolKernel boolLib
open sptreeTheory

val ERR = Feedback.mk_HOL_ERR "sptreeSyntax"

(* Build the (constant, mk_, dest_, is_) quadruple for a constant of the
   sptree theory from an explicit arity n and custom dest/make functions. *)
fun syntax_fns n d m = HolKernel.syntax_fns {n = n, dest = d, make = m} "sptree"

(* ------------------------------------------------------------------------- *)

(* mk_sptree_ty a is the HOL type  a sptree$spt. *)
fun mk_sptree_ty a = Type.mk_thy_type {Tyop = "spt", Thy = "sptree", Args = [a]}

(* Inverse of mk_sptree_ty: returns the element type of an  'a sptree$spt
   type, and raises a HOL_ERR for any other type. *)
fun dest_sptree_ty ty =
   case Lib.total Type.dest_thy_type ty of
      SOME {Tyop = "spt", Thy = "sptree", Args = [a]} => a
    | _ => raise ERR "dest_sptree_ty" ""

(* Element type of a term whose type is  'a sptree$spt. *)
val sptree_ty_of = dest_sptree_ty o Term.type_of

(* ------------------------------------------------------------------------- *)

(* Nullary constant LN: mk_ln takes the element type to instantiate 'a with,
   and dest_ln returns the element type of a matching LN term. *)
val s0 =
   syntax_fns 0
      (fn tm1 => fn e => fn tm2 =>
          if Term.same_const tm1 tm2 then sptree_ty_of tm2 else raise e)
      (fn tm => fn ty => Term.inst [Type.alpha |-> ty] tm)

val (ln_tm, mk_ln, dest_ln, is_ln) = s0 "LN"

(* ------------------------------------------------------------------------- *)

(* Unary constants.  "domain" is built with arity 2 but monop mk/dest,
   presumably because its result (a num set) is itself a function type, so
   its constant type strips to two argument types.
   NOTE(review): confirm against the arity check in HolKernel.syntax_fns. *)
val s1 = HolKernel.syntax_fns1 "sptree"
val s1' = syntax_fns 2 HolKernel.dest_monop HolKernel.mk_monop

val (domain_tm, mk_domain, dest_domain, is_domain) = s1' "domain"
val (fromAList_tm, mk_fromAList, dest_fromAList, is_fromAList) = s1 "fromAList"
val (fromList_tm, mk_fromList, dest_fromList, is_fromList) = s1 "fromList"
val (lrnext_tm, mk_lrnext, dest_lrnext, is_lrnext) = s1 "lrnext"
val (ls_tm, mk_ls, dest_ls, is_ls) = s1 "LS"
val (mk_wf_tm, mk_mk_wf, dest_mk_wf, is_mk_wf) = s1 "mk_wf"
val (size_tm, mk_size, dest_size, is_size) = s1 "size"
val (toAList_tm, mk_toAList, dest_toAList, is_toAList) = s1 "toAList"
val (toList_tm, mk_toList, dest_toList, is_toList) = s1 "toList"
val (wf_tm, mk_wf, dest_wf, is_wf) = s1 "wf"

(* ------------------------------------------------------------------------- *)

(* Binary constants. *)
val s2 = HolKernel.syntax_fns2 "sptree"

val (bn_tm, mk_bn, dest_bn, is_bn) = s2 "BN"
val (delete_tm, mk_delete, dest_delete, is_delete) = s2 "delete"
val (difference_tm, mk_difference, dest_difference, is_difference) =
   s2 "difference"
val (inter_eq_tm, mk_inter_eq, dest_inter_eq, is_inter_eq) = s2 "inter_eq"
val (inter_tm, mk_inter, dest_inter, is_inter) = s2 "inter"
val (lookup_tm, mk_lookup, dest_lookup, is_lookup) = s2 "lookup"
val (mk_bn_tm, mk_mk_bn, dest_mk_bn, is_mk_bn) = s2 "mk_BN"
val (union_tm, mk_union, dest_union, is_union) = s2 "union"

(* ------------------------------------------------------------------------- *)

(* Ternary constants. *)
val s3 = HolKernel.syntax_fns3 "sptree"

val (bs_tm, mk_bs, dest_bs, is_bs) = s3 "BS"
val (mk_bs_tm, mk_mk_bs, dest_mk_bs, is_mk_bs) = s3 "mk_BS"
val (insert_tm, mk_insert, dest_insert, is_insert) = s3 "insert"

(* ------------------------------------------------------------------------- *)

(* Quaternary constants. *)
val s4 = HolKernel.syntax_fns4 "sptree"

val (foldi_tm, mk_foldi, dest_foldi, is_foldi) = s4 "foldi"

(* ------------------------------------------------------------------------- *)

(* Pretty-printing support *)

(* ML mirror of the ground constructors of  'a sptree$spt; stored values are
   kept as HOL terms. *)
datatype spt = LN | LS of term | BN of spt * spt | BS of spt * term * spt

(* Convert a ground sptree term, built only from LN, LS, BN and BS, into the
   ML datatype.  Raises a HOL_ERR if any subterm is not one of those
   constructors applied to the right number of arguments. *)
fun dest_sptree tm =
   case Lib.total boolSyntax.dest_strip_comb tm of
      SOME ("sptree$LN", []) => LN
    | SOME ("sptree$LS", [t]) => LS t
    | SOME ("sptree$BN", [t1, t2]) => BN (dest_sptree t1, dest_sptree t2)
    | SOME ("sptree$BS", [t1, v, t2]) => BS (dest_sptree t1, v, dest_sptree t2)
    | _ => raise ERR "dest_sptree" ""

local
   (* Type of the first stored value found in t, searching the left subtree
      of a BN before the right one; NONE when t stores no values at all. *)
   fun value_ty t =
      case t of
         LN => NONE
       | LS a => SOME (Term.type_of a)
       | BS (_, a, _) => SOME (Term.type_of a)
       | BN (t1, t2) =>
            (case value_ty t1 of
                NONE => value_ty t2
              | found => found)

   (* Rebuild the HOL term for t, giving every LN the element type ty. *)
   fun build ty t =
      case t of
         LN => mk_ln ty
       | LS a => mk_ls a
       | BN (t1, t2) => mk_bn (build ty t1, build ty t2)
       | BS (t1, a, t2) => mk_bs (build ty t1, a, build ty t2)
in
   (* Inverse of dest_sptree.  The element type is fixed once for the whole
      tree (from the first stored value, or 'a when the tree stores nothing),
      so that empty subtrees at any depth agree with the rest of the tree.
      Typing each LN only from its immediate sibling would leave an
      all-empty branch such as BN (LN, LN) at type 'a spt and make mk_bn or
      mk_bs fail when that branch sits next to a non-empty sibling. *)
   fun mk_sptree t = build (Option.getOpt (value_ty t, Type.alpha)) t
end

local
   open Arbnum
   fun even n = n mod two = zero

   (* ML version of lrnext.  For n > 0 the recursive argument is
      (n - 1) div 2, which equals (n - 2) div 2 for even n and (n - 1) div 2
      for odd n. *)
   fun lrnext n =
      if n = zero then one else times2 (lrnext ((n - one) div two))

   (* ML version of foldi: folds f over every stored (key, value) pair, where
      i is the key of the current root.  With inc = lrnext i, the left
      subtree is folded first from key i + 2 * inc, then the root value (if
      any), then the right subtree from key i + inc. *)
   fun foldi f i acc =
      fn LN => acc
       | LS a => f i a acc
       | BN (t1, t2) =>
           let
              val inc = lrnext i
           in
              foldi f (i + inc) (foldi f (i + two * inc) acc t1) t2
           end
       | BS (t1, a, t2) =>
           let
              val inc = lrnext i
           in
              foldi f (i + inc) (f i a (foldi f (i + two * inc) acc t1)) t2
           end

   (* ML version of insert: key 0 is the root; for k > 0, even keys go into
      the left subtree and odd keys into the right subtree, in both cases at
      key (k - 1) div 2 within that subtree.  An existing value at k is
      replaced. *)
   fun insert k a =
      fn LN => if k = zero
                  then LS a
               else if even k
                  then BN (insert ((k - one) div two) a LN, LN)
               else BN (LN, insert ((k - one) div two) a LN)
       | LS a' =>
               if k = zero
                  then LS a
               else if even k
                  then BS (insert ((k - one) div two) a LN, a', LN)
               else BS (LN, a', insert ((k - one) div two) a LN)
       | BN (t1, t2) =>
               if k = zero
                  then BS (t1, a, t2)
               else if even k
                  then BN (insert ((k - one) div two) a t1, t2)
               else BN (t1, insert ((k - one) div two) a t2)
       | BS (t1, a', t2) =>
               if k = zero
                  then BS (t1, a, t2)
               else if even k
                  then BS (insert ((k - one) div two) a t1, a', t2)
               else BS (t1, a', insert ((k - one) div two) a t2)
in
   (* The (key, value) pairs stored in a ground sptree term, sorted by
      increasing key.  Raises a HOL_ERR (from dest_sptree) if the term is
      not a ground tree. *)
   val toAList =
      Lib.sort (fn (a, _) => fn (b, _) => Arbnum.< (a, b)) o
      foldi (fn k => fn v => fn a => (k, v) :: a) zero [] o dest_sptree

   (* Ground sptree term storing the i-th element of l at key i.  The empty
      list gives LN at element type 'a. *)
   fun fromList l =
      mk_sptree (snd (List.foldl (fn (a, (i, t)) => (i + one, insert i a t))
                        (zero, LN) l))

   (* Ground sptree term for an association list of (key, value) pairs.
      Bindings are inserted right to left, so when a key occurs more than
      once its FIRST binding is the one kept.  This is intended to match the
      HOL constant, whose recursive definition inserts the head after the
      tail.  NOTE(review): confirm against sptreeTheory.fromAList_def. *)
   fun fromAList l =
      mk_sptree (List.foldr (fn ((i, a), t) => insert i a t) LN l)
end

local
   (* One fromAList entry: the pair (numeral k, v). *)
   fun mk_entry (k, v) = pairSyntax.mk_pair (numSyntax.mk_numeral k, v)
in
   (* Re-express a ground, non-empty sptree term as  sptree$fromList [...]
      when its keys are exactly 0 .. n-1, and as  sptree$fromAList [...]
      (pairs in increasing key order) otherwise.  Raises a HOL_ERR for LN,
      or for a term that is not a ground tree. *)
   fun sptree_pretty_term tm =
      let
         val ty = sptree_ty_of tm
         val l = toAList tm
      in
         if List.null l
            then raise ERR "sptree_pretty_term" ""
         (* keys are sorted and distinct, so a last key of n-1 means the
            keys are exactly 0 .. n-1 *)
         else if fst (List.last l) = Arbnum.fromInt (List.length l - 1)
            then mk_fromList (listSyntax.mk_list (List.map snd l, ty))
         else let
                 val nl = List.map mk_entry l
                 val pty = pairSyntax.mk_prod (numSyntax.num, ty)
              in
                 mk_fromAList (listSyntax.mk_list (nl, pty))
              end
      end
end

(* User pretty-printer for ground sptree terms: prints the result of
   sptree_pretty_term as  sptree$fromAList l  or  sptree$fromList l,
   wrapped in parentheses when the context precedence is at least 200.
   Terms it cannot handle, including LN, are passed back to the default
   printer by raising UserPP_Failed. *)
fun sptree_print Gs B syspr ppfns (pg, _, _) d t =
   let
      open Portable term_pp_types smpp
      val {add_string = str, add_break = brk, ublock, ...} =
         ppfns: term_pp_types.ppstream_funs
      val t2 = sptree_pretty_term t
               handle HOL_ERR _ => raise term_pp_types.UserPP_Failed
      fun sys gs d = syspr {gravs = gs, depth = d, binderp = false}
      fun delim s =
         case pg of
            Prec (j, _) => if 200 <= j then str s else nothing
          | _ => nothing
      (* Print  name arg  as an application block, delimited as needed. *)
      fun print_app name arg =
         ublock INCONSISTENT 0
            (delim "("
             >> str name
             >> brk (1, 2)
             >> sys (Top, Top, Top) (d - 1) arg
             >> delim ")")
   in
      case Lib.total dest_fromAList t2 of
         SOME l => print_app "sptree$fromAList" l
       | NONE =>
           (case Lib.total dest_fromList t2 of
               SOME l => print_app "sptree$fromList" l
             | NONE => raise term_pp_types.UserPP_Failed)
   end

(* Register sptree_print, under the name "sptree", as a temporary user
   printer for terms of type  'a sptree$spt. *)
fun temp_add_sptree_printer () =
   Parse.temp_add_user_printer ("sptree", ``x: 'a sptree$spt``, sptree_print)

(* Remove the user printer registered under the name "sptree". *)
fun remove_sptree_printer () =
   General.ignore (Parse.remove_user_printer "sptree")

end
233