1open HolKernel Parse boolLib pairTheory pairSyntax combinTheory listTheory;
2
(* Open a new theory segment named state_transformer; the definitions and
   theorems below are stored in it. *)
val _ = new_theory "state_transformer"

(* SML fixity declarations.  >> is the infix used to chain tactics in the
   proofs of the Theorems section; || is not used in the part of the file
   visible here. *)
infixr 0 ||
infix 1 >>;
7
(* DEF q: make the definition given by quotation q with TotalDefn.Define,
   with the flag boolLib.def_suffix set to "_DEF" for the duration of the
   call (presumably so that the stored theorem names end in _DEF rather
   than the default suffix -- confirm against boolLib). *)
fun DEF defn_quotation =
  Lib.with_flag (boolLib.def_suffix, "_DEF") TotalDefn.Define defn_quotation
9
10(* ------------------------------------------------------------------------- *)
11(* Definitions.                                                              *)
12(* ------------------------------------------------------------------------- *)
13
(* Local type abbreviation for state-transformer computations:
   ('a, 'state) M abbreviates 'state -> 'a # 'state, a function from an
   initial state to a (result, final state) pair.
   The quotation is delimited with ASCII double backquotes, as in the rest
   of this script, so that it survives re-encoding of the file. *)
Type M[local] = ``:'state -> 'a # 'state``
15
(* identity of the Kleisli category *)
(* UNIT x is the computation that returns x as its result and passes the
   state s through unchanged. *)
val UNIT_DEF = DEF `UNIT (x:'b) = \(s:'a). (x, s)`;
18
(* BIND g f: run g on the state, then feed the resulting (value, state) pair
   to f.  Expressed point-free as the composition of UNCURRY f after g. *)
val BIND_DEF = DEF `BIND (g: ('b, 'a) M) (f: 'b -> ('c, 'a) M) = UNCURRY f o g`;
20
(* IGNORE_BIND f g: sequence f and then g, discarding the result value of f
   (the continuation \x. g does not mention x). *)
val IGNORE_BIND_DEF = DEF `IGNORE_BIND f g = BIND f (\x. g)`;
22
(* Register this monad with monadsyntax under the name state, supplying
   BIND, IGNORE_BIND and UNIT as its bind, ignore-bind and unit operations;
   no choice, fail or guard operations are provided.  The term quotations
   use ASCII double backquotes, matching the rest of this script.
   Afterwards the monad syntax is added to the grammar and the state monad
   is enabled. *)
val _ =
    monadsyntax.declare_monad (
      "state",
      { bind = ``BIND``, ignorebind = SOME ``IGNORE_BIND``, unit = ``UNIT``,
        choice = NONE, fail = NONE, guard = NONE
      }
    )
val _ = monadsyntax.add_monadsyntax()
val _ = monadsyntax.enable_monad "state"
32
(* MMAP f m: apply the pure function f to the result of computation m,
   defined by binding m to UNIT o f. *)
val MMAP_DEF = DEF `MMAP (f: 'c -> 'b) (m: ('c, 'a) M) = BIND m (UNIT o f)`;
34
(* JOIN z: flatten a computation z whose result is itself a computation,
   defined by binding z to the identity function I. *)
val JOIN_DEF = DEF `JOIN (z: (('b, 'a) M, 'a) M) = BIND z I`;
36
(* functor (on arrows) from the Kleisli category *)
(* EXT f m has the same right-hand side as BIND m f (UNCURRY f o m), with the
   arguments in the opposite order; see BIND_EXT below. *)
val EXT_DEF = DEF `EXT (f: 'b -> ('c, 's) M) (m: ('b, 's) M) = UNCURRY f o m`;
39
(* composition in the Kleisli category *)
(* MCOMP g f: first apply f to the argument, then extend g over the
   resulting computation (EXT g o f). *)
val MCOMP_DEF =
  DEF `MCOMP (g: 'b -> ('c, 's) M) (f: 'a -> ('b, 's) M) = EXT g o f` ;
43
(* FOR (i, j, a): run the action a at index i; if i differs from j, continue
   with the index moved one step towards j (i + 1 when i < j, otherwise
   i - 1).  The action is therefore run at every index from i to j
   inclusive, counting up or down as required.
   Termination is established with the measure that maps (i, j, a) to the
   distance between i and j. *)
