1(*
2 * Copyright 2014, NICTA
3 *
4 * This software may be distributed and modified according to the terms of
5 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
6 * See "LICENSE_BSD2.txt" for details.
7 *
8 * @TAG(NICTA_BSD)
9 *)
10
11(* Author: Thomas Sewell
12
13   Library routines etc expected by Haskell code.
14*)
15
16theory HaskellLib_H
17imports
18  Lib
19  "More_Numeral_Type"
20  "Monad_WP/NonDetMonadVCG"
21begin
22
(* Haskell's flip: argument-swapping combinator, provided by swp. *)
abbreviation (input) "flip \<equiv> swp"

(* Monadic sequencing that discards the first computation's result
   (presumably the translation of Haskell's >>), written >>_ as an
   input-only infix. It binds x and continues with K_bind y. *)
abbreviation(input) bind_drop :: "('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad
                      \<Rightarrow> ('a, 'b) nondet_monad" (infixl ">>'_" 60)
  where "bind_drop \<equiv> (\<lambda>x y. bind x (K_bind y))"

(* Sanity check: folding bind_drop over a list of unit-result actions,
   starting from return (), is the same as sequence_x. *)
lemma bind_drop_test:
  "foldr bind_drop x (return ()) = sequence_x x"
  by (rule ext, simp add: sequence_x_def)
32
(* runState f s: view a nondet_monad computation as an ordinary
   deterministic state-monad step. The result is THE (value, state) pair
   in the result set fst (f s). This is only meaningful when f s has
   exactly one result. Otherwise THE denotes an unspecified pair. *)
definition
  runState :: "('s, 'a) nondet_monad \<Rightarrow> 's \<Rightarrow> ('a \<times> 's)" where
 "runState f s \<equiv> THE x. x \<in> fst (f s)"
38
(* sassert P x: a pure assertion. When P holds it returns x unchanged;
   when P fails the result is undefined. *)
definition
  sassert :: "bool \<Rightarrow> 'a \<Rightarrow> 'a" where
 "sassert P \<equiv> if P then id else (\<lambda>x. undefined)"

(* Congruence rule for the function package: while proving facts about
   the guarded value, the (rewritten) assertion P' may be assumed. *)
lemma sassert_cong[fundef_cong]:
 "\<lbrakk> P = P'; P' \<Longrightarrow> s = s' \<rbrakk> \<Longrightarrow> sassert P s = sassert P' s'"
  by (simp add: sassert_def)
47
(* Monadic assertion: fails when P does not hold. The unit list argument
   (presumably the translated Haskell message) is ignored. *)
definition
  haskell_assert :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
 "haskell_assert P L \<equiv> assert P"

(* The same assertion in the exception monad, via assertE. L is ignored. *)
definition
  haskell_assertE :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, 'e + unit) nondet_monad" where
 "haskell_assertE P L \<equiv> assertE P"

(* Unfold both wrappers automatically so proofs see assert/assertE directly. *)
declare haskell_assert_def [simp] haskell_assertE_def [simp]
57
(* stateAssert P L: read the current state with get and assert that P
   holds of it. The unit list argument L is ignored. *)
definition
  stateAssert :: "('a \<Rightarrow> bool) \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
 "stateAssert P L \<equiv> get >>= (\<lambda>s. assert (P s))"
61
(* Haskell-level failure: the message argument is dropped and the
   nondet_monad fail is used. Unfolded by simp. *)
definition
  haskell_fail :: "unit list \<Rightarrow> ('a, 'b) nondet_monad" where
  haskell_fail_def[simp]:
 "haskell_fail L \<equiv> fail"

(* catchError is the exception-monad handler handleE; unfolded by simp. *)
definition
  catchError_def[simp]:
 "catchError \<equiv> handleE"
70
(* curryN turns a function on an N-tuple into a curried function of N
   arguments. curry1 is the identity and curry2 is HOL's curry. *)
definition
  "curry1 \<equiv> id"
definition
  "curry2 \<equiv> curry"
definition
  "curry3 f a b c \<equiv> f (a, b, c)"
definition
  "curry4 f a b c d \<equiv> f (a, b, c, d)"
definition
  "curry5 f a b c d e \<equiv> f (a, b, c, d, e)"

(* All curryN definitions are unfolded by simp. *)
declare curry1_def[simp] curry2_def[simp]
        curry3_def[simp] curry4_def[simp] curry5_def[simp]
84
(* splitN is the counterpart of curryN: it turns a curried function of N
   arguments into a function on N-tuples. split1 and split2 are unfolded
   by simp. split3..split5 are instead reduced on explicit tuples by the
   splitN_simp lemmas below. *)
definition
  "split1 \<equiv> id"
definition
  "split2 \<equiv> case_prod"
definition
  "split3 f \<equiv> \<lambda>(a, b, c). f a b c"
definition
  "split4 f \<equiv> \<lambda>(a, b, c, d). f a b c d"
definition
  "split5 f \<equiv> \<lambda>(a, b, c, d, e). f a b c d e"

declare split1_def[simp] split2_def[simp]

(* Evaluation of splitN on a syntactic tuple. *)
lemma split3_simp[simp]: "split3 f (a, b, c) = f a b c"
  unfolding split3_def by simp

lemma split4_simp[simp]: "split4 f (a, b, c, d) = f a b c d"
  unfolding split4_def by simp

lemma split5_simp[simp]: "split5 f (a, b, c, d, e) = f a b c d e"
  unfolding split5_def by simp
106
(* Haskell Maybe constructors, mapped onto HOL's option type. *)
definition
 "Just \<equiv> Some"
definition
 "Nothing \<equiv> None"

(* fromJust is HOL's the (unspecified on None); isJust tests for Some. *)
definition
 "fromJust \<equiv> the"
definition
 "isJust x \<equiv> x \<noteq> None"

(* Haskell list head/tail, mapped onto hd/tl. *)
definition
 "tail \<equiv> tl"
definition
 "head \<equiv> hd"
121
(* Haskell's error: the message is discarded and the result is undefined. *)
definition
  error :: "unit list \<Rightarrow> 'a" where
 "error \<equiv> \<lambda>x. undefined"

definition
 "reverse \<equiv> rev"

(* Test for None (Haskell isNothing). *)
definition
 "isNothing x \<equiv> x = None"

(* Haskell fmap on Maybe. *)
definition
 "maybeApply \<equiv> option_map"

(* Haskell maybe: default value for None, function for Some. *)
definition
 "maybe \<equiv> case_option"

(* Haskell foldr f init L. Isabelle's foldr takes the list before the
   initial value, so the last two arguments are swapped. *)
definition
 "foldR f init L \<equiv> foldr f L init"

(* List membership and non-membership, via set. *)
definition
 "elem x L \<equiv> x \<in> set L"

definition
 "notElem x L \<equiv> x \<notin> set L"
146
(* Haskell's three-valued Ordering is approximated by bool. *)
type_synonym ordering = bool

(* compare x y is x < y: True means "less than". The EQ/GT distinction of
   Haskell's compare is not represented. *)
definition
  compare :: "('a :: ord) \<Rightarrow> 'a \<Rightarrow> ordering" where
  "compare \<equiv> (<)"
152
(* insertBy f a xs: scan xs and place a immediately before the first
   element b with f a b. If no such element exists, a is appended. *)
primrec
  insertBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "insertBy f a [] = [a]"
| "insertBy f a (b # bs) = (if (f a b) then (a # b # bs) else (b # (insertBy f a bs)))"

(* Inserting always adds exactly one element. *)
lemma insertBy_length [simp]:
  "length (insertBy f a as) = (1 + length as)"
  by (induct as; simp)

(* sortBy f: insertion sort. The tail is sorted first and the head is then
   inserted with insertBy. *)
primrec
  sortBy :: "('a \<Rightarrow> 'a \<Rightarrow> ordering) \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "sortBy f [] = []"
| "sortBy f (a # as) = insertBy f a (sortBy f as)"

(* Sorting preserves the length of the list. *)
lemma sortBy_length:
  "length (sortBy f as) = length as"
  by (induct as; simp)
172
(* Sort ascending, using compare (i.e. <) as the insertion test. *)
definition
 "sortH \<equiv> sortBy compare"

(* Haskell catMaybes: drop the None entries and unwrap the Some values. *)
definition
 "catMaybes \<equiv> (map the) \<circ> (filter isJust)"
178
(* runExceptT is the identity: exception-monad computations are already
   represented directly as nondet_monad values over a sum type. *)
definition
 "runExceptT \<equiv> id"

(* Unfold the thin Haskell-name wrappers above automatically in proofs. *)
declare Just_def[simp] Nothing_def[simp] fromJust_def[simp]
      isJust_def[simp] tail_def[simp] head_def[simp]
      error_def[simp] reverse_def[simp] isNothing_def[simp]
      maybeApply_def[simp] maybe_def[simp]
      foldR_def[simp] elem_def[simp] notElem_def[simp]
      catMaybes_def[simp] runExceptT_def[simp]
188
(* Monadic head: returns the first element, or fails on the empty list. *)
definition
 "headM L \<equiv> (case L of (h # t) \<Rightarrow> return h | _ \<Rightarrow> fail)"

(* Monadic tail: returns all but the first element, or fails on []. *)
definition
 "tailM L \<equiv> (case L of (h # t) \<Rightarrow> return t | _ \<Rightarrow> fail)"

(* typeOf is an uninterpreted constant: no axioms are stated about it. *)
axiomatization
  typeOf :: "'a \<Rightarrow> unit list"
197
(* Haskell's either: f1 handles a left (Inl) value and f2 a right (Inr)
   value. The simp rule below replaces it by HOL's case_sum. *)
definition
 "either f1 f2 c \<equiv> case c of Inl r1 \<Rightarrow> f1 r1 | Inr r2 \<Rightarrow> f2 r2"

lemma either_simp[simp]: "either = case_sum"
  by (intro ext) (simp add: either_def)
205
206
(* Bitwise operations on nat: convert to int, operate there, and convert
   back with nat (which maps negative ints to 0). *)
instantiation nat :: bit
begin

(* NOTE(review): if int complement is -x - 1 (always negative for x >= 0),
   this bitNOT is constantly 0 on nat -- confirm that this is intended. *)
definition
  "bitNOT = nat o bitNOT o int"

definition
  "bitAND x y = nat (bitAND (int x) (int y))"

definition
  "bitOR x y = nat (bitOR (int x) (int y))"

definition
  "bitXOR x y = nat (bitXOR (int x) (int y))"

instance ..

end
225
(* Extra bit operations used by the Haskell code on top of class bit:
   shifts by a nat amount and the width in bits. *)
class HS_bit = bit +
  fixes shiftL :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes shiftR :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes bitSize :: "'a \<Rightarrow> nat"

(* Words: the shifts are the word library's shiftl/shiftr and the width
   is size. All three are unfolded by simp. *)
instantiation word :: (len0) HS_bit
begin

definition
  shiftL_word[simp]: "(shiftL :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftl"

definition
  shiftR_word[simp]: "(shiftR :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftr"

definition
  bitSize_word[simp]: "(bitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* Naturals: shifting left by n multiplies by 2^n, and shifting right by n
   divides by 2^n, rounding down. *)
instantiation nat ::  HS_bit
begin

definition
  shiftL_nat: "shiftL (x :: nat) n \<equiv> (2 ^ n) * x"

definition
  shiftR_nat: "shiftR (x :: nat) n \<equiv> x div (2 ^ n)"

text {* bitSize not defined for nat *}

instance ..

end
261
(* Types with a fixed bit width (cf. Haskell's FiniteBits). *)
class finiteBit = bit +
  fixes finiteBitSize :: "'a \<Rightarrow> nat"

(* For words the width is size; unfolded by simp. *)
instantiation word :: (len0) finiteBit
begin

definition
  finiteBitSize_word[simp]: "(finiteBitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* bit x: the value 1 shifted left by x places. *)
definition
  bit_def[simp]:
 "bit x \<equiv> shiftL 1 x"

(* isAligned x n: x masked with mask n is zero. Assuming mask n has exactly
   the low n bits set, this means x is a multiple of 2^n. *)
definition
"isAligned x n \<equiv> x && mask n = 0"
281
(* Approximation of Haskell's Integral class: values convert to and from
   nat (standing in for Integer), and converting to nat and back again is
   the identity (integral_inv). *)
class integral = ord +
  fixes fromInteger :: "nat \<Rightarrow> 'a"
  fixes toInteger :: "'a \<Rightarrow> nat"
  assumes integral_inv: "fromInteger \<circ> toInteger = id"

(* nat: both conversions are the identity. *)
instantiation nat :: integral
begin

definition
 fromInteger_nat: "fromInteger \<equiv> id"

definition
 toInteger_nat: "toInteger \<equiv> id"

instance
  apply (intro_classes)
  apply (simp add: toInteger_nat fromInteger_nat)
  done

end


(* Words: convert in with of_nat and out with unat. *)
instantiation word :: (len) integral
begin

definition
 fromInteger_word: "fromInteger \<equiv> of_nat :: nat \<Rightarrow> 'a::len word"

definition
 toInteger_word: "toInteger \<equiv> unat"

instance
  apply (intro_classes)
  apply (rule ext)
  apply (simp add: toInteger_word fromInteger_word)
  done

end
320
(* fromIntegral: convert between integral types by going through nat. *)
definition
  fromIntegral :: "('a :: integral) \<Rightarrow> ('b :: integral)" where
 "fromIntegral \<equiv> fromInteger \<circ> toInteger"

(* Simp rules specialising fromIntegral to concrete conversions. *)
(* nat to word: of_nat. *)
lemma fromIntegral_simp1[simp]: "(fromIntegral :: nat \<Rightarrow> ('a :: len) word) = of_nat"
  by (simp add: fromIntegral_def fromInteger_word toInteger_nat)

(* word to nat: unat. *)
lemma fromIntegral_simp2[simp]: "fromIntegral = unat"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_word)

(* word to word: ucast. *)
lemma fromIntegral_simp3[simp]: "fromIntegral = ucast"
  apply (simp add: fromIntegral_def fromInteger_word toInteger_word)
  apply (rule ext)
  apply (simp add: ucast_def)
  apply (subst word_of_nat)
  apply (simp add: unat_def)
  done

(* nat to nat: the identity. *)
lemma fromIntegral_simp_nat[simp]: "(fromIntegral :: nat \<Rightarrow> nat) = id"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_nat)
341
(* Haskell backquoted infix application a `f` b, written a `~f~` b.
   Unfolded by simp to plain application f a b. *)
definition
  infix_apply :: "'a \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'c" ("_ `~_~` _" [81, 100, 80] 80) where
  infix_apply_def[simp]:
 "infix_apply a f b \<equiv> f a b"

(* Parse check for the infix syntax in combination with $. *)
term "return $ a `~b~` c d"

(* zip3: zip three lists into nested pairs (a, b, c) = (a, (b, c)). *)
definition
  zip3 :: "'a list \<Rightarrow> 'b list \<Rightarrow> 'c list \<Rightarrow> ('a \<times> 'b \<times> 'c) list" where
 "zip3 a b c \<equiv> zip a (zip b c)"

(* avoid even attempting haskell's show class *)
(* show always yields the empty list, so appending it changes nothing
   (show_simp_away). *)
definition
 "show" :: "'a \<Rightarrow> unit list" where
 "show x \<equiv> []"

lemma show_simp_away[simp]: "S @ show t = S"
  by (simp add: show_def)

(* Conjunction / disjunction of a list of booleans (Haskell and / or). *)
definition
 "andList \<equiv> foldl (\<and>) True"

definition
 "orList \<equiv> foldl (\<or>) False"
366
(* mapAccumL f s xs: thread an accumulator left to right through f. Each
   call f s x returns the next accumulator and one output. The result pairs
   the final accumulator with the list of outputs in input order. *)
primrec
  mapAccumL :: "('a \<Rightarrow> 'b \<Rightarrow> 'a \<times> 'c) \<Rightarrow> 'a \<Rightarrow> 'b list \<Rightarrow> 'a \<times> ('c list)"
where
  "mapAccumL f s [] = (s, [])"
| "mapAccumL f s (x#xs) = (
     let (s', r) = f s x;
         (s'', rs) = mapAccumL f s' xs
     in (s'', r#rs)
  )"
376
(* untilM f xs: run f on the elements of xs in order until one call
   returns Some res. The result is that Some res, or None if every call
   returned None (including when xs is empty). Later elements are not
   visited after a hit. *)
primrec
  untilM :: "('b \<Rightarrow> ('s, 'a option) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "untilM f [] = return None"
| "untilM f (x#xs) = do
    r \<leftarrow> f x;
    case r of
      None     \<Rightarrow> untilM f xs
    | Some res \<Rightarrow> return (Some res)
  od"

(* untilME: the same search in the exception monad (doE / returnOk). *)
primrec
  untilME :: "('c \<Rightarrow> ('s, ('a + 'b option)) nondet_monad) \<Rightarrow> 'c list \<Rightarrow> ('s, 'a + 'b option) nondet_monad"
where
  "untilME f [] = returnOk None"
| "untilME f (x#xs) = doE
    r \<leftarrow> f x;
    case r of
      None     \<Rightarrow> untilME f xs
    | Some res \<Rightarrow> returnOk (Some res)
  odE"

(* findM f xs: Some of the first element for which the monadic predicate
   f returns True, or None if there is none. f is not run on elements
   after the first match. *)
primrec
  findM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "findM f [] = return None"
| "findM f (x#xs) = do
    r \<leftarrow> f x;
    if r
      then return (Some x)
      else findM f xs
  od"

(* findME: the same search in the exception monad (doE / returnOk). *)
primrec
  findME :: "('a \<Rightarrow> ('s, ('e + bool)) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, ('e + 'a option)) nondet_monad"
where
  "findME f [] = returnOk None"
| "findME f (x#xs) = doE
    r \<leftarrow> f x;
    if r
      then returnOk (Some x)
      else findME f xs
  odE"
420
(* tails xs: all suffixes of xs, longest (xs itself) first, ending with []. *)
primrec
  tails :: "'a list \<Rightarrow> 'a list list"
where
  "tails [] = [[]]"
| "tails (x#xs) = (x#xs)#(tails xs)"
426
(* If every element of type 'b is f y for some y of a finite type 'a,
   then 'b is finite too. *)
lemma finite_surj_type:
  "\<lbrakk> (\<forall>x. \<exists>y. (x :: 'b) = f (y :: 'a)); finite (UNIV :: 'a set) \<rbrakk> \<Longrightarrow> finite (UNIV :: 'b set)"
  apply (erule finite_surj)
  apply safe
  apply (erule allE)
  apply safe
  apply (erule image_eqI)
  apply simp
  done

(* Every set over a finite type is finite. *)
lemma finite_finite[simp]: "finite (s :: ('a :: finite) set)"
  by simp

(* Over a finite universe U, also removing a from the complement of s
   strictly lowers its cardinality exactly when a was not already in s.
   Stating U = UNIV as a premise lets U appear as a variable. *)
lemma finite_inv_card_less':
  "U = (UNIV :: ('a :: finite) set) \<Longrightarrow> (card (U - insert a s) < card (U - s)) = (a \<notin> s)"
  apply (case_tac "a \<in> s")
   apply (simp_all add: insert_absorb)
  apply (subgoal_tac "card s < card U")
   apply (simp add: card_Diff_subset)
  apply (rule psubset_card_mono)
   apply safe
    apply simp_all
  done

(* The same fact stated directly for UNIV. *)
lemma finite_inv_card_less:
   "(card (UNIV - insert (a :: ('a :: finite)) s) < card (UNIV - s)) = (a \<notin> s)"
  by (simp add: finite_inv_card_less')
454
text {* Support for defining enumerations on datatypes derived from enumerations *}
(* If F is injective, mapping F over the enumeration of an enum type
   yields a list without duplicates. *)
lemma distinct_map_enum: "\<lbrakk> (\<forall> x y. (F x = F y \<longrightarrow> x = y )) \<rbrakk> \<Longrightarrow> distinct (map F (enum :: 'a :: enum list))"
  apply (simp add: distinct_map)
  apply (rule inj_onI)
  apply simp
  done

(* Haskell minimum / maximum via Min / Max of the element set.
   NOTE(review): no meaningful value is presumably intended for [] (Haskell
   raises an error there) -- callers should only use non-empty lists. *)
definition
  "minimum ls \<equiv> Min (set ls)"
definition
  "maximum ls \<equiv> Max (set ls)"
466
(* hdCons x yss: prepend x to the first list of yss. Deliberately
   unspecified when yss is empty (nonexhaustive). *)
primrec (nonexhaustive)
  hdCons :: "'a \<Rightarrow> 'a list list \<Rightarrow> 'a list list"
where
 "hdCons x (ys # zs) = (x # ys) # zs"

(* rangesBy f xs: split xs into maximal runs of consecutive elements such
   that f holds between each element and its successor. In the hdCons
   branch xs is non-empty, so rangesBy f xs is non-empty and the partial
   hdCons is always defined there. *)
primrec
  rangesBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list list"
where
  "rangesBy f [] = []"
| "rangesBy f (x # xs) =
    (case xs of [] \<Rightarrow> [[x]]
             | (y # ys) \<Rightarrow> if (f x y) then hdCons x (rangesBy f xs)
                                     else [x] # (rangesBy f xs))"
480
(* partition f xs: the elements satisfying f and those that do not, each
   kept in their original order. *)
definition
  partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<times> 'a list" where
 "partition f xs \<equiv> (filter f xs, filter (\<lambda>x. \<not> f x) xs)"
484
(* listSubtract xs ys: list difference, i.e. the elements of xs that do
   NOT occur in ys, kept in their original order (model of Haskell's
   xs \\ ys). The previous predicate kept the elements that DO occur in ys,
   which computed an intersection instead.
   Note: Haskell's (\\) removes only one occurrence of xs-element per
   matching element of ys. This version removes every occurrence of any
   element of ys. *)
definition
  listSubtract :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list" where
 "listSubtract xs ys \<equiv> filter (\<lambda>x. x \<notin> set ys) xs"
488
(* init xs: all elements except the last; undefined on the empty list. *)
definition
  init :: "'a list \<Rightarrow> 'a list" where
 "init xs \<equiv> case (length xs) of Suc n \<Rightarrow> take n xs | _ \<Rightarrow> undefined"

(* break f xs: split xs before the first element satisfying f. The first
   component is the longest prefix of elements not satisfying f; the second
   is the remainder, starting with the first element that satisfies f. *)
primrec
  break :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> ('a list \<times> 'a list)"
where
  "break f []       = ([], [])"
| "break f (x # xs) =
   (if f x
    then ([], x # xs)
    else (\<lambda>(ys, zs). (x # ys, zs)) (break f xs))"
501
(* Haskell uncurry: apply a curried function to a pair. *)
definition
 "uncurry \<equiv> case_prod"

(* Sum of a list: left fold with + starting from 0. *)
definition
  sum :: "'a list \<Rightarrow> 'a::{plus,zero}" where
 "sum \<equiv> foldl (+) 0"

(* replicateM n m: run m n times in sequence, collecting the results. *)
definition
 "replicateM n m \<equiv> sequence (replicate n m)"
511
(* Lift an option into the monad via assert_opt; unfolded by simp. *)
definition
  maybeToMonad_def[simp]:
 "maybeToMonad \<equiv> assert_opt"

(* Arrays are modelled as functions, so building one is the identity. *)
definition
  funArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b)"  where
  funArray_def[simp]:
 "funArray \<equiv> id"

(* funPartialArray f (lo, hi): f on the indices in [lo .e. hi], and
   undefined at every other index. *)
definition
  funPartialArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a :: enumeration_alt \<times> 'a) \<Rightarrow> ('a \<Rightarrow> 'b)"  where
 "funPartialArray f xrange \<equiv> \<lambda>x. (if x \<in> set [fst xrange .e. snd xrange] then f x else undefined)"

(* forM / forM_x / forME_x: mapM / mapM_x / mapME_x with the list given
   first. All unfolded by simp. *)
definition
  forM_def[simp]:
 "forM xs f \<equiv> mapM f xs"

definition
  forM_x_def[simp]:
 "forM_x xs f \<equiv> mapM_x f xs"

definition
  forME_x_def[simp]:
 "forME_x xs f \<equiv> mapME_x f xs"

(* f aLU l: apply the (index, value) updates in l to the function f, left
   to right, so a later update to the same index overrides an earlier one. *)
definition
  arrayListUpdate :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<times> 'b) list \<Rightarrow> ('a \<Rightarrow> 'b)" (infixl "aLU" 90)
where
  arrayListUpdate_def[simp]:
  "arrayListUpdate f l  \<equiv> foldl (\<lambda>f p. f(fst p := snd p)) f l"
542
(* take with the count converted from any integral type by fromIntegral. *)
definition
 "genericTake \<equiv> take \<circ> fromIntegral"

(* length converted to any integral type by fromIntegral. *)
definition
 "genericLength \<equiv> fromIntegral \<circ> length"

(* Haskell null: emptiness test, HOL's List.null. *)
abbreviation
  "null == List.null"

(* Input-only Haskell-style list comprehension [e | x <- xs] as an
   alternative spelling of Isabelle's [e . x <- xs]. *)
syntax (input)
  "_listcompr" :: "'a \<Rightarrow> lc_qual \<Rightarrow> lc_quals \<Rightarrow> 'a list"  ("[_ | __")

(* Sanity check: both comprehension notations parse to the same term. *)
lemma "[(x,1) . x \<leftarrow> [0..10]] = [(x,1) | x \<leftarrow> [0..10]]" by (rule refl)
556
557end
558