(*  Title:      HOL/Probability/PMF_Impl.thy
    Author:     Manuel Eberl, TU München

An implementation of PMFs using Mappings, which are implemented with association lists
by default. Also includes Quickcheck setup for PMFs.
*)

section \<open>Code generation for PMFs\<close>

theory PMF_Impl
imports Probability_Mass_Function "HOL-Library.AList_Mapping"
begin

subsection \<open>General code generation setup\<close>

(* Abstraction function of the code setup: reads a mapping from outcomes to probabilities
   as a PMF. Outcomes that are not keys of the mapping get probability 0 (lookup_default 0).
   The result is only meaningful for finite, nonnegative mappings whose values sum to 1;
   see lemma pmf_of_mapping below. *)
definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
  "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)"

(* For a finite mapping with nonnegative values, the counting-measure integral of
   lookup_default 0 m collapses to the finite sum of the values over the keys. *)
lemma nn_integral_lookup_default:
  fixes m :: "('a, real) mapping"
  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)"
  shows   "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
             ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
proof -
  have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
          (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms
    by (subst nn_integral_count_space'[of "Mapping.keys m"])
       (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def)
  also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
    by (intro sum_ennreal)
       (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits)
  finally show ?thesis .
qed

(* Under the well-formedness conditions (finite, nonnegative, values summing to 1), the
   probability of x under pmf_of_mapping m is exactly lookup_default 0 m x. *)
lemma pmf_of_mapping:
  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)"
  assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1"
  shows   "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x"
  unfolding pmf_of_mapping_def
proof (intro pmf_embed_pmf)
  from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1"
    by (subst nn_integral_lookup_default) (simp_all)
qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits)

(* The uniform distribution on a nonempty set A, enumerated without duplicates by xs,
   is the tabulated mapping giving every element of xs the weight 1 / length xs. *)
lemma pmf_of_set_pmf_of_mapping:
  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
  shows   "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))"
           (is "?lhs = ?rhs")
  by (rule pmf_eqI, subst pmf_of_mapping)
     (insert assms, auto intro!: All_mapping_tabulate
                         simp: Mapping.lookup_default_def lookup_tabulate distinct_card)

(* Representation function of the code setup: the mapping whose keys are exactly the
   outcomes with nonzero probability, each sent to its probability. *)
lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
  "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" .

lemma lookup_default_mapping_of_pmf:
  "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x"
  by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq)

(* Local context: the pmf_as_function interpretation is only needed to prove, by transfer,
   that the point probabilities of any PMF integrate to 1. *)
context
begin

interpretation pmf_as_function .

lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
  by transfer simp_all
end

(* Registers pmf_of_mapping / mapping_of_pmf as abstraction / representation pair for the
   code generator. Every code abstract equation below describes an operation on PMFs by
   giving the mapping_of_pmf of its result. *)
lemma pmf_of_mapping_mapping_of_pmf [code abstype]:
    "pmf_of_mapping (mapping_of_pmf p) = p"
  unfolding pmf_of_mapping_def
  by (rule pmf_eqI, subst pmf_embed_pmf)
     (insert nn_integral_pmf_eq_1[of p],
      auto simp: lookup_default_mapping_of_pmf split: option.splits)

(* Introduction rules used to prove the code abstract equations. A mapping m is
   mapping_of_pmf p as soon as its keys are exactly set_pmf p and its values agree with
   pmf p on those keys. The primed variant states the value condition via lookup_default. *)
lemma mapping_of_pmfI:
  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)"
  assumes "Mapping.keys m = set_pmf p"
  shows   "mapping_of_pmf p = m"
  using assms by transfer (rule ext, auto simp: set_pmf_eq)

lemma mapping_of_pmfI':
  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x"
  assumes "Mapping.keys m = set_pmf p"
  shows   "mapping_of_pmf p = m"
  using assms unfolding Mapping.lookup_default_def
  by transfer (rule ext, force simp: set_pmf_eq)

