1(*  Title:      HOL/Probability/PMF_Impl.thy
2    Author:     Manuel Eberl, TU München
3    
4    An implementation of PMFs using Mappings, which are implemented with association lists
5    by default. Also includes Quickcheck setup for PMFs.
6*)
7
8section \<open>Code generation for PMFs\<close>
9
10theory PMF_Impl
11imports Probability_Mass_Function "HOL-Library.AList_Mapping"
12begin
13
14subsection \<open>General code generation setup\<close>
15
(* pmf_of_mapping m: the distribution whose mass function is Mapping.lookup_default 0 m.
   Keys of m carry their stored weight and every other outcome gets weight 0.
   The lemma pmf_of_mapping states that pmf (pmf_of_mapping m) agrees with this lookup
   when m has finitely many keys, non-negative values, and values summing to 1. *)
16definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
17  "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" 
18
(* Consider a mapping with finitely many keys and non-negative values. The non-negative
   integral of its default-0 lookup over the counting measure equals the real sum of its
   values over Mapping.keys m, lifted to ennreal. *)
19lemma nn_integral_lookup_default:
20  fixes m :: "('a, real) mapping"
21  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)"
22  shows   "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = 
23             ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
24proof -
  (* Outside Mapping.keys m the lookup defaults to 0. The integral over the counting
     measure therefore reduces to a finite sum over the keys (nn_integral_count_space'). *)
25  have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = 
26          (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms
27    by (subst nn_integral_count_space'[of "Mapping.keys m"])
28       (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def)
  (* Pull ennreal out of the finite sum with sum_ennreal. Its side conditions are
     discharged from the All_mapping (non-negativity) assumption. *)
29  also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
30    by (intro sum_ennreal) 
31       (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits)
32  finally show ?thesis .
33qed
34
(* Well-formedness conditions: finitely many keys, non-negative values, values summing to 1.
   Under these conditions pmf_of_mapping m has exactly the probabilities stored in m. *)
35lemma pmf_of_mapping: 
36  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)"
37  assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1"
38  shows   "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x"
39  unfolding pmf_of_mapping_def
40proof (intro pmf_embed_pmf)
  (* Total-mass obligation of pmf_embed_pmf, reduced to the finite sum by
     nn_integral_lookup_default. *)
41  from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1"
42    by (subst nn_integral_lookup_default) (simp_all)
  (* The remaining obligation (presumably non-negativity of the lookup) is closed from
     All_mapping. *)
43qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits)
44
(* Let xs be a duplicate-free list enumerating a non-empty set A. The uniform distribution
   on A equals pmf_of_mapping of the table that gives each element of xs weight
   1 / length xs. *)
45lemma pmf_of_set_pmf_of_mapping:
46  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
47  shows   "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))" 
48           (is "?lhs = ?rhs")
49  by (rule pmf_eqI, subst pmf_of_mapping)
50     (insert assms, auto intro!: All_mapping_tabulate 
51                         simp: Mapping.lookup_default_def lookup_tabulate distinct_card)
52
(* mapping_of_pmf p: the representation of p as a mapping. Every outcome x with
   pmf p x \<noteq> 0 maps to Some (pmf p x); every other outcome maps to None.
   Its keys are therefore exactly set_pmf p (see keys_mapping_of_pmf). *)
53lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
54  "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" .
55  
(* Default-0 lookup in mapping_of_pmf p recovers pmf p at every outcome. Absent keys are
   exactly the outcomes of probability 0. *)
56lemma lookup_default_mapping_of_pmf: 
57  "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x"
58  by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq)
59
(* Anonymous context, so that the pmf_as_function interpretation below stays local and
   does not leak into the rest of the theory. *)
60context
61begin
62
63interpretation pmf_as_function .
64
(* The total mass of any PMF w.r.t. the counting measure is 1. The proof uses transfer,
   presumably relying on the transfer setup provided by the pmf_as_function
   interpretation. *)
65lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
66  by transfer simp_all
67end
68  
(* pmf_of_mapping is a left inverse of mapping_of_pmf. The [code abstype] attribute makes
   'a pmf an abstract type for the code generator: pmf_of_mapping is its abstract
   constructor and mapping_of_pmf is the representation. PMF operations below are then
   implemented by [code abstract] equations of the form mapping_of_pmf (op ...) = ... . *)
69lemma pmf_of_mapping_mapping_of_pmf [code abstype]: 
70    "pmf_of_mapping (mapping_of_pmf p) = p"
71  unfolding pmf_of_mapping_def
72  by (rule pmf_eqI, subst pmf_embed_pmf)
73     (insert nn_integral_pmf_eq_1[of p], 
74      auto simp: lookup_default_mapping_of_pmf split: option.splits)
75
(* Introduction rule used to prove the [code abstract] equations. mapping_of_pmf p = m
   holds once the keys of m are exactly set_pmf p and each key is mapped to its
   probability under p. *)
76lemma mapping_of_pmfI:
77  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)" 
78  assumes "Mapping.keys m = set_pmf p"
79  shows   "mapping_of_pmf p = m"
80  using assms by transfer (rule ext, auto simp: set_pmf_eq)
81  
(* Variant of mapping_of_pmfI where the value condition uses Mapping.lookup_default 0
   instead of Mapping.lookup. *)
