(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(* Author: Thomas Sewell

   Library routines etc expected by Haskell code.
*)

theory HaskellLib_H
imports
  Lib
  "More_Numeral_Type"
  "Monad_WP/NonDetMonadVCG"
begin

(* Haskell's flip, as an input-only alias for the library's swp. *)
abbreviation (input) flip where
  "flip \<equiv> swp"

(* Haskell's (>>) operator, written >>_ here: bind the first computation
   and continue with the second via K_bind. *)
abbreviation (input) bind_drop ::
    "('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad"
    (infixl ">>'_" 60)
  where "bind_drop \<equiv> (\<lambda>m k. bind m (K_bind k))"

(* Sanity check: folding >>_ over a list of computations matches sequence_x. *)
lemma bind_drop_test:
  "foldr bind_drop x (return ()) = sequence_x x"
  by (rule ext) (simp add: sequence_x_def)

(* Read a nondet_monad as an ordinary state monad by choosing, with THE,
   the element of the result set. This is only meaningful when that set
   has exactly one element, i.e. when the computation is deterministic. *)
definition
  runState :: "('s, 'a) nondet_monad \<Rightarrow> 's \<Rightarrow> ('a \<times> 's)"
where
  "runState f s \<equiv> THE x. x \<in> fst (f s)"

(* Pure assertion: the identity when P holds, undefined otherwise. *)
definition
  sassert :: "bool \<Rightarrow> 'a \<Rightarrow> 'a"
where
  "sassert P \<equiv> if P then id else (\<lambda>_. undefined)"

(* Congruence rule so function definitions may assume P' inside the body. *)
lemma sassert_cong[fundef_cong]:
  "\<lbrakk> P = P'; P' \<Longrightarrow> s = s' \<rbrakk> \<Longrightarrow> sassert P s = sassert P' s'"
  by (simp add: sassert_def)

(* Monadic assertions. The unit list argument (presumably the Haskell-side
   message) is ignored. *)
definition
  haskell_assert :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad"
where
  "haskell_assert P L \<equiv> assert P"

definition
  haskell_assertE :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, 'e + unit) nondet_monad"
where
  "haskell_assertE P L \<equiv> assertE P"

declare haskell_assert_def[simp] haskell_assertE_def[simp]

(* Assert a predicate of the current state; the message list is ignored. *)
definition
  stateAssert :: "('a \<Rightarrow> bool) \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad"
where
  "stateAssert P L \<equiv> get >>= (\<lambda>st. assert (P st))"

(* Haskell-level failure is plain monadic fail; the message is ignored. *)
definition
  haskell_fail :: "unit list \<Rightarrow> ('a, 'b) nondet_monad"
where
  haskell_fail_def[simp]:
  "haskell_fail L \<equiv> fail"

(* Haskell's catchError is the library's handleE. *)
definition
  catchError_def[simp]:
  "catchError \<equiv> handleE"

(* curryN: turn a function on an N-tuple into a curried N-argument function. *)
definition curry1 where
  "curry1 \<equiv> id"
definition curry2 where
  "curry2 \<equiv> curry"
definition curry3 where
  "curry3 f a b c \<equiv> f (a, b, c)"
definition curry4 where
  "curry4 f a b c d \<equiv> f (a, b, c, d)"
definition curry5 where
  "curry5 f a b c d e \<equiv> f (a, b, c, d, e)"

declare curry1_def[simp] curry2_def[simp]
        curry3_def[simp] curry4_def[simp] curry5_def[simp]

(* splitN: the inverse direction, applying a curried function to an N-tuple. *)
definition split1 where
  "split1 \<equiv> id"
definition split2 where
  "split2 \<equiv> case_prod"
definition split3 where
  "split3 f \<equiv> \<lambda>(a, b, c). f a b c"
definition split4 where
  "split4 f \<equiv> \<lambda>(a, b, c, d). f a b c d"
definition split5 where
  "split5 f \<equiv> \<lambda>(a, b, c, d, e). f a b c d e"

declare split1_def[simp] split2_def[simp]

(* For N >= 3 only the application to an explicit tuple is a simp rule. *)
lemma split3_simp[simp]: "split3 f (a, b, c) = f a b c"
  unfolding split3_def by simp

lemma split4_simp[simp]: "split4 f (a, b, c, d) = f a b c d"
  unfolding split4_def by simp

lemma split5_simp[simp]: "split5 f (a, b, c, d, e) = f a b c d e"
  unfolding split5_def by simp

(* Haskell Maybe / list vocabulary mapped onto HOL equivalents. *)
definition Just where
  "Just \<equiv> Some"
definition Nothing where
  "Nothing \<equiv> None"

definition fromJust where
  "fromJust \<equiv> the"
definition isJust where
  "isJust x \<equiv> x \<noteq> None"

definition tail where
  "tail \<equiv> tl"
definition head where
  "head \<equiv> hd"

(* Haskell's error: the message is dropped and the result is undefined. *)
definition
  error :: "unit list \<Rightarrow> 'a"
where
  "error \<equiv> \<lambda>_. undefined"

definition reverse where
  "reverse \<equiv> rev"

definition isNothing where
  "isNothing x \<equiv> x = None"

definition maybeApply where
  "maybeApply \<equiv> option_map"

definition maybe where
  "maybe \<equiv> case_option"

(* Haskell foldr argument order (function, initial value, list) on top of
   HOL's foldr (function, list, initial value). *)
definition foldR where
  "foldR f init L \<equiv> foldr f L init"

definition elem where
  "elem x L \<equiv> x \<in> set L"

definition notElem where
  "notElem x L \<equiv> x \<notin> set L"

(* Orderings are modelled as plain booleans: compare is strict less-than. *)
type_synonym ordering = bool

definition
  compare :: "('a :: ord) \<Rightarrow> 'a \<Rightarrow> ordering"
where
  "compare \<equiv> (<)"

(* Insert a before the first element b for which f a b holds;
   append it at the end if there is no such element. *)
primrec
  insertBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "insertBy f a [] = [a]"
| "insertBy f a (b # bs) =
     (if f a b then a # b # bs else b # insertBy f a bs)"

lemma insertBy_length [simp]:
  "length (insertBy f a as) = (1 + length as)"
  by (induct as, simp_all)

(* Insertion sort: sort the tail, then insertBy the head. *)
primrec
  sortBy :: "('a \<Rightarrow> 'a \<Rightarrow> ordering) \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "sortBy f [] = []"
| "sortBy f (a # as) = insertBy f a (sortBy f as)"

lemma sortBy_length:
  "length (sortBy f as) = length as"
  by (induct as, simp_all)

definition sortH where
  "sortH \<equiv> sortBy compare"

(* Keep only the Some entries of a list, with the Some wrapper removed. *)
definition
  "catMaybes \<equiv> (map the) \<circ> (filter isJust)"

(* runExceptT is the identity on the underlying monadic computation. *)
definition
  "runExceptT \<equiv> id"

(* Let the thin Haskell aliases unfold automatically in proofs. *)
declare Just_def[simp] Nothing_def[simp] fromJust_def[simp]
        isJust_def[simp] tail_def[simp] head_def[simp]
        error_def[simp] reverse_def[simp] isNothing_def[simp]
        maybeApply_def[simp] maybe_def[simp]
        foldR_def[simp] elem_def[simp] notElem_def[simp]
        catMaybes_def[simp] runExceptT_def[simp]

(* Monadic head/tail: return the head (resp. tail) of a non-empty list and
   fail on the empty list. *)
definition
  "headM L \<equiv> (case L of (h # t) \<Rightarrow> return h | _ \<Rightarrow> fail)"

definition
  "tailM L \<equiv> (case L of (h # t) \<Rightarrow> return t | _ \<Rightarrow> fail)"

(* Uninterpreted: only the signature is fixed, no axioms are stated. *)
axiomatization
  typeOf :: "'a \<Rightarrow> unit list"

(* Haskell's either; either_simp rewrites it to case_sum. *)
definition
  "either f1 f2 c \<equiv> case c of Inl r1 \<Rightarrow> f1 r1 | Inr r2 \<Rightarrow> f2 r2"

lemma either_simp[simp]: "either = case_sum"
  apply (rule ext)+
  apply (simp add: either_def)
  done


(* Bitwise operations on nat, computed by converting to int and back.
   NOTE(review): if int's bitNOT is -x - 1 (as in the Word library), then
   nat's bitNOT below is always 0 -- confirm this is intended. *)
instantiation nat :: bit
begin

definition
  "bitNOT = nat o bitNOT o int"

definition
  "bitAND x y = nat (bitAND (int x) (int y))"

definition
  "bitOR x y = nat (bitOR (int x) (int y))"

definition
  "bitXOR x y = nat (bitXOR (int x) (int y))"

instance ..

end

(* Haskell Data.Bits-style shifts plus a bit-width function. *)
class HS_bit = bit +
  fixes shiftL :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes shiftR :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes bitSize :: "'a \<Rightarrow> nat"

(* On words, the shifts are the word library's shiftl/shiftr and bitSize is size. *)
instantiation word :: (len0) HS_bit
begin

definition
  shiftL_word[simp]: "(shiftL :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftl"

definition
  shiftR_word[simp]: "(shiftR :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftr"

definition
  bitSize_word[simp]: "(bitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* On nat, shifting left multiplies by 2^n and shifting right divides by 2^n. *)
instantiation nat :: HS_bit
begin

definition
  shiftL_nat: "shiftL (x :: nat) n \<equiv> (2 ^ n) * x"

definition
  shiftR_nat: "shiftR (x :: nat) n \<equiv> x div (2 ^ n)"

text \<open>bitSize not defined for nat\<close>

instance ..

end

(* Haskell's FiniteBits.finiteBitSize; on words it is size. *)
class finiteBit = bit +
  fixes finiteBitSize :: "'a \<Rightarrow> nat"

instantiation word :: (len0) finiteBit
begin

definition
  finiteBitSize_word[simp]: "(finiteBitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* bit x: the value 1 shifted left by x positions. *)
definition
  bit_def[simp]:
  "bit x \<equiv> shiftL 1 x"

(* x is n-aligned when masking it with mask n yields zero. *)
definition
  "isAligned x n \<equiv> x && mask n = 0"

(* Types with conversions to and from nat, where fromInteger undoes toInteger. *)
class integral = ord +
  fixes fromInteger :: "nat \<Rightarrow> 'a"
  fixes toInteger :: "'a \<Rightarrow> nat"
  assumes integral_inv: "fromInteger \<circ> toInteger = id"

instantiation nat :: integral
begin

definition
  fromInteger_nat: "fromInteger \<equiv> id"

definition
  toInteger_nat: "toInteger \<equiv> id"

instance
  apply (intro_classes)
  apply (simp add: toInteger_nat fromInteger_nat)
  done

end


instantiation word :: (len) integral
begin

definition
  fromInteger_word: "fromInteger \<equiv> of_nat :: nat \<Rightarrow> 'a::len word"

definition
  toInteger_word: "toInteger \<equiv> unat"

instance
  apply (intro_classes)
  apply (rule ext)
  apply (simp add: toInteger_word fromInteger_word)
  done

end

(* Convert between integral types by going through nat. The simp lemmas below
   specialise it to of_nat, unat, ucast and id. *)
definition
  fromIntegral :: "('a :: integral) \<Rightarrow> ('b :: integral)" where
  "fromIntegral \<equiv> fromInteger \<circ> toInteger"

lemma fromIntegral_simp1[simp]: "(fromIntegral :: nat \<Rightarrow> ('a :: len) word) = of_nat"
  by (simp add: fromIntegral_def fromInteger_word toInteger_nat)

lemma fromIntegral_simp2[simp]: "fromIntegral = unat"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_word)

lemma fromIntegral_simp3[simp]: "fromIntegral = ucast"
  apply (simp add: fromIntegral_def fromInteger_word toInteger_word)
  apply (rule ext)
  apply (simp add: ucast_def)
  apply (subst word_of_nat)
  apply (simp add: unat_def)
  done

lemma fromIntegral_simp_nat[simp]: "(fromIntegral :: nat \<Rightarrow> nat) = id"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_nat)

(* Haskell backtick infix application: a `~f~` b means f a b. *)
definition
  infix_apply :: "'a \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'c" ("_ `~_~` _" [81, 100, 80] 80) where
  infix_apply_def[simp]:
  "infix_apply a f b \<equiv> f a b"

(* Parse check for the infix syntax above. *)
term "return $ a `~b~` c d"

definition
  zip3 :: "'a list \<Rightarrow> 'b list \<Rightarrow> 'c list \<Rightarrow> ('a \<times> 'b \<times> 'c) list" where
  "zip3 a b c \<equiv> zip a (zip b c)"

(* avoid even attempting haskell's show class: show yields the empty list,
   so appended show-strings simplify away (show_simp_away) *)
definition
  "show" :: "'a \<Rightarrow> unit list" where
  "show x \<equiv> []"

lemma show_simp_away[simp]: "S @ show t = S"
  by (simp add: show_def)

(* Conjunction / disjunction of a list of booleans. *)
definition
  "andList \<equiv> foldl (\<and>) True"

definition
  "orList \<equiv> foldl (\<or>) False"

(* Haskell's mapAccumL: thread an accumulator through the list from left to
   right, collecting the per-element outputs. *)
primrec
  mapAccumL :: "('a \<Rightarrow> 'b \<Rightarrow> 'a \<times> 'c) \<Rightarrow> 'a \<Rightarrow> 'b list \<Rightarrow> 'a \<times> ('c list)"
where
  "mapAccumL f s [] = (s, [])"
| "mapAccumL f s (x#xs) = (
     let (s', r) = f s x;
         (s'', rs) = mapAccumL f s' xs
     in (s'', r#rs)
   )"

(* Run f on each element in order, stopping at the first Some result;
   return None if every call returns None. *)
primrec
  untilM :: "('b \<Rightarrow> ('s, 'a option) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "untilM f [] = return None"
| "untilM f (x#xs) = do
      r \<leftarrow> f x;
      case r of
          None \<Rightarrow> untilM f xs
        | Some res \<Rightarrow> return (Some res)
   od"

(* As untilM, in the exception monad (doE / returnOk). *)
primrec
  untilME :: "('c \<Rightarrow> ('s, ('a + 'b option)) nondet_monad) \<Rightarrow> 'c list \<Rightarrow> ('s, 'a + 'b option) nondet_monad"
where
  "untilME f [] = returnOk None"
| "untilME f (x#xs) = doE
      r \<leftarrow> f x;
      case r of
          None \<Rightarrow> untilME f xs
        | Some res \<Rightarrow> returnOk (Some res)
   odE"

(* Return the first element whose monadic test yields True, or None. *)
primrec
  findM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "findM f [] = return None"
| "findM f (x#xs) = do
      r \<leftarrow> f x;
      if r
        then return (Some x)
        else findM f xs
   od"

(* As findM, in the exception monad (doE / returnOk). *)
primrec
  findME :: "('a \<Rightarrow> ('s, ('e + bool)) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, ('e + 'a option)) nondet_monad"
where
  "findME f [] = returnOk None"
| "findME f (x#xs) = doE
      r \<leftarrow> f x;
      if r
        then returnOk (Some x)
        else findME f xs
   odE"

(* All suffixes of a list, longest first, ending with []. *)
primrec
  tails :: "'a list \<Rightarrow> 'a list list"
where
  "tails [] = [[]]"
| "tails (x#xs) = (x#xs)#(tails xs)"

(* The codomain of a surjection from a finite type is finite. *)
lemma finite_surj_type:
  "\<lbrakk> (\<forall>x. \<exists>y. (x :: 'b) = f (y :: 'a)); finite (UNIV :: 'a set) \<rbrakk> \<Longrightarrow> finite (UNIV :: 'b set)"
  apply (erule finite_surj)
  apply safe
  apply (erule allE)
  apply safe
  apply (erule image_eqI)
  apply simp
  done

lemma finite_finite[simp]: "finite (s :: ('a :: finite) set)"
  by simp

(* Over a finite universe, inserting a into s shrinks the complement
   exactly when a was not already in s. *)
lemma finite_inv_card_less':
  "U = (UNIV :: ('a :: finite) set) \<Longrightarrow> (card (U - insert a s) < card (U - s)) = (a \<notin> s)"
  apply (case_tac "a \<in> s")
   apply (simp_all add: insert_absorb)
  apply (subgoal_tac "card s < card U")
   apply (simp add: card_Diff_subset)
  apply (rule psubset_card_mono)
   apply safe
  apply simp_all
  done

lemma finite_inv_card_less:
  "(card (UNIV - insert (a :: ('a :: finite)) s) < card (UNIV - s)) = (a \<notin> s)"
  by (simp add: finite_inv_card_less')

text \<open>Support for defining enumerations on datatypes derived from enumerations\<close>
lemma distinct_map_enum: "\<lbrakk> (\<forall> x y. (F x = F y \<longrightarrow> x = y )) \<rbrakk> \<Longrightarrow> distinct (map F (enum :: 'a :: enum list))"
  apply (simp add: distinct_map)
  apply (rule inj_onI)
  apply simp
  done

(* Minimum / maximum of a list, taken as Min / Max of its element set. *)
definition
  "minimum ls \<equiv> Min (set ls)"
definition
  "maximum ls \<equiv> Max (set ls)"

(* Cons x onto the first inner list. No equation is given for []. *)
primrec (nonexhaustive)
  hdCons :: "'a \<Rightarrow> 'a list list \<Rightarrow> 'a list list"
where
  "hdCons x (ys # zs) = (x # ys) # zs"

(* Group consecutive elements into runs: x stays in the same run as its
   successor y exactly when f x y holds. *)
primrec
  rangesBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list list"
where
  "rangesBy f [] = []"
| "rangesBy f (x # xs) =
     (case xs of [] \<Rightarrow> [[x]]
        | (y # ys) \<Rightarrow> if (f x y) then hdCons x (rangesBy f xs)
                       else [x] # (rangesBy f xs))"

(* Split a list into the elements satisfying f and those that do not,
   each keeping the original order. *)
definition
  partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<times> 'a list" where
  "partition f xs \<equiv> (filter f xs, filter (\<lambda>x. \<not> f x) xs)"

(* List difference: keep the elements of xs that do NOT occur in ys.
   Every occurrence of such an element is removed, unlike Haskell's (\\),
   which deletes one occurrence per element of ys. *)
definition
  listSubtract :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list" where
  "listSubtract xs ys \<equiv> filter (\<lambda>x. x \<notin> set ys) xs"

(* Sanity check: subtraction removes the elements of ys. *)
lemma "listSubtract [1, 2, 3 :: nat] [2] = [1, 3]"
  by (simp add: listSubtract_def)

(* All but the last element; undefined on the empty list. *)
definition
  init :: "'a list \<Rightarrow> 'a list" where
  "init xs \<equiv> case (length xs) of Suc n \<Rightarrow> take n xs | _ \<Rightarrow> undefined"

(* Split at the first element satisfying f: the prefix before it, and the
   remainder starting with it. *)
primrec
  break :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> ('a list \<times> 'a list)"
where
  "break f [] = ([], [])"
| "break f (x # xs) =
     (if f x
        then ([], x # xs)
        else (\<lambda>(ys, zs). (x # ys, zs)) (break f xs))"

definition
  "uncurry \<equiv> case_prod"

(* Sum of a list, folding + from the left starting at 0. *)
definition
  sum :: "'a list \<Rightarrow> 'a::{plus,zero}" where
  "sum \<equiv> foldl (+) 0"

definition
  "replicateM n m \<equiv> sequence (replicate n m)"

definition
  maybeToMonad_def[simp]:
  "maybeToMonad \<equiv> assert_opt"

(* Arrays are modelled as total functions. *)
definition
  funArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b)" where
  funArray_def[simp]:
  "funArray \<equiv> id"

(* Agrees with f on the enumerated range [fst xrange .e. snd xrange] and is
   undefined outside it. *)
definition
  funPartialArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a :: enumeration_alt \<times> 'a) \<Rightarrow> ('a \<Rightarrow> 'b)" where
  "funPartialArray f xrange \<equiv> \<lambda>x. (if x \<in> set [fst xrange .e. snd xrange] then f x else undefined)"

(* forM-family: the mapM-family with the arguments flipped. *)
definition
  forM_def[simp]:
  "forM xs f \<equiv> mapM f xs"

definition
  forM_x_def[simp]:
  "forM_x xs f \<equiv> mapM_x f xs"

definition
  forME_x_def[simp]:
  "forME_x xs f \<equiv> mapME_x f xs"

(* Apply a list of (index, value) updates from left to right, so a later pair
   for the same index overrides an earlier one. *)
definition
  arrayListUpdate :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<times> 'b) list \<Rightarrow> ('a \<Rightarrow> 'b)" (infixl "aLU" 90)
where
  arrayListUpdate_def[simp]:
  "arrayListUpdate f l \<equiv> foldl (\<lambda>f p. f(fst p := snd p)) f l"

definition
  "genericTake \<equiv> take \<circ> fromIntegral"

definition
  "genericLength \<equiv> fromIntegral \<circ> length"

abbreviation
  "null == List.null"

(* Accept Haskell-style list comprehensions [e | x <- xs] as input syntax. *)
syntax (input)
  "_listcompr" :: "'a \<Rightarrow> lc_qual \<Rightarrow> lc_quals \<Rightarrow> 'a list" ("[_ | __")

lemma "[(x,1) . x \<leftarrow> [0..10]] = [(x,1) | x \<leftarrow> [0..10]]" by (rule refl)

end