(* A point mass at x is represented by the singleton mapping x -> 1. *)
lemma return_pmf_code [code abstract]:
  "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty"
  by (intro mapping_of_pmfI) (auto simp: lookup_update')

lemma pmf_of_set_code_aux:
  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
  shows   "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))"
  using assms
  by (intro mapping_of_pmfI, subst pmf_of_set)
     (auto simp: lookup_tabulate distinct_card)

(* pmf_of_set_impl names mapping_of_pmf (pmf_of_set A), so that code equations can be
   stated per set representation (e.g. for sets of the form set xs, below). *)
definition pmf_of_set_impl where
  "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"

(* This equation can be used to easily implement pmf_of_set for other set implementations *)
lemma pmf_of_set_impl_code_alt:
  assumes "A \<noteq> {}" "finite A"
  shows   "pmf_of_set_impl A =
             (let p = 1 / real (card A)
              in  Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)"
proof -
  define p where "p = 1 / real (card A)"
  let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A"
  interpret comp_fun_idem "\<lambda>x. Mapping.update x p"
    by standard (transfer, force simp: fun_eq_iff)+
  have keys: "Mapping.keys ?m = A"
    using assms(2) by (induction A rule: finite_induct) simp_all
  have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x
    using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update')
  from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def
    by (intro mapping_of_pmfI) (simp_all add: Let_def p_def)
qed

(* Code for list-represented sets. The empty list aborts at runtime, because the lemmas
   above only characterise pmf_of_set for nonempty sets. Otherwise every distinct element
   gets weight 1 / (number of distinct elements). *)
lemma pmf_of_set_impl_code [code]:
  "pmf_of_set_impl (set xs) =
    (if xs = [] then
       Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs)))
     else let xs' = remdups xs; p = 1 / real (length xs') in
       Mapping.tabulate xs' (\<lambda>_. p))"
  unfolding pmf_of_set_impl_def
  using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def)

lemma pmf_of_set_code [code abstract]:
  "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A"
  by (simp add: pmf_of_set_impl_def)


(* Multiset analogue: each distinct element of a nonempty multiset A gets weight
   count A x / size A. *)
lemma pmf_of_multiset_pmf_of_mapping:
  assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
  shows   "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))"
  using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)

definition pmf_of_multiset_impl where
  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"

(* Representation-independent variant: fold over the multiset, adding 1 / size A to the
   entry of each element, with missing entries starting from 0. *)
lemma pmf_of_multiset_impl_code_alt:
  assumes "A \<noteq> {#}"
  shows   "pmf_of_multiset_impl A =
             (let p = 1 / real (size A)
              in  fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A)"
proof -
  define p where "p = 1 / real (size A)"
  interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 ((+) p)"
    unfolding Mapping.map_default_def [abs_def]
    by (standard, intro mapping_eqI ext)
       (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
  let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A"
  have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
  have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
    by (induction A)
       (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
  from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
    by (intro mapping_of_pmfI') (simp_all add: Let_def p_def)
qed

(* Code for list-represented multisets. The empty list aborts at runtime. Otherwise each
   distinct element x gets count (mset xs) x times 1 / length xs; note that length xs, not
   the number of distinct elements, is the size of mset xs. *)
lemma pmf_of_multiset_impl_code [code]:
  "pmf_of_multiset_impl (mset xs) =
     (if xs = [] then
        Code.abort (STR ''pmf_of_multiset of empty multiset'')
          (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs)))
      else let xs' = remdups xs; p = 1 / real (length xs) in
        Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
  using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
  by (simp add: pmf_of_multiset_impl_def)

lemma pmf_of_multiset_code [code abstract]:
  "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
  by (simp add: pmf_of_multiset_impl_def)


(* Bernoulli distribution. p <= 0 gives a point mass on False and p >= 1 a point mass on
   True; otherwise there are two entries, False -> 1 - p and True -> p. *)
lemma bernoulli_pmf_code [code abstract]:
  "mapping_of_pmf (bernoulli_pmf p) =
     (if p \<le> 0 then Mapping.update False 1 Mapping.empty
      else if p \<ge> 1 then Mapping.update True 1 Mapping.empty
      else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))"
  by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)