82lemma mapping_of_pmfI':
83  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x" 
84  assumes "Mapping.keys m = set_pmf p"
85  shows   "mapping_of_pmf p = m"
86  using assms unfolding Mapping.lookup_default_def 
87  by transfer (rule ext, force simp: set_pmf_eq)
88
(* return_pmf x is the point mass at x, represented by a single entry: x with weight 1. *)
89lemma return_pmf_code [code abstract]:
90  "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty"
91  by (intro mapping_of_pmfI) (auto simp: lookup_update')
92
(* Let xs be a duplicate-free list enumerating a non-empty set A. The representation of
   the uniform distribution on A is the table giving each element of xs weight
   1 / length xs. *)
93lemma pmf_of_set_code_aux:
94  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
95  shows   "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))"
96  using assms
97  by (intro mapping_of_pmfI, subst pmf_of_set)
98     (auto simp: lookup_tabulate distinct_card)
99
(* pmf_of_set_impl A is just mapping_of_pmf (pmf_of_set A). It is a separate constant so
   that code equations can be given for specific set representations; see
   pmf_of_set_impl_code for sets of the form set xs. *)
100definition pmf_of_set_impl where
101  "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
102
103(* This equation can be used to easily implement pmf_of_set for other set implementations *)
(* For finite non-empty A, insert every element of A with weight p = 1 / card A by folding
   Mapping.update over A. *)
104lemma pmf_of_set_impl_code_alt:
105  assumes "A \<noteq> {}" "finite A"
106  shows   "pmf_of_set_impl A = 
107             (let p = 1 / real (card A) 
108              in  Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)"
109proof -
110  define p where "p = 1 / real (card A)"
111  let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A"
  (* Finite_Set.fold needs a fold function insensitive to order and repetition. Updating
     keys with the same value p is commutative and idempotent (comp_fun_idem). *)
112  interpret comp_fun_idem "\<lambda>x. Mapping.update x p"
113    by standard (transfer, force simp: fun_eq_iff)+
114  have keys: "Mapping.keys ?m = A"
115    using assms(2) by (induction A rule: finite_induct) simp_all
116  have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x
117    using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update')
118  from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def
119    by (intro mapping_of_pmfI) (simp_all add: Let_def p_def)
120qed
121
(* Executable equation for pmf_of_set_impl on sets given as lists. For the empty list it
   calls Code.abort. Otherwise it removes duplicates and gives each remaining element
   weight 1 / (number of distinct elements). *)
122lemma pmf_of_set_impl_code [code]:
123  "pmf_of_set_impl (set xs) = 
124    (if xs = [] then
125         Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs)))
126      else let xs' = remdups xs; p = 1 / real (length xs') in
127         Mapping.tabulate xs' (\<lambda>_. p))"
128  unfolding pmf_of_set_impl_def
129  using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def)
130
(* [code abstract] equation routing pmf_of_set through pmf_of_set_impl. *)
131lemma pmf_of_set_code [code abstract]:
132  "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A"
133  by (simp add: pmf_of_set_impl_def)
134
135
(* Multiset analogue of pmf_of_set_code_aux. Each distinct element x of the non-empty
   multiset A (enumerated without duplicates by xs) gets weight count A x / size A. *)
136lemma pmf_of_multiset_pmf_of_mapping:
137  assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
138  shows   "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))" 
139  using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
140
(* Separate constant so that code equations for concrete multiset representations can be
   stated; see pmf_of_multiset_impl_code. *)
