(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(* Author: Thomas Sewell

   Library routines etc expected by Haskell code.
*)

theory HaskellLib_H
imports
  Lib
  NatBitwise
  More_Numeral_Type
  NonDetMonadVCG
begin

(* Haskell's flip, as an input-only alias for swp. *)
abbreviation (input) "flip \<equiv> swp"

(* Haskell's ">>": run x, discard its result, then run y.
   Written as bind with the constant continuation K_bind y. *)
abbreviation(input) bind_drop :: "('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad
                  \<Rightarrow> ('a, 'b) nondet_monad" (infixl ">>'_" 60)
  where "bind_drop \<equiv> (\<lambda>x y. bind x (K_bind y))"

(* Sanity check: folding ">>_" over a list of actions, ending in return (),
   is the same as sequence_x. *)
lemma bind_drop_test:
  "foldr bind_drop x (return ()) = sequence_x x"
  by (rule ext, simp add: sequence_x_def)

(* If the given monad is deterministic, this function converts
    the nondet_monad type into a normal deterministic state monad.
   It picks the element of the result set fst (f s) by definite
   description (THE). If that set is not a singleton, the result is an
   unspecified value, so this is only useful for deterministic f. *)
definition
  runState :: "('s, 'a) nondet_monad \<Rightarrow> 's \<Rightarrow> ('a \<times> 's)" where
  "runState f s \<equiv> THE x. x \<in> fst (f s)"

(* Haskell-style assertion on a pure value: returns the value unchanged
   when P holds, and undefined otherwise. *)
definition
  sassert :: "bool \<Rightarrow> 'a \<Rightarrow> 'a" where
  "sassert P \<equiv> if P then id else (\<lambda>x. undefined)"

(* Congruence rule for function definitions: the guarded value only needs
   to agree under the assumption that the condition holds. *)
lemma sassert_cong[fundef_cong]:
  "\<lbrakk> P = P'; P' \<Longrightarrow> s = s' \<rbrakk> \<Longrightarrow> sassert P s = sassert P' s'"
  apply (simp add: sassert_def)
  done

(* Monadic Haskell assertions. The message argument L is ignored and the
   definitions reduce to assert / assertE. Both are unfolded by simp. *)
definition
  haskell_assert :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
  "haskell_assert P L \<equiv> assert P"

definition
  haskell_assertE :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, 'e + unit) nondet_monad" where
  "haskell_assertE P L \<equiv> assertE P"

declare haskell_assert_def [simp] haskell_assertE_def [simp]

(* Assert a predicate on the current state. The message L is ignored. *)
definition
  stateAssert :: "('a \<Rightarrow> bool) \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad" where
  "stateAssert P L \<equiv> get >>= (\<lambda>s. assert (P s))"

(* Haskell's fail. The message L is ignored. *)
definition
  haskell_fail :: "unit list \<Rightarrow> ('a, 'b) nondet_monad" where
  haskell_fail_def[simp]:
  "haskell_fail L \<equiv> fail"

(* Haskell's catchError, an alias for handleE. *)
definition
  catchError_def[simp]:
  "catchError \<equiv> handleE"

(* curryN: turn a function on an N-tuple into a curried N-argument
   function. All are unfolded by simp. *)
definition
  "curry1 \<equiv> id"
definition
  "curry2 \<equiv> curry"
definition
  "curry3 f a b c \<equiv> f (a, b, c)"
definition
  "curry4 f a b c d \<equiv> f (a, b, c, d)"
definition
  "curry5 f a b c d e \<equiv> f (a, b, c, d, e)"

declare curry1_def[simp] curry2_def[simp]
        curry3_def[simp] curry4_def[simp] curry5_def[simp]

(* splitN: the inverse direction, turning a curried N-argument function into
   a function on N-tuples. Only split1/split2 are unfolded by simp. For
   N = 3..5 there are simp rules that fire on explicit tuples only
   (presumably to avoid exposing nested case_prod terms; confirm). *)
definition
  "split1 \<equiv> id"
definition
  "split2 \<equiv> case_prod"
definition
  "split3 f \<equiv> \<lambda>(a, b, c). f a b c"
definition
  "split4 f \<equiv> \<lambda>(a, b, c, d). f a b c d"
definition
  "split5 f \<equiv> \<lambda>(a, b, c, d, e). f a b c d e"

declare split1_def[simp] split2_def[simp]

lemma split3_simp[simp]: "split3 f (a, b, c) = f a b c"
  by (simp add: split3_def)

lemma split4_simp[simp]: "split4 f (a, b, c, d) = f a b c d"
  by (simp add: split4_def)

lemma split5_simp[simp]: "split5 f (a, b, c, d, e) = f a b c d e"
  by (simp add: split5_def)

(* Haskell Prelude / Data.Maybe names mapped onto their HOL counterparts.
   Their definitions are declared [simp] further below, so they vanish
   during simplification. *)
definition
  "Just \<equiv> Some"
definition
  "Nothing \<equiv> None"

definition
  "fromJust \<equiv> the"
definition
  "isJust x \<equiv> x \<noteq> None"

definition
  "tail \<equiv> tl"
definition
  "head \<equiv> hd"

(* Haskell's error: the message is ignored and the result is undefined. *)
definition
  error :: "unit list \<Rightarrow> 'a" where
  "error \<equiv> \<lambda>x. undefined"

definition
  "reverse \<equiv> rev"

definition
  "isNothing x \<equiv> x = None"

definition
  "maybeApply \<equiv> option_map"

definition
  "maybe \<equiv> case_option"

(* Haskell argument order (f, init, list) for HOL's foldr. *)
definition
  "foldR f init L \<equiv> foldr f L init"

definition
  "elem x L \<equiv> x \<in> set L"

definition
  "notElem x L \<equiv> x \<notin> set L"

(* Haskell's Ordering is modelled as bool. compare a b is the strict test
   a < b, which is the form insertBy and sortBy expect. *)
type_synonym ordering = bool

definition
  compare :: "('a :: ord) \<Rightarrow> 'a \<Rightarrow> ordering" where
  "compare \<equiv> (<)"

(* Insert a in front of the first element b with f a b, or at the end of
   the list if there is no such element. *)
primrec
  insertBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "insertBy f a [] = [a]"
| "insertBy f a (b # bs) = (if (f a b) then (a # b # bs) else (b # (insertBy f a bs)))"

lemma insertBy_length [simp]:
  "length (insertBy f a as) = (1 + length as)"
  by (induct as) simp_all

(* Insertion sort: sort the tail, then insertBy the head into it. *)
primrec
  sortBy :: "('a \<Rightarrow> 'a \<Rightarrow> ordering) \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "sortBy f [] = []"
| "sortBy f (a # as) = insertBy f a (sortBy f as)"

lemma sortBy_length:
  "length (sortBy f as) = length as"
  by (induct as) simp_all

(* Haskell's sort, via sortBy with the strict compare above. *)
definition
  "sortH \<equiv> sortBy compare"

(* Drop the None entries and unwrap the Some entries. *)
definition
  "catMaybes \<equiv> (map the) \<circ> (filter isJust)"

definition
  "runExceptT \<equiv> id"

declare Just_def[simp] Nothing_def[simp] fromJust_def[simp]
        isJust_def[simp] tail_def[simp] head_def[simp]
        error_def[simp] reverse_def[simp] isNothing_def[simp]
        maybeApply_def[simp] maybe_def[simp]
        foldR_def[simp] elem_def[simp] notElem_def[simp]
        catMaybes_def[simp] runExceptT_def[simp]

(* Monadic head/tail: fail on the empty list instead of being partial. *)
definition
  "headM L \<equiv> (case L of (h # t) \<Rightarrow> return h | _ \<Rightarrow> fail)"

definition
  "tailM L \<equiv> (case L of (h # t) \<Rightarrow> return t | _ \<Rightarrow> fail)"

(* Only the type of typeOf is fixed; nothing else is known about it. *)
axiomatization
  typeOf :: "'a \<Rightarrow> unit list"

(* Haskell's either. The simp rule rewrites it to case_sum. *)
definition
  "either f1 f2 c \<equiv> case c of Inl r1 \<Rightarrow> f1 r1 | Inr r2 \<Rightarrow> f2 r2"

lemma either_simp[simp]: "either = case_sum"
  apply (rule ext)+
  apply (simp add: either_def)
  done

(* Haskell's Data.Bits operations used by the translated code: shifts by a
   natural number of bits, and the bit width. *)
class HS_bit = bit_operations +
  fixes shiftL :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes shiftR :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes bitSize :: "'a \<Rightarrow> nat"

(* On words, the shifts are shiftl/shiftr and bitSize is size. *)
instantiation word :: (len0) HS_bit
begin

definition
  shiftL_word[simp]: "(shiftL :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftl"

definition
  shiftR_word[simp]: "(shiftR :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftr"

definition
  bitSize_word[simp]: "(bitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* On nat, shifts are multiplication / division by 2 ^ n. *)
instantiation nat :: HS_bit
begin

definition
  shiftL_nat: "shiftL (x :: nat) n \<equiv> (2 ^ n) * x"

definition
  shiftR_nat: "shiftR (x :: nat) n \<equiv> x div (2 ^ n)"

text \<open>bitSize not defined for nat\<close>

instance ..

end

(* Haskell's FiniteBits.finiteBitSize. For words it is size. *)
class finiteBit = bit_operations +
  fixes finiteBitSize :: "'a \<Rightarrow> nat"

instantiation word :: (len0) finiteBit
begin

definition
  finiteBitSize_word[simp]: "(finiteBitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* bit x is 1 shifted left by x bits. *)
definition bit :: "nat \<Rightarrow> 'a::{one,HS_bit}" where
  bit_def[simp]: "bit x \<equiv> shiftL 1 x"

(* True when the low n bits of x are all zero. *)
definition
"isAligned x n \<equiv> x && mask n = 0"

(* Haskell's Integral, modelled by conversions to and from nat.
   The class assumption requires fromInteger to undo toInteger. *)
class integral = ord +
  fixes fromInteger :: "nat \<Rightarrow> 'a"
  fixes toInteger :: "'a \<Rightarrow> nat"
  assumes integral_inv: "fromInteger \<circ> toInteger = id"

(* On nat, both conversions are the identity. *)
instantiation nat :: integral
begin

definition
  fromInteger_nat: "fromInteger \<equiv> id"

definition
  toInteger_nat: "toInteger \<equiv> id"

instance
  apply (intro_classes)
  apply (simp add: toInteger_nat fromInteger_nat)
  done

end


(* On words, the conversions are of_nat and unat. *)
instantiation word :: (len) integral
begin

definition
  fromInteger_word: "fromInteger \<equiv> of_nat :: nat \<Rightarrow> 'a::len word"

definition
  toInteger_word: "toInteger \<equiv> unat"

instance
  apply (intro_classes)
  apply (rule ext)
  apply (simp add: toInteger_word fromInteger_word)
  done

end

(* Haskell's fromIntegral: convert via nat. The simp rules below reduce it
   to of_nat (nat to word), unat (word to nat), ucast (word to word) and
   id (nat to nat). *)
definition
  fromIntegral :: "('a :: integral) \<Rightarrow> ('b :: integral)" where
  "fromIntegral \<equiv> fromInteger \<circ> toInteger"

lemma fromIntegral_simp1[simp]: "(fromIntegral :: nat \<Rightarrow> ('a :: len) word) = of_nat"
  by (simp add: fromIntegral_def fromInteger_word toInteger_nat)

lemma fromIntegral_simp2[simp]: "fromIntegral = unat"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_word)

lemma fromIntegral_simp3[simp]: "fromIntegral = ucast"
  apply (simp add: fromIntegral_def fromInteger_word toInteger_word)
  apply (rule ext)
  apply (simp add: ucast_def)
  apply (subst word_of_nat)
  apply (simp add: unat_def)
  done

lemma fromIntegral_simp_nat[simp]: "(fromIntegral :: nat \<Rightarrow> nat) = id"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_nat)

(* Haskell backtick infix application: a `~f~` b means f a b.
   The term command below only checks that the notation parses. *)
definition
  infix_apply :: "'a \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'c" ("_ `~_~` _" [81, 100, 80] 80) where
  infix_apply_def[simp]:
  "infix_apply a f b \<equiv> f a b"

term "return $ a `~b~` c d"

(* Three-way zip, producing right-nested pairs (a, (b, c)). *)
definition
  zip3 :: "'a list \<Rightarrow> 'b list \<Rightarrow> 'c list \<Rightarrow> ('a \<times> 'b \<times> 'c) list" where
  "zip3 a b c \<equiv> zip a (zip b c)"

(* avoid even attempting haskell's show class:
   show always yields the empty message, and show_simp_away lets simp drop
   it from appended message lists. *)
definition
  "show" :: "'a \<Rightarrow> unit list" where
  "show x \<equiv> []"

lemma show_simp_away[simp]: "S @ show t = S"
  by (simp add: show_def)

(* Conjunction / disjunction of a list of booleans. *)
definition
  "andList \<equiv> foldl (\<and>) True"

definition
  "orList \<equiv> foldl (\<or>) False"

(* Haskell's mapAccumL: map f over the list from left to right, threading
   the accumulator s through each call. Returns the final accumulator and
   the list of results. *)
primrec
  mapAccumL :: "('a \<Rightarrow> 'b \<Rightarrow> 'a \<times> 'c) \<Rightarrow> 'a \<Rightarrow> 'b list \<Rightarrow> 'a \<times> ('c list)"
where
  "mapAccumL f s [] = (s, [])"
| "mapAccumL f s (x#xs) = (
    let (s', r) = f s x;
        (s'', rs) = mapAccumL f s' xs
    in (s'', r#rs)
  )"

(* Run f on each element in order and stop at the first Some result.
   Returns None if every call yields None. *)
primrec
  untilM :: "('b \<Rightarrow> ('s, 'a option) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "untilM f [] = return None"
| "untilM f (x#xs) = do
      r \<leftarrow> f x;
      case r of
          None \<Rightarrow> untilM f xs
        | Some res \<Rightarrow> return (Some res)
    od"

(* As untilM, but in the exception monad (doE / returnOk). *)
primrec
  untilME :: "('c \<Rightarrow> ('s, ('a + 'b option)) nondet_monad) \<Rightarrow> 'c list \<Rightarrow> ('s, 'a + 'b option) nondet_monad"
where
  "untilME f [] = returnOk None"
| "untilME f (x#xs) = doE
      r \<leftarrow> f x;
      case r of
          None \<Rightarrow> untilME f xs
        | Some res \<Rightarrow> returnOk (Some res)
    odE"

(* Return Some x for the first element x on which the monadic test f
   returns True, or None if there is no such element. *)
primrec
  findM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "findM f [] = return None"
| "findM f (x#xs) = do
    r \<leftarrow> f x;
    if r
      then return (Some x)
      else findM f xs
  od"

(* Exception-monad version of findM. Returns Some x for the first element x
   on which f returns True, or None if there is no such element. *)
primrec
  findME :: "('a \<Rightarrow> ('s, ('e + bool)) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, ('e + 'a option)) nondet_monad"
where
  "findME f [] = returnOk None"
| "findME f (x#xs) = doE
    r \<leftarrow> f x;
    if r
      then returnOk (Some x)
      else findME f xs
  odE"

(* All suffixes of a list, longest first, ending with the empty list. *)
primrec
  tails :: "'a list \<Rightarrow> 'a list list"
where
  "tails [] = [[]]"
| "tails (x#xs) = (x#xs)#(tails xs)"

(* If every element of 'b is f of some element of a finite type 'a,
   then 'b is finite as well. *)
lemma finite_surj_type:
  "\<lbrakk> (\<forall>x. \<exists>y. (x :: 'b) = f (y :: 'a)); finite (UNIV :: 'a set) \<rbrakk> \<Longrightarrow> finite (UNIV :: 'b set)"
  apply (erule finite_surj)
  apply safe
  apply (erule allE)
  apply safe
  apply (erule image_eqI)
  apply simp
  done

(* Every set over a finite type is finite. *)
lemma finite_finite[simp]: "finite (s :: ('a :: finite) set)"
  by simp

(* Over a finite universe U, inserting a into s strictly shrinks the
   complement of s exactly when a was not already in s. *)
lemma finite_inv_card_less':
  "U = (UNIV :: ('a :: finite) set) \<Longrightarrow> (card (U - insert a s) < card (U - s)) = (a \<notin> s)"
  apply (case_tac "a \<in> s")
   apply (simp_all add: insert_absorb)
  apply (subgoal_tac "card s < card U")
   apply (simp add: card_Diff_subset)
  apply (rule psubset_card_mono)
   apply safe
   apply simp_all
  done

lemma finite_inv_card_less:
  "(card (UNIV - insert (a :: ('a :: finite)) s) < card (UNIV - s)) = (a \<notin> s)"
  by (simp add: finite_inv_card_less')

(* Smallest / largest element of a list. NOTE: only meaningful for
   non-empty lists, like their Haskell counterparts. *)
definition
  "minimum ls \<equiv> Min (set ls)"
definition
  "maximum ls \<equiv> Max (set ls)"

(* Cons x onto the first list of a list of lists. Nonexhaustive: nothing
   is specified for the empty outer list. *)
primrec (nonexhaustive)
  hdCons :: "'a \<Rightarrow> 'a list list \<Rightarrow> 'a list list"
where
  "hdCons x (ys # zs) = (x # ys) # zs"

(* Split a list into maximal runs in which every adjacent pair (x, y)
   satisfies f x y. A new run starts whenever f x y fails. *)
primrec
  rangesBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list list"
where
  "rangesBy f [] = []"
| "rangesBy f (x # xs) =
     (case xs of [] \<Rightarrow> [[x]]
        | (y # ys) \<Rightarrow> if (f x y) then hdCons x (rangesBy f xs)
                                else [x] # (rangesBy f xs))"

(* Haskell's partition: (elements satisfying f, elements not satisfying f),
   each kept in the original order. *)
definition
  partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<times> 'a list" where
  "partition f xs \<equiv> (filter f xs, filter (\<lambda>x. \<not> f x) xs)"

(* List difference: the elements of xs that do NOT occur in ys, in the
   order of xs. The previous definition kept the elements that DO occur in
   ys, which is intersection, not subtraction.
   Every occurrence of a member of ys is removed. Haskell's Data.List
   (\\) removes only one occurrence per element of ys, so check callers if
   multiplicities matter. *)
definition
  listSubtract :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list" where
  "listSubtract xs ys \<equiv> filter (\<lambda>x. x \<notin> set ys) xs"

(* Haskell's init: all but the last element. Undefined on []. *)
definition
  init :: "'a list \<Rightarrow> 'a list" where
  "init xs \<equiv> case (length xs) of Suc n \<Rightarrow> take n xs | _ \<Rightarrow> undefined"

(* Haskell's break: split before the first element satisfying f, giving
   (the prefix where f fails, the rest of the list). *)
primrec
  break :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> ('a list \<times> 'a list)"
where
  "break f [] = ([], [])"
| "break f (x # xs) =
    (if f x
       then ([], x # xs)
       else (\<lambda>(ys, zs). (x # ys, zs)) (break f xs))"

definition
  "uncurry \<equiv> case_prod"

(* Sum of the list elements, folding (+) from 0. *)
definition
  sum :: "'a list \<Rightarrow> 'a::{plus,zero}" where
  "sum \<equiv> foldl (+) 0"

(* Run n copies of the action m in sequence and collect the results. *)
definition
  "replicateM n m \<equiv> sequence (replicate n m)"

(* Lift an option into the monad via assert_opt
   (presumably failing on None; see assert_opt). *)
definition
  maybeToMonad_def[simp]:
  "maybeToMonad \<equiv> assert_opt"

(* Haskell arrays are modelled as total functions; funArray is the
   identity on that representation. *)
definition
  funArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b)" where
  funArray_def[simp]:
  "funArray \<equiv> id"

(* Restrict f to the index range [fst xrange .e. snd xrange]. It is
   undefined outside that range. *)
definition
  funPartialArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a :: enumeration_alt \<times> 'a) \<Rightarrow> ('a \<Rightarrow> 'b)" where
  "funPartialArray f xrange \<equiv> \<lambda>x. (if x \<in> set [fst xrange .e. snd xrange] then f x else undefined)"

(* forM-style variants: mapM and friends with the arguments flipped. *)
definition
  forM_def[simp]:
  "forM xs f \<equiv> mapM f xs"

definition
  forM_x_def[simp]:
  "forM_x xs f \<equiv> mapM_x f xs"

definition
  forME_x_def[simp]:
  "forME_x xs f \<equiv> mapME_x f xs"

(* Haskell's (//) on arrays: apply the (index, value) updates in l from left
   to right. A later pair overrides an earlier one for the same index. *)
definition
  arrayListUpdate :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<times> 'b) list \<Rightarrow> ('a \<Rightarrow> 'b)" (infixl "aLU" 90)
where
  arrayListUpdate_def[simp]:
  "arrayListUpdate f l \<equiv> foldl (\<lambda>f p. f(fst p := snd p)) f l"

(* take / length with the count converted through fromIntegral. *)
definition
  "genericTake \<equiv> take \<circ> fromIntegral"

definition
  "genericLength \<equiv> fromIntegral \<circ> length"

abbreviation
  "null == List.null"

(* Haskell-style list comprehension input syntax [e | x <- xs]. The lemma
   checks that it means the same as Isabelle's [e . x <- xs]. *)
syntax (input)
  "_listcompr" :: "'a \<Rightarrow> lc_qual \<Rightarrow> lc_quals \<Rightarrow> 'a list"  ("[_ | __")

lemma "[(x,1) . x \<leftarrow> [0..10]] = [(x,1) | x \<leftarrow> [0..10]]" by (rule refl)

end