(* Queries on a PMF are answered from its representing mapping: point probabilities by
   lookup with default 0, and the support by the set of keys. *)
lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
  unfolding mapping_of_pmf_def Mapping.lookup_default_def
  by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)

lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)"
  by transfer (auto simp: dom_def set_pmf_eq)

lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p"
  by transfer (auto simp: dom_def set_pmf_eq)



(* Pointwise sum of a family of real-valued mappings indexed by a set, built with
   Mapping.combine and addition (characterised by lookup_default_fold_combine_plus).
   It is a helper for the bind_pmf code and is hidden again after that section. *)
definition fold_combine_plus where
  "fold_combine_plus = comm_monoid_set.F (Mapping.combine ((+) :: real \<Rightarrow> _)) Mapping.empty"

context
begin

interpretation fold_combine_plus: combine_mapping_abel_semigroup "(+) :: real \<Rightarrow> _"
  by unfold_locales (simp_all add: add_ac)

qualified lemma lookup_default_fold_combine_plus:
  fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
  assumes "finite A"
  shows   "Mapping.lookup_default 0 (fold_combine_plus f A) x =
             (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)"
  unfolding fold_combine_plus_def using assms
    by (induction A rule: finite_induct)
       (simp_all add: lookup_default_empty lookup_default_neutral_combine)

qualified lemma keys_fold_combine_plus:
  "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
  by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)

(* On list-represented sets, the pointwise sum is a right fold over the distinct
   elements. *)
qualified lemma fold_combine_plus_code [code]:
  "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine (+) (g x)) (remdups xs) Mapping.empty"
  by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)

private lemma lookup_default_0_map_values:
  assumes "f x 0 = 0"
  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
  unfolding Mapping.lookup_default_def
  using assms by transfer (auto split: option.splits)

(* For finite support, bind_pmf p f is the pointwise sum over x in set_pmf p of the
   mapping of f x scaled by pmf p x. *)
qualified lemma mapping_of_bind_pmf:
  assumes "finite (set_pmf p)"
  shows   "mapping_of_pmf (bind_pmf p f) =
             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x))
               (mapping_of_pmf (f x))) (set_pmf p)"
  using assms
  by (intro mapping_of_pmfI')
     (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus
                 pmf_bind integral_measure_pmf lookup_default_0_map_values
                 lookup_default_mapping_of_pmf mult_ac)

(* bind_pmf_aux p f A has as keys the outcomes reachable through f from elements of A.
   Each key x carries the expectation over p of indicator A y times pmf (f y) x.
   With A = set_pmf p this is exactly mapping_of_pmf (bind_pmf p f), see
   bind_pmf_aux_correct. *)
lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
  "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b).
     if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
       Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x))
     else None" .

lemma keys_bind_pmf_aux [simp]:
  "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))"
  by transfer (auto split: if_splits)

lemma lookup_default_bind_pmf_aux:
  "Mapping.lookup_default 0 (bind_pmf_aux p f A) x =
     (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
        measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)"
  unfolding lookup_default_def by transfer' simp_all

lemma lookup_default_bind_pmf_aux' [simp]:
  "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x"
  unfolding lookup_default_def
  by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq
                    intro!: integral_cong_AE integral_eq_zero_AE)