141definition pmf_of_multiset_impl where
142  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"  
143
(* Alternative implementation for arbitrary non-empty multisets. It folds over A, adding
   p = 1 / size A to an element's entry (which starts at 0) once per occurrence. *)
144lemma pmf_of_multiset_impl_code_alt:
145  assumes "A \<noteq> {#}"
146  shows   "pmf_of_multiset_impl A =
147             (let p = 1 / real (size A)
148              in  fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A)"
149proof -
150  define p where "p = 1 / real (size A)"
  (* fold_mset requires a commuting fold function (comp_fun_commute). Incrementing the
     entries of two keys can be done in either order. *)
151  interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 ((+) p)"
152    unfolding Mapping.map_default_def [abs_def]
153    by (standard, intro mapping_eqI ext) 
154       (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
155  let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A"
156  have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
157  have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
158    by (induction A)
159       (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
160  from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
161    by (intro mapping_of_pmfI') (simp_all add: Let_def p_def)
162qed
163
(* Executable equation for multisets given as mset xs. For the empty list it calls
   Code.abort. Otherwise it tabulates each distinct element x with weight
   count (mset xs) x / length xs; length xs is the size of mset xs. *)
164lemma pmf_of_multiset_impl_code [code]:
165  "pmf_of_multiset_impl (mset xs) =
166     (if xs = [] then 
167        Code.abort (STR ''pmf_of_multiset of empty multiset'') 
168          (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs)))
169      else let xs' = remdups xs; p = 1 / real (length xs) in
170         Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
171  using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
172  by (simp add: pmf_of_multiset_impl_def)        
173
(* [code abstract] equation routing pmf_of_multiset through pmf_of_multiset_impl. *)
174lemma pmf_of_multiset_code [code abstract]:
175  "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
176  by (simp add: pmf_of_multiset_impl_def)
177
178  
(* bernoulli_pmf p is represented by its outcomes of non-zero probability only:
   - for p \<le> 0, just False with weight 1;
   - for p \<ge> 1, just True with weight 1;
   - otherwise both outcomes, True with weight p and False with weight 1 - p.
   The boundary cases matter because the keys must coincide with set_pmf
   (mapping_of_pmfI). *)
179lemma bernoulli_pmf_code [code abstract]:
180  "mapping_of_pmf (bernoulli_pmf p) = 
181     (if p \<le> 0 then Mapping.update False 1 Mapping.empty 
182      else if p \<ge> 1 then Mapping.update True 1 Mapping.empty
183      else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))"
184  by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
185
186
(* Code equation for evaluating probabilities. The weight is read off the representation,
   defaulting to 0 for outcomes that are not keys. *)
187lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
188  unfolding mapping_of_pmf_def Mapping.lookup_default_def
189  by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
190
(* Code equation for the support: it is the key set of the representation. *)
191lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)"
192  by transfer (auto simp: dom_def set_pmf_eq)
193
(* Same fact as set_pmf_code, oriented as a simp rule. *)
194lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p"
195  by transfer (auto simp: dom_def set_pmf_eq)
196  
197
198
(* fold_combine_plus f A: the pointwise sum of the mappings f x for x \<in> A. It is the set
   big-operator comm_monoid_set.F over Mapping.combine (+) with neutral element
   Mapping.empty, so values at shared keys are added
   (see lookup_default_fold_combine_plus). *)
199definition fold_combine_plus where
200  "fold_combine_plus = comm_monoid_set.F (Mapping.combine ((+) :: real \<Rightarrow> _)) Mapping.empty"
201
(* Anonymous context for the bind_pmf implementation. The interpretation below and the
   private lemma are local to it. *)
202context
203begin
204
(* (+) on real forms an abelian semigroup. The interpretation provides the lemmas about
   folding with Mapping.combine (+) used below. *)
205interpretation fold_combine_plus: combine_mapping_abel_semigroup "(+) :: real \<Rightarrow> _"
206  by unfold_locales (simp_all add: add_ac)
207  
(* Looking up in a combined mapping yields the sum of the individual lookups. *)
208qualified lemma lookup_default_fold_combine_plus: 
209  fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
210  assumes "finite A"
211  shows   "Mapping.lookup_default 0 (fold_combine_plus f A) x = 
212             (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)"
213  unfolding fold_combine_plus_def using assms 
214    by (induction A rule: finite_induct) 
215       (simp_all add: lookup_default_empty lookup_default_neutral_combine)
216
(* The keys of a combined mapping are the union of the individual key sets. *)
217qualified lemma keys_fold_combine_plus: 
218  "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
219  by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)
220
(* Executable form for sets given as lists: a foldr over the duplicate-free list. *)
221qualified lemma fold_combine_plus_code [code]:
222  "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine (+) (g x)) (remdups xs) Mapping.empty"
223  by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
224
(* Mapping.map_values commutes with default-0 lookup provided f x maps 0 to 0. *)
225private lemma lookup_default_0_map_values:
226  assumes "f x 0 = 0"
227  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
228  unfolding Mapping.lookup_default_def
229  using assms by transfer (auto split: option.splits)
230
(* For finite support of p, the representation of bind_pmf p f is the sum, over
   x \<in> set_pmf p, of the representation of f x with every weight multiplied by
   pmf p x. *)