val FOR_def = TotalDefn.tDefine "FOR"
 `(FOR : num # num # (num -> (unit, 'state) M) -> (unit, 'state) M) (i, j, a) =
     if i = j then
        a i
     else
        BIND (a i) (\u. FOR (if i < j then i + 1 else i - 1, j, a))`
  (TotalDefn.WF_REL_TAC `measure (\(i, j, a). if i < j then j - i else i - j)`)
51
(* FOREACH (l, a): run the action a on each element of the list l, head
   first.  On the empty list the result is UNIT (). *)
val FOREACH_def = TotalDefn.Define`
   ((FOREACH : 'a list # ('a -> (unit, 'state) M) -> (unit, 'state) M) ([], a) =
       UNIT ()) /\
   (FOREACH (h :: t, a) = BIND (a h) (\u. FOREACH (t, a)))`
56
(* READ f: return f applied to the current state, leaving the state
   unchanged. *)
val READ_def = TotalDefn.Define`
   (READ : ('state -> 'a) -> ('a, 'state) M) f = \s. (f s, s)`;
59
(* WRITE f: replace the current state s by f s and return the unit value. *)
val WRITE_def = TotalDefn.Define`
   (WRITE : ('state -> 'state) -> (unit, 'state) M) f = \s. ((), f s)`;