lemma bind_pmf_aux_correct:
  "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)"
  by (intro mapping_of_pmfI') simp_all

(* For finite A the expectation becomes a finite sum over A of pmf p y times
   pmf (f y) x, which is what the fold_combine_plus expression computes. *)
lemma bind_pmf_aux_code_aux:
  assumes "finite A"
  shows   "bind_pmf_aux p f A =
             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x))
               (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
proof (intro mapping_eqI'[where d = 0])
  fix x assume "x \<in> Mapping.keys ?lhs"
  then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
  hence "Mapping.lookup_default 0 ?lhs x =
           measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)"
    by (auto simp: lookup_default_bind_pmf_aux)
  also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)"
    by (subst integral_measure_pmf [of A])
       (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits)
  also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x"
    by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values
          lookup_default_mapping_of_pmf)
  finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
qed (insert assms, simp_all add: keys_fold_combine_plus)

(* Code equation for bind_pmf_aux. It only applies when the set argument has the form
   set xs, i.e. when the support set produced by set_pmf_code is list-based. *)
lemma bind_pmf_aux_code [code]:
  "bind_pmf_aux p f (set xs) =
     fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x))
       (mapping_of_pmf (f x))) (set xs)"
  by (rule bind_pmf_aux_code_aux) simp_all

lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct

end

hide_const (open) fold_combine_plus


(* Option-valued helper for cond_pmf. The result is None if A is disjoint from set_pmf p
   (A has probability 0). Otherwise it is Some of the mapping that restricts p to A and
   divides by prob p A. *)
lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
  "\<lambda>p A. if A \<inter> set_pmf p = {} then None else
     Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .

(* Executable characterisation for finite A. The normalising constant is the finite sum
   of pmf p over C = A inter set_pmf p, and the result restricts mapping_of_pmf p to C
   and divides every value by that constant. *)
lemma cond_pmf_impl_code_alt:
  assumes "finite A"
  shows   "cond_pmf_impl p A = (
             let C = A \<inter> set_pmf p;
                 prob = (\<Sum>x\<in>C. pmf p x)
             in  if prob = 0 then
                   None
                 else
                   Some (Mapping.map_values (\<lambda>_ y. y / prob)
                     (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
proof -
  define C where "C = A \<inter> set_pmf p"
  define prob where "prob = (\<Sum>x\<in>C. pmf p x)"
  also note C_def
  also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)"
    by (intro sum.mono_neutral_left) (auto simp: set_pmf_eq)
  finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" .
  hence prob2: "prob = measure_pmf.prob p A"
    using assms by (subst measure_measure_pmf_finite) simp_all
  have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
    by (subst prob1, subst sum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms)
  from assms have prob4: "prob = measure_pmf.prob p C"
    unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def)

  show ?thesis
  proof (cases "prob = 0")
    case True
    hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
    with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq)
  next
    case False
    hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
    with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def)
    fix x
    have "cond_pmf_impl p A =
            Some (mapping.Mapping (\<lambda>x. if x \<in> C then
              Some (pmf p x / measure_pmf.prob p C) else None))"
         (is "_ = Some ?m")
      using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff)
    also have "?m = Mapping.map_values (\<lambda>_ y. y / prob)
                      (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))"
      using prob_nz prob4 assms unfolding C_def
      by transfer (auto simp: fun_eq_iff set_pmf_eq)
    finally show ?thesis using False by (simp add: Let_def prob_def C_def)
  qed
qed

lemma cond_pmf_impl_code [code]:
  "cond_pmf_impl p (set xs) = (
     let C = set xs \<inter> set_pmf p;
         prob = (\<Sum>x\<in>C. pmf p x)
     in  if prob = 0 then
           None
         else
           Some (Mapping.map_values (\<lambda>_ y. y / prob)
             (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
  by (rule cond_pmf_impl_code_alt) simp_all

(* cond_pmf aborts at runtime when the conditioning set has probability 0
   (cond_pmf_impl returns None). *)
lemma cond_pmf_code [code abstract]:
  "mapping_of_pmf (cond_pmf p A) =
     (case cond_pmf_impl p A of
        None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'')
                  (\<lambda>_. mapping_of_pmf (cond_pmf p A))
      | Some m \<Rightarrow> m)"
proof (cases "cond_pmf_impl p A")
  case (Some m)
  hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
  from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)"
    by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits)
  with Some A have "mapping_of_pmf (cond_pmf p A) = m"
    by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond)
  with Some show ?thesis by simp