231qualified lemma mapping_of_bind_pmf:
232  assumes "finite (set_pmf p)"
233  shows   "mapping_of_pmf (bind_pmf p f) = 
234             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x)) 
235               (mapping_of_pmf (f x))) (set_pmf p)"
236  using assms
237  by (intro mapping_of_pmfI')
238     (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus 
239                 pmf_bind integral_measure_pmf lookup_default_0_map_values 
240                 lookup_default_mapping_of_pmf mult_ac)
241
(* bind_pmf_aux p f A is the mass of bind_pmf restricted to the "first-stage" outcomes in A.
   - Keys: the union of set_pmf (f y) over y \<in> A.
   - Value at key x: the expectation under p of indicator A y * pmf (f y) x.
   For A = set_pmf p this is exactly the representation of bind_pmf p f
   (bind_pmf_aux_correct). *)
242lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
243  "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b). 
244     if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then 
245       Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)) 
246     else None" .
247
248lemma keys_bind_pmf_aux [simp]:
249  "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))"
250  by transfer (auto split: if_splits)
251
252lemma lookup_default_bind_pmf_aux:
253  "Mapping.lookup_default 0 (bind_pmf_aux p f A) x = 
254     (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then 
255        measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)"
256  unfolding lookup_default_def by transfer' simp_all
257
(* With A = set_pmf p the indicator is 1 almost everywhere, so the lookup is
   pmf (bind_pmf p f) x. *)
258lemma lookup_default_bind_pmf_aux' [simp]:
259  "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x"
260  unfolding lookup_default_def
261  by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq
262                    intro!: integral_cong_AE integral_eq_zero_AE)
263  
264lemma bind_pmf_aux_correct:
265  "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)"
266  by (intro mapping_of_pmfI') simp_all
267
(* For finite A, bind_pmf_aux can be computed by combining the scaled representations of
   f x for x \<in> A. The expectation becomes the finite sum \<Sum>y\<in>A. pmf p y * pmf (f y) x. *)
268lemma bind_pmf_aux_code_aux:
269  assumes "finite A"
270  shows   "bind_pmf_aux p f A = 
271             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x))
272               (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
273proof (intro mapping_eqI'[where d = 0])
274  fix x assume "x \<in> Mapping.keys ?lhs"
275  then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
276  hence "Mapping.lookup_default 0 ?lhs x = 
277           measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)"
278    by (auto simp: lookup_default_bind_pmf_aux)
279  also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)"
280    by (subst integral_measure_pmf [of A])
281       (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits)
282  also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x"
283    by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values
284          lookup_default_mapping_of_pmf)
285  finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
286qed (insert assms, simp_all add: keys_fold_combine_plus)
287
(* Code equation for bind_pmf_aux on sets given as lists. *)
288lemma bind_pmf_aux_code [code]:
289  "bind_pmf_aux p f (set xs) = 
290     fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x))
291               (mapping_of_pmf (f x))) (set xs)"
292  by (rule bind_pmf_aux_code_aux) simp_all
293
(* bind_pmf is implemented as bind_pmf_aux p f (set_pmf p). This presumably relies on
   set_pmf p (computed via set_pmf_code) evaluating to a set of the form set xs, so that
   bind_pmf_aux_code applies. *)
294lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct
295
296end
297
(* Hide the unqualified name; the helper stays reachable only by its qualified name. *)
298hide_const (open) fold_combine_plus
299
300
(* Partial implementation of conditioning on A.
   - None if A does not meet the support of p; cond_pmf_code aborts in that case.
   - Otherwise the mapping sending each x \<in> A \<inter> set_pmf p to pmf p x / prob p A. *)
301lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
302  "\<lambda>p A. if A \<inter> set_pmf p = {} then None else 
303     Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
304
(* For finite A, cond_pmf_impl is computable:
   - restrict mapping_of_pmf p to the keys in C = A \<inter> set_pmf p;
   - divide each weight by prob = \<Sum>x\<in>C. pmf p x;
   - return None when prob = 0, which happens iff A \<inter> set_pmf p = {}. *)
305lemma cond_pmf_impl_code_alt:
306  assumes "finite A"
307  shows   "cond_pmf_impl p A = (
308             let C = A \<inter> set_pmf p;
309                 prob = (\<Sum>x\<in>C. pmf p x)
310             in  if prob = 0 then 
311                   None
312                 else
313                   Some (Mapping.map_values (\<lambda>_ y. y / prob) 
314                     (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
315proof -
  (* prob1..prob4 are equivalent characterisations of prob. They are used below to switch
     between the finite sum, measure_pmf.prob p A / C, and emptiness of C. *)
316  define C where "C = A \<inter> set_pmf p"
317  define prob where "prob = (\<Sum>x\<in>C. pmf p x)"
318  also note C_def
319  also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)"
320    by (intro sum.mono_neutral_left) (auto simp: set_pmf_eq)
321  finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" .
322  hence prob2: "prob = measure_pmf.prob p A"
323    using assms by (subst measure_measure_pmf_finite) simp_all
324  have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
325    by (subst prob1, subst sum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms)
326  from assms have prob4: "prob = measure_pmf.prob p C"
327    unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def)
328  
329  show ?thesis
330  proof (cases "prob = 0")
331    case True
    (* Zero mass on A: both sides are None. *)
332    hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
333    with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq)
334  next
335    case False
    (* Positive mass on A: both sides are Some of the same normalised restriction. *)
336    hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
337    with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def)
339    have "cond_pmf_impl p A = 
340            Some (mapping.Mapping (\<lambda>x. if x \<in> C then 
341              Some (pmf p x / measure_pmf.prob p C) else None))" 
342         (is "_ = Some ?m")
343      using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff)
344    also have "?m = Mapping.map_values (\<lambda>_ y. y / prob) 
345                 (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))"
346      using prob_nz prob4 assms unfolding C_def
347      by transfer (auto simp: fun_eq_iff set_pmf_eq)
348    finally show ?thesis using False by (simp add: Let_def prob_def C_def)
349  qed
350qed
351
(* Code equation: cond_pmf_impl_code_alt instantiated to sets given as lists, which are
   always finite. *)
352lemma cond_pmf_impl_code [code]:
353  "cond_pmf_impl p (set xs) = (
354     let C = set xs \<inter> set_pmf p;
355         prob = (\<Sum>x\<in>C. pmf p x)
356     in  if prob = 0 then 
357           None
358         else
359           Some (Mapping.map_values (\<lambda>_ y. y / prob) 
360             (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
361  by (rule cond_pmf_impl_code_alt) simp_all
362
(* [code abstract] equation for cond_pmf. It calls Code.abort when cond_pmf_impl returns
   None, i.e. when A does not meet the support of p. Otherwise it uses the computed
   mapping. *)
363lemma cond_pmf_code [code abstract]:
364  "mapping_of_pmf (cond_pmf p A) = 
365     (case cond_pmf_impl p A of
366        None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'')
367                  (\<lambda>_. mapping_of_pmf (cond_pmf p A))
368      | Some m \<Rightarrow> m)"
369proof (cases "cond_pmf_impl p A")
370  case (Some m)
371  hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
372  from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)"
373    by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits)
374  with Some A have "mapping_of_pmf (cond_pmf p A) = m"
375    by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond)
376  with Some show ?thesis by simp
377qed simp_all
378
379
(* Representation of binomial_pmf n p:
   - abort if p lies outside [0, 1];
   - point mass at 0 for p = 0, and at n for p = 1 (so no zero-probability keys are
     stored);
   - otherwise tabulate k = 0..n with weight (n choose k) * p^k * (1 - p)^(n - k). *)
