1(*---------------------------------------------------------------------------*
2 * Syntax support for the theory of pairs. When possible, functions          *
3 * below deal with both paired and unpaired input.                           *
4 *---------------------------------------------------------------------------*)
5
6structure pairSyntax :> pairSyntax =
7struct
8
9open HolKernel boolTheory pairTheory boolSyntax Abbrev;
10
(* ERR origin message: a HOL_ERR exception attributed to this structure. *)
fun ERR origin message = mk_HOL_ERR "pairSyntax" origin message
12
13(*---------------------------------------------------------------------------
14             Operations on product types
15 ---------------------------------------------------------------------------*)
16
(* mk_prod (l, r) builds the product type l # r, i.e. pair$prod applied
   to the two argument types. *)
fun mk_prod (l, r) =
   Type.mk_thy_type {Thy = "pair", Tyop = "prod", Args = [l, r]}
19
(* dest_prod (l # r) returns (l, r).  Any type that is not an instance of
   pair$prod (including type variables) raises ERR "dest_prod". *)
fun dest_prod ty =
   (case total dest_thy_type ty of
       SOME {Thy = "pair", Tyop = "prod", Args = [l, r]} => (l, r)
     | _ => raise ERR "dest_prod" "not a product type")
24
(* Flatten a product type along its right spine (spine_prod) or into all of
   its leaves (strip_prod), and rebuild a right-nested product from a
   non-empty list of types (list_mk_prod). *)
fun spine_prod ty = spine_binop (total dest_prod) ty
fun strip_prod ty = strip_binop dest_prod ty
fun list_mk_prod tys = end_itlist (curry mk_prod) tys
28
29(*---------------------------------------------------------------------------
30         Useful constants in the theory of pairs
31 ---------------------------------------------------------------------------*)
32
local
   (* Look up a constant declared in theory "pair" by name. *)
   fun pair_const name = prim_mk_const {Name = name, Thy = "pair"}
in
   val uncurry_tm   = pairTheory.uncurry_tm
   val comma_tm     = pairTheory.comma_tm
   val fst_tm       = pair_const "FST"
   val snd_tm       = pair_const "SND"
   val curry_tm     = pair_const "CURRY"
   val pair_map_tm  = pair_const "##"
   val lex_tm       = pair_const "LEX"
   val swap_tm      = pair_const "SWAP"
   val pair_case_tm = pair_const "pair_CASE"
end
42
43(*---------------------------------------------------------------------------
44     Make a pair from two components, or a tuple from a list of components
45 ---------------------------------------------------------------------------*)
46
(* mk_pair (l, r) builds the term (l, r), instantiating the comma constant
   at the types of the two components. *)
fun mk_pair (l, r) =
   let
      val theta = [alpha |-> type_of l, beta |-> type_of r]
   in
      list_mk_comb (inst theta comma_tm, [l, r])
   end

(* list_mk_pair [t1, ..., tn] builds the right-nested tuple (t1, ..., tn);
   the list must be non-empty. *)
fun list_mk_pair tms = end_itlist (curry mk_pair) tms
56
57(*---------------------------------------------------------------------------
58      Take a pair apart, once, and repeatedly. The atoms appear
59      in left-to-right order.
60 ---------------------------------------------------------------------------*)
61
(* Re-exported from pairTheory: split a pair once (dest_pair), fully into its
   atoms (strip_pair), or along its right spine (spine_pair). *)
fun dest_pair tm  = pairTheory.dest_pair tm
fun strip_pair tm = pairTheory.strip_pair tm
fun spine_pair tm = pairTheory.spine_pair tm
65
66(*---------------------------------------------------------------------------
67    Inverse of strip_pair ... returns unconsumed elements in input list.
68    This is so that it can be used easily over lists of things to be
69    unstripped.
70 ---------------------------------------------------------------------------*)
71
local
   (* Take the next element off the list; running out is an error. *)
   fun next [] = raise ERR "unstrip_pair" "unable"
     | next (v :: rest) = (v, rest)
in
   (* unstrip_pair ty V consumes elements of V to build a tuple whose pair
      structure follows the product structure of ty.  Every non-product
      position (including type variables) takes one element.  Returns the
      tuple together with the unconsumed suffix of V. *)
   fun unstrip_pair ty V =
      if is_vartype ty then next V
      else
         (case total dest_prod ty of
             NONE => next V
           | SOME (lty, rty) =>
               let
                  val (l, rest1) = unstrip_pair lty V
                  val (r, rest2) = unstrip_pair rty rest1
               in
                  (mk_pair (l, r), rest2)
               end)
end
88
89(*---------------------------------------------------------------------------
90       Is it a pair?
91 ---------------------------------------------------------------------------*)
92
(* True exactly when dest_pair succeeds on the term. *)
fun is_pair tm = can dest_pair tm
94
95(*---------------------------------------------------------------------------
96      Making applications of FST and SND
97 ---------------------------------------------------------------------------*)
98
local
   (* Apply a projection constant (FST or SND) to tm.  The constant's type
      variables are instantiated from the product type of tm.  Any failure
      is reported as ERR name "". *)
   fun mk_proj name proj tm =
      let
         val (l, r) = dest_prod (type_of tm)
      in
         mk_comb (inst [alpha |-> l, beta |-> r] proj, tm)
      end
      handle HOL_ERR _ => raise ERR name ""
in
   (* mk_fst tm builds ``FST tm``; tm must have a product type. *)
   fun mk_fst tm = mk_proj "mk_fst" fst_tm tm

   (* mk_snd tm builds ``SND tm``; tm must have a product type. *)
   fun mk_snd tm = mk_proj "mk_snd" snd_tm tm
end
114
(* The UNCURRY constant with alpha, beta, gamma instantiated to a, b, c. *)
fun mk_uncurry_tm (a, b, c) =
   let
      val theta = [alpha |-> a, beta |-> b, gamma |-> c]
   in
      inst theta uncurry_tm
   end
117
(* mk_curry (f, x, y) builds ``CURRY f x y``.  f must take a product
   argument.  Any failure is reported as ERR "mk_curry" "". *)
fun mk_curry (f, x, y) =
   let
      val (argty, resty) = dom_rng (type_of f)
      val (xty, yty) = dest_prod argty
      val curry_c = inst [alpha |-> xty, beta |-> yty, gamma |-> resty] curry_tm
   in
      list_mk_comb (curry_c, [f, x, y])
   end
   handle HOL_ERR _ => raise ERR "mk_curry" ""
127
(* mk_uncurry (f, x) builds ``UNCURRY f x`` for f : a -> b -> c.
   The three types are read off with two dom_rng steps rather than with
   strip_fun.  strip_fun flattens every arrow, so when the result type c is
   itself a function type it yields more than two argument types and the
   (valid) construction was previously rejected.
   Raises ERR "mk_uncurry" "" if f does not take at least two curried
   arguments.  Errors from applying the result to x propagate from mk_comb,
   as before. *)
fun mk_uncurry (f, x) =
   let
      val (a, b, c) =
         let
            val (a, bc) = dom_rng (type_of f)
            val (b, c) = dom_rng bc
         in
            (a, b, c)
         end
         handle HOL_ERR _ => raise ERR "mk_uncurry" ""
   in
      mk_comb (mk_comb (mk_uncurry_tm (a, b, c), f), x)
   end
132
(* mk_pair_map (f, g) builds ``f ## g``.  The four type variables of ##
   are instantiated from the domain and range of f and of g. *)
fun mk_pair_map (f, g) =
   let
      val (fdom, frng) = dom_rng (type_of f)
      val (gdom, grng) = dom_rng (type_of g)
      val theta = [alpha |-> fdom, beta |-> gdom,
                   gamma |-> frng, delta |-> grng]
   in
      list_mk_comb (inst theta pair_map_tm, [f, g])
   end
143
(* mk_lex (r1, r2) builds ``r1 LEX r2``.  LEX is instantiated at the domain
   types of the two relations. *)
fun mk_lex (r1, r2) =
   let
      val ty1 = #1 (dom_rng (type_of r1))
      val ty2 = #1 (dom_rng (type_of r2))
      val lex_c = inst [alpha |-> ty1, beta |-> ty2] lex_tm
   in
      list_mk_comb (lex_c, [r1, r2])
   end
151
(* mk_swap t builds ``SWAP t``; t must have a product type. *)
fun mk_swap t =
   let
      val (l, r) = dest_prod (Term.type_of t)
      val swap_c = Term.inst [Type.alpha |-> l, Type.beta |-> r] swap_tm
   in
      Term.mk_comb (swap_c, t)
   end
159
(* mk_pair_case {pairtm, ftm} builds ``pair_CASE pairtm ftm``.  Types are
   instantiated by list_mk_icomb. *)
fun mk_pair_case {pairtm, ftm} =
   let
      val args = [pairtm, ftm]
   in
      list_mk_icomb (pair_case_tm, args)
   end
161
(* dest_pair_case ``pair_CASE p f`` returns {pairtm = p, ftm = f}.  Any
   other term, including pair_CASE applied to a different number of
   arguments, raises ERR "dest_pair_case". *)
fun dest_pair_case tm =
   (case strip_comb tm of
       (head, [p, g]) =>
          if same_const pair_case_tm head then {pairtm = p, ftm = g}
          else raise ERR "dest_pair_case" "Term not a pair_CASE"
     | _ => raise ERR "dest_pair_case" "Term not a pair_CASE")

val is_pair_case = can dest_pair_case
172
173
(* Destructors and recognisers for FST, SND, CURRY, ##, LEX and SWAP
   applications.  Each destructor raises ERR with its own name on failure. *)
val dest_fst = dest_monop fst_tm (ERR "dest_fst" "")
val is_fst = can dest_fst

val dest_snd = dest_monop snd_tm (ERR "dest_snd" "")
val is_snd = can dest_snd

(* dest_curry ``CURRY f x y`` returns (f, x, y). *)
fun dest_curry tm =
   let
      val err = ERR "dest_curry" ""
      val (curried_fx, y) = with_exn dest_comb tm err
      val (f, x) = dest_binop curry_tm err curried_fx
   in
      (f, x, y)
   end
val is_curry = can dest_curry

val dest_pair_map = dest_binop pair_map_tm (ERR "dest_pair_map" "")
val is_pair_map = can dest_pair_map

val dest_lex = dest_binop lex_tm (ERR "dest_lex" "")
val is_lex = can dest_lex

val dest_swap = dest_monop swap_tm (ERR "dest_swap" "")
val is_swap = can dest_swap
197
198(*---------------------------------------------------------------------------*)
199(* Constructor, destructor and discriminator functions for paired            *)
200(* abstractions and ordinary abstractions.                                   *)
201(* [JRH 91.07.17]                                                            *)
202(*---------------------------------------------------------------------------*)
203
val mk_pabs = pairTheory.mk_pabs

(* mk_plet (vstruct, rhs, body) builds ``let vstruct = rhs in body``. *)
fun mk_plet (vstruct, rhs, body) =
   mk_let (mk_pabs (vstruct, body), rhs)
   handle HOL_ERR _ => raise ERR "mk_plet" ""

local
   (* Apply a binder constant, instantiated at the type of vstruct, to the
      paired abstraction of p.  Any failure is reported as ERR name "". *)
   fun mk_pbinder binder name (p as (vstruct, _)) =
      mk_comb (inst [alpha |-> type_of vstruct] binder, mk_pabs p)
      handle HOL_ERR _ => raise ERR name ""
in
   fun mk_pforall p  = mk_pbinder universal   "mk_pforall"  p
   fun mk_pexists p  = mk_pbinder existential "mk_pexists"  p
   fun mk_pexists1 p = mk_pbinder exists1     "mk_pexists1" p
   fun mk_pselect p  = mk_pbinder select      "mk_pselect"  p
end
225
(* dest_pabs splits an ordinary abstraction, or a paired abstraction
   encoded as ``UNCURRY (\l. \r. body)``, into (vstruct, body).  Nested
   UNCURRYs give nested pair vstructs.  Non-abstractions raise
   ERR "dest_pabs". *)
fun dest_pabs tm =
   Term.dest_abs tm
   handle HOL_ERR _ =>
      let
         val (rator, rand) = with_exn dest_comb tm (ERR "dest_pabs" "")
      in
         if not (same_const uncurry_tm rator) then raise ERR "dest_pabs" ""
         else
            let
               val (lv, inner) = dest_pabs rand
               val (rv, body) = dest_pabs inner
            in
               (mk_pair (lv, rv), body)
            end
      end
241
(* The bound vstruct and the body of a (possibly paired) abstraction. *)
fun pbvar tm = #1 (dest_pabs tm) handle HOL_ERR _ => failwith "pbvar"
fun pbody tm = #2 (dest_pabs tm) handle HOL_ERR _ => failwith "pbody"
244
(* dest_plet ``let vstruct = rhs in body`` returns (vstruct, rhs, body).
   vstruct may be a single variable or a (nested) tuple.
   Raises ERR "dest_plet" if the term is not such a let.  Only HOL_ERR is
   caught: the previous wildcard handler also swallowed Interrupt and other
   non-HOL exceptions and turned them into a plain failure. *)
fun dest_plet M =
   let
      val (f, rhs) = dest_let M
      val (vstruct, body) = dest_pabs f
   in
      (vstruct, rhs, body)
   end
   handle HOL_ERR _ => raise ERR "dest_plet" "not a (possibly paired) \"let\""
253
254(*---------------------------------------------------------------------------*)
255(* Paired binders                                                            *)
256(*---------------------------------------------------------------------------*)
257
local
   val FORALL_ERR  = ERR "dest_pforall"  "not a (possibly paired) \"!\""
   val EXISTS_ERR  = ERR "dest_pexists"  "not a (possibly paired) \"?\""
   val EXISTS1_ERR = ERR "dest_pexists1" "not a (possibly paired) \"?!\""
   val SELECT_ERR  = ERR "dest_pselect"  "not a (possibly paired) \"@\""
in
   (* dest_pbinder c e M: if M is c applied to a (possibly paired)
      abstraction, return (vstruct, body); otherwise raise e. *)
   fun dest_pbinder c e M =
      (case total dest_comb M of
          NONE => raise e
        | SOME (rator, rand) =>
            if same_const c rator then with_exn dest_pabs rand e
            else raise e)

   val dest_pforall  = dest_pbinder universal   FORALL_ERR
   val dest_pexists  = dest_pbinder existential EXISTS_ERR
   val dest_pexists1 = dest_pbinder exists1     EXISTS1_ERR
   val dest_pselect  = dest_pbinder select      SELECT_ERR
end
275
(* dest_uncurry ``UNCURRY f`` splits f as a (possibly paired) abstraction. *)
fun dest_uncurry tm =
   dest_pabs (dest_monop uncurry_tm (ERR "dest_uncurry" "") tm)

(* Recognisers: each holds exactly when the matching destructor succeeds. *)
val is_uncurry  = can dest_uncurry
val is_pabs     = can dest_pabs
val is_plet     = can dest_plet
val is_pforall  = can dest_pforall
val is_pexists  = can dest_pexists
val is_pexists1 = can dest_pexists1
val is_pselect  = can dest_pselect
285
(* Iterated constructors: the first vstruct in V becomes the outermost
   binder around M. *)
fun list_mk_pabs (V, M)    = List.foldr mk_pabs M V
fun list_mk_pforall (V, M) = List.foldr mk_pforall M V
fun list_mk_pexists (V, M) = List.foldr mk_pexists M V
289
290(*---------------------------------------------------------------------------*)
291
(* strip dest M repeatedly applies the optional destructor dest to the body.
   It collects the bound parts in order and returns them with the innermost
   kernel. *)
fun strip dest =
   let
      fun go M =
         (case dest M of
             SOME (vstruct, body) =>
               let val (vs, kernel) = go body in (vstruct :: vs, kernel) end
           | NONE => ([], M))
   in
      go
   end

val strip_pabs    = strip (total dest_pabs)
val strip_pforall = strip (total dest_pforall)
val strip_pexists = strip (total dest_pexists)
306
307(*---------------------------------------------------------------------------*)
308(* Support for dealing with the syntax of any kind of let.                   *)
309(*---------------------------------------------------------------------------*)
310
local
   (* A single (possibly paired) let, as a one-element binding list. *)
   fun dest_plet' tm = let val (a, b, c) = dest_plet tm in ([(a, b)], c) end

   (* A let binding a plain variable: let v = x in M. *)
   fun dest_simple_let tm =
      let
         val (f, x) = dest_let tm
         val (v, M) = dest_abs f
      in
         ([(v, x)], M)
      end

   (* Strip exactly n (possibly paired) abstractions off the front of tm.
      Raises HOL_ERR (from dest_pabs) if there are fewer than n. *)
   fun strip_n_pabs n tm =
      if n <= 0 then ([], tm)
      else
         let
            val (v, body) = dest_pabs tm
            val (vs, kern) = strip_n_pabs (n - 1) body
         in
            (v :: vs, kern)
         end

   (* `let x1 = M1 and ... and xn = Mn in N` is LET (...(LET f M1)...) Mn
      with f = \x1 ... xn. N (this is how mk_anylet builds it).  Peel off
      the LETs while collecting M1 ... Mn, then strip exactly n binders from
      f.  The body N may itself be an abstraction, so stripping every binder
      (as strip_pabs does) could take too many and make the zip fail. *)
   fun dest_and_let tm acc =
      let
         val (f, x) = boolSyntax.dest_let tm
         val args = x :: acc
      in
         if is_let f then dest_and_let f args
         else
            let val (vstructs, M) = strip_n_pabs (length args) f
            in (zip vstructs args, M) end
      end

   (* Turn a binding (l, \a1 ... an. r) into (l a1 ... an, r), so that
      function-style bindings such as `let f x = M` are recovered. *)
   fun fixup (l, r) =
      let
         val (vstructs, M) = strip_pabs r
      in
         (list_mk_comb (l, vstructs), M)
      end
in
   (* dest_anylet splits any let-term (simple, paired, or with `and`) into a
      list of (lhs, rhs) bindings and the body.  Raises ERR "dest_anylet"
      if the term is not a let. *)
   fun dest_anylet tm =
      let
         val (blist, M) = dest_simple_let tm handle HOL_ERR _ =>
                          dest_plet' tm      handle HOL_ERR _ =>
                          dest_and_let tm []
      in
         (map fixup blist, M)
      end
      handle HOL_ERR _ => raise ERR "dest_anylet" "not a \"let\"-term"
end
345
local
   (* Normalise one binding.  A variable or tuple lhs is kept as is.  A
      function-style lhs `f a1 ... an` becomes f bound to \a1 ... an. r. *)
   fun abstr (l, r) =
      if is_pair l orelse is_var l then (l, r)
      else
         let val (f, args) = strip_comb l
         in (f, list_mk_pabs (args, r)) end
in
   (* mk_anylet (bindings, M) builds a let-term.  One binding gives an
      ordinary (possibly paired) let.  Several give the nested LET encoding
      of `let ... and ... in M`.  An empty list raises ERR "mk_anylet". *)
   fun mk_anylet ([], _) = raise ERR "mk_anylet" "no binding"
     | mk_anylet ([binding], M) =
         let val (vstruct, rhs) = abstr binding
         in mk_plet (vstruct, rhs, M) end
     | mk_anylet (blist, M) =
         let
            val (lhss, rhss) = unzip (map abstr blist)
            val lam = list_mk_pabs (lhss, M)
         in
            List.foldl (fn (r, tm) => mk_let (tm, r)) lam rhss
         end
end
363
(* strip_anylet peels off successive let-terms.  It returns one binding list
   per let, outermost first, together with the innermost body. *)
fun strip_anylet tm =
   (case total dest_anylet tm of
       NONE => ([], tm)
     | SOME (blist, body) =>
         let val (rest, kernel) = strip_anylet body
         in (blist :: rest, kernel) end)

(* list_mk_anylet (L, M) nests the lets of L around M, first list outermost. *)
fun list_mk_anylet (L, M) = List.foldr mk_anylet M L
370
371(* Examples
372  val tm1 = Term `let x = M in N x`;
373  val tm2 = Term `let (x,y,z) = M in N x y z`;
374  val tm3 = Term `let x = M and y = N in P x y`;
375  val tm4 = Term `let (x,y) = M and z = N in P x y z`;
376  val tm5 = Term `let (x,y) = M and z = N in let u = x in P x y z u`;
377  val tm6 = Term `let f(x,y) = M
378                  and g z = N
379                  in let u = x in P (g(f(x,u)))`;
380  val tm7 = Term `let f x = M in P (f y)`;
381  val tm8 = Term `let g x = A in
382                  let v = g x y in
383                  let f x y (a,b) = g a
384                  and foo = M
385                  in f x foo v`;
386*)
387
388
389(*---------------------------------------------------------------------------*)
390(* A "vstruct" is a tuple of variables, possibly nested, with no duplicate   *)
391(* occurrences.                                                              *)
392(*---------------------------------------------------------------------------*)
393
394val is_vstruct = pairTheory.is_vstruct
395
396(* ===================================================================== *)
(* Generates a tuple of fresh variables (made with genvar) whose nested  *)
(* pair structure mirrors the product structure of the given type.       *)
399(* ===================================================================== *)
400
(* Product types are split recursively, left component first.  Every
   non-product leaf type gets a fresh genvar. *)
fun genvarstruct ty =
   (case total dest_prod ty of
       NONE => genvar ty
     | SOME (lty, rty) => mk_pair (genvarstruct lty, genvarstruct rty))
405
406(*---------------------------------------------------------------------------*)
407(* Lift from ML pairs to HOL pairs                                           *)
408(*---------------------------------------------------------------------------*)
409
(* lift_prod ty f g turns an ML pair (x, y) into the HOL pair (f x, g y).
   The comma constant is instantiated once, via TypeBasePure.cinst at ty,
   before f and g are supplied. *)
fun lift_prod ty =
   let
      val comma = TypeBasePure.cinst ty comma_tm
      fun lift f g (x, y) = list_mk_comb (comma, [f x, g y])
   in
      lift
   end
416
417end
418