qed simp_all


(* Binomial distribution. A probability outside [0, 1] aborts at runtime. p = 0 and p = 1
   give point masses at 0 and at n. Otherwise every k in 0..n is tabulated with
   (n choose k) p^k (1 - p)^(n - k). *)
lemma binomial_pmf_code [code abstract]:
  "mapping_of_pmf (binomial_pmf n p) = (
     if p < 0 \<or> p > 1 then
       Code.abort (STR ''binomial_pmf with invalid probability'')
         (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
     else if p = 0 then Mapping.update 0 1 Mapping.empty
     else if p = 1 then Mapping.update n 1 Mapping.empty
     else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
  by (cases "p < 0 \<or> p > 1")
     (simp, intro mapping_of_pmfI,
      auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)


(* pred_pmf is a bounded quantifier over the support set_pmf p. *)
lemma pred_pmf_code [code]:
  "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
  by (auto simp: pred_pmf_def)


(* pmf_of_list with strictly positive weights summing to 1: each distinct first component
   is mapped to the sum of the weights of all its occurrences in xs. *)
lemma mapping_of_pmf_pmf_of_list:
  assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "sum_list (map snd xs) = 1"
  shows   "mapping_of_pmf (pmf_of_list xs) =
             Mapping.tabulate (remdups (map fst xs))
               (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs)))"
proof -
  from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force
  with assms have "set_pmf (pmf_of_list xs) = fst ` set xs"
    by (intro set_pmf_of_list_eq) auto
  with wf show ?thesis
    by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list)
qed

(* Same for an arbitrary well-formed list, after dropping the zero-weight entries. *)
lemma mapping_of_pmf_pmf_of_list':
  assumes "pmf_of_list_wf xs"
  defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
  shows   "mapping_of_pmf (pmf_of_list xs) =
             Mapping.tabulate (remdups (map fst xs'))
               (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs")
proof -
  have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact
  have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def
    by (force simp: pmf_of_list_wf_def)
  from assms have "pmf_of_list xs = pmf_of_list xs'"
    unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all
  also from wf pos have "mapping_of_pmf \<dots> = ?rhs"
    by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def)
  finally show ?thesis .
qed

(* Well-formedness check: all weights nonnegative and summing to exactly 1. *)
lemma pmf_of_list_wf_code [code]:
  "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> sum_list (map snd xs) = 1"
  by (auto simp add: pmf_of_list_wf_def list_all_def)

(* pmf_of_list aborts at runtime on lists that are not well formed. *)
lemma pmf_of_list_code [code abstract]:
  "mapping_of_pmf (pmf_of_list xs) = (
     if pmf_of_list_wf xs then
       let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs
       in  Mapping.tabulate (remdups (map fst xs'))
             (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))
     else
       Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
  using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)

(* mapping_of_pmf is injective. As a simp rule it justifies deciding equality of PMFs by
   comparing their mappings, which is what the equal instance below does. *)
lemma mapping_of_pmf_eq_iff [simp]:
  "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
proof (transfer, intro iffI pmf_eqI)
  fix p q :: "'a pmf" and x :: 'a
  assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) =
            (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))"
  hence "(if pmf p x = 0 then None else Some (pmf p x)) =
           (if pmf q x = 0 then None else Some (pmf q x))" for x
    by (simp add: fun_eq_iff)
  from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
qed (simp_all cong: if_cong)


subsection \<open>Code abbreviations for integrals and probabilities\<close>