380lemma binomial_pmf_code [code abstract]:
381  "mapping_of_pmf (binomial_pmf n p) = (
382     if p < 0 \<or> p > 1 then 
383       Code.abort (STR ''binomial_pmf with invalid probability'') 
384         (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
385     else if p = 0 then Mapping.update 0 1 Mapping.empty
386     else if p = 1 then Mapping.update n 1 Mapping.empty
387     else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
388  by (cases "p < 0 \<or> p > 1")
389     (simp, intro mapping_of_pmfI, 
390      auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
391
392
(* pred_pmf P p is a bounded quantification over the support. The support is computable
   via set_pmf_code. *)
393lemma pred_pmf_code [code]:
394  "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
395  by (auto simp: pred_pmf_def)
396
397
(* Let xs be a list of (outcome, weight) pairs with strictly positive weights summing
   to 1. Each distinct outcome x maps to the total weight of the pairs whose first
   component is x. *)
398lemma mapping_of_pmf_pmf_of_list:
399  assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "sum_list (map snd xs) = 1"
400  shows   "mapping_of_pmf (pmf_of_list xs) = 
401             Mapping.tabulate (remdups (map fst xs)) 
402               (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs)))"
403proof -
404  from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force
405  with assms have "set_pmf (pmf_of_list xs) = fst ` set xs"
406    by (intro set_pmf_of_list_eq) auto
407  with wf show ?thesis
408    by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list)
409qed
410
(* Version for any well-formed list. Zero-weight pairs are dropped first (they would
   otherwise yield zero-probability keys); then mapping_of_pmf_pmf_of_list applies. *)
411lemma mapping_of_pmf_pmf_of_list':
412  assumes "pmf_of_list_wf xs"
413  defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
414  shows   "mapping_of_pmf (pmf_of_list xs) = 
415             Mapping.tabulate (remdups (map fst xs')) 
416               (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs") 
417proof -
418  have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact
419  have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def
420    by (force simp: pmf_of_list_wf_def)
421  from assms have "pmf_of_list xs = pmf_of_list xs'" 
422    unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all
423  also from wf pos have "mapping_of_pmf \<dots> = ?rhs"
424    by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def)
425  finally show ?thesis .
426qed
427
(* Executable well-formedness check: all weights are non-negative and they sum to 1. *)
428lemma pmf_of_list_wf_code [code]:
429  "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> sum_list (map snd xs) = 1"
430  by (auto simp add: pmf_of_list_wf_def list_all_def)
431
(* [code abstract] equation for pmf_of_list. Well-formed lists use
   mapping_of_pmf_pmf_of_list'; ill-formed ones call Code.abort. *)
432lemma pmf_of_list_code [code abstract]:
433  "mapping_of_pmf (pmf_of_list xs) = (
434     if pmf_of_list_wf xs then
435       let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs
436       in  Mapping.tabulate (remdups (map fst xs')) 
437             (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))
438     else
439       Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
440  using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
441
(* mapping_of_pmf is injective: two PMFs are equal iff their representations are.
   After transfer, the "\<Longrightarrow>" direction compares the two functions pointwise. *)
442lemma mapping_of_pmf_eq_iff [simp]:
443  "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
444proof (transfer, intro iffI pmf_eqI)
445  fix p q :: "'a pmf" and x :: 'a
446  assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) =
447            (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))"
448  hence "(if pmf p x = 0 then None else Some (pmf p x)) =
449           (if pmf q x = 0 then None else Some (pmf q x))" for x
450    by (simp add: fun_eq_iff)
451  from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
452qed (simp_all cong: if_cong)
453
454
455subsection \<open>Code abbreviations for integrals and probabilities\<close>
456
457text \<open>
458  Integrals and probabilities are defined for general measures, so we cannot give any
459  code equations directly. We can, however, specialise these constants to PMFs, 
460  give code equations for these specialised constants, and tell the code generator 
461  to unfold the original constants to the specialised ones whenever possible.
462\<close>
463
(* Specialisations of the general integral, set integral and probability to PMFs (and,
   for the integrals, to real-valued functions). Code equations are given for these
   constants below. *)
464definition pmf_integral where
465  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
466
467definition pmf_set_integral where
468  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
469
470definition pmf_prob where
471  "pmf_prob p A = measure_pmf.prob p A"
472
(* Probability of a complement; used for the List.coset case of pmf_prob_code. *)
473lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A"
474  using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV)
475
(* Integrating over the whole space equals integrating over the support, since the
   complement of the support is a null set. *)
476lemma pmf_integral_pmf_set_integral [code]:
477  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
478  unfolding pmf_integral_def pmf_set_integral_def
479  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
480
(* A probability is the set integral of the constant function 1. *)
481lemma pmf_prob_pmf_set_integral:
482  "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A"
483  by (simp add: pmf_prob_def pmf_set_integral_def)
484  
(* Over a finite set, the set integral is the finite weighted sum of f. *)
485lemma pmf_set_integral_code_alt_finite:
486  "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)"
487  unfolding pmf_set_integral_def
488  by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits)
489  
(* Code equation for sets given as lists. *)
490lemma pmf_set_integral_code [code]:
491  "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)"
492  by (rule pmf_set_integral_code_alt_finite) simp_all
493
494
(* Over a finite set, the probability is the sum of the point probabilities. *)
495lemma pmf_prob_code_alt_finite:
496  "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)"
497  by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite)
498
(* Code equations for both list-based set representations. For a cofinite set
   List.coset xs, the code computes 1 minus the probability of set xs. *)
