1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7(* Author: Thomas Sewell
8
9   Library routines etc expected by Haskell code.
10*)
11
12theory HaskellLib_H
13imports
14  Lib
15  NatBitwise
16  More_Numeral_Type
17  NonDetMonadVCG
18begin
19
(* Haskell's flip, provided as an input-only abbreviation for swp (defined in an
   imported theory): terms written with flip are parsed as swp, but are never
   printed back as flip. *)
20abbreviation (input) "flip \<equiv> swp"
21
(* Haskell's ">>" operator, written ">>_" here: run x, then run y, discarding
   x's result.  It is bind with the continuation K_bind y (K_bind comes from the
   monad library; presumably the constant function ignoring its argument).
   Input-only abbreviation, left-associative, priority 60. *)
22abbreviation(input) bind_drop :: "('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad
23                      \<Rightarrow> ('a, 'b) nondet_monad" (infixl ">>'_" 60)
24  where "bind_drop \<equiv> (\<lambda>x y. bind x (K_bind y))"
25
(* Sanity check: right-folding the ">>_" operator over a list of actions,
   with return () as the final action, gives the same monad as sequence_x. *)
26lemma bind_drop_test:
27  "foldr bind_drop x (return ()) = sequence_x x"
  apply (rule ext)
  apply (simp add: sequence_x_def)
  done
29
30(* runState: if the given monad is deterministic, this converts the
   nondet_monad computation into an ordinary state function that returns a
   single (result, new state) pair.  It picks THE element of the result set
   fst (f s).  When f s does not contain exactly one result, the value is
   unspecified. *)
32definition
33  runState :: "('s, 'a) nondet_monad \<Rightarrow> 's \<Rightarrow> ('a \<times> 's)" where
34 "runState f s \<equiv> THE x. x \<in> fst (f s)"
35
(* Haskell-style assertion on a pure value.  If P holds, sassert P is the
   identity.  Otherwise the result is undefined. *)
36definition
37  sassert :: "bool \<Rightarrow> 'a \<Rightarrow> 'a" where
38 "sassert P \<equiv> if P then id else (\<lambda>x. undefined)"
39
(* Congruence rule registered with the function package (fundef_cong).  When
   the package extracts recursive-call contexts, the argument s of sassert P s
   only needs to be considered under the assumption P. *)
40lemma sassert_cong[fundef_cong]:
41 "\<lbrakk> P = P'; P' \<Longrightarrow> s = s' \<rbrakk> \<Longrightarrow> sassert P s = sassert P' s'"
  by (simp add: sassert_def)
44
(* Translation of Haskell's monadic assert.  The unit list argument stands for
   the Haskell message string and is ignored; the result is just assert P. *)