text \<open>
  Integrals and probabilities are defined for general measures, so we cannot give any
  code equations directly. We can, however, specialise these constants to PMFs,
  give code equations for these specialised constants, and tell the code generator
  to unfold the original constants to the specialised ones whenever possible.
\<close>

definition pmf_integral where
  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"

definition pmf_set_integral where
  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"

definition pmf_prob where
  "pmf_prob p A = measure_pmf.prob p A"

lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A"
  using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV)

(* The integral over a PMF only depends on its support, so it is the set integral over
   set_pmf p. *)
lemma pmf_integral_pmf_set_integral [code]:
  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
  unfolding pmf_integral_def pmf_set_integral_def
  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)

lemma pmf_prob_pmf_set_integral:
  "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A"
  by (simp add: pmf_prob_def pmf_set_integral_def)

(* Over a finite set the set integral is the weighted sum of pmf p x times f x. *)
lemma pmf_set_integral_code_alt_finite:
  "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)"
  unfolding pmf_set_integral_def
  by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits)

lemma pmf_set_integral_code [code]:
  "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)"
  by (rule pmf_set_integral_code_alt_finite) simp_all


lemma pmf_prob_code_alt_finite:
  "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)"
  by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite)

(* Probabilities of list-represented sets are finite sums. For complements of finite
   sets (List.coset xs) they are computed as 1 minus the probability of set xs. *)
lemma pmf_prob_code [code]:
  "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)"
  "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)"
  by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl)


(* code_abbrev: the code generator replaces measure_pmf.prob p and
   measure_pmf.expectation p by the specialised, executable constants above. *)
lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
  by (intro ext) (simp add: pmf_prob_def)

(* FIXME: Why does this not work without parameters? *)
lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
  by (intro ext) (simp add: pmf_integral_def)



(* Presentation of evaluation results: the code_post rule below rewrites results of the
   form pmf_of_mapping (Mapping xs) into pmf_of_alist xs, a PMF read directly from an
   association list (absent keys get probability 0). *)
definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"

lemma pmf_of_mapping_Mapping [code_post]:
    "pmf_of_mapping (Mapping xs) = pmf_of_alist xs"
  unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def
  by transfer simp_all


(* Executable equality: two PMFs are compared through their representing mappings. *)
instantiation pmf :: (equal) equal
begin

definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))"

instance by standard (simp add: equal_pmf_def)
end

(* Quickcheck setup. single is a named constant for the singleton multiset, so that
   pmfify can build a term representation mentioning it. *)
definition single :: "'a \<Rightarrow> 'a multiset" where
"single s = {#s#}"

(* pmfify A x pairs the value pmf_of_multiset (A + single x) with its term representation.
   Adding the element x guarantees a nonempty multiset, so the generated PMFs never hit
   the empty-multiset abort in pmf_of_multiset_impl_code. *)
definition (in term_syntax)
  pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
             'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
             'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
  [code_unfold]: "pmfify A x =
    Code_Evaluation.valtermify pmf_of_multiset {\<cdot>}
      (Code_Evaluation.valtermify (+) {\<cdot>} A {\<cdot>}
       (Code_Evaluation.valtermify single {\<cdot>} x))"


(* The composition notations are only enabled for the random instance and removed
   again right after it. *)
notation fcomp (infixl "\<circ>>" 60)
notation scomp (infixl "\<circ>\<rightarrow>" 60)

(* Random generator: draw a multiset A and an element x, both with size parameter i,
   and return pmfify A x. *)
instantiation pmf :: (random) random
begin

definition
  "Quickcheck_Random.random i =
     Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A.
       Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))"

instance ..

end

no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)

(* Exhaustive generator: enumerate multisets A and elements x up to size i and test the
   property f on pmfify A x. *)
instantiation pmf :: (full_exhaustive) full_exhaustive
begin

definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
where
  "full_exhaustive_pmf f i =
     Quickcheck_Exhaustive.full_exhaustive (\<lambda>A.
       Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i"

instance ..

end

end