62
(* NARROW v f: run f, which works on a paired state of type 'b # 'state,
   from the initial state (v, s), then drop the first component of the final
   state (SND s1) so that the whole computation works on 'state alone. *)
val NARROW_def = TotalDefn.Define`
   (NARROW : 'b -> ('a, 'b # 'state) M -> ('a, 'state) M) v f =
   \s. let (r, s1) = f (v, s) in (r, SND s1)`
66
(* WIDEN f: lift f to a paired state (s1, s2).  f is run on the second
   component s2 only; the first component s1 is carried through unchanged. *)
val WIDEN_def = TotalDefn.Define`
   (WIDEN : ('a, 'state) M -> ('a, 'b # 'state) M) f =
   \(s1, s2). let (r, s3) = f s2 in (r, (s1, s3))`
70
(* sequence: turn a list of computations into one computation returning the
   list of their results.  Defined as a FOLDR whose step binds the head
   computation m first and then the already-folded tail ms, consing the two
   results; the base case is UNIT []. *)
val sequence_def = TotalDefn.Define`
   sequence = FOLDR (\m ms. BIND m (\x. BIND ms (\xs. UNIT (x::xs)))) (UNIT [])`
73
(* mapM f: map f over a list to obtain a list of computations, then combine
   them with sequence. *)
val mapM_def = TotalDefn.Define`
   mapM f = sequence o MAP f`
76
77open simpLib BasicProvers boolSimps metisLib
78
(* Existence lemma for the monadic while loop: for every guard computation g
   and body b there is an f satisfying the unfolding equation
     f = BIND g (\gv. if gv then IGNORE_BIND b f else UNIT ()).
   The witness iterates the state update SND o b o SND o g.  If some number
   of iterations n makes the guard g return false, it takes the least such n
   and returns ((), state after the final run of g); otherwise it is ARB.
   This lemma is used only to justify the specification MWHILE_DEF. *)
val mwhile_exists = prove(
  ``!g b. ?f.
      f = BIND g (\gv. if gv then IGNORE_BIND b f else UNIT ())``,
  MAP_EVERY Q.X_GEN_TAC [`g`, `b`] THEN
  Q.EXISTS_TAC
    `\s0. if ?n. ~FST (g (FUNPOW (SND o b o SND o g) n s0)) then
            let n = LEAST n. ~FST (g (FUNPOW (SND o b o SND o g) n s0))
            in
              ((), SND (g (FUNPOW (SND o b o SND o g) n s0)))
          else ARB` THEN
  SIMP_TAC (srw_ss()) [FUN_EQ_THM] THEN Q.X_GEN_TAC `s` THEN
  COND_CASES_TAC THENL [
    (* Case 1: some iteration count n0 makes the guard false from state s. *)
    POP_ASSUM (Q.X_CHOOSE_THEN `n0` ASSUME_TAC) THEN
    SIMP_TAC (srw_ss()) [SimpLHS, LET_THM] THEN
    numLib.LEAST_ELIM_TAC THEN CONJ_TAC THEN1 METIS_TAC[] THEN
    Q.X_GEN_TAC `n` THEN SIMP_TAC (srw_ss()) [] THEN STRIP_TAC THEN
    SIMP_TAC (srw_ss()) [BIND_DEF] THEN
    Q.SPEC_THEN `g s` (Q.X_CHOOSE_THEN `gv1`
                                       (Q.X_CHOOSE_THEN `s1` ASSUME_TAC))
                pairTheory.pair_CASES THEN
    ASM_SIMP_TAC (srw_ss()) [] THEN REVERSE (Cases_on `gv1`)
    THEN1 (`n = 0`
             by (SPOSE_NOT_THEN ASSUME_TAC THEN
                 `0 < n` by SRW_TAC [numSimps.ARITH_ss][] THEN
                 FIRST_X_ASSUM (Q.SPEC_THEN `0` MP_TAC) THEN
                 SRW_TAC [][]) THEN
           SRW_TAC [][UNIT_DEF]) THEN
    ASM_SIMP_TAC (srw_ss()) [IGNORE_BIND_DEF, BIND_DEF] THEN
    Q.SPEC_THEN `b s1` (Q.X_CHOOSE_THEN `bv1`
                                        (Q.X_CHOOSE_THEN `s2` ASSUME_TAC))
                pairTheory.pair_CASES THEN
    ASM_SIMP_TAC (srw_ss()) [] THEN
    `?m. n = SUC m`
      by (Cases_on `n` THEN FULL_SIMP_TAC (srw_ss()) []) THEN
    Q.SUBGOAL_THEN `?n. ~FST (g (FUNPOW (SND o b o SND o g) n s2))`
      ASSUME_TAC
    THEN1 (Q.EXISTS_TAC `m` THEN
           FULL_SIMP_TAC (srw_ss()) [arithmeticTheory.FUNPOW]) THEN
    ASM_SIMP_TAC (srw_ss()) [arithmeticTheory.FUNPOW] THEN
    Q_TAC SUFF_TAC
       `(LEAST n. ~FST (g (FUNPOW (SND o b o SND o g) n s2))) = m`
       THEN1 SRW_TAC [][] THEN
    numLib.LEAST_ELIM_TAC THEN CONJ_TAC THEN1 SRW_TAC [][] THEN
    Q.X_GEN_TAC `p` THEN SRW_TAC [][] THEN
    Q_TAC SUFF_TAC `~(m < p) /\ ~(p < m)` THEN1 numLib.ARITH_TAC THEN
    REPEAT STRIP_TAC THENL [
      `FST (g (FUNPOW (SND o b o SND o g) m s2))` by METIS_TAC[] THEN
      `FST (g (FUNPOW (SND o b o SND o g) (SUC m) s))`
         by (SIMP_TAC (srw_ss())[arithmeticTheory.FUNPOW] THEN
             SRW_TAC [][]),
      `SUC p < SUC m` by SRW_TAC [numSimps.ARITH_ss][] THEN
      RES_THEN MP_TAC THEN
      SIMP_TAC (srw_ss()) [arithmeticTheory.FUNPOW] THEN
      SRW_TAC [][]
    ],
    (* Case 2: no iteration count makes the guard false from state s. *)
    FULL_SIMP_TAC (srw_ss()) [BIND_DEF] THEN
    Q.SPEC_THEN `g s` (Q.X_CHOOSE_THEN `gv1`
                                       (Q.X_CHOOSE_THEN `s1` ASSUME_TAC))
                pairTheory.pair_CASES THEN
    REVERSE (SRW_TAC [][])
      THEN1(FIRST_X_ASSUM (Q.SPEC_THEN `0` MP_TAC) THEN SRW_TAC [][]) THEN
    SRW_TAC [][IGNORE_BIND_DEF, BIND_DEF] THEN
    Q.SPEC_THEN `b s1` (Q.X_CHOOSE_THEN `bv1`
                                        (Q.X_CHOOSE_THEN `s2` ASSUME_TAC))
                pairTheory.pair_CASES THEN
    SRW_TAC [][] THEN
    FIRST_X_ASSUM (Q.SPEC_THEN `SUC m` (MP_TAC o Q.GEN `m`)) THEN
    SRW_TAC [][arithmeticTheory.FUNPOW]
  ])
148
(* Introduce the constant MWHILE by specification.  mwhile_exists is first
   rewritten with SKOLEM_THM so that the existential quantifier is outermost,
   and new_specification then names the witness MWHILE; the resulting
   theorem is stored as MWHILE_DEF. *)
val MWHILE_DEF = new_specification(
  "MWHILE_DEF", ["MWHILE"],
  mwhile_exists |> SIMP_RULE bool_ss [SKOLEM_THM]);
152
153(* ------------------------------------------------------------------------- *)
154(* Theorems.                                                                 *)
155(* ------------------------------------------------------------------------- *)
156
(* Shorthands for the proof scripts.
   Suff q / Know q : SUFF_TAC / KNOW_TAC applied through Q_TAC to the
                     quotation q.
   FUN_EQ_TAC      : apply FUN_EQ_CONV once, via ONCE_DEPTH_CONV, to the
                     goal. *)
fun Suff goal_quotation = Q_TAC SUFF_TAC goal_quotation
fun Know goal_quotation = Q_TAC KNOW_TAC goal_quotation
val FUN_EQ_TAC =
  let
    val fun_eq_once = ONCE_DEPTH_CONV FUN_EQ_CONV
  in
    CONV_TAC fun_eq_once
  end
160
(* UNIT and MCOMP are identity and composition of the Kleisli category *)
(* UNIT_CURRY: UNIT is the curried form of the identity function on pairs. *)
val UNIT_CURRY = store_thm
  ("UNIT_CURRY",
   ``UNIT = CURRY I``,
   REWRITE_TAC [CURRY_DEF, UNIT_DEF, FUN_EQ_THM, combinTheory.I_THM]
    >> BETA_TAC >> REWRITE_TAC []) ;
167
(* MCOMP_ALT: Kleisli composition expressed as ordinary composition of the
   uncurried arrows, curried again. *)
val MCOMP_ALT = store_thm
  ("MCOMP_ALT",
  ``MCOMP g f = CURRY (UNCURRY g o UNCURRY f)``,
  REWRITE_TAC [MCOMP_DEF, CURRY_DEF, FUN_EQ_THM, o_THM, UNCURRY_DEF, EXT_DEF]);
172
(* MCOMP_ID: UNIT is both a right and a left identity for MCOMP. *)
val MCOMP_ID = store_thm
  ("MCOMP_ID",
   ``(MCOMP g UNIT = g) /\ (MCOMP UNIT f = f)``,
  REWRITE_TAC [MCOMP_ALT, UNIT_CURRY,
    UNCURRY_CURRY_THM, CURRY_UNCURRY_THM, I_o_ID]);
178
(* MCOMP_ASSOC: Kleisli composition is associative. *)
val MCOMP_ASSOC = store_thm
  ("MCOMP_ASSOC",
   ``MCOMP f (MCOMP g h) = MCOMP (MCOMP f g) h``,
  REWRITE_TAC [MCOMP_ALT, o_ASSOC, UNCURRY_CURRY_THM, CURRY_UNCURRY_THM]);
183
(* EXT is a functor from the Kleisli category into the (I,o) category *)
(* EXT_UNIT: EXT maps the Kleisli identity UNIT to the identity function I. *)
val EXT_UNIT = store_thm
  ("EXT_UNIT",
  ``EXT UNIT = I``,
  REWRITE_TAC [FUN_EQ_THM, EXT_DEF, UNIT_CURRY,
    UNCURRY_CURRY_THM, o_THM, I_THM]);
190
(* EXT_MCOMP: EXT maps Kleisli composition to ordinary function
   composition. *)
val EXT_MCOMP = store_thm
  ("EXT_MCOMP",
  ``EXT (MCOMP g f) = EXT g o EXT f``,
  REWRITE_TAC [FUN_EQ_THM, EXT_DEF, UNCURRY_CURRY_THM, o_THM, MCOMP_ALT]);
195
(* EXT_o_UNIT: extending f and then precomposing with UNIT gives back f.
   Proved by folding the left-hand side into MCOMP f UNIT and using
   MCOMP_ID. *)
val EXT_o_UNIT = store_thm
  ("EXT_o_UNIT",
  ``EXT f o UNIT = f``,
  REWRITE_TAC [GSYM MCOMP_DEF, MCOMP_ID]);
200
(* UNIT o _ is the functor in the opposite direction *)
(* UNIT_o_MCOMP: the Kleisli composition of two pure arrows (UNIT o g and
   UNIT o f) is the pure arrow of their ordinary composition. *)
val UNIT_o_MCOMP = store_thm
  ("UNIT_o_MCOMP",
  ``MCOMP (UNIT o g) (UNIT o f) = UNIT o g o f``,
  REWRITE_TAC [MCOMP_DEF, o_ASSOC, EXT_o_UNIT]) ;
206
(* BIND_EXT: BIND is EXT with its two arguments swapped. *)
val BIND_EXT = store_thm
  ("BIND_EXT",
  ``BIND m f = EXT f m``,
  REWRITE_TAC [BIND_DEF, EXT_DEF]) ;
211
(* MMAP_EXT: MMAP f expressed through EXT, as the extension of UNIT o f. *)
val MMAP_EXT = store_thm
  ("MMAP_EXT",
  ``MMAP f = EXT (UNIT o f)``,
  REWRITE_TAC [FUN_EQ_THM, MMAP_DEF, BIND_EXT]) ;
216
(* JOIN_EXT: JOIN expressed through EXT, as the extension of the identity. *)
val JOIN_EXT = store_thm
  ("JOIN_EXT",
  ``JOIN = EXT I``,
  REWRITE_TAC [FUN_EQ_THM, JOIN_DEF, BIND_EXT]) ;
221
(* EXT_JM: EXT recovered from JOIN and MMAP -- map f over the computation,
   then flatten. *)
val EXT_JM = store_thm
  ("EXT_JM",
  ``EXT f = JOIN o MMAP f``,
  REWRITE_TAC [JOIN_EXT, BIND_EXT, MMAP_EXT, GSYM EXT_MCOMP,
    MCOMP_DEF, o_ASSOC, EXT_o_UNIT, I_o_ID]) ;
227
(* BIND_LEFT_UNIT: left identity law -- binding UNIT x to k is the same as
   applying k to x.  Proved by function extensionality after unfolding
   BIND, UNIT and composition. *)
val BIND_LEFT_UNIT = store_thm
  ("BIND_LEFT_UNIT",
   ``!(k:'a->'b->'c#'b) x. BIND (UNIT x) k = k x``,
   REPEAT STRIP_TAC
   >> MATCH_MP_TAC EQ_EXT
   >> REWRITE_TAC [BIND_DEF, UNIT_DEF, o_DEF]
   >> CONV_TAC (DEPTH_CONV BETA_CONV)
   >> REWRITE_TAC [UNCURRY_DEF]);
236
(* UNIT_UNCURRY: the uncurried form of UNIT is the identity on pairs.  Used
   as a rewrite in the proof of BIND_RIGHT_UNIT. *)
val UNIT_UNCURRY = store_thm
  ("UNIT_UNCURRY",
   ``!(s:'a#'b). UNCURRY UNIT s = s``,
   REWRITE_TAC [UNCURRY_VAR, UNIT_DEF]
   >> CONV_TAC (DEPTH_CONV BETA_CONV)
   >> REWRITE_TAC [PAIR]);
243
(* BIND_RIGHT_UNIT: right identity law -- binding a computation k to UNIT
   leaves k unchanged. *)
val BIND_RIGHT_UNIT = store_thm
  ("BIND_RIGHT_UNIT",
   ``!(k:'a->'b#'a). BIND k UNIT = k``,
   REPEAT STRIP_TAC
   >> MATCH_MP_TAC EQ_EXT
   >> REWRITE_TAC [BIND_DEF, UNIT_UNCURRY, o_DEF]
   >> CONV_TAC (DEPTH_CONV BETA_CONV)
   >> REWRITE_TAC []);
252
(* BIND_ASSOC: associativity law -- binding k to (m followed by n) equals
   binding (k followed by m) to n. *)
val BIND_ASSOC = store_thm
  ("BIND_ASSOC",
   ``!(k:'a->'b#'a) (m:'b->'a->'c#'a) (n:'c->'a->'d#'a).
       BIND k (\a. BIND (m a) n) = BIND (BIND k m) n``,
   REWRITE_TAC [BIND_DEF, UNCURRY_VAR, o_DEF]
   >> CONV_TAC (DEPTH_CONV BETA_CONV)
   >> REWRITE_TAC []);
260
(* MMAP_ID: mapping the identity function over a computation is the
   identity on computations. *)
val MMAP_ID = store_thm
  ("MMAP_ID",
   ``MMAP I = (I:('a->'b#'a)->('a->'b#'a))``,
   REWRITE_TAC [MMAP_EXT, I_o_ID, EXT_UNIT]) ;
265
(* MMAP_COMP: MMAP preserves function composition. *)
val MMAP_COMP = store_thm
  ("MMAP_COMP",
   ``!f g. (MMAP (f o g):('a->'b#'a)->('a->'d#'a))
           = (MMAP f:('a->'c#'a)->('a->'d#'a)) o MMAP g``,
   REWRITE_TAC [MMAP_EXT, o_THM, GSYM EXT_MCOMP, UNIT_o_MCOMP]) ;
271
(* MMAP_UNIT: mapping f over a UNIT computation equals applying f first and
   then UNIT. *)
val MMAP_UNIT = store_thm
  ("MMAP_UNIT",
   ``!(f:'b->'c). MMAP f o UNIT = (UNIT:'c->'a->'c#'a) o f``,
   REWRITE_TAC [MMAP_EXT, EXT_o_UNIT]) ;
276
(* EXT_o_JOIN: flattening and then extending f equals extending EXT f over
   the nested computation. *)
val EXT_o_JOIN = store_thm
  ("EXT_o_JOIN",
   ``!f. EXT f o JOIN = EXT (EXT f:('a->'b#'a)->('a->'c#'a))``,
   REWRITE_TAC [JOIN_EXT, GSYM EXT_MCOMP, MCOMP_DEF, I_o_ID]) ;
281
(* MMAP_JOIN: mapping f after flattening equals flattening after mapping
   MMAP f over the outer computation. *)
val MMAP_JOIN = store_thm
  ("MMAP_JOIN",
   ``!f. MMAP f o JOIN = JOIN o MMAP (MMAP f:('a->'b#'a)->('a->'c#'a))``,
   REWRITE_TAC [GSYM EXT_JM] >> REWRITE_TAC [MMAP_EXT, EXT_o_JOIN]) ;
286
(* JOIN_UNIT: wrapping a computation with UNIT and then flattening is the
   identity. *)
val JOIN_UNIT = store_thm
  ("JOIN_UNIT",
   ``JOIN o UNIT = (I:('a->'b#'a)->('a->'b#'a))``,
   REWRITE_TAC [JOIN_EXT, EXT_o_UNIT]) ;
291
(* JOIN_MMAP_UNIT: mapping UNIT over the result of a computation and then
   flattening is the identity. *)
val JOIN_MMAP_UNIT = store_thm
  ("JOIN_MMAP_UNIT",
   ``JOIN o MMAP UNIT = (I:('a->'b#'a)->('a->'b#'a))``,
   REWRITE_TAC [GSYM EXT_JM, EXT_UNIT]) ;
296
(* JOIN_MAP_JOIN: for a triply nested computation, flattening the inner two
   levels first (MMAP JOIN) and then the outer gives the same result as
   flattening twice from the outside (JOIN o JOIN). *)
val JOIN_MAP_JOIN = store_thm
  ("JOIN_MAP_JOIN",
   ``JOIN o MMAP JOIN = ((JOIN o JOIN)
       :('a -> ('a -> ('a -> 'b # 'a) # 'a) # 'a) -> 'a -> 'b # 'a)``,
   REWRITE_TAC [GSYM EXT_JM] >> REWRITE_TAC [JOIN_EXT, GSYM EXT_o_JOIN]) ;
302
(* JOIN_MAP: BIND expressed through JOIN and MMAP -- map m over k, then
   flatten. *)
val JOIN_MAP = store_thm
  ("JOIN_MAP",
   ``!k (m:'b->'a->'c#'a). BIND k m = JOIN (MMAP m k)``,
   REWRITE_TAC [BIND_EXT, EXT_JM, o_THM]) ;
307
(* FST_o_UNIT: the result component of UNIT x is x for every state, i.e.
   the constant function K x. *)
val FST_o_UNIT = store_thm
  ("FST_o_UNIT",
   ``!x. FST o UNIT x = K x``,
   FUN_EQ_TAC
   >> REWRITE_TAC [o_THM, UNIT_DEF, K_THM]
   >> BETA_TAC
   >> REWRITE_TAC [FST]);
315
(* SND_o_UNIT: the state component of UNIT x is the input state, i.e. the
   identity function I. *)
val SND_o_UNIT = store_thm
  ("SND_o_UNIT",
   ``!x. SND o UNIT x = I``,
   FUN_EQ_TAC
   >> REWRITE_TAC [o_THM, UNIT_DEF, I_THM]
   >> BETA_TAC
   >> REWRITE_TAC [SND]);
323
(* FST_o_MMAP: the result component of MMAP f g is f applied to the result
   component of g. *)
val FST_o_MMAP = store_thm
  ("FST_o_MMAP",
   ``!f g. FST o MMAP f g = f o FST o g``,
   FUN_EQ_TAC
   >> REWRITE_TAC [MMAP_DEF, BIND_DEF, UNCURRY, o_THM, UNIT_DEF]
   >> BETA_TAC
   >> REWRITE_TAC [FST]);
331
(* sequence_nil: sequencing the empty list of computations returns the empty
   list of results.  The theorem is also registered through
   BasicProvers.export_rewrites. *)
val sequence_nil = store_thm("sequence_nil",
  ``sequence [] = UNIT []``,
  BasicProvers.SRW_TAC[][sequence_def])
val _ = BasicProvers.export_rewrites["sequence_nil"]
336
(* mapM_nil: mapM over the empty list returns the empty list of results.
   The theorem is also registered through BasicProvers.export_rewrites. *)
val mapM_nil = store_thm("mapM_nil",
  ``mapM f [] = UNIT []``,
  BasicProvers.SRW_TAC[][mapM_def])
val _ = BasicProvers.export_rewrites["mapM_nil"]
341
(* mapM_cons: unfolding of mapM on a non-empty list -- run f on the head,
   then mapM f on the tail, and cons the two results.  Unlike the nil cases
   above, this theorem is not passed to export_rewrites. *)
val mapM_cons = store_thm("mapM_cons",
  ``mapM f (x::xs) = BIND (f x) (\y. BIND (mapM f xs) (\ys. UNIT (y::ys)))``,
  BasicProvers.SRW_TAC[][mapM_def,sequence_def])
345
346(*---------------------------------------------------------------------------*)
347(* Support for termination condition extraction for recursive monadic defns. *)
348(*---------------------------------------------------------------------------*)
349(*
350val BIND_CONG = Q.store_thm
351("BIND_CONG",
352 `!a b c d.
353   (a = c) /\
354   (!x y s. (c s = (x,y)) ==> (b x y = d x y))
355    ==>
356   (BIND a b = BIND c d)`,
357 SRW_TAC [] [BIND_DEF,pairTheory.UNCURRY_VAR,combinTheory.o_DEF,FUN_EQ_THM]
358  THEN FIRST_ASSUM MATCH_MP_TAC
359  THEN METIS_TAC [pairTheory.PAIR]);
360
361val _ = adjoin_to_theory
362{sig_ps = NONE,
363 struct_ps = SOME
364 (fn ppstrm => let
365   val S = (fn s => (PP.add_string ppstrm s; PP.add_newline ppstrm))
366 in
367   S "val _ = DefnBase.add_cong BIND_CONG;";
368   S "val _ = TotalDefn.termination_simps := (!TotalDefn.termination_simps @ [UNIT_DEF]);"
369 end)};
370*)
371
372(* ------------------------------------------------------------------------- *)
373
(* Close the theory segment opened by new_theory at the top of this script
   and export it. *)
val _ = export_theory ();
375