45definition
46  haskell_assert :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
47 "haskell_assert P L \<equiv> assert P"
48
(* Same as haskell_assert, but in the exception monad (result type 'e + unit),
   using assertE. *)
49definition
50  haskell_assertE :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, 'e + unit) nondet_monad" where
51 "haskell_assertE P L \<equiv> assertE P"
52
(* Unfold both wrappers automatically, so proofs only ever see assert/assertE. *)
53declare haskell_assert_def [simp] haskell_assertE_def [simp]
54
(* Assert a predicate on the current state: read the state with get, then
   assert P of it.  The unit list (message) argument is ignored. *)
55definition
56  stateAssert :: "('a \<Rightarrow> bool) \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
57 "stateAssert P L \<equiv> get >>= (\<lambda>s. assert (P s))"
58
(* Haskell's fail with its message (unit list) discarded.  Declared [simp], so it
   unfolds to the monad's fail. *)
59definition
60  haskell_fail :: "unit list \<Rightarrow> ('a, 'b) nondet_monad" where
61  haskell_fail_def[simp]:
62 "haskell_fail L \<equiv> fail"
63
(* Haskell's catchError, mapped directly to handleE (the exception handler from
   the imported monad library).  Declared [simp]. *)
64definition
65  catchError_def[simp]:
66 "catchError \<equiv> handleE"
67
(* Curried views of functions on tuples: curryN f a1 ... aN = f (a1, ..., aN).
   curry1 is the identity and curry2 is HOL's curry.  Tuples nest to the right,
   e.g. (a, b, c) = (a, (b, c)). *)
68definition
69  "curry1 \<equiv> id"
70definition
71  "curry2 \<equiv> curry"
72definition
73  "curry3 f a b c \<equiv> f (a, b, c)"
74definition
75  "curry4 f a b c d \<equiv> f (a, b, c, d)"
76definition
77  "curry5 f a b c d e \<equiv> f (a, b, c, d, e)"
78
(* All curryN definitions are unfolded automatically. *)
79declare curry1_def[simp] curry2_def[simp]
80        curry3_def[simp] curry4_def[simp] curry5_def[simp]
81
(* Uncurried views, the inverse of curryN: splitN f takes an N-tuple and applies
   f to its components.  split1 is the identity and split2 is HOL's case_prod. *)
82definition
83  "split1 \<equiv> id"
84definition
85  "split2 \<equiv> case_prod"
86definition
87  "split3 f \<equiv> \<lambda>(a, b, c). f a b c"
88definition
89  "split4 f \<equiv> \<lambda>(a, b, c, d). f a b c d"
90definition
91  "split5 f \<equiv> \<lambda>(a, b, c, d, e). f a b c d e"
92
(* Only split1/split2 are unfolded automatically.  split3..split5 are handled by
   dedicated simp lemmas stated on explicit tuples. *)
93declare split1_def[simp] split2_def[simp]
94
(* Evaluation rules for split3..split5 applied to explicit tuples.  Their
   definitions are not simp rules; these lemmas are used instead. *)
95lemma split3_simp[simp]: "split3 f (a, b, c) = f a b c"
  unfolding split3_def by simp
97
98lemma split4_simp[simp]: "split4 f (a, b, c, d) = f a b c d"
  unfolding split4_def by simp
100
101lemma split5_simp[simp]: "split5 f (a, b, c, d, e) = f a b c d e"
  unfolding split5_def by simp
103
(* Direct aliases from Haskell Prelude / Data.Maybe / Data.List names to their
   HOL counterparts.  All HOL functions are total, so unlike Haskell these never
   raise an error.  For example, fromJust / head / tail on Nothing / [] just
   yield some value about which little is known.  The _def theorems are
   declared [simp] further down in this theory. *)
104definition
105 "Just \<equiv> Some"
106definition
107 "Nothing \<equiv> None"
108
109definition
110 "fromJust \<equiv> the"
111definition
112 "isJust x \<equiv> x \<noteq> None"
113
114definition
115 "tail \<equiv> tl"
116definition
117 "head \<equiv> hd"
118
(* Haskell's error: the message (unit list) is discarded and the result is
   undefined. *)
119definition
120  error :: "unit list \<Rightarrow> 'a" where
121 "error \<equiv> \<lambda>x. undefined"
122
123definition
124 "reverse \<equiv> rev"
125
126definition
127 "isNothing x \<equiv> x = None"
128
(* Haskell's fmap on Maybe. *)
129definition
130 "maybeApply \<equiv> option_map"
131
(* maybe d f m: d for None, f v for Some v (HOL's case_option). *)
132definition
133 "maybe \<equiv> case_option"
134
(* Haskell's foldr argument order (function, initial value, list), mapped onto
   HOL's foldr, which takes the list before the initial value. *)
135definition
136 "foldR f init L \<equiv> foldr f L init"
137
(* List membership and non-membership via the set of list elements. *)
138definition
139 "elem x L \<equiv> x \<in> set L"
140
141definition
142 "notElem x L \<equiv> x \<notin> set L"
143
(* Haskell's three-valued Ordering is approximated by bool.  compare x y is
   True exactly when x is strictly less than y. *)
144type_synonym ordering = bool
145
146definition
147  compare :: "('a :: ord) \<Rightarrow> 'a \<Rightarrow> ordering" where
148  "compare \<equiv> (<)"
149
(* Insert a in front of the first element b for which f a b holds, or at the end
   if there is none.  With f = compare (strict <), a lands after any elements
   equal to it. *)
150primrec
151  insertBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
152where
153  "insertBy f a [] = [a]"
154| "insertBy f a (b # bs) = (if (f a b) then (a # b # bs) else (b # (insertBy f a bs)))"
155
(* insertBy always adds exactly one element to the list. *)
156lemma insertBy_length [simp]:
157  "length (insertBy f a as) = (1 + length as)"
proof (induction as)
  case Nil
  show ?case by simp
next
  case (Cons b bs)
  then show ?case by simp
qed
159
(* Insertion sort: sort the tail recursively, then insertBy the head into it. *)
160primrec
161  sortBy :: "('a \<Rightarrow> 'a \<Rightarrow> ordering) \<Rightarrow> 'a list \<Rightarrow> 'a list"
162where
163  "sortBy f [] = []"
164| "sortBy f (a # as) = insertBy f a (sortBy f as)"
165
(* sortBy is length-preserving.  Not a simp rule. *)
166lemma sortBy_length:
167  "length (sortBy f as) = length as"
proof (induction as)
  case Nil
  show ?case by simp
next
  case (Cons b bs)
  then show ?case by simp
qed
169
(* Haskell's sort, via sortBy with compare (strict <).  The H suffix presumably
   avoids a clash with HOL's own sort; TODO confirm. *)
170definition
171 "sortH \<equiv> sortBy compare"
172
(* Keep the Some elements of a list and strip the Some constructor. *)
173definition
174 "catMaybes \<equiv> (map the) \<circ> (filter isJust)"
175
(* runExceptT is the identity: the translated exception monad needs no
   wrapper/unwrapper here. *)
176definition
177 "runExceptT \<equiv> id"
178
(* Unfold all of the Haskell aliases above eagerly, so proofs work directly with
   the underlying HOL functions. *)
179declare Just_def[simp] Nothing_def[simp] fromJust_def[simp]
180      isJust_def[simp] tail_def[simp] head_def[simp]
181      error_def[simp] reverse_def[simp] isNothing_def[simp]
182      maybeApply_def[simp] maybe_def[simp]
183      foldR_def[simp] elem_def[simp] notElem_def[simp]
184      catMaybes_def[simp] runExceptT_def[simp]
185
(* Monadic head: return the first element of a non-empty list, fail on []. *)
186definition
187 "headM L \<equiv> (case L of (h # t) \<Rightarrow> return h | _ \<Rightarrow> fail)"
188
(* Monadic tail: return all but the first element of a non-empty list, fail on
   []. *)
189definition
190 "tailM L \<equiv> (case L of (h # t) \<Rightarrow> return t | _ \<Rightarrow> fail)"
191
(* Haskell's typeOf, declared as an uninterpreted constant.  The axiomatization
   states no axioms, so nothing can be proved about its value. *)
192axiomatization
193  typeOf :: "'a \<Rightarrow> unit list"
194
(* Haskell's either: eliminate a sum, applying f1 to an Inl payload and f2 to an
   Inr payload. *)
195definition
196 "either f1 f2 c \<equiv> case c of Inl r1 \<Rightarrow> f1 r1 | Inr r2 \<Rightarrow> f2 r2"
197
(* Rewrite either to HOL's case_sum everywhere.  either_def is only stated for
   fully applied either, so extensionality is applied first. *)
198lemma either_simp[simp]: "either = case_sum"
  by (intro ext) (simp add: either_def)
202
(* Type class for the parts of Haskell's Bits interface used by the translated
   code: left/right shifts by a nat amount and a bit size.  It extends
   bit_operations (from an imported word/bit library). *)
203class HS_bit = bit_operations +
204  fixes shiftL :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
205  fixes shiftR :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
206  fixes bitSize :: "'a \<Rightarrow> nat"
207
(* Machine words: shifts map to the word library's shiftl/shiftr, and bitSize to
   size (presumably the word length in bits).  All three definitions are simp
   rules. *)
208instantiation word :: (len0) HS_bit
209begin
210
211definition
212  shiftL_word[simp]: "(shiftL :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftl"
213
214definition
215  shiftR_word[simp]: "(shiftR :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftr"
216
217definition
218  bitSize_word[simp]: "(bitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"
219
220instance ..
221
222end
223
(* Natural numbers: shiftL x n multiplies by 2^n and shiftR x n divides by 2^n
   (rounding down).  These definitions are not simp rules. *)
224instantiation nat ::  HS_bit
225begin
226
227definition
228  shiftL_nat: "shiftL (x :: nat) n \<equiv> (2 ^ n) * x"
229
230definition
231  shiftR_nat: "shiftR (x :: nat) n \<equiv> x div (2 ^ n)"
232
(* nat has no fixed width, so bitSize is left unspecified for it. *)
233text \<open>bitSize not defined for nat\<close>
234
235instance ..
236
237end
238
(* Type class for Haskell's FiniteBits: types with a known bit width. *)
239class finiteBit = bit_operations +
240  fixes finiteBitSize :: "'a \<Rightarrow> nat"
241
(* For words, finiteBitSize is size (presumably the word length in bits).  It is
   unfolded automatically. *)
242instantiation word :: (len0) finiteBit
243begin
244
245definition
246  finiteBitSize_word[simp]: "(finiteBitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"
247
248instance ..
249
250end
251
(* Haskell's bit x: the value 1 shifted left by x.  For nat this is 2^x, by
   shiftL_nat.  Unfolded automatically. *)
252definition bit :: "nat \<Rightarrow> 'a::{one,HS_bit}" where
253  bit_def[simp]: "bit x \<equiv> shiftL 1 x"
254
(* x is aligned to n bits when x AND mask n is zero.  mask comes from an imported
   library and is presumably the mask of the low n bits, making this
   "x is a multiple of 2^n". *)
255definition
256"isAligned x n \<equiv> x && mask n = 0"
257
(* Model of Haskell's Integral class, with nat playing the role of Integer.  The
   law requires fromInteger (toInteger x) = x, i.e. converting to nat and back
   is lossless. *)
258class integral = ord +
259  fixes fromInteger :: "nat \<Rightarrow> 'a"
260  fixes toInteger :: "'a \<Rightarrow> nat"
261  assumes integral_inv: "fromInteger \<circ> toInteger = id"
262
(* nat is trivially integral: both conversions are the identity. *)
263instantiation nat :: integral
264begin
265
266definition
267 fromInteger_nat: "fromInteger \<equiv> id"
268
269definition
270 toInteger_nat: "toInteger \<equiv> id"
271
272instance
273  apply (intro_classes)
274  apply (simp add: toInteger_nat fromInteger_nat)
275  done
276
277end
278
279
(* Words of non-zero length are integral.  fromInteger is of_nat and toInteger
   is unat.  The class law amounts to of_nat (unat w) = w, proved pointwise. *)
280instantiation word :: (len) integral
281begin
282
283definition
284 fromInteger_word: "fromInteger \<equiv> of_nat :: nat \<Rightarrow> 'a::len word"
285
286definition
287 toInteger_word: "toInteger \<equiv> unat"
288
289instance
290  apply (intro_classes)
291  apply (rule ext)
292  apply (simp add: toInteger_word fromInteger_word)
293  done
294
295end
296
(* Haskell's fromIntegral: convert between integral types by going through nat.
   The simp lemmas that follow specialise it for nat/word combinations. *)
297definition
298  fromIntegral :: "('a :: integral) \<Rightarrow> ('b :: integral)" where
299 "fromIntegral \<equiv> fromInteger \<circ> toInteger"
300
(* fromIntegral from nat into a word is of_nat. *)
301lemma fromIntegral_simp1[simp]: "(fromIntegral :: nat \<Rightarrow> ('a :: len) word) = of_nat"
  unfolding fromIntegral_def fromInteger_word toInteger_nat by simp
303
(* fromIntegral from a word into nat is unat. *)
304lemma fromIntegral_simp2[simp]: "fromIntegral = unat"
  unfolding fromIntegral_def fromInteger_nat toInteger_word by simp
306
(* fromIntegral between two word types is ucast.  After unfolding, the left-hand
   side is of_nat (unat x).  The remaining steps match this against ucast_def
   using word_of_nat and unat_def.  They depend on the exact shape of those
   definitions in the imported word library, so step order matters. *)
307lemma fromIntegral_simp3[simp]: "fromIntegral = ucast"
308  apply (simp add: fromIntegral_def fromInteger_word toInteger_word)
309  apply (rule ext)
310  apply (simp add: ucast_def)
311  apply (subst word_of_nat)
312  apply (simp add: unat_def)
313  done
314
(* fromIntegral from nat to nat is the identity. *)
315lemma fromIntegral_simp_nat[simp]: "(fromIntegral :: nat \<Rightarrow> nat) = id"
  unfolding fromIntegral_def fromInteger_nat toInteger_nat by simp
317
(* Haskell backquote infix application: a `~f~` b means f a b.  The mixfix
   priorities give the operator slot 100 and the result 80.  Unfolded
   automatically. *)
318definition
319  infix_apply :: "'a \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'c" ("_ `~_~` _" [81, 100, 80] 80) where
320  infix_apply_def[simp]:
321 "infix_apply a f b \<equiv> f a b"
322
(* Parsing check for the notation above, combined with $ and application. *)
323term "return $ a `~b~` c d"
324
(* Haskell's zip3, producing right-nested pairs (a, (b, c)) via two calls to
   HOL's zip.  The result is as long as the shortest input. *)
325definition
326  zip3 :: "'a list \<Rightarrow> 'b list \<Rightarrow> 'c list \<Rightarrow> ('a \<times> 'b \<times> 'c) list" where
327 "zip3 a b c \<equiv> zip a (zip b c)"
328
329(* Haskell's Show class is deliberately not modelled.  Strings appear as
   unit list throughout this theory, and show always returns the empty list, so
   any text it would produce is erased. *)
330definition
331 "show" :: "'a \<Rightarrow> unit list" where
332 "show x \<equiv> []"
333
(* show produces no text, so appending it to a string is a no-op. *)
334lemma show_simp_away[simp]: "S @ show t = S"
  unfolding show_def by simp
336
(* Haskell's and / or on a list of booleans, as left folds starting from the
   respective neutral element. *)
337definition
338 "andList \<equiv> foldl (\<and>) True"
339
340definition
341 "orList \<equiv> foldl (\<or>) False"
342
(* Haskell's mapAccumL: thread an accumulator left to right through f.  For each
   element, f returns the next accumulator and an output.  The result is the
   final accumulator paired with the list of outputs. *)
343primrec
344  mapAccumL :: "('a \<Rightarrow> 'b \<Rightarrow> 'a \<times> 'c) \<Rightarrow> 'a \<Rightarrow> 'b list \<Rightarrow> 'a \<times> ('c list)"
345where
346  "mapAccumL f s [] = (s, [])"
347| "mapAccumL f s (x#xs) = (
348     let (s', r) = f s x;
349         (s'', rs) = mapAccumL f s' xs
350     in (s'', r#rs)
351  )"
352
(* Run f on each element in order, stopping at the first Some result and
   returning it.  Returns None when the list is exhausted. *)
353primrec
354  untilM :: "('b \<Rightarrow> ('s, 'a option) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> ('s, 'a option) nondet_monad"
355where
356  "untilM f [] = return None"
357| "untilM f (x#xs) = do
358    r \<leftarrow> f x;
359    case r of
360      None     \<Rightarrow> untilM f xs
361    | Some res \<Rightarrow> return (Some res)
362  od"
363
(* Exception-monad version of untilM, using doE/returnOk.  Presumably an
   exception (Inl) raised by f aborts the loop, per the semantics of doE bind;
   TODO confirm against the monad library. *)
364primrec
365  untilME :: "('c \<Rightarrow> ('s, ('a + 'b option)) nondet_monad) \<Rightarrow> 'c list \<Rightarrow> ('s, 'a + 'b option) nondet_monad"
366where
367  "untilME f [] = returnOk None"
368| "untilME f (x#xs) = doE
369    r \<leftarrow> f x;
370    case r of
371      None     \<Rightarrow> untilME f xs
372    | Some res \<Rightarrow> returnOk (Some res)
373  odE"
374
(* Monadic find: return Some x for the first element x on which the monadic
   predicate f yields True, or None if there is none.  Elements after the match
   are not examined. *)
375primrec
376  findM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a option) nondet_monad"
377where
378  "findM f [] = return None"
379| "findM f (x#xs) = do
380    r \<leftarrow> f x;
381    if r
382      then return (Some x)
383      else findM f xs
384  od"
385
(* Exception-monad version of findM, using doE/returnOk. *)
386primrec
387  findME :: "('a \<Rightarrow> ('s, ('e + bool)) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, ('e + 'a option)) nondet_monad"
388where
389  "findME f [] = returnOk None"
390| "findME f (x#xs) = doE
391    r \<leftarrow> f x;
392    if r
393      then returnOk (Some x)
394      else findME f xs
395  odE"
396
(* Haskell's tails: every suffix of the list, longest first, ending with []. *)
397primrec
398  tails :: "'a list \<Rightarrow> 'a list list"
399where
400  "tails [] = [[]]"
401| "tails (x#xs) = (x#xs)#(tails xs)"
402
(* If every value of type 'b is f y for some y of type 'a, and 'a is finite,
   then 'b is finite: UNIV of 'b is contained in the image of a finite set. *)
403lemma finite_surj_type:
404  "\<lbrakk> (\<forall>x. \<exists>y. (x :: 'b) = f (y :: 'a)); finite (UNIV :: 'a set) \<rbrakk> \<Longrightarrow> finite (UNIV :: 'b set)"
405  apply (erule finite_surj)
406  apply safe
407  apply (erule allE)
408  apply safe
409  apply (erule image_eqI)
410  apply simp
411  done
412
(* Every set over a finite type is finite, being a subset of that type's finite
   universe. *)
413lemma finite_finite[simp]: "finite (s :: ('a :: finite) set)"
  by (rule finite_subset [OF subset_UNIV finite_UNIV])
415
(* In a finite type, removing a from U - s (with U the universe) strictly
   decreases the cardinality exactly when a is not already in s.  The U = UNIV
   premise lets the lemma be instantiated by unification. *)
416lemma finite_inv_card_less':
417  "U = (UNIV :: ('a :: finite) set) \<Longrightarrow> (card (U - insert a s) < card (U - s)) = (a \<notin> s)"
418  apply (case_tac "a \<in> s")
419   apply (simp_all add: insert_absorb)
420  apply (subgoal_tac "card s < card U")
421   apply (simp add: card_Diff_subset)
422  apply (rule psubset_card_mono)
423   apply safe
424    apply simp_all
425  done
426
(* finite_inv_card_less' specialised to U = UNIV. *)
427lemma finite_inv_card_less:
428   "(card (UNIV - insert (a :: ('a :: finite)) s) < card (UNIV - s)) = (a \<notin> s)"
  by (rule finite_inv_card_less' [OF refl])
430
(* Haskell's minimum / maximum via Min / Max of the set of list elements.  As in
   Haskell, these are only meaningful for non-empty lists; the value for [] is
   presumably unspecified. *)
431definition
432  "minimum ls \<equiv> Min (set ls)"
433definition
434  "maximum ls \<equiv> Max (set ls)"
435
(* Prepend x to the first list of a non-empty list of lists.  Deliberately
   non-exhaustive: hdCons x [] is unspecified. *)
436primrec (nonexhaustive)
437  hdCons :: "'a \<Rightarrow> 'a list list \<Rightarrow> 'a list list"
438where
439 "hdCons x (ys # zs) = (x # ys) # zs"
440
(* Split a list into maximal consecutive runs in which f holds between each pair
   of adjacent elements.  hdCons is only applied when xs is non-empty, and
   rangesBy of a non-empty list is itself non-empty, so the missing hdCons case
   is never reached. *)
441primrec
442  rangesBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list list"
443where
444  "rangesBy f [] = []"
445| "rangesBy f (x # xs) =
446    (case xs of [] \<Rightarrow> [[x]]
447             | (y # ys) \<Rightarrow> if (f x y) then hdCons x (rangesBy f xs)
448                                     else [x] # (rangesBy f xs))"
449
(* Haskell's partition: the elements satisfying f, and those that do not, each
   in their original order. *)
450definition
451  partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<times> 'a list" where
452 "partition f xs \<equiv> (filter f xs, filter (\<lambda>x. \<not> f x) xs)"
453
(* List difference: keep exactly those elements of xs that do not occur in ys,
   preserving their order and multiplicity.  Every occurrence of an element of
   ys is removed.  NOTE(review): if the Haskell source uses Data.List's list
   difference operator, that removes only one occurrence per element of ys;
   confirm this coarser model is intended. *)
454definition
455  listSubtract :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list" where
456 "listSubtract xs ys \<equiv> filter (\<lambda>x. x \<notin> set ys) xs"
457
(* Haskell's init: all elements except the last.  Undefined on the empty list,
   mirroring Haskell's runtime error. *)
458definition
459  init :: "'a list \<Rightarrow> 'a list" where
460 "init xs \<equiv> case (length xs) of Suc n \<Rightarrow> take n xs | _ \<Rightarrow> undefined"
461
(* Haskell's break: split the list before the first element satisfying f.
   Returns (longest prefix with no element satisfying f, remainder). *)
462primrec
463  break :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> ('a list \<times> 'a list)"
464where
465  "break f []       = ([], [])"
466| "break f (x # xs) =
467   (if f x
468    then ([], x # xs)
469    else (\<lambda>(ys, zs). (x # ys, zs)) (break f xs))"
470
(* Haskell's uncurry is HOL's case_prod. *)
471definition
472 "uncurry \<equiv> case_prod"
473
(* Haskell's sum on lists, as a left fold of + starting from 0. *)
474definition
475  sum :: "'a list \<Rightarrow> 'a::{plus,zero}" where
476 "sum \<equiv> foldl (+) 0"
477
(* Haskell's replicateM: run m n times in sequence, collecting the results. *)
478definition
479 "replicateM n m \<equiv> sequence (replicate n m)"
480
(* Lift a Maybe into the monad via assert_opt, from the imported monad library.
   Presumably it returns the Some payload and fails on None.  Declared [simp]. *)
481definition
482  maybeToMonad_def[simp]:
483 "maybeToMonad \<equiv> assert_opt"
484
(* Haskell arrays are modelled as plain functions, so funArray is the identity.
   Declared [simp]. *)
485definition
486  funArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b)"  where
487  funArray_def[simp]:
488 "funArray \<equiv> id"
489
(* An array defined only on an index range.  It agrees with f on indices in
   [fst xrange .e. snd xrange] (presumably the inclusive enumeration from
   enumeration_alt) and is undefined elsewhere. *)
490definition
491  funPartialArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a :: enumeration_alt \<times> 'a) \<Rightarrow> ('a \<Rightarrow> 'b)"  where
492 "funPartialArray f xrange \<equiv> \<lambda>x. (if x \<in> set [fst xrange .e. snd xrange] then f x else undefined)"
493
(* forM-style combinators: mapM, mapM_x and mapME_x with the arguments flipped
   (list first, function second).  All are unfolded automatically.  The _x
   variants presumably discard results, like Haskell's forM_. *)
494definition
495  forM_def[simp]:
496 "forM xs f \<equiv> mapM f xs"
497
498definition
499  forM_x_def[simp]:
500 "forM_x xs f \<equiv> mapM_x f xs"
501
502definition
503  forME_x_def[simp]:
504 "forME_x xs f \<equiv> mapME_x f xs"
505
(* Bulk update of a function-modelled array ("aLU", presumably Haskell's array
   update operator).  Updates (index, value) are applied left to right, so a
   later pair overrides an earlier one with the same index.  Declared [simp]. *)
506definition
507  arrayListUpdate :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<times> 'b) list \<Rightarrow> ('a \<Rightarrow> 'b)" (infixl "aLU" 90)
508where
509  arrayListUpdate_def[simp]:
510  "arrayListUpdate f l  \<equiv> foldl (\<lambda>f p. f(fst p := snd p)) f l"
511
(* Haskell's genericTake: take with a count of any integral type, converted to
   nat by fromIntegral. *)
512definition
513 "genericTake \<equiv> take \<circ> fromIntegral"
514
(* Haskell's genericLength: list length converted to any integral type. *)
515definition
516 "genericLength \<equiv> fromIntegral \<circ> length"
517
(* Haskell's null, abbreviating HOL's List.null (True exactly for the empty
   list). *)
518abbreviation
519  "null == List.null"
520
(* Input-only syntax accepting Haskell-style list comprehensions written with
   "|", in addition to HOL's native "." form.  The lemma below checks that both
   notations parse to the same term. *)
521syntax (input)
522  "_listcompr" :: "'a \<Rightarrow> lc_qual \<Rightarrow> lc_quals \<Rightarrow> 'a list"  ("[_ | __")
523
524lemma "[(x,1) . x \<leftarrow> [0..10]] = [(x,1) | x \<leftarrow> [0..10]]" by (rule refl)
525
526end
527