499lemma pmf_prob_code [code]:
500  "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)"
501  "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)"
502  by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl)
503  
504
(* [code_abbrev]: the code generator replaces measure_pmf.prob p by pmf_prob p. *)
505lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
506  by (intro ext) (simp add: pmf_prob_def)
507
508(* FIXME: Why does this not work without parameters? *)
(* [code_abbrev]: the code generator replaces measure_pmf.expectation p by pmf_integral p. *)
509lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
510  by (intro ext) (simp add: pmf_integral_def)
511
512
513
(* pmf_of_alist xs: PMF whose weights are looked up in the association list xs, with 0 for
   absent outcomes. The [code_post] rule below rewrites pmf_of_mapping (Mapping xs) to
   pmf_of_alist xs, presumably so that evaluated PMFs are displayed in terms of the
   association list. *)
514definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
515
516lemma pmf_of_mapping_Mapping [code_post]:
517    "pmf_of_mapping (Mapping xs) = pmf_of_alist xs"
518  unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def
519  by transfer simp_all
520
521
(* Executable equality on PMFs: compare the representations. This is sound because
   mapping_of_pmf is injective (mapping_of_pmf_eq_iff). *)
522instantiation pmf :: (equal) equal
523begin
524
525definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))"
526
527instance by standard (simp add: equal_pmf_def)
528end
529
(* single s = {#s#}, as a named constant. This is presumably so that it can be passed to
   Code_Evaluation.valtermify in pmfify. *)
530definition single :: "'a \<Rightarrow> 'a multiset" where
531"single s = {#s#}"
532
(* Quickcheck helper: from a generated multiset A and element x, builds
   pmf_of_multiset (A + single x) together with its term representation. Adding x
   presumably guarantees a non-empty multiset, as pmf_of_multiset expects. *)
533definition (in term_syntax)
534  pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
535             'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
536             'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
537  [code_unfold]: "pmfify A x =  
538    Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} 
539      (Code_Evaluation.valtermify (+) {\<cdot>} A {\<cdot>} 
540       (Code_Evaluation.valtermify single {\<cdot>} x))"
541
542
(* Local infix notation for forward composition of random generators. It is enabled only
   for the random instantiation and removed again afterwards. *)
543notation fcomp (infixl "\<circ>>" 60)
544notation scomp (infixl "\<circ>\<rightarrow>" 60)
545
(* Random PMF generator: draw a multiset A and an element x, then return pmfify A x. *)
546instantiation pmf :: (random) random
547begin
548
549definition
550  "Quickcheck_Random.random i = 
551     Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A. 
552       Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))"
553
554instance ..
555
556end
557
558no_notation fcomp (infixl "\<circ>>" 60)
559no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
560
(* Exhaustive PMF enumeration: enumerate multisets A and elements x with size bound i,
   then test f on pmfify A x. *)
561instantiation pmf :: (full_exhaustive) full_exhaustive
562begin
563
564definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
565where
566  "full_exhaustive_pmf f i =
567     Quickcheck_Exhaustive.full_exhaustive (\<lambda>A. 
568       Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i"
569
570instance ..
571
572end
573
574end
575