(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
   Author: Thomas Sewell

   Library routines etc expected by Haskell code.
*)

theory HaskellLib_H
imports
  Lib
  "More_Numeral_Type"
  "Monad_WP/NonDetMonadVCG"
begin

(* Haskell's flip is the existing swp combinator (input-only abbreviation). *)
abbreviation (input) "flip \<equiv> swp"

(* Haskell's (>>): run x, discard its result, then run y. *)
abbreviation(input)
  bind_drop :: "('a, 'c) nondet_monad \<Rightarrow> ('a, 'b) nondet_monad
                  \<Rightarrow> ('a, 'b) nondet_monad" (infixl ">>'_" 60)
where
  "bind_drop \<equiv> (\<lambda>x y. bind x (K_bind y))"

(* Sanity check: folding >>_ over a list of actions is sequence_x. *)
lemma bind_drop_test:
  "foldr bind_drop x (return ()) = sequence_x x"
  by (rule ext, simp add: sequence_x_def)

(* If the given monad is deterministic, this function converts the
   nondet_monad type into a normal deterministic state monad.
   If f s does not have exactly one result, THE yields an unspecified value. *)
definition
  runState :: "('s, 'a) nondet_monad \<Rightarrow> 's \<Rightarrow> ('a \<times> 's)"
where
  "runState f s \<equiv> THE x. x \<in> fst (f s)"

(* Pure assertion: the identity when P holds, otherwise undefined. *)
definition
  sassert :: "bool \<Rightarrow> 'a \<Rightarrow> 'a"
where
  "sassert P \<equiv> if P then id else (\<lambda>x. undefined)"

(* Registered as a fundef_cong rule, so the function package may assume P'
   when analysing recursive calls that occur inside s. *)
lemma sassert_cong[fundef_cong]:
  "\<lbrakk> P = P'; P' \<Longrightarrow> s = s' \<rbrakk> \<Longrightarrow> sassert P s = sassert P' s'"
  apply (simp add: sassert_def)
  done

(* In haskell_assert, haskell_assertE, stateAssert and haskell_fail the
   unit list argument L is ignored; presumably it stands in for the message
   String of the Haskell original. *)
definition
  haskell_assert :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad"
where
  "haskell_assert P L \<equiv> assert P"

definition
  haskell_assertE :: "bool \<Rightarrow> unit list \<Rightarrow> ('a, 'e + unit) nondet_monad"
where
  "haskell_assertE P L \<equiv> assertE P"

declare haskell_assert_def [simp] haskell_assertE_def [simp]

(* Assert that predicate P holds of the current state. *)
definition
  stateAssert :: "('a \<Rightarrow> bool) \<Rightarrow> unit list \<Rightarrow> ('a, unit) nondet_monad"
where
  "stateAssert P L \<equiv> get >>= (\<lambda>s. assert (P s))"

definition
  haskell_fail :: "unit list \<Rightarrow> ('a, 'b) nondet_monad"
where
  haskell_fail_def[simp]:
  "haskell_fail L \<equiv> fail"

definition
  catchError_def[simp]:
  "catchError \<equiv> handleE"

(* curryN: turn a function on an N-tuple into a curried N-argument function. *)
definition "curry1 \<equiv> id"
definition "curry2 \<equiv> curry"
definition "curry3 f a b c \<equiv> f (a, b, c)"
definition "curry4 f a b c d \<equiv> f (a, b, c, d)"
definition "curry5 f a b c d e \<equiv> f (a, b, c, d, e)"

declare curry1_def[simp] curry2_def[simp] curry3_def[simp] curry4_def[simp]
        curry5_def[simp]

(* splitN: turn a curried N-argument function into a function on N-tuples. *)
definition "split1 \<equiv> id"
definition "split2 \<equiv> case_prod"
definition "split3 f \<equiv> \<lambda>(a, b, c). f a b c"
definition "split4 f \<equiv> \<lambda>(a, b, c, d). f a b c d"
definition "split5 f \<equiv> \<lambda>(a, b, c, d, e). f a b c d e"

declare split1_def[simp] split2_def[simp]

(* split3..split5 are not unfolded in general; these simp rules only fire
   when the argument is an explicit tuple. *)
lemma split3_simp[simp]: "split3 f (a, b, c) = f a b c"
  by (simp add: split3_def)

lemma split4_simp[simp]: "split4 f (a, b, c, d) = f a b c d"
  by (simp add: split4_def)

lemma split5_simp[simp]: "split5 f (a, b, c, d, e) = f a b c d e"
  by (simp add: split5_def)

(* Haskell names for the corresponding HOL option/list operations. They are
   unfolded by the simp declaration that follows runExceptT. *)
definition "Just \<equiv> Some"
definition "Nothing \<equiv> None"
definition "fromJust \<equiv> the"
definition "isJust x \<equiv> x \<noteq> None"
definition "tail \<equiv> tl"
definition "head \<equiv> hd"

(* Haskell's error: the message argument is ignored and the result is undefined. *)
definition
  error :: "unit list \<Rightarrow> 'a"
where
  "error \<equiv> \<lambda>x. undefined"

definition "reverse \<equiv> rev"
definition "isNothing x \<equiv> x = None"
definition "maybeApply \<equiv> option_map"
definition "maybe \<equiv> case_option"
(* Haskell's foldr argument order (function, initial value, list). *)
definition "foldR f init L \<equiv> foldr f L init"
definition "elem x L \<equiv> x \<in> set L"
definition "notElem x L \<equiv> x \<notin> set L"

(* Haskell's Ordering is modelled as bool: compare x y holds iff x < y. *)
type_synonym ordering = bool

definition
  compare :: "('a :: ord) \<Rightarrow> 'a \<Rightarrow> ordering"
where
  "compare \<equiv> (<)"

(* Insert a immediately before the first element b for which f a b holds,
   or at the end if there is none. *)
primrec
  insertBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "insertBy f a [] = [a]"
| "insertBy f a (b # bs) =
     (if (f a b) then (a # b # bs) else (b # (insertBy f a bs)))"

lemma insertBy_length [simp]:
  "length (insertBy f a as) = (1 + length as)"
  by (induct as) simp_all

(* Insertion sort built from insertBy. *)
primrec
  sortBy :: "('a \<Rightarrow> 'a \<Rightarrow> ordering) \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "sortBy f [] = []"
| "sortBy f (a # as) = insertBy f a (sortBy f as)"

lemma sortBy_length:
  "length (sortBy f as) = length as"
  by (induct as) simp_all

definition "sortH \<equiv> sortBy compare"

(* Keep only the Some elements of a list, with the Some wrapper removed. *)
definition "catMaybes \<equiv> (map the) \<circ> (filter isJust)"

definition "runExceptT \<equiv> id"

declare Just_def[simp] Nothing_def[simp] fromJust_def[simp] isJust_def[simp]
        tail_def[simp] head_def[simp] error_def[simp] reverse_def[simp]
        isNothing_def[simp] maybeApply_def[simp] maybe_def[simp]
        foldR_def[simp] elem_def[simp] notElem_def[simp]
        catMaybes_def[simp] runExceptT_def[simp]

(* Monadic head/tail: fail on the empty list instead of being partial. *)
definition
  "headM L \<equiv> (case L of (h # t) \<Rightarrow> return h | _ \<Rightarrow> fail)"

definition
  "tailM L \<equiv> (case L of (h # t) \<Rightarrow> return t | _ \<Rightarrow> fail)"

(* Declared without any defining axiom, so nothing is known about typeOf. *)
axiomatization
  typeOf :: "'a \<Rightarrow> unit list"

(* Haskell's either; the simp rule rewrites it to HOL's case_sum. *)
definition
  "either f1 f2 c \<equiv> case c of Inl r1 \<Rightarrow> f1 r1 | Inr r2 \<Rightarrow> f2 r2"

lemma either_simp[simp]: "either = case_sum"
  apply (rule ext)+
  apply (simp add: either_def)
  done

(* Bitwise operations on nat, computed by converting to int and back.
   NOTE(review): bitNOT of a non-negative int is negative, and nat maps
   negative ints to 0, so bitNOT on nat appears to be constantly 0 --
   confirm this is intended. *)
instantiation nat :: bit
begin

definition
  "bitNOT = nat o bitNOT o int"

definition
  "bitAND x y = nat (bitAND (int x) (int y))"

definition
  "bitOR x y = nat (bitOR (int x) (int y))"

definition
  "bitXOR x y = nat (bitXOR (int x) (int y))"

instance ..

end

(* Haskell-style shift and bit-size operations on top of the bit class. *)
class HS_bit = bit +
  fixes shiftL :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes shiftR :: "'a \<Rightarrow> nat \<Rightarrow> 'a"
  fixes bitSize :: "'a \<Rightarrow> nat"

instantiation word :: (len0) HS_bit
begin

definition
  shiftL_word[simp]: "(shiftL :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftl"

definition
  shiftR_word[simp]: "(shiftR :: 'a::len0 word \<Rightarrow> nat \<Rightarrow> 'a word) \<equiv> shiftr"

definition
  bitSize_word[simp]: "(bitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* On nat, shifts are multiplication and division by powers of two. *)
instantiation nat :: HS_bit
begin

definition
  shiftL_nat: "shiftL (x :: nat) n \<equiv> (2 ^ n) * x"

definition
  shiftR_nat: "shiftR (x :: nat) n \<equiv> x div (2 ^ n)"

text {* bitSize not defined for nat *}

instance ..

end

class finiteBit = bit +
  fixes finiteBitSize :: "'a \<Rightarrow> nat"

instantiation word :: (len0) finiteBit
begin

definition
  finiteBitSize_word[simp]: "(finiteBitSize :: 'a::len0 word \<Rightarrow> nat) \<equiv> size"

instance ..

end

(* bit x: the value 1 shifted left by x places. *)
definition
  bit_def[simp]: "bit x \<equiv> shiftL 1 x"

(* x counts as aligned to n when the bits selected by mask n are all zero. *)
definition
  "isAligned x n \<equiv> x && mask n = 0"

(* Haskell's Integral conversions, with nat standing in for Integer.
   integral_inv requires fromInteger to undo toInteger. *)
class integral = ord +
  fixes fromInteger :: "nat \<Rightarrow> 'a"
  fixes toInteger :: "'a \<Rightarrow> nat"
  assumes integral_inv: "fromInteger \<circ> toInteger = id"

instantiation nat :: integral
begin

definition
  fromInteger_nat: "fromInteger \<equiv> id"

definition
  toInteger_nat: "toInteger \<equiv> id"

instance
  apply (intro_classes)
  apply (simp add: toInteger_nat fromInteger_nat)
  done

end

instantiation word :: (len) integral
begin

definition
  fromInteger_word: "fromInteger \<equiv> of_nat :: nat \<Rightarrow> 'a::len word"

definition
  toInteger_word: "toInteger \<equiv> unat"

instance
  apply (intro_classes)
  apply (rule ext)
  apply (simp add: toInteger_word fromInteger_word)
  done

end

(* Convert between integral types by going through nat. *)
definition
  fromIntegral :: "('a :: integral) \<Rightarrow> ('b :: integral)"
where
  "fromIntegral \<equiv> fromInteger \<circ> toInteger"

(* Simp rules that rewrite fromIntegral at the nat and word instances
   to the usual HOL conversion functions. *)
lemma fromIntegral_simp1[simp]: "(fromIntegral :: nat \<Rightarrow> ('a :: len) word) = of_nat"
  by (simp add: fromIntegral_def fromInteger_word toInteger_nat)

lemma fromIntegral_simp2[simp]: "fromIntegral = unat"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_word)

lemma fromIntegral_simp3[simp]: "fromIntegral = ucast"
  apply (simp add: fromIntegral_def fromInteger_word toInteger_word)
  apply (rule ext)
  apply (simp add: ucast_def)
  apply (subst word_of_nat)
  apply (simp add: unat_def)
  done

lemma fromIntegral_simp_nat[simp]: "(fromIntegral :: nat \<Rightarrow> nat) = id"
  by (simp add: fromIntegral_def fromInteger_nat toInteger_nat)

(* Haskell backtick-style infix application: a `~f~` b means f a b. *)
definition
  infix_apply :: "'a \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'c" ("_ `~_~` _" [81, 100, 80] 80)
where
  infix_apply_def[simp]:
  "infix_apply a f b \<equiv> f a b"

(* Parse check for the infix_apply notation. *)
term "return $ a `~b~` c d"

(* Triples are represented as nested pairs ('a, ('b, 'c)). *)
definition
  zip3 :: "'a list \<Rightarrow> 'b list \<Rightarrow> 'c list \<Rightarrow> ('a \<times> 'b \<times> 'c) list"
where
  "zip3 a b c \<equiv> zip a (zip b c)"

(* avoid even attempting haskell's show class *)
definition
  "show" :: "'a \<Rightarrow> unit list"
where
  "show x \<equiv> []"

lemma show_simp_away[simp]:
  "S @ show t = S"
  by (simp add: show_def)

(* Conjunction / disjunction of a list of booleans. *)
definition
  "andList \<equiv> foldl (\<and>) True"

definition
  "orList \<equiv> foldl (\<or>) False"

(* Map f over the list while threading an accumulator from left to right;
   returns the final accumulator and the list of outputs. *)
primrec
  mapAccumL :: "('a \<Rightarrow> 'b \<Rightarrow> 'a \<times> 'c) \<Rightarrow> 'a \<Rightarrow> 'b list \<Rightarrow> 'a \<times> ('c list)"
where
  "mapAccumL f s [] = (s, [])"
| "mapAccumL f s (x#xs) = (
     let (s', r) = f s x;
         (s'', rs) = mapAccumL f s' xs
     in (s'', r#rs)
  )"

(* Run f on each element in turn, stopping at and returning the first
   Some result; return None if every call returns None. *)
primrec
  untilM :: "('b \<Rightarrow> ('s, 'a option) nondet_monad) \<Rightarrow> 'b list
               \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "untilM f [] = return None"
| "untilM f (x#xs) = do
     r \<leftarrow> f x;
     case r of None \<Rightarrow> untilM f xs
             | Some res \<Rightarrow> return (Some res)
   od"

(* As untilM, in the exception monad. *)
primrec
  untilME :: "('c \<Rightarrow> ('s, ('a + 'b option)) nondet_monad) \<Rightarrow> 'c list
                \<Rightarrow> ('s, 'a + 'b option) nondet_monad"
where
  "untilME f [] = returnOk None"
| "untilME f (x#xs) = doE
     r \<leftarrow> f x;
     case r of None \<Rightarrow> untilME f xs
             | Some res \<Rightarrow> returnOk (Some res)
   odE"

(* Return the first element for which the monadic test f returns True. *)
primrec
  findM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a option) nondet_monad"
where
  "findM f [] = return None"
| "findM f (x#xs) = do
     r \<leftarrow> f x;
     if r then return (Some x) else findM f xs
   od"

(* As findM, in the exception monad. *)
primrec
  findME :: "('a \<Rightarrow> ('s, ('e + bool)) nondet_monad) \<Rightarrow> 'a list
               \<Rightarrow> ('s, ('e + 'a option)) nondet_monad"
where
  "findME f [] = returnOk None"
| "findME f (x#xs) = doE
     r \<leftarrow> f x;
     if r then returnOk (Some x) else findME f xs
   odE"

(* All suffixes of a list, longest first, ending with []. *)
primrec
  tails :: "'a list \<Rightarrow> 'a list list"
where
  "tails [] = [[]]"
| "tails (x#xs) = (x#xs)#(tails xs)"

(* A type that is the image of a finite type is finite. *)
lemma finite_surj_type:
  "\<lbrakk> (\<forall>x. \<exists>y. (x :: 'b) = f (y :: 'a)); finite (UNIV :: 'a set) \<rbrakk>
     \<Longrightarrow> finite (UNIV :: 'b set)"
  apply (erule finite_surj)
  apply safe
  apply (erule allE)
  apply safe
  apply (erule image_eqI)
  apply simp
  done

lemma finite_finite[simp]:
  "finite (s :: ('a :: finite) set)"
  by simp

(* Inserting a into s strictly shrinks the complement exactly when a is new. *)
lemma finite_inv_card_less':
  "U = (UNIV :: ('a :: finite) set)
     \<Longrightarrow> (card (U - insert a s) < card (U - s)) = (a \<notin> s)"
  apply (case_tac "a \<in> s")
   apply (simp_all add: insert_absorb)
  apply (subgoal_tac "card s < card U")
   apply (simp add: card_Diff_subset)
  apply (rule psubset_card_mono)
   apply safe
   apply simp_all
  done

lemma finite_inv_card_less:
  "(card (UNIV - insert (a :: ('a :: finite)) s) < card (UNIV - s)) = (a \<notin> s)"
  by (simp add: finite_inv_card_less')

text {* Support for defining enumerations on datatypes derived from enumerations *}

lemma distinct_map_enum:
  "\<lbrakk> (\<forall> x y. (F x = F y \<longrightarrow> x = y )) \<rbrakk>
     \<Longrightarrow> distinct (map F (enum :: 'a :: enum list))"
  apply (simp add: distinct_map)
  apply (rule inj_onI)
  apply simp
  done

(* Smallest / largest list element via Min / Max of the element set;
   presumably unspecified for the empty list -- confirm before relying on it. *)
definition
  "minimum ls \<equiv> Min (set ls)"

definition
  "maximum ls \<equiv> Max (set ls)"

(* Cons x onto the first list of a non-empty list of lists; no equation is
   given for the empty list. *)
primrec (nonexhaustive)
  hdCons :: "'a \<Rightarrow> 'a list list \<Rightarrow> 'a list list"
where
  "hdCons x (ys # zs) = (x # ys) # zs"

(* Split a list into maximal runs in which f x y holds for every element x
   and its successor y. *)
primrec
  rangesBy :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list list"
where
  "rangesBy f [] = []"
| "rangesBy f (x # xs) =
     (case xs of [] \<Rightarrow> [[x]]
               | (y # ys) \<Rightarrow> if (f x y) then hdCons x (rangesBy f xs)
                                        else [x] # (rangesBy f xs))"

(* Elements satisfying f, and the rest, each in original order. *)
definition
  partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<times> 'a list"
where
  "partition f xs \<equiv> (filter f xs, filter (\<lambda>x. \<not> f x) xs)"

(* Remove from xs every element that occurs anywhere in ys.
   NOTE(review): this drops all occurrences, unlike Haskell's (\\) which
   removes one occurrence per element of ys -- confirm callers agree. *)
definition
  listSubtract :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "listSubtract xs ys \<equiv> filter (\<lambda>x. x \<notin> set ys) xs"

(* All but the last element; undefined on the empty list. *)
definition
  init :: "'a list \<Rightarrow> 'a list"
where
  "init xs \<equiv> case (length xs) of Suc n \<Rightarrow> take n xs | _ \<Rightarrow> undefined"

(* Split the list just before the first element satisfying f. *)
primrec
  break :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> ('a list \<times> 'a list)"
where
  "break f [] = ([], [])"
| "break f (x # xs) =
     (if f x
        then ([], x # xs)
        else (\<lambda>(ys, zs). (x # ys, zs)) (break f xs))"

definition
  "uncurry \<equiv> case_prod"

(* Left-to-right sum of the list elements, starting from 0. *)
definition
  sum :: "'a list \<Rightarrow> 'a::{plus,zero}"
where
  "sum \<equiv> foldl (+) 0"

definition
  "replicateM n m \<equiv> sequence (replicate n m)"

definition
  maybeToMonad_def[simp]:
  "maybeToMonad \<equiv> assert_opt"

(* Haskell arrays are modelled as plain functions. *)
definition
  funArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b)"
where
  funArray_def[simp]:
  "funArray \<equiv> id"

(* f restricted to the enumerated index range [fst xrange .e. snd xrange];
   undefined outside it. *)
definition
  funPartialArray :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a :: enumeration_alt \<times> 'a) \<Rightarrow> ('a \<Rightarrow> 'b)"
where
  "funPartialArray f xrange \<equiv> \<lambda>x.
     (if x \<in> set [fst xrange .e. snd xrange] then f x else undefined)"

(* forM / forM_x / forME_x: mapM-style combinators with the list first. *)
definition
  forM_def[simp]:
  "forM xs f \<equiv> mapM f xs"

definition
  forM_x_def[simp]:
  "forM_x xs f \<equiv> mapM_x f xs"

definition
  forME_x_def[simp]:
  "forME_x xs f \<equiv> mapME_x f xs"

(* Apply (index, value) updates in list order; for a repeated index the
   later pair wins. *)
definition
  arrayListUpdate :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<times> 'b) list \<Rightarrow> ('a \<Rightarrow> 'b)" (infixl "aLU" 90)
where
  arrayListUpdate_def[simp]:
  "arrayListUpdate f l \<equiv> foldl (\<lambda>f p. f(fst p := snd p)) f l"

definition
  "genericTake \<equiv> take \<circ> fromIntegral"

definition
  "genericLength \<equiv> fromIntegral \<circ> length"

abbreviation
  "null == List.null"

(* Input-only syntax for Haskell-style list comprehensions, written with a
   vertical bar between the element expression and the qualifiers. *)
syntax (input)
  "_listcompr" :: "'a \<Rightarrow> lc_qual \<Rightarrow> lc_quals \<Rightarrow> 'a list" ("[_ | __")

(* Check that the bar form agrees with HOL's native comprehension syntax. *)
lemma "[(x,1) . x \<leftarrow> [0..10]] = [(x,1) | x \<leftarrow> [0..10]]"
  by (rule refl)

end