1(* Author: Andreas Lochbihler, ETH Zurich *)
2
3section \<open>Discrete subprobability distribution\<close>
4
5theory SPMF imports
6  Probability_Mass_Function
7  "HOL-Library.Complete_Partial_Order2"
8  "HOL-Library.Rewrite"
9begin
10
11subsection \<open>Auxiliary material\<close>
12
(* The supremum of f over the singleton index set {x} is just f x, in any
   conditionally complete lattice. Declared as a simp rule. *)
13lemma cSUP_singleton [simp]: "(SUP x\<in>{x}. f x :: _ :: conditionally_complete_lattice) = f x"
14by (metis cSup_singleton image_empty image_insert)
15
16subsubsection \<open>More about extended reals\<close>
17
(* The coercion ennreal already clamps at 0, so an explicit max with 0 (in
   either argument order) can be dropped. Both facts are simp rules. *)
18lemma [simp]:
19  shows ennreal_max_0: "ennreal (max 0 x) = ennreal x"
20  and ennreal_max_0': "ennreal (max x 0) = ennreal x"
21by(simp_all add: max_def ennreal_eq_0_iff)
22
(* Converting the extended real 0 with e2ennreal gives 0. *)
23lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0"
24by(simp add: zero_ennreal_def)
25
(* The bottom element of ennreal is mapped to the real number 0 by enn2real. *)
26lemma enn2real_bot [simp]: "enn2real \<bottom> = 0"
27by(simp add: bot_ennreal_def)
28
(* Continuity is preserved by post-composing with the coercion ennreal.
   Registered as a continuous_intros rule. *)
29lemma continuous_at_ennreal[continuous_intros]: "continuous F f \<Longrightarrow> continuous F (\<lambda>x. ennreal (f x))"
30  unfolding continuous_def by auto
31
(* ennreal commutes with Sup for a non-empty set A of reals, provided that the
   supremum of the coerced values is finite (different from top).
   The finiteness assumption yields a real r whose coercion equals that
   supremum. This r bounds A from above and so discharges the bdd_above A
   side condition of continuous_at_Sup_mono. The remaining side conditions
   (monotonicity and continuity of ennreal) are closed by auto in the qed. *)
32lemma ennreal_Sup:
33  assumes *: "(SUP a\<in>A. ennreal a) \<noteq> \<top>"
34  and "A \<noteq> {}"
35  shows "ennreal (Sup A) = (SUP a\<in>A. ennreal a)"
36proof (rule continuous_at_Sup_mono)
37  obtain r where r: "ennreal r = (SUP a\<in>A. ennreal a)" "r \<ge> 0"
38    using * by(cases "(SUP a\<in>A. ennreal a)") simp_all
39  then show "bdd_above A"
40    by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric])
41qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms)
42
(* Indexed form of ennreal_Sup, obtained by instantiating it with the image of
   A under f. *)
43lemma ennreal_SUP:
44  "\<lbrakk> (SUP a\<in>A. ennreal (f a)) \<noteq> \<top>; A \<noteq> {} \<rbrakk> \<Longrightarrow> ennreal (SUP a\<in>A. f a) = (SUP a\<in>A. ennreal (f a))"
45using ennreal_Sup[of "f ` A"] by (auto simp add: image_comp)
46
(* Negative reals are mapped to 0 by ennreal. *)
47lemma ennreal_lt_0: "x < 0 \<Longrightarrow> ennreal x = 0"
48by(simp add: ennreal_eq_0_iff)
49
50subsubsection \<open>More about \<^typ>\<open>'a option\<close>\<close>
51
(* None lies in the image of A under map_option f iff None already lies in A. *)
52lemma None_in_map_option_image [simp]: "None \<in> map_option f ` A \<longleftrightarrow> None \<in> A"
53by auto
54
(* Some x lies in the image of A under map_option f iff x = f y for some y
   such that Some y lies in A. *)
55lemma Some_in_map_option_image [simp]: "Some x \<in> map_option f ` A \<longleftrightarrow> (\<exists>y. x = f y \<and> Some y \<in> A)"
56by(auto intro: rev_image_eqI dest: sym)
57
(* A case distinction on an option whose branches both yield x is the constant
   function returning x. *)
58lemma case_option_collapse: "case_option x (\<lambda>_. x) = (\<lambda>_. x)"
59by(simp add: fun_eq_iff split: option.split)
60
(* Rebuilding an option from its own constructors is the identity function. *)
61lemma case_option_id: "case_option None Some = id"
62by(rule ext)(simp split: option.split)
63
(* Lifts a relation ord between elements to a relation between options.
   None is related to every option, and Some x is related to Some y whenever
   ord x y holds. No rule relates Some x to None. *)
64inductive ord_option :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a option \<Rightarrow> 'b option \<Rightarrow> bool"
65  for ord :: "'a \<Rightarrow> 'b \<Rightarrow> bool"
66where
67  None: "ord_option ord None x"
68| Some: "ord x y \<Longrightarrow> ord_option ord (Some x) (Some y)"
69
(* Case-analysis simp rules derived from the inductive definition, covering
   None and Some on either side. *)
70inductive_simps ord_option_simps [simp]:
71  "ord_option ord None x"
72  "ord_option ord x None"
73  "ord_option ord (Some x) (Some y)"
74  "ord_option ord (Some x) None"
75
(* Specialised simp rules for the case where the element relation is
   equality. *)
76inductive_simps ord_option_eq_simps [simp]:
77  "ord_option (=) None y"
78  "ord_option (=) (Some x) y"
79
(* ord_option ord relates x to itself as soon as ord is reflexive on the value
   (if any) contained in x. *)
80lemma ord_option_reflI: "(\<And>y. y \<in> set_option x \<Longrightarrow> ord y y) \<Longrightarrow> ord_option ord x x"
81by(cases x) simp_all
82
(* Reflexivity lifts from ord to ord_option ord. *)
83lemma reflp_ord_option: "reflp ord \<Longrightarrow> reflp (ord_option ord)"
84by(simp add: reflp_def ord_option_reflI)
85
(* Transitivity of ord_option. ord only has to be transitive on the values
   actually contained in x, y and z. *)
86lemma ord_option_trans:
87  "\<lbrakk> ord_option ord x y; ord_option ord y z;
88    \<And>a b c. \<lbrakk> a \<in> set_option x; b \<in> set_option y; c \<in> set_option z; ord a b; ord b c \<rbrakk> \<Longrightarrow> ord a c \<rbrakk>
89  \<Longrightarrow> ord_option ord x z"
90by(auto elim!: ord_option.cases)
91
(* Transitivity lifts from ord to ord_option ord. *)
92lemma transp_ord_option: "transp ord \<Longrightarrow> transp (ord_option ord)"
93unfolding transp_def by(blast intro: ord_option_trans)
94
(* Antisymmetry lifts from ord to ord_option ord. *)
95lemma antisymp_ord_option: "antisymp ord \<Longrightarrow> antisymp (ord_option ord)"
96by(auto intro!: antisympI elim!: ord_option.cases dest: antisympD)
97
(* A chain of options under ord_option ord yields a chain under ord of the
   values wrapped in Some. *)
98lemma ord_option_chainD:
99  "Complete_Partial_Order.chain (ord_option ord) Y
100  \<Longrightarrow> Complete_Partial_Order.chain ord {x. Some x \<in> Y}"
101by(rule chainI)(auto dest: chainD)
102
(* Lifts a least-upper-bound operator lub on sets of values to sets of options.
   If Y contains no Some element (Y is a subset of {None}), the result is None.
   Otherwise the result is Some applied to lub of the values wrapped in Some
   in Y. *)
103definition lub_option :: "('a set \<Rightarrow> 'b) \<Rightarrow> 'a option set \<Rightarrow> 'b option"
104where "lub_option lub Y = (if Y \<subseteq> {None} then None else Some (lub {x. Some x \<in> Y}))"
105
(* map_option can be pushed inside lub_option by composing it with lub. *)
106lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f \<circ> lub) Y"
107by(simp add: lub_option_def)
108
(* Upper-bound property of lub_option on chains, derived from the upper-bound
   property of lub on chains. *)
109lemma lub_option_upper:
110  assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x \<in> Y"
111  and lub_upper: "\<And>Y x. \<lbrakk> Complete_Partial_Order.chain ord Y; x \<in> Y \<rbrakk> \<Longrightarrow> ord x (lub Y)"
112  shows "ord_option ord x (lub_option lub Y)"
113using assms(1-2)
114by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD])
115
(* Least-upper-bound property of lub_option on chains, derived from the
   corresponding property of lub on chains. *)
116lemma lub_option_least:
117  assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
118  and upper: "\<And>x. x \<in> Y \<Longrightarrow> ord_option ord x y"
119  assumes lub_least: "\<And>Y y. \<lbrakk> Complete_Partial_Order.chain ord Y; \<And>x. x \<in> Y \<Longrightarrow> ord x y \<rbrakk> \<Longrightarrow> ord (lub Y) y"
120  shows "ord_option ord (lub_option lub Y) y"
121using Y
122by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper)
123
(* lub_option over the image of Y under map_option f equals lub_option over Y
   with lub precomposed by taking images under f.
   Apply-style proof: after unfolding lub_option_def, the remaining subgoal is
   reduced via arg_cong to an equality of the sets passed to lub, which auto
   then closes. *)
124lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub \<circ> (`) f) Y"
125apply(auto simp add: lub_option_def)
126apply(erule notE)
127apply(rule arg_cong[where f=lub])
128apply(auto intro: rev_image_eqI dest: sym)
129done
130
(* ord_option is monotone in the element relation. *)
131lemma ord_option_mono: "\<lbrakk> ord_option A x y; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_option B x y"
132by(auto elim: ord_option.cases)
133
(* Object-level (implication) form of ord_option_mono, declared as a mono rule. *)
134lemma ord_option_mono' [mono]:
135  "(\<And>x y. A x y \<longrightarrow> B x y) \<Longrightarrow> ord_option A x y \<longrightarrow> ord_option B x y"
136by(blast intro: ord_option_mono)
137
(* ord_option distributes over relation composition OO. *)
138lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B"
139by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)
140
(* ord_option commutes with the infimum of relations. The inclusion from left
   to right is proved by case analysis; the converse follows from
   monotonicity. *)
141lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs")
142proof(rule antisym)
143  show "?lhs \<le> ?rhs" by(auto elim!: ord_option.cases)
144qed(auto elim: ord_option_mono)
145
(* A map_option on the right argument can be absorbed into the element
   relation. *)
146lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (\<lambda>x y. ord x (f y)) x y"
147by(auto elim: ord_option.cases)
148
(* A map_option on the left argument can be absorbed into the element
   relation. *)
149lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (\<lambda>x y. ord (f x) y) x y"
150by(auto elim: ord_option.cases)
151
(* In the flat order option_ord, Some x is related only to itself. *)
152lemma option_ord_Some1_iff: "option_ord (Some x) y \<longleftrightarrow> y = Some x"
153by(auto simp add: flat_ord_def)
154
155subsubsection \<open>A relator for sets that treats sets like predicates\<close>
156
(* This context enables the lifting_syntax notation (===>) for relators. *)
157context includes lifting_syntax
158begin
159
(* rel_pred R A B holds iff the membership predicates of A and B are related by
   R ===> (=). That is, for every R-related pair x and y, x is in A exactly
   when y is in B (see rel_predD). *)
160definition rel_pred :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a set \<Rightarrow> 'b set \<Rightarrow> bool"
161where "rel_pred R A B = (R ===> (=)) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B)"
162
(* Introduction rule for rel_pred, directly from the definition. *)
163lemma rel_predI: "(R ===> (=)) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B) \<Longrightarrow> rel_pred R A B"
164by(simp add: rel_pred_def)
165
(* Elimination rule: R-related elements agree on membership. *)
166lemma rel_predD: "\<lbrakk> rel_pred R A B; R x y \<rbrakk> \<Longrightarrow> x \<in> A \<longleftrightarrow> y \<in> B"
167by(simp add: rel_pred_def rel_fun_def)
168
(* Set comprehension Collect is parametric with respect to rel_pred. *)
169lemma Collect_parametric: "((A ===> (=)) ===> rel_pred A) Collect Collect"
170  \<comment> \<open>Declare this rule as @{attribute transfer_rule} only locally
171      because it blows up the search space for @{method transfer}
172      (in combination with @{thm [source] Collect_transfer})\<close>
173by(simp add: rel_funI rel_predI)
174
175end
176
177subsubsection \<open>Monotonicity rules\<close>
178
(* Addition on enat is monotone in its first argument with respect to the
   reversed order (\<ge>). *)
179lemma monotone_gfp_eadd1: "monotone (\<ge>) (\<ge>) (\<lambda>x. x + y :: enat)"
180by(auto intro!: monotoneI)
181
(* Addition on enat is monotone in its second argument with respect to the
   reversed order (\<ge>). *)
182lemma monotone_gfp_eadd2: "monotone (\<ge>) (\<ge>) (\<lambda>y. x + y :: enat)"
183by(auto intro!: monotoneI)
184
(* Joint monotonicity of enat addition with respect to (\<ge>) on pairs.
   The plain statement is named monotone_eadd. The lemma name
   mono2mono_gfp_eadd denotes the result of THEN gfp.mono2mono2, which is
   declared as cont_intro and simp. *)
185lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]:
186  shows monotone_eadd: "monotone (rel_prod (\<ge>) (\<ge>)) (\<ge>) (\<lambda>(x, y). x + y :: enat)"
187by(simp add: monotone_gfp_eadd1 monotone_gfp_eadd2)
188
(* The pointwise sum of two functionals that are monotone with respect to
   fun_ord (\<ge>) is again monotone. Registered as a partial_function_mono
   rule. *)
189lemma eadd_gfp_partial_function_mono [partial_function_mono]:
190  "\<lbrakk> monotone (fun_ord (\<ge>)) (\<ge>) f; monotone (fun_ord (\<ge>)) (\<ge>) g \<rbrakk>
191  \<Longrightarrow> monotone (fun_ord (\<ge>)) (\<ge>) (\<lambda>x. f x + g x :: enat)"
192by(rule mono2mono_gfp_eadd)
193
(* The coercion ereal is monotone. The plain statement is monotone_ereal; the
   lemma name denotes its THEN lfp.mono2mono form. *)
194lemma mono2mono_ereal[THEN lfp.mono2mono]:
195  shows monotone_ereal: "monotone (\<le>) (\<le>) ereal"
196by(rule monotoneI) simp
197
(* The coercion ennreal is monotone. The plain statement is monotone_ennreal;
   the lemma name denotes its THEN lfp.mono2mono form. *)
198lemma mono2mono_ennreal[THEN lfp.mono2mono]:
199  shows monotone_ennreal: "monotone (\<le>) (\<le>) ennreal"
200by(rule monotoneI)(simp add: ennreal_leI)
201
202subsubsection \<open>Bijections\<close>
203
(* Suppose R is bi-unique and relates A and B elementwise (rel_set R A B).
   Then some bijection f from A onto B has its graph contained in R.
   Proof: choice yields f with R x (f x) and f x in B for every x in A.
   Injectivity follows from left-uniqueness of R, and the image of A being
   exactly B follows from rel_setD2 together with right-uniqueness. *)
204lemma bi_unique_rel_set_bij_betw:
205  assumes unique: "bi_unique R"
206  and rel: "rel_set R A B"
207  shows "\<exists>f. bij_betw f A B \<and> (\<forall>x\<in>A. R x (f x))"
208proof -
209  from assms obtain f where f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)" and B: "\<And>x. x \<in> A \<Longrightarrow> f x \<in> B"
210    apply(atomize_elim)
211    apply(fold all_conj_distrib)
212    apply(subst choice_iff[symmetric])
213    apply(auto dest: rel_setD1)
214    done
215  have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique])
216  moreover have "f ` A = B" using rel
217    by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
218  ultimately have "bij_betw f A B" unfolding bij_betw_def ..
219  thus ?thesis using f by blast
220qed
221
(* Conversely, a bijection f between A and B relates A and B via its graph. *)
222lemma bij_betw_rel_setD: "bij_betw f A B \<Longrightarrow> rel_set (\<lambda>x y. y = f x) A B"
223by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric])
224
225subsection \<open>Subprobability mass function\<close>
226
(* A subprobability mass function is a pmf on the option type. Some x carries
   the probability of outcome x, and None carries the remaining mass. *)
227type_synonym 'a spmf = "'a option pmf"
(* Print-only translation: the option pmf type is displayed as spmf. *)
228translations (type) "'a spmf" \<leftharpoondown> (type) "'a option pmf"
229
(* The measure of an spmf is built in two steps. First, measure_pmf p is
   restricted to the outcomes of the form Some x. Second, the result is pushed
   forward along the function the onto count_space UNIV. *)
230definition measure_spmf :: "'a spmf \<Rightarrow> 'a measure"
231where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the"
232
(* spmf p x is the probability of the outcome Some x. *)
233abbreviation spmf :: "'a spmf \<Rightarrow> 'a \<Rightarrow> real"
234where "spmf p x \<equiv> pmf p (Some x)"
235
(* measure_spmf p lives on the whole type. *)
236lemma space_measure_spmf: "space (measure_spmf p) = UNIV"
237by(simp add: measure_spmf_def)
238
(* measure_spmf p has the measurable sets of count_space UNIV. *)
239lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)"
240by(simp add: measure_spmf_def)
241
(* measure_spmf p is never the bottom measure, because their spaces differ
   (the space of measure_spmf p is UNIV). *)
242lemma measure_spmf_not_bot [simp]: "measure_spmf p \<noteq> \<bottom>"
243proof
244  assume "measure_spmf p = \<bottom>"
245  hence "space (measure_spmf p) = space \<bottom>" by simp
246  thus False by(simp add: space_measure_spmf)
247qed
248
(* Measurability side condition for the pushforward along the function the in
   measure_spmf_def. *)
249lemma measurable_the_measure_pmf_Some [measurable, simp]:
250  "the \<in> measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)"
251by(auto simp add: measurable_def sets_restrict_space space_restrict_space integral_restrict_space)
252
(* Measurability out of measure_spmf M reduces to mapping into space N. *)
253lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV \<rightarrow> space N"
254by(auto simp: measurable_def space_measure_spmf)
255
(* Measurability into measure_spmf M is the same as measurability into
   count_space UNIV. *)
256lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)"
257by(intro measurable_cong_sets) simp_all
258
(* measure_spmf p is a subprobability space: its total mass is at most 1, and
   its space is non-empty. *)
259lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)"
260proof
261  show "emeasure (measure_spmf p) (space (measure_spmf p)) \<le> 1"
262    by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1)
263qed(simp add: space_measure_spmf)
264
(* Makes the subprob_space facts available under the prefix measure_spmf for
   every p. *)
265interpretation measure_spmf: subprob_space "measure_spmf p" for p
266by(rule subprob_space_measure_spmf)
267
(* In particular, measure_spmf p is a finite measure. *)
268lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)"
269by unfold_locales
270
(* spmf p x is the measure of the singleton {x} under measure_spmf p. *)
271lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}"
272by(auto simp add: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure])
273
(* The measure of A under measure_spmf p is the measure of the image of A under
   Some, taken under measure_pmf p (extended-real version). *)
274lemma emeasure_measure_spmf_conv_measure_pmf:
275  "emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)"
276by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
277
(* Real-valued version of emeasure_measure_spmf_conv_measure_pmf. *)
278lemma measure_measure_spmf_conv_measure_pmf:
279  "measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)"
280using emeasure_measure_spmf_conv_measure_pmf[of p A]
281by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
282
(* Wrapping every outcome of a pmf in Some gives an spmf with the same measure. *)
283lemma emeasure_spmf_map_pmf_Some [simp]:
284  "emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A"
285by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
286
(* Real-valued version of emeasure_spmf_map_pmf_Some. *)
287lemma measure_spmf_map_pmf_Some [simp]:
288  "measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A"
289using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
290
(* The non-negative integral against measure_spmf p is a weighted sum over all
   values, with weight spmf p x at x. The proof first writes the integral over
   the options in range Some. It then reindexes along the bijection Some from
   UNIV onto range Some. *)
291lemma nn_integral_measure_spmf: "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space UNIV"
292  (is "?lhs = ?rhs")
293proof -
294  have "?lhs = \<integral>\<^sup>+ x. pmf p x * f (the x) \<partial>count_space (range Some)"
295    by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps times_ereal.simps(1)[symmetric] del: times_ereal.simps(1))
296  also have "\<dots> = \<integral>\<^sup>+ x. ennreal (spmf p (the x)) * f (the x) \<partial>count_space (range Some)"
297    by(rule nn_integral_cong) auto
298  also have "\<dots> = \<integral>\<^sup>+ x. spmf p (the (Some x)) * f (the (Some x)) \<partial>count_space UNIV"
299    by(rule nn_integral_bij_count_space[symmetric])(simp add: bij_betw_def)
300  also have "\<dots> = ?rhs" by simp
301  finally show ?thesis .
302qed
303
(* Bochner-integral analogue of nn_integral_measure_spmf, for f integrable with
   respect to measure_spmf p. The proof first transfers integrability to the
   weighted function on count_space UNIV. *)
304lemma integral_measure_spmf:
305  assumes "integrable (measure_spmf p) f"
306  shows "(\<integral> x. f x \<partial>measure_spmf p) = \<integral> x. spmf p x * f x \<partial>count_space UNIV"
307proof -
308  have "integrable (count_space UNIV) (\<lambda>x. spmf p x * f x)"
309    using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'')
310  then show ?thesis using assms
311    by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric])
312qed
313
(* Extended-real version of spmf_conv_measure_spmf. *)
314lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x"
315by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf)
316
(* Any family of measures measure_spmf (M x), indexed over count_space UNIV,
   is measurable into the subprobability algebra of count_space UNIV. *)
317lemma measurable_measure_spmf[measurable]:
318  "(\<lambda>x. measure_spmf (M x)) \<in> measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
319by (auto simp: space_subprob_algebra)
320
(* The non-negative integral against measure_spmf p, written as an integral
   against the restriction of measure_pmf p to range Some, with f composed
   with the function the. *)
321lemma nn_integral_measure_spmf_conv_measure_pmf:
322  assumes [measurable]: "f \<in> borel_measurable (count_space UNIV)"
323  shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f \<circ> the)"
324by(simp add: measure_spmf_def nn_integral_distr o_def)
325
(* measure_spmf p is an element of the subprobability algebra of
   count_space UNIV. *)
326lemma measure_spmf_in_space_subprob_algebra [simp]:
327  "measure_spmf p \<in> space (subprob_algebra (count_space UNIV))"
328by(simp add: space_subprob_algebra)
329
(* The total mass of spmf p is finite. This is nn_integral_measure_spmf with
   the constant function 1. *)
330lemma nn_integral_spmf_neq_top: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV) \<noteq> \<top>"
331using nn_integral_measure_spmf[where f="\<lambda>_. 1", of p, symmetric] by simp
332
(* The next three lemmas bound a supremum of probabilities by 1, via
   neq_top_trans, and conclude that it is finite. *)
(* Supremum over a set Y of spmfs. *)
333lemma SUP_spmf_neq_top': "(SUP p\<in>Y. ennreal (spmf p x)) \<noteq> \<top>"
334proof(rule neq_top_trans)
335  show "(SUP p\<in>Y. ennreal (spmf p x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
336qed simp
337
(* Supremum over a sequence Y of spmfs. *)
338lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) \<noteq> \<top>"
339proof(rule neq_top_trans)
340  show "(SUP i. ennreal (spmf (Y i) x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
341qed simp
342
(* Supremum of the measures of a fixed set A, over a set Y of spmfs. *)
343lemma SUP_emeasure_spmf_neq_top: "(SUP p\<in>Y. emeasure (measure_spmf p) A) \<noteq> \<top>"
344proof(rule neq_top_trans)
345  show "(SUP p\<in>Y. emeasure (measure_spmf p) A) \<le> 1"
346    by(rule SUP_least)(simp add: measure_spmf.subprob_emeasure_le_1)
347qed simp
348
349subsection \<open>Support\<close>
350
(* Support of an spmf: the values x such that Some x lies in the support of the
   underlying pmf (see in_set_spmf). *)
351definition set_spmf :: "'a spmf \<Rightarrow> 'a set"
352where "set_spmf p = set_pmf p \<bind> set_option"
353
(* Characterisation via measure_spmf: x is in the support iff the singleton
   {x} has non-zero measure. *)
354lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} \<noteq> 0}"
355proof -
356  have "\<And>x :: 'a. the -` {x} \<inter> range Some = {Some x}" by auto
357  then show ?thesis
358    by(auto simp add: set_spmf_def set_pmf.rep_eq measure_spmf_def measure_distr measure_restrict_space space_restrict_space intro: rev_image_eqI)
359qed
360
(* Membership in set_spmf in terms of the support of the underlying pmf. *)
361lemma in_set_spmf: "x \<in> set_spmf p \<longleftrightarrow> Some x \<in> set_pmf p"
362by(simp add: set_spmf_def)
363
(* Almost-everywhere statements under measure_spmf p are exactly statements
   about all elements of the support. *)
364lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. P x)"
365by(auto 4 3 simp add: measure_spmf_def AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff set_spmf_def cong del: AE_cong)
366
(* spmf p x is zero exactly outside the support. *)
367lemma spmf_eq_0_set_spmf: "spmf p x = 0 \<longleftrightarrow> x \<notin> set_spmf p"
368by(auto simp add: pmf_eq_0_set_pmf set_spmf_def intro: rev_image_eqI)
369
(* The same fact stated positively: x is in the support iff spmf p x is
   non-zero. *)
370lemma in_set_spmf_iff_spmf: "x \<in> set_spmf p \<longleftrightarrow> spmf p x \<noteq> 0"
371by(auto simp add: set_spmf_def set_pmf_iff intro: rev_image_eqI)
372
(* The distribution that always yields None has empty support. *)
373lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}"
374by(auto simp add: set_spmf_def)
375
(* The support of an spmf is countable. *)
376lemma countable_set_spmf [simp]: "countable (set_spmf p)"
377by(simp add: set_spmf_def bind_UNION)
378
(* Extensionality for spmfs: two spmfs are equal if they agree on the
   probability of every value. The Some case is the assumption itself. For the
   None case, pmf p None is 1 minus the mass of range Some. That mass is a
   countable sum of spmf values over the support, and this sum is the same for
   p and q. *)
379lemma spmf_eqI:
380  assumes "\<And>i. spmf p i = spmf q i"
381  shows "p = q"
382proof(rule pmf_eqI)
383  fix i
384  show "pmf p i = pmf q i"
385  proof(cases i)
386    case (Some i')
387    thus ?thesis by(simp add: assms)
388  next
389    case None
    (* Rewrite pmf p None as 1 minus the mass of range Some. Split range Some
       into the part over the support and a null part, and write the support
       part as a countable sum of spmf values. *)
390    have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def)
391    also have "{i} = space (measure_pmf p) - range Some"
392      by(auto simp add: None intro: ccontr)
393    also have "measure (measure_pmf p) \<dots> = ennreal 1 - measure (measure_pmf p) (range Some)"
394      by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
395    also have "range Some = (\<Union>x\<in>set_spmf p. {Some x}) \<union> Some ` (- set_spmf p)"
396      by auto
397    also have "measure (measure_pmf p) \<dots> = measure (measure_pmf p) (\<Union>x\<in>set_spmf p. {Some x})"
398      by(rule measure_pmf.measure_zero_union)(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
399    also have "ennreal \<dots> = \<integral>\<^sup>+ x. measure (measure_pmf p) {Some x} \<partial>count_space (set_spmf p)"
400      unfolding measure_pmf.emeasure_eq_measure[symmetric]
401      by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
402    also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)" by(simp add: pmf_def)
403    also have "\<dots> = \<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf p)" by(simp add: assms)
    (* p and q have the same support, so the same steps run backwards for q. *)
404    also have "set_spmf p = set_spmf q" by(auto simp add: in_set_spmf_iff_spmf assms)
405    also have "(\<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf q)) = \<integral>\<^sup>+ x. measure (measure_pmf q) {Some x} \<partial>count_space (set_spmf q)"
406      by(simp add: pmf_def)
407    also have "\<dots> = measure (measure_pmf q) (\<Union>x\<in>set_spmf q. {Some x})"
408      unfolding measure_pmf.emeasure_eq_measure[symmetric]
409      by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
410    also have "\<dots> = measure (measure_pmf q) ((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q))"
411      by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
412    also have "((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q)) = range Some" by auto
413    also have "ennreal 1 - measure (measure_pmf q) \<dots> = measure (measure_pmf q) (space (measure_pmf q) - range Some)"
414      by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
415    also have "space (measure_pmf q) - range Some = {i}"
416      by(auto simp add: None intro: ccontr)
417    also have "measure (measure_pmf q) \<dots> = pmf q i" by(simp add: pmf_def)
418    finally show ?thesis by simp
419  qed
420qed
421
(* Integrating against measure_spmf M gives the same result as integrating
   against its restriction to the support set_spmf M. *)
422lemma integral_measure_spmf_restrict:
423  fixes f ::  "'a \<Rightarrow> 'b :: {banach, second_countable_topology}" shows
424  "(\<integral> x. f x \<partial>measure_spmf M) = (\<integral> x. f x \<partial>restrict_space (measure_spmf M) (set_spmf M))"
425by(auto intro!: integral_cong_AE simp add: integral_restrict_space)
426
(* Variant of nn_integral_measure_spmf that sums only over the support. *)
427lemma nn_integral_measure_spmf':
428  "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space (set_spmf p)"
429by(auto simp add: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
430
431subsection \<open>Functorial structure\<close>
432
(* map_spmf applies f to the Some outcomes of an spmf and leaves None
   unchanged. *)
433abbreviation map_spmf :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf"
434where "map_spmf f \<equiv> map_pmf (map_option f)"
435
(* The lemmas in this context are stored under the mandatory name prefix spmf,
   for example spmf.map_comp. *)
436context begin
437local_setup \<open>Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf")\<close>
438
(* Functor law: mapping twice equals mapping with the composition. *)
439lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f \<circ> g) p"
440by(simp add: pmf.map_comp o_def option.map_comp)
441
(* Functor law: mapping the identity is the identity. *)
442lemma map_id0: "map_spmf id = id"
443by(simp add: pmf.map_id option.map_id0)
444
(* Pointwise form of map_id0. *)
445lemma map_id [simp]: "map_spmf id p = p"
446by(simp add: map_id0)
447
(* Same as map_id, with the identity written as a lambda. *)
448lemma map_ident [simp]: "map_spmf (\<lambda>x. x) p = p"
449by(simp add: id_def[symmetric])
450
451end
452
(* The support of a mapped spmf is the image of the support. *)
453lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p"
454by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map)
455
(* Congruence rule: the mapped functions only need to agree on the support. *)
456lemma map_spmf_cong:
457  "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
458  \<Longrightarrow> map_spmf f p = map_spmf g q"
459by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf)
460
(* Simplifier variant of map_spmf_cong, using =simp=>. *)
461lemma map_spmf_cong_simp:
462  "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
463  \<Longrightarrow> map_spmf f p = map_spmf g q"
464unfolding simp_implies_def by(rule map_spmf_cong)
465
(* Mapping a function that is the identity on the support does nothing. *)
466lemma map_spmf_idI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> f x = x) \<Longrightarrow> map_spmf f p = p"
467by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf)
468
(* The measure of A after mapping with f is the measure of the preimage of A
   under f before mapping. *)
469lemma emeasure_map_spmf:
470  "emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)"
471by(auto simp add: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
472
(* Real-valued version of emeasure_map_spmf. *)
473lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)"
474using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure)
475
(* measure_spmf of a mapped spmf is the pushforward measure along f. *)
476lemma measure_map_spmf_conv_distr:
477  "measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f"
478by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr)
479
(* Wrapping a pmf with Some keeps the probability of every value. *)
480lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i"
481by(simp add: pmf_map_inj')
482
(* If f is injective on the support and x is in the support, then mapping with
   f moves the probability of x unchanged to f x. *)
483lemma spmf_map_inj: "\<lbrakk> inj_on f (set_spmf M); x \<in> set_spmf M \<rbrakk> \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x"
484by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj, auto simp add: in_set_spmf inj_on_def elim!: option.inj_map_strong[rotated])
485
(* Same as spmf_map_inj for a globally injective f, with no support
   assumption. *)
486lemma spmf_map_inj': "inj f \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x"
487by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map])
488
(* Values outside the image of the support have probability 0 after mapping. *)
489lemma spmf_map_outside: "x \<notin> f ` set_spmf M \<Longrightarrow> spmf (map_spmf f M) x = 0"
490unfolding spmf_eq_0_set_spmf by simp
491
(* Probability of x after mapping is the measure of the preimage of {x}
   (ennreal version). *)
492lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})"
493by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
494
(* Real-valued version of ennreal_spmf_map. *)
495lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})"
496using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure)
497
(* The same quantity written as the non-negative integral of the indicator of
   the preimage. *)
498lemma ennreal_spmf_map_conv_nn_integral:
499  "ennreal (spmf (map_spmf f p) x) = integral\<^sup>N (measure_spmf p) (indicator (f -` {x}))"
500by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
501
502subsection \<open>Monad operations\<close>
503
504subsubsection \<open>Return\<close>
505
(* The spmf that yields x with probability 1, as a point mass at Some x. *)
506abbreviation return_spmf :: "'a \<Rightarrow> 'a spmf"
507where "return_spmf x \<equiv> return_pmf (Some x)"
508
(* Probability mass of return_spmf x, stated as an indicator. *)
509lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)"
510by(fact pmf_return)
511
(* measure_spmf maps return_spmf to the return of the Giry monad on
   count_space UNIV. *)
512lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x"
513by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def)
514
(* The always-None distribution has the null measure. *)
515lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)"
516by(rule measure_eqI)(auto simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff)
517
(* The support of return_spmf x is {x}. *)
518lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}"
519by(auto simp add: set_spmf_def)
520
521subsubsection \<open>Bind\<close>
522
(* Monadic bind for spmfs: a bind in the pmf monad where a None outcome of x
   stays None, and an outcome Some a continues with f a. *)
523definition bind_spmf :: "'a spmf \<Rightarrow> ('a \<Rightarrow> 'b spmf) \<Rightarrow> 'b spmf"
524where "bind_spmf x f = bind_pmf x (\<lambda>a. case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> f a')"
525
(* Registers bind_spmf as an instance of the overloaded Monad_Syntax.bind. *)
526adhoc_overloading Monad_Syntax.bind bind_spmf
527
(* The failing distribution return_pmf None is a left zero of bind. *)
528lemma return_None_bind_spmf [simp]: "return_pmf None \<bind> (f :: 'a \<Rightarrow> _) = return_pmf None"
529by(simp add: bind_spmf_def bind_return_pmf)
530
(* Monad law: left identity. *)
531lemma return_bind_spmf [simp]: "return_spmf x \<bind> f = f x"
532by(simp add: bind_spmf_def bind_return_pmf)
533
(* Monad law: right identity. With return_spmf as the continuation, the case
   expression from bind_spmf_def equals return_pmf pointwise. *)
534lemma bind_return_spmf [simp]: "x \<bind> return_spmf = x"
535proof -
536  have "\<And>a :: 'a option. (case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> return_spmf a') = return_pmf a"
537    by(simp split: option.split)
538  then show ?thesis
539    by(simp add: bind_spmf_def bind_return_pmf')
540qed
541
(* Monad law: associativity. *)
542lemma bind_spmf_assoc [simp]:
543  fixes x :: "'a spmf" and f :: "'a \<Rightarrow> 'b spmf" and g :: "'b \<Rightarrow> 'c spmf"
544  shows "(x \<bind> f) \<bind> g = x \<bind> (\<lambda>y. f y \<bind> g)"
545by(auto simp add: bind_spmf_def bind_assoc_pmf fun_eq_iff bind_return_pmf split: option.split intro: arg_cong[where f="bind_pmf x"])
546
(* Failure probability of a bind: the failure probability of p plus the
   expected failure probability of f x, with x drawn from measure_spmf p.
   The proof splits the integrand into its None part and its Some part,
   integrates the two separately, and moves the Some part to measure_spmf. *)
547lemma pmf_bind_spmf_None: "pmf (p \<bind> f) None = pmf p None + \<integral> x. pmf (f x) None \<partial>measure_spmf p"
548  (is "?lhs = ?rhs")
549proof -
550  let ?f = "\<lambda>x. pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x) None"
551  have "?lhs = \<integral> x. ?f x \<partial>measure_pmf p"
552    by(simp add: bind_spmf_def pmf_bind)
553  also have "\<dots> = \<integral> x. ?f None * indicator {None} x + ?f x * indicator (range Some) x \<partial>measure_pmf p"
554    by(rule Bochner_Integration.integral_cong)(auto simp add: indicator_def)
555  also have "\<dots> = (\<integral> x. ?f None * indicator {None} x \<partial>measure_pmf p) + (\<integral> x. ?f x * indicator (range Some) x \<partial>measure_pmf p)"
556    by(rule Bochner_Integration.integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1)
557  also have "\<dots> = pmf p None + \<integral> x. indicator (range Some) x * pmf (f (the x)) None \<partial>measure_pmf p"
558    by(auto simp add: measure_measure_pmf_finite indicator_eq_0_iff intro!: Bochner_Integration.integral_cong)
559  also have "\<dots> = ?rhs" unfolding measure_spmf_def
560    by(subst integral_distr)(auto simp add: integral_restrict_space)
561  finally show ?thesis .
562qed
563
(* The probability of y after a bind is the expectation of spmf (f x) y, with
   x drawn from measure_spmf p. *)
564lemma spmf_bind: "spmf (p \<bind> f) y = \<integral> x. spmf (f x) y \<partial>measure_spmf p"
565unfolding measure_spmf_def
566by(subst integral_distr)(auto simp add: bind_spmf_def pmf_bind integral_restrict_space indicator_eq_0_iff intro!: Bochner_Integration.integral_cong split: option.split)
567
(* ennreal version of spmf_bind, stated with the non-negative integral. *)
568lemma ennreal_spmf_bind: "ennreal (spmf (p \<bind> f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p"
569by(auto simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space intro: nn_integral_cong split: split_indicator option.split)
570
(* measure_spmf turns a bind starting from measure_pmf p into the bind of the
   Giry monad on measures. The right-hand side uses measure_pmf p, so this is
   presumably the bind_pmf variant, as the lemma name suggests. The sets of
   both sides agree, and the measures of each set A agree. *)
571lemma measure_spmf_bind_pmf: "measure_spmf (p \<bind> f) = measure_pmf p \<bind> measure_spmf \<circ> f"
572  (is "?lhs = ?rhs")
573proof(rule measure_eqI)
574  show "sets ?lhs = sets ?rhs"
575    by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
576next
577  fix A :: "'a set"
578  have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_pmf p"
579    by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
580  also have "\<dots> = emeasure ?rhs A"
581    by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
582  finally show "emeasure ?lhs A = emeasure ?rhs A" .
583qed
584
(* measure_spmf commutes with bind_spmf: the result is the Giry monad bind of
   measure_spmf p with measure_spmf composed with f. The proof drops the None
   branch using an indicator of range Some, then rewrites the integral as an
   integral against measure_spmf p. *)
585lemma measure_spmf_bind: "measure_spmf (p \<bind> f) = measure_spmf p \<bind> measure_spmf \<circ> f"
586  (is "?lhs = ?rhs")
587proof(rule measure_eqI)
588  show "sets ?lhs = sets ?rhs"
589    by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
590next
591  fix A :: "'a set"
592  let ?A = "the -` A \<inter> range Some"
593  have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x)) ?A \<partial>measure_pmf p"
594    by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
595  also have "\<dots> =  \<integral>\<^sup>+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x \<partial>measure_pmf p"
596    by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def)
597  also have "\<dots> = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_spmf p"
598    by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space)
599  also have "\<dots> = emeasure ?rhs A"
600    by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
601  finally show "emeasure ?lhs A = emeasure ?rhs A" .
602qed
603
604lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f \<circ> g)"
605by(auto simp add: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf])
606
607lemma bind_map_spmf: "map_spmf f p \<bind> g = p \<bind> g \<circ> f"
608by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak)
609
(* Upper bound for the probability of x in a bind: if every continuation f y with
   y in the support of p gives x probability at most r (with 0 \<le> r), so does the bind. *)
lemma spmf_bind_leI:
  assumes "\<And>y. y \<in> set_spmf p \<Longrightarrow> spmf (f y) x \<le> r"
  and "0 \<le> r"
  shows "spmf (bind_spmf p f) x \<le> r"
proof -
  have "ennreal (spmf (bind_spmf p f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p" by(rule ennreal_spmf_bind)
  (* the bound only needs to hold almost everywhere, i.e. on the support *)
  also have "\<dots> \<le> \<integral>\<^sup>+ y. r \<partial>measure_spmf p" by(rule nn_integral_mono_AE)(simp add: assms)
  (* integrating the constant r gives r times the total mass, which is at most 1 *)
  also have "\<dots> \<le> r" using assms measure_spmf.emeasure_space_le_1
    by(auto simp add: measure_spmf.emeasure_eq_measure intro!: mult_left_le)
  finally show ?thesis using assms(2) by(simp)
qed
621
(* map_spmf expressed as a bind followed by return_spmf. *)
lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p \<bind> (\<lambda>x. return_spmf (f x)))"
by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split)
624
(* Congruence rule for bind_spmf: the continuations only need to agree on the support of q. *)
lemma bind_spmf_cong:
  "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
  \<Longrightarrow> bind_spmf p f = bind_spmf q g"
by(auto simp add: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong)
629
(* Variant of bind_spmf_cong whose premise uses =simp=>; presumably intended for use
   as a simplifier congruence rule.  Derived from bind_spmf_cong. *)
lemma bind_spmf_cong_simp:
  "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
  \<Longrightarrow> bind_spmf p f = bind_spmf q g"
by(simp add: simp_implies_def cong: bind_spmf_cong)
634
(* The support of a bind is the set-bind of the continuations' supports over the support of M. *)
lemma set_bind_spmf: "set_spmf (M \<bind> f) = set_spmf M \<bind> (set_spmf \<circ> f)"
by(auto simp add: set_spmf_def bind_spmf_def bind_UNION split: option.splits)
637
(* Binding into the constant continuation return_pmf None yields return_pmf None. *)
lemma bind_spmf_const_return_None [simp]: "bind_spmf p (\<lambda>_. return_pmf None) = return_pmf None"
by(simp add: bind_spmf_def case_option_collapse)
640
(* Binds of two independent spmfs commute.  Both sides are rewritten as pmf binds with an
   explicit two-argument case split ?f, so that bind_commute_pmf applies. *)
lemma bind_commute_spmf:
  "bind_spmf p (\<lambda>x. bind_spmf q (f x)) = bind_spmf q (\<lambda>y. bind_spmf p (\<lambda>x. f x y))"
  (is "?lhs = ?rhs")
proof -
  (* ?f returns return_pmf None if either argument is None, otherwise f a b *)
  let ?f = "\<lambda>x y. case x of None \<Rightarrow> return_pmf None | Some a \<Rightarrow> (case y of None \<Rightarrow> return_pmf None | Some b \<Rightarrow> f a b)"
  have "?lhs = p \<bind> (\<lambda>x. q \<bind> ?f x)"
    unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split)
  also have "\<dots> = q \<bind> (\<lambda>y. p \<bind> (\<lambda>x. ?f x y))" by(rule bind_commute_pmf)
  (* the None case of y needs bind_spmf_const_return_None *)
  also have "\<dots> = ?rhs" unfolding bind_spmf_def
    by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def)
  finally show ?thesis .
qed
653
654subsection \<open>Relator\<close>
655
(* Relator on spmfs: rel_pmf lifted through rel_option R on the option type. *)
abbreviation rel_spmf :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf \<Rightarrow> bool"
where "rel_spmf R \<equiv> rel_pmf (rel_option R)"
658
(* Monotonicity of rel_pmf in the relation, stated with a meta-level premise instead of \<le>. *)
lemma rel_pmf_mono:
  "\<lbrakk>rel_pmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_pmf B f g"
using pmf.rel_mono[of A B] by(simp add: le_fun_def)
662
(* Monotonicity of rel_spmf in the relation: reduce to rel_pmf_mono and option.rel_mono. *)
lemma rel_spmf_mono:
  "\<lbrakk>rel_spmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
apply(erule rel_pmf_mono)
using option.rel_mono[of A B] by(simp add: le_fun_def)
667
(* Strong monotonicity: B only has to hold for related elements of the respective supports. *)
lemma rel_spmf_mono_strong:
  "\<lbrakk> rel_spmf A f g; \<And>x y. \<lbrakk> A x y; x \<in> set_spmf f; y \<in> set_spmf g \<rbrakk> \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
apply(erule pmf.rel_mono_strong)
apply(erule option.rel_mono_strong)
apply(auto simp add: in_set_spmf)
done
674
(* rel_spmf P relates p to itself if P is reflexive on the support of p. *)
lemma rel_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> P x x) \<Longrightarrow> rel_spmf P p p"
by(rule rel_pmf_reflI)(auto simp add: set_spmf_def intro: rel_option_reflI)
677
(* Introduction rule: an spmf of pairs pq whose support satisfies P and whose marginals
   are p and q witnesses rel_spmf P p q.  The pmf-level coupling sends None to
   (None, None) and Some (a, b) to (Some a, Some b). *)
lemma rel_spmfI [intro?]:
  "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
  \<Longrightarrow> rel_spmf P p q"
by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
  (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)
683
(* Elimination rule, converse to rel_spmfI: from a pmf coupling pq on option pairs,
   build an spmf of pairs ?pq that maps (Some x, Some y) to Some (x, y) and every
   other pair to None. *)
lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]:
  assumes "rel_spmf P p q"
  obtains pq where
    "\<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y"
    "p = map_spmf fst pq"
    "q = map_spmf snd pq"
using assms
proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf])
  case (rel_pmf pq)
  let ?pq = "map_pmf (\<lambda>(a, b). case (a, b) of (Some x, Some y) \<Rightarrow> Some (x, y) | _ \<Rightarrow> None) pq"
  have "\<And>x y. (x, y) \<in> set_spmf ?pq \<Longrightarrow> P x y"
    by(auto simp add: in_set_spmf split: option.split_asm dest: rel_pmf(1))
  moreover
  (* rel_option never pairs None with Some, so dropping mixed pairs loses no mass of p *)
  have "\<And>x. (x, None) \<in> set_pmf pq \<Longrightarrow> x = None" by(auto dest!: rel_pmf(1))
  then have "p = map_spmf fst ?pq" using rel_pmf(2)
    by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
  moreover
  (* symmetric argument for the second marginal *)
  have "\<And>y. (None, y) \<in> set_pmf pq \<Longrightarrow> y = None" by(auto dest!: rel_pmf(1))
  then have "q = map_spmf snd ?pq" using rel_pmf(3)
    by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
  ultimately show thesis ..
qed
706
(* Characterisation of rel_spmf by a coupling spmf, combining rel_spmfI and rel_spmfE. *)
lemma rel_spmf_simps:
  "rel_spmf R p q \<longleftrightarrow> (\<exists>pq. (\<forall>(x, y)\<in>set_spmf pq. R x y) \<and> map_spmf fst pq = p \<and> map_spmf snd pq = q)"
by(auto intro: rel_spmfI elim!: rel_spmfE)
710
(* A map_spmf on either argument of rel_spmf can be absorbed into the relation. *)
lemma spmf_rel_map:
  shows spmf_rel_map1: "\<And>R f x. rel_spmf R (map_spmf f x) = rel_spmf (\<lambda>x. R (f x)) x"
  and spmf_rel_map2: "\<And>R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (\<lambda>x y. R x (g y)) x y"
by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def])
715
(* rel_spmf commutes with taking the converse relation. *)
lemma spmf_rel_conversep: "rel_spmf R\<inverse>\<inverse> = (rel_spmf R)\<inverse>\<inverse>"
by(simp add: option.rel_conversep pmf.rel_conversep)
718
(* rel_spmf of equality is equality. *)
lemma spmf_rel_eq: "rel_spmf (=) = (=)"
by(simp add: pmf.rel_eq option.rel_eq)
721
(* Parametricity theorems for the basic spmf operations, proved with transfer_prover.
   Only bind_spmf_parametric and set_spmf_parametric are declared [transfer_rule] here. *)
context includes lifting_syntax
begin

lemma bind_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf"
unfolding bind_spmf_def[abs_def] by transfer_prover

lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf"
by transfer_prover

lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf"
by transfer_prover

lemma rel_spmf_parametric:
  "((A ===> B ===> (=)) ===> rel_spmf A ===> rel_spmf B ===> (=)) rel_spmf rel_spmf"
by transfer_prover

lemma set_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> rel_set A) set_spmf set_spmf"
unfolding set_spmf_def[abs_def] by transfer_prover

(* The empty spmf return_pmf None is related to itself by every rel_spmf A. *)
lemma return_spmf_None_parametric:
  "(rel_spmf A) (return_pmf None) (return_pmf None)"
by simp

end
748
(* Bind preserves rel_spmf when the continuations map R-related inputs to P-related
   outputs; a direct instance of bind_spmf_parametric. *)
lemma rel_spmf_bindI:
  "\<lbrakk> rel_spmf R p q; \<And>x y. R x y \<Longrightarrow> rel_spmf P (f x) (g y) \<rbrakk>
  \<Longrightarrow> rel_spmf P (p \<bind> f) (q \<bind> g)"
by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])
753
(* Special case of rel_spmf_bindI with the same spmf p on both sides: the continuations
   only need to be related for arguments in the support of p. *)
lemma rel_spmf_bind_reflI:
  "(\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf P (f x) (g x)) \<Longrightarrow> rel_spmf P (p \<bind> f) (p \<bind> g)"
by(rule rel_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: rel_spmf_reflI)
757
(* Point masses are related whenever their points are; the witness is return_pmf (x, y). *)
lemma rel_pmf_return_pmfI: "P x y \<Longrightarrow> rel_pmf P (return_pmf x) (return_pmf y)"
by(rule rel_pmf.intros[where pq="return_pmf (x, y)"])(simp_all)
760
(* Parametricity of measure for pmfs and spmfs; sets are related via rel_pred. *)
context includes lifting_syntax
begin

text \<open>We do not yet have a relator for \<^typ>\<open>'a measure\<close>, so we combine \<^const>\<open>measure\<close> and \<^const>\<open>measure_pmf\<close>\<close>
lemma measure_pmf_parametric:
  "(rel_pmf A ===> rel_pred A ===> (=)) (\<lambda>p. measure (measure_pmf p)) (\<lambda>q. measure (measure_pmf q))"
proof(rule rel_funI)+
  fix p q X Y
  assume "rel_pmf A p q" and "rel_pred A X Y"
  (* obtain the coupling pq of p and q *)
  from this(1) obtain pq where A: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> A x y"
    and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
  (* both sides become measures w.r.t. pq of preimages that agree on the support of pq *)
  show "measure p X = measure q Y" unfolding p q measure_map_pmf
    by(rule measure_pmf.finite_measure_eq_AE)(auto simp add: AE_measure_pmf_iff dest!: A rel_predD[OF \<open>rel_pred _ _ _\<close>])
qed

(* Reduce to measure_pmf_parametric on the option level. *)
lemma measure_spmf_parametric:
  "(rel_spmf A ===> rel_pred A ===> (=)) (\<lambda>p. measure (measure_spmf p)) (\<lambda>q. measure (measure_spmf q))"
unfolding measure_measure_spmf_conv_measure_pmf[abs_def]
apply(rule rel_funI)+
apply(erule measure_pmf_parametric[THEN rel_funD, THEN rel_funD])
apply(auto simp add: rel_pred_def rel_fun_def elim: option.rel_cases)
done

end
785
786subsection \<open>From \<^typ>\<open>'a pmf\<close> to \<^typ>\<open>'a spmf\<close>\<close>
787
(* Embed a pmf as an spmf by wrapping every outcome in Some (so None gets probability 0). *)
definition spmf_of_pmf :: "'a pmf \<Rightarrow> 'a spmf"
where "spmf_of_pmf = map_pmf Some"
790
(* The embedding preserves the support. *)
lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p"
by(auto simp add: spmf_of_pmf_def set_spmf_def bind_image o_def)
793
(* The embedding preserves point probabilities. *)
lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x"
by(simp add: spmf_of_pmf_def)
796
(* An embedded pmf never yields None. *)
lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0"
using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def)
799
(* The embedding preserves the extended measure of every set; Some is injective. *)
lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A"
by(simp add: emeasure_measure_spmf_conv_measure_pmf spmf_of_pmf_def inj_vimage_image_eq)
802
(* Hence the measures coincide (via measure_eqI and emeasure_spmf_of_pmf). *)
lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p"
by(rule measure_eqI) simp_all
805
(* map_spmf on an embedded pmf is the embedding of map_pmf. *)
lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)"
by(simp add: spmf_of_pmf_def pmf.map_comp o_def)
808
(* rel_spmf on embedded pmfs reduces to rel_pmf. *)
lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q"
by(simp add: spmf_of_pmf_def pmf.rel_map)
811
(* Embedding a point mass gives return_spmf. *)
lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x"
by(simp add: spmf_of_pmf_def)
814
(* bind_spmf on an embedded pmf is just bind_pmf. *)
lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f"
by(simp add: spmf_of_pmf_def bind_spmf_def bind_map_pmf)
817
(* Support of a pmf bind whose continuations are spmfs; derived from set_bind_spmf
   through bind_spmf_of_pmf. *)
lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf \<circ> f)"
unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp
820
(* The embedding can be pushed into the continuation of a bind_pmf. *)
lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (\<lambda>x. spmf_of_pmf (f x))"
by(simp add: spmf_of_pmf_def map_bind_pmf)
823
(* A pmf bind into return_spmf is the embedding of the mapped pmf. *)
lemma bind_pmf_return_spmf: "p \<bind> (\<lambda>x. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)"
by(simp add: map_pmf_def spmf_of_pmf_bind)
826
827subsection \<open>Weight of a subprobability\<close>
828
(* Total mass of an spmf: the measure of the whole space under measure_spmf. *)
abbreviation weight_spmf :: "'a spmf \<Rightarrow> real"
where "weight_spmf p \<equiv> measure (measure_spmf p) (space (measure_spmf p))"
831
(* The weight stated with UNIV in place of the space of measure_spmf. *)
lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV"
by(simp add: space_measure_spmf)
834
(* The weight of a subprobability distribution is at most 1. *)
lemma weight_spmf_le_1: "weight_spmf p \<le> 1"
by(simp add: measure_spmf.subprob_measure_le_1)
837
(* return_spmf has full weight. *)
lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1"
by(simp add: measure_spmf_return_spmf measure_return)
840
(* The empty spmf return_pmf None has weight 0. *)
lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0"
by(simp)
843
(* Mapping does not change the weight. *)
lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p"
by(simp add: weight_spmf_def measure_map_spmf)
846
(* An embedded pmf has full weight, since measure_pmf p is a probability space. *)
lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1"
using measure_pmf.prob_space[of p] by(simp add: spmf_of_pmf_def weight_spmf_def)
849
(* Weights are nonnegative. *)
lemma weight_spmf_nonneg: "weight_spmf p \<ge> 0"
by(fact measure_nonneg)
852
(* In a finite measure, a measurable weight function is integrable because it is
   bounded by the constant 1 (weight_spmf_nonneg, weight_spmf_le_1). *)
lemma (in finite_measure) integrable_weight_spmf [simp]:
  "(\<lambda>x. weight_spmf (f x)) \<in> borel_measurable M \<Longrightarrow> integrable M (\<lambda>x. weight_spmf (f x))"
by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
856
(* The weight as the nonnegative integral (sum) of spmf p over the counting measure. *)
lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
by(simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf measure_pmf.emeasure_eq_measure[symmetric] nn_integral_pmf[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
859
(* Same as weight_spmf_eq_nn_integral_spmf, but summing only over the support;
   points outside the support have probability 0. *)
lemma weight_spmf_eq_nn_integral_support:
  "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)"
unfolding weight_spmf_eq_nn_integral_spmf
by(auto simp add: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
864
(* The probability of None is the mass missing from the weight.  Proof: the weight is the
   integral of pmf p over range Some; adding pmf p None gives the integral of pmf p over
   all option values, which is 1. *)
lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p"
proof -
  have "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf)
  also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x \<partial>count_space UNIV"
    by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
  (* add the None part to the integrand *)
  also have "\<dots> + pmf p None = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x \<partial>count_space UNIV"
    by(subst nn_integral_add)(simp_all add: max_def)
  also have "\<dots> = \<integral>\<^sup>+ x. pmf p x \<partial>count_space UNIV"
    by(rule nn_integral_cong)(auto split: split_indicator)
  also have "\<dots> = 1" by (simp add: nn_integral_pmf)
  finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus)
qed
877
(* The weight expressed through the probability of None. *)
lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None"
by(simp add: pmf_None_eq_weight_spmf)
880
(* A nonpositive weight is zero. *)
lemma weight_spmf_le_0: "weight_spmf p \<le> 0 \<longleftrightarrow> weight_spmf p = 0"
by(rule measure_le_0_iff)
883
(* The weight is never negative. *)
lemma weight_spmf_lt_0: "\<not> weight_spmf p < 0"
by(simp add: not_less weight_spmf_nonneg)
886
(* The probability of a single point is bounded by the weight: one summand of a sum of
   nonnegative terms is at most the sum. *)
lemma spmf_le_weight: "spmf p x \<le> weight_spmf p"
proof -
  have "ennreal (spmf p x) \<le> weight_spmf p"
    unfolding weight_spmf_eq_nn_integral_spmf by(rule nn_integral_ge_point) simp
  then show ?thesis by simp
qed
893
(* Only the empty spmf return_pmf None has weight 0. *)
lemma weight_spmf_eq_0: "weight_spmf p = 0 \<longleftrightarrow> p = return_pmf None"
by(auto intro!: pmf_eqI simp add: pmf_None_eq_weight_spmf split: split_indicator)(metis not_Some_eq pmf_le_0_iff spmf_le_weight)
896
(* The weight of a bind is the integral, w.r.t. measure_spmf x, of the continuations' weights. *)
lemma weight_bind_spmf: "weight_spmf (x \<bind> f) = lebesgue_integral (measure_spmf x) (weight_spmf \<circ> f)"
unfolding weight_spmf_def
by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"])
900
(* Spmfs related by rel_spmf have equal weight, being marginals of the same coupling. *)
lemma rel_spmf_weightD: "rel_spmf A p q \<Longrightarrow> weight_spmf p = weight_spmf q"
by(erule rel_spmfE) simp
903
(* A probability-preserving bijection f between the supports of p and q induces a
   coupling that relates each x to f x. *)
lemma rel_spmf_bij_betw:
  assumes f: "bij_betw f (set_spmf p) (set_spmf q)"
  and eq: "\<And>x. x \<in> set_spmf p \<Longrightarrow> spmf p x = spmf q (f x)"
  shows "rel_spmf (\<lambda>x y. f x = y) p q"
proof -
  let ?f = "map_option f"

  (* reindexing the sum over the support along f shows that the weights agree *)
  have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)"
    unfolding weight_spmf_eq_nn_integral_support
    by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space)
  (* hence None is in the support of p iff it is in the support of q *)
  then have "None \<in> set_pmf p \<longleftrightarrow> None \<in> set_pmf q"
    by(simp add: pmf_None_eq_weight_spmf set_pmf_iff)
  (* lift f to a bijection between the pmf supports on the option type *)
  with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)"
    apply(auto simp add: bij_betw_def in_set_spmf inj_on_def intro: option.expand)
    apply(rename_tac [!] x)
    apply(case_tac [!] x)
    apply(auto iff: in_set_spmf)
    done
  (* apply rel_pmf_bij_betw on the option level, then convert to rel_option *)
  then have "rel_pmf (\<lambda>x y. ?f x = y) p q"
    by(rule rel_pmf_bij_betw)(case_tac x, simp_all add: weq[simplified] eq in_set_spmf pmf_None_eq_weight_spmf)
  thus ?thesis by(rule pmf.rel_mono_strong)(auto intro!: rel_optionI simp add: Option.is_none_def)
qed
926
927subsection \<open>From density to spmfs\<close>
928
(* embed_spmf turns a real-valued density f into an spmf.  Some x' gets probability
   max 0 (f x'), and None gets 1 - enn2real of the integral of f over the counting measure. *)
context fixes f :: "'a \<Rightarrow> real" begin

definition embed_spmf :: "'a spmf"
where "embed_spmf = embed_pmf (\<lambda>x. case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x'))"

(* The lemmas below assume that f integrates to at most 1. *)
context
  assumes prob: "(\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) \<le> 1"
begin

(* The density passed to embed_pmf sums to 1: split it into the None part and the Some part. *)
lemma nn_integral_embed_spmf_eq_1:
  "(\<integral>\<^sup>+ x. ennreal (case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x')) \<partial>count_space UNIV) = 1"
  (is "?lhs = _" is "(\<integral>\<^sup>+ x. ?f x \<partial>?M) = _")
proof -
  have "?lhs = \<integral>\<^sup>+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x \<partial>?M"
    by(rule nn_integral_cong)(auto split: split_indicator)
  also have "\<dots> = (1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)) + \<integral>\<^sup>+ x. ?f x * indicator (range Some) x \<partial>?M"
    (is "_ = ?None + ?Some")
    by(subst nn_integral_add)(simp_all add: AE_count_space max_def le_diff_eq real_le_ereal_iff one_ereal_def[symmetric] prob split: option.split)
  (* the Some part is the integral of f, transported along the embedding Some *)
  also have "?Some = \<integral>\<^sup>+ x. ?f x \<partial>count_space (range Some)"
    by(simp add: nn_integral_count_space_indicator)
  also have "count_space (range Some) = embed_measure (count_space UNIV) Some"
    by(simp add: embed_measure_count_space)
  also have "(\<integral>\<^sup>+ x. ?f x \<partial>\<dots>) = \<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV"
    by(subst nn_integral_embed_measure)(simp_all add: measurable_embed_measure1)
  (* finiteness from prob makes 1 - I + I = 1 *)
  also have "?None + \<dots> = 1" using prob
    by(auto simp add: ennreal_minus[symmetric] ennreal_1[symmetric] ennreal_enn2real_if top_unique simp del: ennreal_1)(simp add: diff_add_self_ennreal)
  finally show ?thesis .
qed

(* Probability of None in embed_spmf.  After pmf_embed_pmf, the two side conditions are
   nonnegativity of the density (subgoal) and total mass 1 (nn_integral_embed_spmf_eq_1). *)
lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)"
unfolding embed_spmf_def
apply(subst pmf_embed_pmf)
  subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
 apply(rule nn_integral_embed_spmf_eq_1)
apply simp
done

(* Probability of Some x in embed_spmf: the density clipped at 0.  Same proof pattern. *)
lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)"
unfolding embed_spmf_def
apply(subst pmf_embed_pmf)
  subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
 apply(rule nn_integral_embed_spmf_eq_1)
apply simp
done

end

end
977
(* The zero density yields the empty spmf. *)
lemma embed_spmf_K_0[simp]: "embed_spmf (\<lambda>_. 0) = return_pmf None" (is "?lhs = ?rhs")
by(rule spmf_eqI)(simp add: zero_ereal_def[symmetric])
980
981subsection \<open>Ordering on spmfs\<close>
982
983text \<open>
984  \<^const>\<open>rel_pmf\<close> does not preserve a ccpo structure. Counterexample by Saheb-Djahromi:
985  Take prefix order over \<open>bool llist\<close> and
986  the set \<open>range (\<lambda>n :: nat. uniform (llist_n n))\<close> where \<open>llist_n\<close> is the set
987  of all \<open>llist\<close>s of length \<open>n\<close> and \<open>uniform\<close> returns a uniform distribution over
  the given set. The set forms a chain in \<open>ord_pmf lprefix\<close>, but it has no upper bound.
989  Any upper bound may contain only infinite lists in its support because otherwise it is not greater
990  than the \<open>n+1\<close>-st element in the chain where \<open>n\<close> is the length of the finite list.
  Moreover, its support must contain all infinite lists, because otherwise there is a finite list
  all of whose finite extensions are not in the support - a contradiction to the upper bound property.
  Hence, the support is uncountable, but pmfs have only countable support.
994
995  However, if all chains in the ccpo are finite, then it should preserve the ccpo structure.
996\<close>
997
(* Ordering on spmfs: rel_pmf lifted through ord_option ord on the option type. *)
abbreviation ord_spmf :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf \<Rightarrow> bool"
where "ord_spmf ord \<equiv> rel_pmf (ord_option ord)"
1000
(* Infix \<sqsubseteq> notation for ord_spmf.  It is interpreted in contexts that fix the order as a
   (structure) argument, so the order can be left implicit in p \<sqsubseteq> q. *)
locale ord_spmf_syntax begin
notation ord_spmf (infix "\<sqsubseteq>\<index>" 60)
end
1004
(* A map_spmf on the left argument of ord_spmf is absorbed into the order. *)
lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (\<lambda>x. R (f x)) p"
by(simp add: pmf.rel_map[abs_def] ord_option_map1[abs_def])
1007
(* A map_spmf on the right argument of ord_spmf is absorbed into the order. *)
lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (\<lambda>x y. R x (f y)) p q"
by(simp add: pmf.rel_map ord_option_map2)
1010
(* The same function mapped over both arguments is absorbed into the order. *)
lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (\<lambda>x y. R (f x) (f y)) p q"
by(simp add: pmf.rel_map ord_option_map1[abs_def] ord_option_map2)
1013
(* Collection of the map rules ord_spmf_map_spmf1, ord_spmf_map_spmf2 and ord_spmf_map_spmf12. *)
lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12
1015
(* Basic lemmas about ord_spmf, written with the \<sqsubseteq> syntax for the fixed element order ord. *)
context fixes ord :: "'a \<Rightarrow> 'a \<Rightarrow> bool" (structure) begin
interpretation ord_spmf_syntax .

(* Introduction rule via a coupling spmf of pairs, analogous to rel_spmfI. *)
lemma ord_spmfI:
  "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> ord x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
  \<Longrightarrow> p \<sqsubseteq> q"
by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
  (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)

(* The empty spmf is below every spmf; the witness pairs None with every outcome of x. *)
lemma ord_spmf_None [simp]: "return_pmf None \<sqsubseteq> x"
by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp add: pmf.map_comp o_def)

(* Reflexivity, given ord is reflexive on the support of p. *)
lemma ord_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord x x) \<Longrightarrow> p \<sqsubseteq> p"
by(rule rel_pmf_reflI ord_option_reflI)+(auto simp add: in_set_spmf)

(* If p \<sqsubseteq> q and q \<sqsubseteq> p for a reflexive and transitive ord, then p and q are related by
   rel_spmf of the symmetric part inf ord ord\<inverse>\<inverse>. *)
lemma rel_spmf_inf:
  assumes "p \<sqsubseteq> q"
  and "q \<sqsubseteq> p"
  and refl: "reflp ord"
  and trans: "transp ord"
  shows "rel_spmf (inf ord ord\<inverse>\<inverse>) p q"
proof -
  from \<open>p \<sqsubseteq> q\<close> \<open>q \<sqsubseteq> p\<close>
  have "rel_pmf (inf (ord_option ord) (ord_option ord)\<inverse>\<inverse>) p q"
    by(rule rel_pmf_inf)(blast intro: reflp_ord_option transp_ord_option refl trans)+
  (* the symmetric part of ord_option ord is rel_option of the symmetric part of ord *)
  also have "inf (ord_option ord) (ord_option ord)\<inverse>\<inverse> = rel_option (inf ord ord\<inverse>\<inverse>)"
    by(auto simp add: fun_eq_iff elim: ord_option.cases option.rel_cases)
  finally show ?thesis .
qed

end
1047
(* p is below return_spmf y iff every element of p's support is R-below y. *)
lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. R x y)"
by(auto simp add: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr)
1050
(* Monotonicity of ord_spmf in the element order. *)
lemma ord_spmf_mono: "\<lbrakk> ord_spmf A p q; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_spmf B p q"
by(erule rel_pmf_mono)(erule ord_option_mono)
1053
(* ord_spmf distributes over relation composition. *)
lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B"
by(simp add: ord_option_compp pmf.rel_compp)
1056
(* Bind is monotone w.r.t. ord_spmf when the continuations map R-related inputs to
   P-ordered outputs; the None branch is handled by the case split on option. *)
lemma ord_spmf_bindI:
  assumes pq: "ord_spmf R p q"
  and fg: "\<And>x y. R x y \<Longrightarrow> ord_spmf P (f x) (g y)"
  shows "ord_spmf P (p \<bind> f) (q \<bind> g)"
unfolding bind_spmf_def using pq
by(rule rel_pmf_bindI)(auto split: option.split intro: fg)
1063
(* Special case of ord_spmf_bindI with the same spmf p on both sides: the continuations
   only need to be ordered for arguments in the support of p. *)
lemma ord_spmf_bind_reflI:
  "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord_spmf R (f x) (g x))
  \<Longrightarrow> ord_spmf R (p \<bind> f) (p \<bind> g)"
by(rule ord_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: ord_spmf_reflI)
1068
(* If q assigns every element at least the probability that p does, then p is below q
   w.r.t. any R that is reflexive on the support of p.  The proof builds an explicit
   coupling pq with embed_pmf:
     (Some x, Some x) gets spmf p x,
     (None, Some y)   gets the surplus spmf q y - spmf p y,
     (None, None)     gets pmf q None,
   and every other pair gets 0.  It then checks both marginals. *)
lemma ord_pmf_increaseI:
  assumes le: "\<And>x. spmf p x \<le> spmf q x"
  and refl: "\<And>x. x \<in> set_spmf p \<Longrightarrow> R x x"
  shows "ord_spmf R p q"
proof(rule rel_pmf.intros)
  define pq where "pq = embed_pmf
    (\<lambda>(x, y). case x of Some x' \<Rightarrow> (case y of Some y' \<Rightarrow> if x' = y' then spmf p x' else 0 | None \<Rightarrow> 0)
      | None \<Rightarrow> (case y of None \<Rightarrow> pmf q None | Some y' \<Rightarrow> spmf q y' - spmf p y'))"
     (is "_ = embed_pmf ?f")
  (* the surplus spmf q y' - spmf p y' is nonnegative by le *)
  have nonneg: "\<And>xy. ?f xy \<ge> 0"
    by(clarsimp simp add: le field_simps split: option.split)
  (* the density ?f sums to 1 *)
  have integral: "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) = 1" (is "nn_integral ?M _ = _")
  proof -
    (* split the sum into the three kinds of pairs on which ?f can be nonzero *)
    have "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) =
      \<integral>\<^sup>+ xy. ennreal (?f xy) * indicator {(None, None)} xy +
             ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy +
             ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M"
      by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm)
    also have "\<dots> = (\<integral>\<^sup>+ xy. ?f xy * indicator {(None, None)} xy \<partial>?M) +
        (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy \<partial>?M) +
        (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M)"
      (is "_ = ?None + ?Some2 + ?Some")
      by(subst nn_integral_add)(simp_all add: nn_integral_add AE_count_space le_diff_eq le split: option.split)
    also have "?None = pmf q None" by simp
    also have "?Some2 = \<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV"
      by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
    (* the integral of the difference is the difference of the (finite) integrals *)
    also have "\<dots> = (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
      (is "_ = ?Some2' - ?Some2''")
      by(subst nn_integral_diff)(simp_all add: le nn_integral_spmf_neq_top)
    also have "?Some = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
      by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
    (* the mass of p cancels out *)
    also have "pmf q None + (?Some2' - ?Some2'') + \<dots> = pmf q None + ?Some2'"
      by(auto simp add: diff_add_self_ennreal le intro!: nn_integral_mono)
    also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x \<partial>count_space UNIV"
      by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
    also have "\<dots> = \<integral>\<^sup>+ x. pmf q x \<partial>count_space UNIV"
      by(rule nn_integral_cong)(auto split: split_indicator)
    also have "\<dots> = 1" by(simp add: nn_integral_pmf)
    finally show ?thesis .
  qed
  (* side conditions of pmf_embed_pmf and set_embed_pmf *)
  note f = nonneg integral

  (* every pair in the support of pq is related by ord_option R *)
  { fix x y
    assume "(x, y) \<in> set_pmf pq"
    hence "?f (x, y) \<noteq> 0" unfolding pq_def by(simp add: set_embed_pmf[OF f])
    then show "ord_option R x y"
      by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) }

  have weight_le: "weight_spmf p \<le> weight_spmf q"
    by(subst ennreal_le_iff[symmetric])(auto simp add: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le)

  (* the first marginal of pq is p *)
  show "map_pmf fst pq = p"
  proof(rule pmf_eqI)
    fix i
    have "ennreal (pmf (map_pmf fst pq) i) = (\<integral>\<^sup>+ y. pmf pq (i, y) \<partial>count_space UNIV)"
      unfolding pq_def ennreal_pmf_map
      apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
      apply(subst pmf_embed_pmf[OF f])
      apply(rule nn_integral_bij_count_space[symmetric])
      apply(auto simp add: bij_betw_def inj_on_def)
      done
    also have "\<dots> = pmf p i"
    proof(cases i)
      case (Some x)
      (* only the pair (Some x, Some x) contributes *)
      have "(\<integral>\<^sup>+ y. pmf pq (Some x, y) \<partial>count_space UNIV) = \<integral>\<^sup>+ y. pmf p (Some x) * indicator {Some x} y \<partial>count_space UNIV"
        by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
      then show ?thesis using Some by simp
    next
      case None
      (* mass at None: pmf q None plus the total surplus weight q - weight p *)
      have "(\<integral>\<^sup>+ y. pmf pq (None, y) \<partial>count_space UNIV) =
            (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y +
                   ennreal (pmf pq (None, None)) * indicator {None} y \<partial>count_space UNIV)"
        by(rule nn_integral_cong)(auto split: split_indicator)
      also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) \<partial>count_space (range Some)) + pmf pq (None, None)"
        by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator)
      also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) + pmf q None"
        by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
      also have "(\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) =
                 (\<integral>\<^sup>+ y. spmf q y \<partial>count_space UNIV) - (\<integral>\<^sup>+ y. spmf p y \<partial>count_space UNIV)"
        by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator)
      (* weight q - weight p = (1 - pmf q None) - (1 - pmf p None) *)
      also have "\<dots> = pmf p None - pmf q None"
        by(simp add: pmf_None_eq_weight_spmf weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
      also have "\<dots> = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus)
      finally show ?thesis using None weight_le
        by(auto simp add: diff_add_self_ennreal pmf_None_eq_weight_spmf intro: ennreal_leI)
    qed
    finally show "pmf (map_pmf fst pq) i = pmf p i" by simp
  qed

  (* the second marginal of pq is q *)
  show "map_pmf snd pq = q"
  proof(rule pmf_eqI)
    fix i
    have "ennreal (pmf (map_pmf snd pq) i) = (\<integral>\<^sup>+ x. pmf pq (x, i) \<partial>count_space UNIV)"
      unfolding pq_def ennreal_pmf_map
      apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
      apply(subst pmf_embed_pmf[OF f])
      apply(rule nn_integral_bij_count_space[symmetric])
      apply(auto simp add: bij_betw_def inj_on_def)
      done
    also have "\<dots> = ennreal (pmf q i)"
    proof(cases i)
      case None
      (* only the pair (None, None) contributes *)
      have "(\<integral>\<^sup>+ x. pmf pq (x, None) \<partial>count_space UNIV) = \<integral>\<^sup>+ x. pmf q None * indicator {None :: 'a option} x \<partial>count_space UNIV"
        by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
      then show ?thesis using None by simp
    next
      case (Some y)
      (* mass at Some y: spmf p y from (Some y, Some y) plus the surplus from (None, Some y) *)
      have "(\<integral>\<^sup>+ x. pmf pq (x, Some y) \<partial>count_space UNIV) =
        (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x +
               ennreal (pmf pq (None, Some y)) * indicator {None} x \<partial>count_space UNIV)"
        by(rule nn_integral_cong)(auto split: split_indicator)
      also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x \<partial>count_space UNIV) + pmf pq (None, Some y)"
        by(subst nn_integral_add)(simp_all)
      also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (spmf p y) * indicator {Some y} x \<partial>count_space UNIV) + (spmf q y - spmf p y)"
        by(auto simp add: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="(+)"] nn_integral_cong split: option.split)
      also have "\<dots> = spmf q y" by(simp add: ennreal_minus[symmetric] le)
      finally show ?thesis using Some by simp
    qed
    finally show "pmf (map_pmf snd pq) i = pmf q i" by simp
  qed
qed
1190
(* Pointwise monotonicity of the flat ccpo order: if p is below q under
   ord_spmf (=), then the mass p assigns to any value x is at most the mass
   q assigns to x.
   Proof idea: unfold ord_spmf into a coupling pq of p and q whose support
   satisfies ord_option (=).  The pq-mass of first component Some x equals,
   almost everywhere, the pq-mass of the single pair (Some x, Some x), which
   is bounded by the pq-mass of second component Some x. *)
lemma ord_spmf_eq_leD:
  assumes "ord_spmf (=) p q"
  shows "spmf p x \<le> spmf q x"
proof(cases "x \<in> set_spmf p")
  case False
  (* x is outside the support of p; the bound follows from in_set_spmf_iff_spmf. *)
  thus ?thesis by(simp add: in_set_spmf_iff_spmf)
next
  case True
  (* Coupling pq with marginals p (via fst) and q (via snd). *)
  from assms obtain pq
    where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> ord_option (=) x y"
    and p: "p = map_pmf fst pq"
    and q: "q = map_pmf snd pq" by cases auto
  have "ennreal (spmf p x) = integral\<^sup>N pq (indicator (fst -` {Some x}))"
    using p by(simp add: ennreal_pmf_map)
  (* On the support of pq, the preimage shrinks to the single pair (Some x, Some x). *)
  also have "\<dots> = integral\<^sup>N pq (indicator {(Some x, Some x)})"
    by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq)
  also have "\<dots> \<le> integral\<^sup>N pq (indicator (snd -` {Some x}))"
    by(rule nn_integral_mono) simp
  also have "\<dots> = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
  finally show ?thesis by simp
qed
1212
(* Supports are monotone in the flat ccpo order: a direct consequence of the
   pointwise bound ord_spmf_eq_leD. *)
lemma ord_spmf_eqD_set_spmf: "ord_spmf (=) p q \<Longrightarrow> set_spmf p \<subseteq> set_spmf q"
by(rule subsetI)(drule_tac x=x in ord_spmf_eq_leD, auto simp add: in_set_spmf_iff_spmf)
1215
(* If p is below q, then every event A has no larger measure under p than
   under q.  The proof integrates the pointwise bound ord_spmf_eq_leD over
   the indicator of A. *)
lemma ord_spmf_eqD_emeasure:
  "ord_spmf (=) p q \<Longrightarrow> emeasure (measure_spmf p) A \<le> emeasure (measure_spmf q) A"
by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric])
1219
(* Lifts ord_spmf_eqD_emeasure to the order on measures, via le_measure. *)
lemma ord_spmf_eqD_measure_spmf: "ord_spmf (=) p q \<Longrightarrow> measure_spmf p \<le> measure_spmf q"
  by (subst le_measure) (auto simp: ord_spmf_eqD_emeasure)
1222
1223subsection \<open>CCPO structure for the flat ccpo \<^term>\<open>ord_option (=)\<close>\<close>
1224
1225context fixes Y :: "'a spmf set" begin
1226
(* Least upper bound of the fixed set Y in the flat ccpo order.  It takes the
   pointwise supremum of the masses spmf p x over p in Y in ennreal, converts
   it back to real with enn2real, and packages the result with embed_spmf. *)
definition lub_spmf :: "'a spmf"
where "lub_spmf = embed_spmf (\<lambda>x. enn2real (SUP p \<in> Y. ennreal (spmf p x)))"
  \<comment> \<open>We go through \<^typ>\<open>ennreal\<close> to have a sensible definition even if \<^term>\<open>Y\<close> is empty.\<close>
1230
(* The lub of the empty set is return_pmf None, i.e. all mass on None. *)
lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None"
by(simp add: SPMF.lub_spmf_def bot_ereal_def)
1233
1234context assumes chain: "Complete_Partial_Order.chain (ord_spmf (=)) Y" begin
1235
(* Viewing each p in Y as the function x -> ennreal (spmf p x) turns the
   chain Y (under ord_spmf (=)) into a chain of functions under the pointwise
   order.  Comparability is transferred with ord_spmf_eq_leD, using the
   context assumption chain. *)
lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (\<le>) ((\<lambda>p x. ennreal (spmf p x)) ` Y)"
  (is "Complete_Partial_Order.chain _ (?f ` _)")
proof(rule chainI)
  fix f g
  assume "f \<in> ?f ` Y" "g \<in> ?f ` Y"
  then obtain p q where f: "f = ?f p" "p \<in> Y" and g: "g = ?f q" "q \<in> Y" by blast
  (* The underlying spmfs p and q are comparable because Y is a chain. *)
  from chain \<open>p \<in> Y\<close> \<open>q \<in> Y\<close> have "ord_spmf (=) p q \<or> ord_spmf (=) q p" by(rule chainD)
  thus "f \<le> g \<or> g \<le> f"
  proof
    assume "ord_spmf (=) p q"
    hence "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
    hence "f \<le> g" unfolding f g by(auto intro: le_funI)
    thus ?thesis ..
  next
    assume "ord_spmf (=) q p"
    hence "\<And>x. spmf q x \<le> spmf p x" by(rule ord_spmf_eq_leD)
    hence "g \<le> f" unfolding f g by(auto intro: le_funI)
    thus ?thesis ..
  qed
qed
1256
(* If p is below q and both put the same mass on None, then p = q.
   Proof sketch:
   - The integral of the pointwise difference spmf q - spmf p equals
     (1 - pmf q None) - (1 - pmf p None), which is 0 by assumption.
   - Hence spmf q x <= spmf p x for all x.
   - Antisymmetry with ord_spmf_eq_leD yields equality. *)
lemma ord_spmf_eq_pmf_None_eq:
  assumes le: "ord_spmf (=) p q"
  and None: "pmf p None = pmf q None"
  shows "p = q"
proof(rule spmf_eqI)
  fix i
  from le have le': "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
  (* Integral of the difference as a difference of integrals (finite by nn_integral_spmf_neq_top). *)
  have "(\<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV) =
     (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
    by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top)
  also have "\<dots> = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf
    by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
  also have "\<dots> = 0" using None by simp
  (* A zero integral over count_space forces the truncated difference to vanish at every point. *)
  finally have "\<And>x. spmf q x \<le> spmf p x"
    by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff)
  with le' show "spmf p i = spmf q i" by(rule antisym)
qed
1274
(* Moving up in the flat ccpo order can only decrease the mass at None:
   if x is below y, then pmf y None <= pmf x None.  The proof goes through
   the coupling provided by rel_pmf and monotonicity of the integral. *)
lemma ord_spmf_eqD_pmf_None:
  assumes "ord_spmf (=) x y"
  shows "pmf x None \<ge> pmf y None"
using assms
apply cases
(* Express both masses as integrals over the coupling. *)
apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map)
apply(fastforce simp add: AE_measure_pmf_iff intro!: nn_integral_mono_AE)
done
1283
text \<open>
  Chains on \<^typ>\<open>'a spmf\<close> have countable support: the union of the
  supports of all elements of a chain is a countable set.
  Thanks to Johannes Hölzl for the proof idea.
\<close>
(* The union of the supports of all elements of the chain Y is countable.
   Proof outline:
   - If Y is empty, the final simp closes the goal.
   - If Y has a greatest element x, every support lies inside set_spmf x.
   - Otherwise, write N p for the mass of p at None:
     * N_less_imp_le_spmf: a strictly smaller N means higher in the order;
     * N_eq_imp_eq: elements of Y with equal N coincide;
     * NC_less: the infimum of N over Y is not attained by any element;
     * a sequence X n in Y with N (X n) tending to that infimum puts every
       element of Y below some X n.  So the union of supports is covered
       by the countable union over n. *)
lemma spmf_chain_countable: "countable (\<Union>p\<in>Y. set_spmf p)"
proof(cases "Y = {}")
  case Y: False
  show ?thesis
  proof(cases "\<exists>x\<in>Y. \<forall>y\<in>Y. ord_spmf (=) y x")
    case True
    (* Y has a greatest element x; all supports lie inside set_spmf x. *)
    then obtain x where x: "x \<in> Y" and upper: "\<And>y. y \<in> Y \<Longrightarrow> ord_spmf (=) y x" by blast
    hence "(\<Union>x\<in>Y. set_spmf x) \<subseteq> set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf)
    thus ?thesis by(rule countable_subset) simp
  next
    case False
    (* No greatest element: compare elements of Y by their mass at None instead. *)
    define N :: "'a option pmf \<Rightarrow> real" where "N p = pmf p None" for p

    have N_less_imp_le_spmf: "\<lbrakk> x \<in> Y; y \<in> Y; N y < N x \<rbrakk> \<Longrightarrow> ord_spmf (=) x y" for x y
      using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x]
      by (auto simp: N_def)
    have N_eq_imp_eq: "\<lbrakk> x \<in> Y; y \<in> Y; N y = N x \<rbrakk> \<Longrightarrow> x = y" for x y
      using chainD[OF chain, of x y] by(auto simp add: N_def dest: ord_spmf_eq_pmf_None_eq)

    (* N is bounded below by 0 on the nonempty set Y, so its infimum is well-behaved. *)
    have NC: "N ` Y \<noteq> {}" "bdd_below (N ` Y)"
      using \<open>Y \<noteq> {}\<close> by(auto intro!: bdd_belowI[of _ 0] simp: N_def)
    (* If some N x were the infimum, x would be a greatest element, contradicting False. *)
    have NC_less: "Inf (N ` Y) < N x" if "x \<in> Y" for x unfolding cInf_less_iff[OF NC]
    proof(rule ccontr)
      assume **: "\<not> (\<exists>y\<in>N ` Y. y < N x)"
      { fix y
        assume "y \<in> Y"
        with ** consider "N x < N y" | "N x = N y" by(auto simp add: not_less le_less)
        hence "ord_spmf (=) y x" using \<open>y \<in> Y\<close> \<open>x \<in> Y\<close>
          by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) }
      with False \<open>x \<in> Y\<close> show False by blast
    qed

    (* Approximate the infimum by a sequence X of elements of Y. *)
    from NC have "Inf (N ` Y) \<in> closure (N ` Y)" by (intro closure_contains_Inf)
    then obtain X' where "\<And>n. X' n \<in> N ` Y" and X': "X' \<longlonglongrightarrow> Inf (N ` Y)"
      unfolding closure_sequential by auto
    then obtain X where X: "\<And>n. X n \<in> Y" and "X' = (\<lambda>n. N (X n))" unfolding image_iff Bex_def by metis

    with X' have seq: "(\<lambda>n. N (X n)) \<longlonglongrightarrow> Inf (N ` Y)" by simp
    (* Every element of Y is below some X i, so its support is contained in set_spmf (X i). *)
    have "(\<Union>x \<in> Y. set_spmf x) \<subseteq> (\<Union>n. set_spmf (X n))"
    proof(rule UN_least)
      fix x
      assume "x \<in> Y"
      from order_tendstoD(2)[OF seq NC_less[OF \<open>x \<in> Y\<close>]]
      obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially)
      thus "set_spmf x \<subseteq> (\<Union>n. set_spmf (X n))" using X \<open>x \<in> Y\<close>
        by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf)
    qed
    thus ?thesis by(rule countable_subset) simp
  qed
qed simp
1338
(* The pointwise supremum used in lub_spmf has total mass at most 1, so
   embed_spmf in lub_spmf receives a subprobability mass function.
   Proof:
   - restrict the integral to the countable set ?B (spmf_chain_countable);
   - exchange integral and SUP by monotone convergence along the chain
     (chain_ord_spmf_eqD);
   - bound each resulting integral by the weight of p. *)
lemma lub_spmf_subprob: "(\<integral>\<^sup>+ x. (SUP p \<in> Y. ennreal (spmf p x)) \<partial>count_space UNIV) \<le> 1"
proof(cases "Y = {}")
  case True
  thus ?thesis by(simp add: bot_ennreal)
next
  case False
  let ?B = "\<Union>p\<in>Y. set_spmf p"
  have countable: "countable ?B" by(rule spmf_chain_countable)

  (* Outside ?B every spmf p x with p in Y is zero, so the integrand can be cut down to ?B. *)
  have "(\<integral>\<^sup>+ x. (SUP p\<in>Y. ennreal (spmf p x)) \<partial>count_space UNIV) =
        (\<integral>\<^sup>+ x. (SUP p\<in>Y. ennreal (spmf p x) * indicator ?B x) \<partial>count_space UNIV)"
    by (intro nn_integral_cong arg_cong [of _ _ Sup]) (auto split: split_indicator simp add: spmf_eq_0_set_spmf)
  also have "\<dots> = (\<integral>\<^sup>+ x. (SUP p\<in>Y. ennreal (spmf p x)) \<partial>count_space ?B)"
    unfolding ennreal_indicator[symmetric] using False
    by(subst SUP_mult_right_ennreal[symmetric])(simp add: ennreal_indicator nn_integral_count_space_indicator)
  (* Monotone convergence over the countable space ?B. *)
  also have "\<dots> = (SUP p\<in>Y. \<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B)" using False _ countable
    by(rule nn_integral_monotone_convergence_SUP_countable)(rule chain_ord_spmf_eqD)
  also have "\<dots> \<le> 1"
  proof(rule SUP_least)
    fix p
    assume "p \<in> Y"
    have "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) = \<integral>\<^sup>+ x. ennreal (spmf p x) * indicator ?B x \<partial>count_space UNIV"
      by(simp add: nn_integral_count_space_indicator)
    also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
      by(rule nn_integral_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf \<open>p \<in> Y\<close>)
    (* The total integral of spmf p is its weight, which is at most 1. *)
    also have "\<dots> \<le> 1"
      by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] weight_spmf_le_1)
    finally show "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) \<le> 1" .
  qed
  finally show ?thesis .
qed
1370
(* For nonempty Y, lub_spmf evaluated at x is the real supremum of the values
   spmf p x for p in Y.  The proof unfolds embed_spmf (using lub_spmf_subprob)
   and moves ennreal through the supremum with ennreal_SUP, relying on
   SUP_spmf_neq_top for finiteness. *)
lemma spmf_lub_spmf:
  assumes "Y \<noteq> {}"
  shows "spmf lub_spmf x = (SUP p \<in> Y. spmf p x)"
proof -
  (* NOTE(review): the witness p obtained here does not appear to be referenced
     by the steps below; possibly removable -- confirm by re-checking. *)
  from assms obtain p where "p \<in> Y" by auto
  have "spmf lub_spmf x = max 0 (enn2real (SUP p\<in>Y. ennreal (spmf p x)))" unfolding lub_spmf_def
    by(rule spmf_embed_spmf)(simp del: SUP_eq_top_iff Sup_eq_top_iff add: ennreal_enn2real_if SUP_spmf_neq_top' lub_spmf_subprob)
  also have "\<dots> = enn2real (SUP p\<in>Y. ennreal (spmf p x))"
    by(rule max_absorb2)(simp)
  also have "\<dots> = enn2real (ennreal (SUP p \<in> Y. spmf p x))" using assms
    by(subst ennreal_SUP[symmetric])(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
  (* This also adds the previous equation to the calculation; the inequality
     proved here is only consumed by the following then-step. *)
  also have "0 \<le> (\<Squnion>p\<in>Y. spmf p x)" using assms
    by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] simp add: pmf_le_1)
  then have "enn2real (ennreal (SUP p \<in> Y. spmf p x)) = (SUP p \<in> Y. spmf p x)"
    by(rule enn2real_ennreal)
  finally show ?thesis .
qed
1388
(* ennreal-valued variant of spmf_lub_spmf: the supremum is taken in ennreal. *)
lemma ennreal_spmf_lub_spmf: "Y \<noteq> {} \<Longrightarrow> ennreal (spmf lub_spmf x) = (SUP p\<in>Y. ennreal (spmf p x))"
unfolding spmf_lub_spmf by(subst ennreal_SUP)(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
1391
(* lub_spmf is an upper bound: every element of the chain Y lies below it.
   ord_pmf_increaseI reduces the goal to the pointwise bound on spmf values;
   any remaining subgoal is closed by the final simp. *)
lemma lub_spmf_upper:
  assumes p: "p \<in> Y"
  shows "ord_spmf (=) p lub_spmf"
proof(rule ord_pmf_increaseI)
  fix x
  from p have [simp]: "Y \<noteq> {}" by auto
  from p have "ennreal (spmf p x) \<le> (SUP p\<in>Y. ennreal (spmf p x))" by(rule SUP_upper)
  also have "\<dots> = ennreal (spmf lub_spmf x)" using p
    by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' simp del: SUP_eq_top_iff Sup_eq_top_iff)
  finally show "spmf p x \<le> spmf lub_spmf x" by simp
qed simp
1403
(* lub_spmf is the least upper bound: it lies below every upper bound z of Y.
   For nonempty Y the pointwise supremum is bounded by spmf z x using
   ord_spmf_eq_leD.  The empty case is closed by the final simp. *)
lemma lub_spmf_least:
  assumes z: "\<And>x. x \<in> Y \<Longrightarrow> ord_spmf (=) x z"
  shows "ord_spmf (=) lub_spmf z"
proof(cases "Y = {}")
  case nonempty: False
  show ?thesis
  proof(rule ord_pmf_increaseI)
    fix x
    (* NOTE(review): the witness p obtained here does not appear to be used
       below; possibly removable -- confirm by re-checking. *)
    from nonempty obtain p where p: "p \<in> Y" by auto
    have "ennreal (spmf lub_spmf x) = (SUP p\<in>Y. ennreal (spmf p x))"
      by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' nonempty simp del: SUP_eq_top_iff Sup_eq_top_iff)
    also have "\<dots> \<le> ennreal (spmf z x)" by(rule SUP_least)(simp add: ord_spmf_eq_leD z)
    finally show "spmf lub_spmf x \<le> spmf z x" by simp
  qed simp
qed simp
1419
(* The support of lub_spmf is the union of the supports of the elements of Y.
   For nonempty Y: x is in the support iff its ennreal mass is positive, and
   by ennreal_spmf_lub_spmf and less_SUP_iff that holds iff some p in Y gives
   x positive mass.  The empty case is closed by the final simp. *)
lemma set_lub_spmf: "set_spmf lub_spmf = (\<Union>p\<in>Y. set_spmf p)" (is "?lhs = ?rhs")
proof(cases "Y = {}")
  case [simp]: False
  show ?thesis
  proof(rule set_eqI)
    fix x
    have "x \<in> ?lhs \<longleftrightarrow> ennreal (spmf lub_spmf x) > 0"
      by(simp_all add: in_set_spmf_iff_spmf less_le)
    also have "\<dots> \<longleftrightarrow> (\<exists>p\<in>Y. ennreal (spmf p x) > 0)"
      by(simp add: ennreal_spmf_lub_spmf less_SUP_iff)
    also have "\<dots> \<longleftrightarrow> x \<in> ?rhs"
      by(auto simp add: in_set_spmf_iff_spmf less_le)
    finally show "x \<in> ?lhs \<longleftrightarrow> x \<in> ?rhs" .
  qed
qed simp
1435
(* For nonempty Y, the measure of any event A under lub_spmf is the supremum
   of its measures under the elements of Y.
   Proof:
   - write the measure as an integral over the support of lub_spmf;
   - expand the density pointwise as a SUP;
   - exchange integral and SUP by monotone convergence along the chain;
   - drop the support restriction again. *)
lemma emeasure_lub_spmf:
  assumes Y: "Y \<noteq> {}"
  shows "emeasure (measure_spmf lub_spmf) A = (SUP y\<in>Y. emeasure (measure_spmf y) A)"
  (is "?lhs = ?rhs")
proof -
  let ?M = "count_space (set_spmf lub_spmf)"
  have "?lhs = \<integral>\<^sup>+ x. ennreal (spmf lub_spmf x) * indicator A x \<partial>?M"
    by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf')
  also have "\<dots> = \<integral>\<^sup>+ x. (SUP y\<in>Y. ennreal (spmf y x) * indicator A x) \<partial>?M"
    unfolding ennreal_indicator[symmetric]
    by(simp add: spmf_lub_spmf assms ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal)
  (* Monotone convergence; the integrands form a chain because multiplying by an indicator preserves the pointwise order. *)
  also from assms have "\<dots> = (SUP y\<in>Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>?M)"
  proof(rule nn_integral_monotone_convergence_SUP_countable)
    have "(\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y = (\<lambda>f x. f x * indicator A x) ` (\<lambda>p x. ennreal (spmf p x)) ` Y"
      by(simp add: image_image)
    also have "Complete_Partial_Order.chain (\<le>) \<dots>" using chain_ord_spmf_eqD
      by(rule chain_imageI)(auto simp add: le_fun_def split: split_indicator)
    finally show "Complete_Partial_Order.chain (\<le>) ((\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y)" .
  qed simp
  (* The support of lub_spmf contains every support in Y (set_lub_spmf), so the restriction can be dropped. *)
  also have "\<dots> = (SUP y\<in>Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>count_space UNIV)"
    by(auto simp add: nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: arg_cong [of _ _ Sup] image_cong nn_integral_cong)
  also have "\<dots> = ?rhs"
    by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf)
  finally show ?thesis .
qed
1461
(* Real-valued variant of emeasure_lub_spmf.  The ennreal images of both sides
   agree, and the right-hand side is nonnegative, so the reals agree too. *)
lemma measure_lub_spmf:
  assumes Y: "Y \<noteq> {}"
  shows "measure (measure_spmf lub_spmf) A = (SUP y\<in>Y. measure (measure_spmf y) A)" (is "?lhs = ?rhs")
proof -
  have "ennreal ?lhs = ennreal ?rhs"
    using emeasure_lub_spmf[OF assms] SUP_emeasure_spmf_neq_top[of A Y] Y
    unfolding measure_spmf.emeasure_eq_measure by(subst ennreal_SUP)
  moreover have "0 \<le> ?rhs" using Y
    by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] measure_spmf.subprob_measure_le_1)
  ultimately show ?thesis by(simp)
qed
1473
(* The weight of lub_spmf is the supremum of the weights in Y.  Unfolding
   weight_spmf_def reduces this to measure_lub_spmf. *)
lemma weight_lub_spmf:
  assumes Y: "Y \<noteq> {}"
  shows "weight_spmf lub_spmf = (SUP y\<in>Y. weight_spmf y)"
unfolding weight_spmf_def by(rule measure_lub_spmf) fact
1478
(* As measures, lub_spmf is the supremum of the measures of the elements of Y.
   The measures form a chain (ord_spmf_eqD_measure_spmf).  The sigma-algebras
   agree by sets_SUP, and the emeasures agree by emeasure_SUP_chain combined
   with emeasure_lub_spmf. *)
lemma measure_spmf_lub_spmf:
  assumes Y: "Y \<noteq> {}"
  shows "measure_spmf lub_spmf = (SUP p\<in>Y. measure_spmf p)" (is "?lhs = ?rhs")
proof(rule measure_eqI)
  from assms obtain p where p: "p \<in> Y" by auto
  from chain have chain': "Complete_Partial_Order.chain (\<le>) (measure_spmf ` Y)"
    by(rule chain_imageI)(rule ord_spmf_eqD_measure_spmf)
  show "sets ?lhs = sets ?rhs"
    using Y by (subst sets_SUP) auto
  show "emeasure ?lhs A = emeasure ?rhs A" for A
    using chain' Y p by (subst emeasure_SUP_chain) (auto simp:  emeasure_lub_spmf)
qed
1491
1492end
1493
1494end
1495
(* ord_spmf (=) together with lub_spmf satisfies the axioms of
   partial_function_definitions: it is a partial order, and lub_spmf is an
   upper bound and the least upper bound of every chain. *)
lemma partial_function_definitions_spmf: "partial_function_definitions (ord_spmf (=)) lub_spmf"
  (is "partial_function_definitions ?R _")
proof
  (* Reflexivity. *)
  fix x show "?R x x" by(simp add: ord_spmf_reflI)
next
  (* Transitivity, inherited from transitivity of rel_pmf over ord_option (=). *)
  fix x y z
  assume "?R x y" "?R y z"
  with transp_ord_option[OF transp_equality] show "?R x z" by(rule transp_rel_pmf[THEN transpD])
next
  (* Antisymmetry via rel_pmf_antisym. *)
  fix x y
  assume "?R x y" "?R y x"
  thus "x = y"
    by(rule rel_pmf_antisym)(simp_all add: reflp_ord_option transp_ord_option antisymp_ord_option)
next
  (* Upper bound property of lub_spmf on chains. *)
  fix Y x
  assume "Complete_Partial_Order.chain ?R Y" "x \<in> Y"
  then show "?R x (lub_spmf Y)"
    by(rule lub_spmf_upper)
next
  (* Least upper bound property; the empty chain is handled by simp. *)
  fix Y z
  assume "Complete_Partial_Order.chain ?R Y" "\<And>x. x \<in> Y \<Longrightarrow> ?R x z"
  then show "?R (lub_spmf Y) z"
    by(cases "Y = {}")(simp_all add: lub_spmf_least)
qed
1520
(* The same structure as an instance of the class.ccpo predicate. *)
lemma ccpo_spmf: "class.ccpo lub_spmf (ord_spmf (=)) (mk_less (ord_spmf (=)))"
by(rule ccpo partial_function_definitions_spmf)+
1523
(* Interpret the generic partial_function_definitions theory under the prefix
   spmf.  The rewrites clause replaces lub_spmf {} by return_pmf None in the
   inherited facts; its proof obligation is discharged by simp. *)
interpretation spmf: partial_function_definitions "ord_spmf (=)" "lub_spmf"
  rewrites "lub_spmf {} \<equiv> return_pmf None"
by(rule partial_function_definitions_spmf) simp
1527
(* Register the mode spmf with the partial_function package.  It uses the
   fixpoint combinator spmf.fixp_fun, the monotonicity predicate
   spmf.mono_body, and the rules spmf.fixp_rule_uc and spmf.fixp_induct_uc
   from the interpretation spmf. *)
declaration \<open>Partial_Function.init "spmf" \<^term>\<open>spmf.fixp_fun\<close>
  \<^term>\<open>spmf.mono_body\<close> @{thm spmf.fixp_rule_uc} @{thm spmf.fixp_induct_uc}
  NONE\<close>
1531
(* Reflexivity of the order as a simp rule.  Admissibility of ord_spmf (=)
   comparisons (admissible_leI instantiated with ccpo_spmf) is registered as a
   cont_intro rule. *)
declare spmf.leq_refl[simp]
declare admissible_leI[OF ccpo_spmf, cont_intro]
1534
(* Monotonicity of a functional with respect to the pointwise lifting (fun_ord) of ord_spmf (=). *)
abbreviation "mono_spmf \<equiv> monotone (fun_ord (ord_spmf (=))) (ord_spmf (=))"
1536
(* The lub of a singleton chain is its only element. *)
lemma lub_spmf_const [simp]: "lub_spmf {p} = p"
by(rule spmf_eqI)(simp add: spmf_lub_spmf[OF ccpo.chain_singleton[OF ccpo_spmf]])
1539
(* bind_spmf is monotone in both arguments: a smaller first argument and
   pointwise smaller continuations give a smaller result.  Proved via
   rel_pmf_bindI on the unfolded definition of bind_spmf. *)
lemma bind_spmf_mono':
  assumes fg: "ord_spmf (=) f g"
  and hk: "\<And>x :: 'a. ord_spmf (=) (h x) (k x)"
  shows "ord_spmf (=) (f \<bind> h) (g \<bind> k)"
unfolding bind_spmf_def using assms(1)
by(rule rel_pmf_bindI)(auto split: option.split simp add: hk)
1546
(* Monotonicity rule for bind_spmf used by partial_function: if both the bound
   computation B and every continuation C y are monotone in the recursive
   function, so is their bind.  Follows from bind_spmf_mono'. *)
lemma bind_spmf_mono [partial_function_mono]:
  assumes mf: "mono_spmf B" and mg: "\<And>y. mono_spmf (\<lambda>f. C y f)"
  shows "mono_spmf (\<lambda>f. bind_spmf (B f) (\<lambda>y. C y f))"
proof (rule monotoneI)
  fix f g :: "'a \<Rightarrow> 'b spmf"
  assume fg: "fun_ord (ord_spmf (=)) f g"
  with mf have "ord_spmf (=) (B f) (B g)" by (rule monotoneD[of _ _ _ f g])
  moreover from mg have "\<And>y'. ord_spmf (=) (C y' f) (C y' g)"
    by (rule monotoneD) (rule fg)
  ultimately show "ord_spmf (=) (bind_spmf (B f) (\<lambda>y. C y f)) (bind_spmf (B g) (\<lambda>y'. C y' g))"
    by(rule bind_spmf_mono')
qed
1559
(* bind_spmf is monotone in its first argument for a fixed continuation g. *)
lemma monotone_bind_spmf1: "monotone (ord_spmf (=)) (ord_spmf (=)) (\<lambda>y. bind_spmf y g)"
by(rule monotoneI)(simp add: bind_spmf_mono' ord_spmf_reflI)
1562
(* bind_spmf with a fixed first argument p is monotone in a parameter y on
   which every continuation g y x depends monotonically (w.r.t. ord). *)
lemma monotone_bind_spmf2:
  assumes g: "\<And>x. monotone ord (ord_spmf (=)) (\<lambda>y. g y x)"
  shows "monotone ord (ord_spmf (=)) (\<lambda>y. bind_spmf p (g y))"
by(rule monotoneI)(auto intro: bind_spmf_mono' monotoneD[OF g] ord_spmf_reflI)
1567
(* bind_spmf distributes over lubs of chains in its first argument.
   Proof, pointwise at each result i:
   - express the bind as an integral over the support of the lub;
   - expand the density as a SUP over the chain;
   - apply monotone convergence;
   - recognise the result as the lub of the bound chain.
   The empty chain is closed by the final simp. *)
lemma bind_lub_spmf:
  assumes chain: "Complete_Partial_Order.chain (ord_spmf (=)) Y"
  shows "bind_spmf (lub_spmf Y) f = lub_spmf ((\<lambda>p. bind_spmf p f) ` Y)" (is "?lhs = ?rhs")
proof(cases "Y = {}")
  case Y: False
  show ?thesis
  proof(rule spmf_eqI)
    fix i
    (* The integrands form a chain in the pointwise order. *)
    have chain': "Complete_Partial_Order.chain (\<le>) ((\<lambda>p x. ennreal (spmf p x * spmf (f x) i)) ` Y)"
      using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD intro: mult_right_mono)
    (* The bound spmfs form a chain in ord_spmf (=). *)
    have chain'': "Complete_Partial_Order.chain (ord_spmf (=)) ((\<lambda>p. p \<bind> f) ` Y)"
      using chain by(rule chain_imageI)(auto intro!: monotoneI bind_spmf_mono' ord_spmf_reflI)
    let ?M = "count_space (set_spmf (lub_spmf Y))"
    have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ x. ennreal (spmf (lub_spmf Y) x) * ennreal (spmf (f x) i) \<partial>?M"
      by(auto simp add: ennreal_spmf_lub_spmf ennreal_spmf_bind nn_integral_measure_spmf')
    also have "\<dots> = \<integral>\<^sup>+ x. (SUP p\<in>Y. ennreal (spmf p x * spmf (f x) i)) \<partial>?M"
      by(subst ennreal_spmf_lub_spmf[OF chain Y])(subst SUP_mult_right_ennreal, simp_all add: ennreal_mult Y)
    also have "\<dots> = (SUP p\<in>Y. \<integral>\<^sup>+ x. ennreal (spmf p x * spmf (f x) i) \<partial>?M)"
      using Y chain' by(rule nn_integral_monotone_convergence_SUP_countable) simp
    also have "\<dots> = (SUP p\<in>Y. ennreal (spmf (bind_spmf p f) i))"
      by(auto simp add: ennreal_spmf_bind nn_integral_measure_spmf nn_integral_count_space_indicator set_lub_spmf[OF chain] in_set_spmf_iff_spmf ennreal_mult intro!: arg_cong [of _ _ Sup] image_cong nn_integral_cong split: split_indicator)
    also have "\<dots> = ennreal (spmf ?rhs i)" using chain'' by(simp add: ennreal_spmf_lub_spmf Y image_comp)
    finally show "spmf ?lhs i = spmf ?rhs i" by simp
  qed
qed simp
1593
(* map_spmf commutes with lubs of chains; special case of bind_lub_spmf via
   map_spmf_conv_bind_spmf. *)
lemma map_lub_spmf:
  "Complete_Partial_Order.chain (ord_spmf (=)) Y
  \<Longrightarrow> map_spmf f (lub_spmf Y) = lub_spmf (map_spmf f ` Y)"
unfolding map_spmf_conv_bind_spmf[abs_def] by(simp add: bind_lub_spmf o_def)
1598
(* bind_spmf is monotone and continuous in its first argument
   (monotone_bind_spmf1 plus bind_lub_spmf). *)
lemma mcont_bind_spmf1: "mcont lub_spmf (ord_spmf (=)) lub_spmf (ord_spmf (=)) (\<lambda>y. bind_spmf y f)"
using monotone_bind_spmf1 by(rule mcontI)(rule contI, simp add: bind_lub_spmf)
1601
(* bind_spmf commutes with lubs taken inside the continuation: for a chain Y
   (w.r.t. ord) and continuations g y monotone in the chain parameter,
   binding x to the pointwise lubs equals the lub of the binds.  The proof
   works pointwise at each result i using monotone convergence over
   set_spmf x.  The empty chain is closed by the final simp. *)
lemma bind_lub_spmf2:
  assumes chain: "Complete_Partial_Order.chain ord Y"
  and g: "\<And>y. monotone ord (ord_spmf (=)) (g y)"
  shows "bind_spmf x (\<lambda>y. lub_spmf (g y ` Y)) = lub_spmf ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
  (is "?lhs = ?rhs")
proof(cases "Y = {}")
  case Y: False
  show ?thesis
  proof(rule spmf_eqI)
    fix i
    (* For each y, the images g y ` Y form a chain. *)
    have chain': "\<And>y. Complete_Partial_Order.chain (ord_spmf (=)) (g y ` Y)"
      using chain g[THEN monotoneD] by(rule chain_imageI)
    (* The integrands form a chain in the pointwise order. *)
    have chain'': "Complete_Partial_Order.chain (\<le>) ((\<lambda>p y. ennreal (spmf x y * spmf (g y p) i)) ` Y)"
      using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD monotoneD[OF g] intro!: mult_left_mono)
    (* The resulting binds form a chain in ord_spmf (=). *)
    have chain''': "Complete_Partial_Order.chain (ord_spmf (=)) ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
      using chain by(rule chain_imageI)(rule monotone_bind_spmf2[OF g, THEN monotoneD])

    have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ y. (SUP p\<in>Y. ennreal (spmf x y * spmf (g y p) i)) \<partial>count_space (set_spmf x)"
      by(simp add: ennreal_spmf_bind ennreal_spmf_lub_spmf[OF chain'] Y nn_integral_measure_spmf' SUP_mult_left_ennreal ennreal_mult image_comp)
    also have "\<dots> = (SUP p\<in>Y. \<integral>\<^sup>+ y. ennreal (spmf x y * spmf (g y p) i) \<partial>count_space (set_spmf x))"
      unfolding nn_integral_measure_spmf' using Y chain''
      by(rule nn_integral_monotone_convergence_SUP_countable) simp
    also have "\<dots> = (SUP p\<in>Y. ennreal (spmf (bind_spmf x (\<lambda>y. g y p)) i))"
      by(simp add: ennreal_spmf_bind nn_integral_measure_spmf' ennreal_mult)
    also have "\<dots> = ennreal (spmf ?rhs i)" using chain'''
      by(auto simp add: ennreal_spmf_lub_spmf Y image_comp)
    finally show "spmf ?lhs i = spmf ?rhs i" by simp
  qed
qed simp
1631
(* Continuity introduction rule for bind_spmf: if both the bound computation f
   and every continuation g y are monotone and continuous, so is their bind.
   Uses spmf.mcont2mcont': continuity in the first argument comes from
   mcont_bind_spmf1, and continuity in the parameter of the continuations from
   monotone_bind_spmf2 plus bind_lub_spmf2. *)
lemma mcont_bind_spmf [cont_intro]:
  assumes f: "mcont luba orda lub_spmf (ord_spmf (=)) f"
  and g: "\<And>y. mcont luba orda lub_spmf (ord_spmf (=)) (g y)"
  shows "mcont luba orda lub_spmf (ord_spmf (=)) (\<lambda>x. bind_spmf (f x) (\<lambda>y. g y x))"
proof(rule spmf.mcont2mcont'[OF _ _ f])
  fix z
  show "mcont lub_spmf (ord_spmf (=)) lub_spmf (ord_spmf (=)) (\<lambda>x. bind_spmf x (\<lambda>y. g y z))"
    by(rule mcont_bind_spmf1)
next
  fix x
  let ?f = "\<lambda>z. bind_spmf x (\<lambda>y. g y z)"
  have "monotone orda (ord_spmf (=)) ?f" using mcont_mono[OF g] by(rule monotone_bind_spmf2)
  moreover have "cont luba orda lub_spmf (ord_spmf (=)) ?f"
  proof(rule contI)
    fix Y
    assume chain: "Complete_Partial_Order.chain orda Y" and Y: "Y \<noteq> {}"
    (* Push luba Y into each continuation using continuity of g y. *)
    have "bind_spmf x (\<lambda>y. g y (luba Y)) = bind_spmf x (\<lambda>y. lub_spmf (g y ` Y))"
      by(rule bind_spmf_cong)(simp_all add: mcont_contD[OF g chain Y])
    also have "\<dots> = lub_spmf ((\<lambda>p. x \<bind> (\<lambda>y. g y p)) ` Y)" using chain
      by(rule bind_lub_spmf2)(rule mcont_mono[OF g])
    finally show "bind_spmf x (\<lambda>y. g y (luba Y)) = \<dots>" .
  qed
  ultimately show "mcont luba orda lub_spmf (ord_spmf (=)) ?f" by(rule mcontI)
qed
1656
(* Monotonicity rule for bind_pmf with a fixed pmf p.  It is the instance of
   bind_spmf_mono with the constant computation spmf_of_pmf p. *)
lemma bind_pmf_mono [partial_function_mono]:
  "(\<And>y. mono_spmf (\<lambda>f. C y f)) \<Longrightarrow> mono_spmf (\<lambda>f. bind_pmf p (\<lambda>x. C x f))"
using bind_spmf_mono[of "\<lambda>_. spmf_of_pmf p" C] by simp
1660
(* Monotonicity rule for map_spmf, reduced to bind_spmf_mono. *)
lemma map_spmf_mono [partial_function_mono]: "mono_spmf B \<Longrightarrow> mono_spmf (\<lambda>g. map_spmf f (B g))"
unfolding map_spmf_conv_bind_spmf by(rule bind_spmf_mono) simp_all
1663
(* Continuity introduction rule for map_spmf, reduced to mcont_bind_spmf. *)
lemma mcont_map_spmf [cont_intro]:
  "mcont luba orda lub_spmf (ord_spmf (=)) g
  \<Longrightarrow> mcont luba orda lub_spmf (ord_spmf (=)) (\<lambda>x. map_spmf f (g x))"
unfolding map_spmf_conv_bind_spmf by(rule mcont_bind_spmf) simp_all
1668
(* set_spmf is monotone from ord_spmf (=) to set inclusion. *)
lemma monotone_set_spmf: "monotone (ord_spmf (=)) (\<subseteq>) set_spmf"
by(rule monotoneI)(rule ord_spmf_eqD_set_spmf)
1671
(* set_spmf maps lubs of chains to unions (set_lub_spmf). *)
lemma cont_set_spmf: "cont lub_spmf (ord_spmf (=)) Union (\<subseteq>) set_spmf"
by(rule contI)(subst set_lub_spmf; simp)
1674
(* Proves mcont_set_spmf.  Its composition with mcont2mcont is stored as
   mcont2mcont_set_spmf and registered as a cont_intro rule. *)
lemma mcont2mcont_set_spmf[THEN mcont2mcont, cont_intro]:
  shows mcont_set_spmf: "mcont lub_spmf (ord_spmf (=)) Union (\<subseteq>) set_spmf"
by(rule mcontI monotone_set_spmf cont_set_spmf)+
1678
(* Evaluating at a fixed point x is monotone (ord_spmf_eq_leD). *)
lemma monotone_spmf: "monotone (ord_spmf (=)) (\<le>) (\<lambda>p. spmf p x)"
by(rule monotoneI)(simp add: ord_spmf_eq_leD)
1681
(* Evaluating at a fixed point x maps lubs of chains to suprema (spmf_lub_spmf). *)
lemma cont_spmf: "cont lub_spmf (ord_spmf (=)) Sup (\<le>) (\<lambda>p. spmf p x)"
by(rule contI)(simp add: spmf_lub_spmf)
1684
(* Combination of monotone_spmf and cont_spmf. *)
lemma mcont_spmf: "mcont lub_spmf (ord_spmf (=)) Sup (\<le>) (\<lambda>p. spmf p x)"
by(rule mcontI monotone_spmf cont_spmf)+
1687
(* ennreal-valued evaluation at x is continuous (ennreal_spmf_lub_spmf). *)
lemma cont_ennreal_spmf: "cont lub_spmf (ord_spmf (=)) Sup (\<le>) (\<lambda>p. ennreal (spmf p x))"
by(rule contI)(simp add: ennreal_spmf_lub_spmf)
1690
(* Proves mcont_ennreal_spmf.  Its composition with mcont2mcont is stored as
   mcont2mcont_ennreal_spmf and registered as a cont_intro rule. *)
lemma mcont2mcont_ennreal_spmf [THEN mcont2mcont, cont_intro]:
  shows mcont_ennreal_spmf: "mcont lub_spmf (ord_spmf (=)) Sup (\<le>) (\<lambda>p. ennreal (spmf p x))"
by(rule mcontI mono2mono_ennreal monotone_spmf cont_ennreal_spmf)+
1694
(* Change of variables: integrating g against the image of p under f is the
   same as integrating g composed with f against p. *)
lemma nn_integral_map_spmf [simp]: "nn_integral (measure_spmf (map_spmf f p)) g = nn_integral (measure_spmf p) (g \<circ> f)"
by(auto 4 3 simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space intro: nn_integral_cong split: split_indicator)
1697
1698subsubsection \<open>Admissibility of \<^term>\<open>rel_spmf\<close>\<close>
1699
(* If rel_spmf R p q holds, then every event A has measure under p at most the
   measure under q of its R-image {y. EX x:A. R x y}.  The proof transfers
   rel_pmf_measureD from the underlying option pmfs via the image Some ` A. *)
lemma rel_spmf_measureD:
  assumes "rel_spmf R p q"
  shows "measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}" (is "?lhs \<le> ?rhs")
proof -
  have "?lhs = measure (measure_pmf p) (Some ` A)" by(simp add: measure_measure_spmf_conv_measure_pmf)
  also have "\<dots> \<le> measure (measure_pmf q) {y. \<exists>x\<in>Some ` A. rel_option R x y}"
    using assms by(rule rel_pmf_measureD)
  also have "\<dots> = ?rhs" unfolding measure_measure_spmf_conv_measure_pmf
    by(rule arg_cong2[where f=measure])(auto simp add: option_rel_Some1)
  finally show ?thesis .
qed
1711
(* Locale assuming the measure-theoretic characterisation of rel_pmf on option
   types: if every event A has measure under p at most the measure under q of
   its R-image {y. EX x:A. R x y}, then rel_pmf R p q holds.  This is the
   converse of the inequality provided by rel_pmf_measureD. *)
locale rel_spmf_characterisation =
  assumes rel_pmf_measureI:
    "\<And>(R :: 'a option \<Rightarrow> 'b option \<Rightarrow> bool) p q.
    (\<And>A. measure (measure_pmf p) A \<le> measure (measure_pmf q) {y. \<exists>x\<in>A. R x y})
    \<Longrightarrow> rel_pmf R p q"
  \<comment> \<open>This assumption is shown to hold in general in the AFP entry \<open>MFMC_Countable\<close>.\<close>
begin
1719
1720context fixes R :: "'a \<Rightarrow> 'b \<Rightarrow> bool" begin
1721
(* Introduction rule for rel_spmf.  It needs the event inequality eq1 for the
   measure_spmf measures together with weight q <= weight p (eq2).
   Proof:
   - split an event A on option values into its Some-part Some ` A' and its
     None-part A'';
   - eq1 bounds the Some-part;
   - eq1 at UNIV gives weight p <= weight q, so with eq2 the weights, and
     hence the masses at None, coincide, which handles the None-part. *)
lemma rel_spmf_measureI:
  assumes eq1: "\<And>A. measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}"
  assumes eq2: "weight_spmf q \<le> weight_spmf p"
  shows "rel_spmf R p q"
proof(rule rel_pmf_measureI)
  fix A :: "'a option set"
  define A' where "A' = the ` (A \<inter> range Some)"
  define A'' where "A'' = A \<inter> {None}"
  (* A is the disjoint union of its Some-part and its None-part. *)
  have A: "A = Some ` A' \<union> A''" "Some ` A' \<inter> A'' = {}"
    unfolding A'_def A''_def by(auto 4 3 intro: rev_image_eqI)
  have "measure (measure_pmf p) A = measure (measure_pmf p) (Some ` A') + measure (measure_pmf p) A''"
    by(simp add: A measure_pmf.finite_measure_Union)
  also have "measure (measure_pmf p) (Some ` A') = measure (measure_spmf p) A'"
    by(simp add: measure_measure_spmf_conv_measure_pmf)
  also have "\<dots> \<le> measure (measure_spmf q) {y. \<exists>x\<in>A'. R x y}" by(rule eq1)
  also (ord_eq_le_trans[OF _ add_right_mono])
  have "\<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y}"
    unfolding measure_measure_spmf_conv_measure_pmf
    by(rule arg_cong2[where f=measure])(auto simp add: A'_def option_rel_Some1)
  also
  { have "weight_spmf p \<le> measure (measure_spmf q) {y. \<exists>x. R x y}"
      using eq1[of UNIV] unfolding weight_spmf_def by simp
    also have "\<dots> \<le> weight_spmf q" unfolding weight_spmf_def
      by(rule measure_spmf.finite_measure_mono) simp_all
    finally have "weight_spmf p = weight_spmf q" using eq2 by simp }
  then have "measure (measure_pmf p) A'' = measure (measure_pmf q) (if None \<in> A then {None} else {})"
    unfolding A''_def by(simp add: pmf_None_eq_weight_spmf measure_pmf_single)
  (* Reassemble the two parts on the q side into the R-image of A. *)
  also have "measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y} + \<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A. rel_option R x y}"
    by(subst measure_pmf.finite_measure_Union[symmetric])
      (auto 4 3 intro!: arg_cong2[where f=measure] simp add: option_rel_Some1 option_rel_Some2 A'_def intro: rev_bexI elim: option.rel_cases)
  finally show "measure (measure_pmf p) A \<le> \<dots>" .
qed
1754
(* rel_spmf R, as a predicate on pairs, is admissible for the product ccpo:
   it is preserved by lubs of nonempty chains of pairs.
   Proof:
   - project the chain to its two component chains;
   - check the hypotheses of rel_spmf_measureI for the two lubs, using
     weight_lub_spmf for the weight condition and measure_lub_spmf with
     rel_spmf_measureD for the event inequality. *)
lemma admissible_rel_spmf:
  "ccpo.admissible (prod_lub lub_spmf lub_spmf) (rel_prod (ord_spmf (=)) (ord_spmf (=))) (case_prod (rel_spmf R))"
  (is "ccpo.admissible ?lub ?ord ?P")
proof(rule ccpo.admissibleI)
  fix Y
  assume chain: "Complete_Partial_Order.chain ?ord Y"
    and Y: "Y \<noteq> {}"
    and R: "\<forall>(p, q) \<in> Y. rel_spmf R p q"
  from R have R: "\<And>p q. (p, q) \<in> Y \<Longrightarrow> rel_spmf R p q" by auto
  have chain1: "Complete_Partial_Order.chain (ord_spmf (=)) (fst ` Y)"
    and chain2: "Complete_Partial_Order.chain (ord_spmf (=)) (snd ` Y)"
    using chain by(rule chain_imageI; clarsimp)+
  from Y have Y1: "fst ` Y \<noteq> {}" and Y2: "snd ` Y \<noteq> {}" by auto

  have "rel_spmf R (lub_spmf (fst ` Y)) (lub_spmf (snd ` Y))"
  proof(rule rel_spmf_measureI)
    (* Weight condition: related spmfs have equal weights (rel_spmf_weightD). *)
    show "weight_spmf (lub_spmf (snd ` Y)) \<le> weight_spmf (lub_spmf (fst ` Y))"
      by(auto simp add: weight_lub_spmf chain1 chain2 Y rel_spmf_weightD[OF R, symmetric] intro!: cSUP_least intro: cSUP_upper2[OF bdd_aboveI2[OF weight_spmf_le_1]])

    (* Event inequality, checked pairwise along the chain and lifted to the suprema. *)
    fix A
    have "measure (measure_spmf (lub_spmf (fst ` Y))) A = (SUP y\<in>fst ` Y. measure (measure_spmf y) A)"
      using chain1 Y1 by(rule measure_lub_spmf)
    also have "\<dots> \<le> (SUP y\<in>snd ` Y. measure (measure_spmf y) {y. \<exists>x\<in>A. R x y})" using Y1
      by(rule cSUP_least)(auto intro!: cSUP_upper2[OF bdd_aboveI2[OF measure_spmf.subprob_measure_le_1]] rel_spmf_measureD R)
    also have "\<dots> = measure (measure_spmf (lub_spmf (snd ` Y))) {y. \<exists>x\<in>A. R x y}"
      using chain2 Y2 by(rule measure_lub_spmf[symmetric])
    finally show "measure (measure_spmf (lub_spmf (fst ` Y))) A \<le> \<dots>" .
  qed
  then show "?P (?lub Y)" by(simp add: prod_lub_def)
qed
1785
(* Admissibility of \<lambda>x. rel_spmf R (f x) (g x) over an arbitrary ccpo
   (lub, ord), provided f and g are continuous into the spmf ccpo.  Obtained
   from admissible_rel_spmf by substituting the pair \<lambda>x. (f x, g x), whose
   continuity is mcont_Pair.  Registered as [cont_intro]. *)
lemma admissible_rel_spmf_mcont [cont_intro]:
  "\<lbrakk> mcont lub ord lub_spmf (ord_spmf (=)) f; mcont lub ord lub_spmf (ord_spmf (=)) g \<rbrakk>
  \<Longrightarrow> ccpo.admissible lub ord (\<lambda>x. rel_spmf R (f x) (g x))"
by(rule admissible_subst[OF admissible_rel_spmf, where f="\<lambda>x. (f x, g x)", simplified])(rule mcont_Pair)
1790
1791context includes lifting_syntax
1792begin
1793
(* Parametricity of the least fixpoint in the spmf ccpo: if F and G are monotone
   and map rel_spmf R-related arguments to rel_spmf R-related results, then
   their least fixpoints are rel_spmf R-related.  Proved by parallel fixpoint
   induction (parallel_fixp_induct) on two copies of ccpo_spmf.
   NOTE(review): the bound variable x in assumptions f and g does not occur in
   their bodies, so the quantifier is vacuous. *)
lemma fixp_spmf_parametric':
  assumes f: "\<And>x. monotone (ord_spmf (=)) (ord_spmf (=)) F"
  and g: "\<And>x. monotone (ord_spmf (=)) (ord_spmf (=)) G"
  and param: "(rel_spmf R ===> rel_spmf R) F G"
  shows "(rel_spmf R) (ccpo.fixp lub_spmf (ord_spmf (=)) F) (ccpo.fixp lub_spmf (ord_spmf (=)) G)"
by(rule parallel_fixp_induct[OF ccpo_spmf ccpo_spmf _ f g])(auto intro: param[THEN rel_funD])
1800
(* Function-level parametricity of spmf.fixp_fun: for functionals F and G that
   are pointwise monotone (mono_spmf) and parametric with respect to
   (A ===> rel_spmf R), their fixpoints are related by A ===> rel_spmf R.
   Proved by parallel_fixp_induct_1_1 for partial_function_definitions_spmf. *)
lemma fixp_spmf_parametric:
  assumes f: "\<And>x. mono_spmf (\<lambda>f. F f x)"
  and g: "\<And>x. mono_spmf (\<lambda>f. G f x)"
  and param: "((A ===> rel_spmf R) ===> A ===> rel_spmf R) F G"
  shows "(A ===> rel_spmf R) (spmf.fixp_fun F) (spmf.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF partial_function_definitions_spmf partial_function_definitions_spmf _ _ reflexive reflexive, where P="(A ===> rel_spmf R)"])
  (* Admissibility: after unfolding rel_fun, reduce to admissible_rel_spmf_mcont
     applied to the first and second projection evaluated at a fixed argument. *)
  show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_spmf)) (rel_prod (fun_ord (ord_spmf (=))) (fun_ord (ord_spmf (=)))) (\<lambda>x. (A ===> rel_spmf R) (fst x) (snd x))"
    unfolding rel_fun_def
    apply(rule admissible_all admissible_imp admissible_rel_spmf_mcont)+
    apply(rule spmf.mcont2mcont[OF mcont_call])
     apply(rule mcont_fst)
    apply(rule spmf.mcont2mcont[OF mcont_call])
     apply(rule mcont_snd)
    done
  (* Base case: the bottom function \<lambda>_. lub_spmf {} is related to itself. *)
  show "(A ===> rel_spmf R) (\<lambda>_. lub_spmf {}) (\<lambda>_. lub_spmf {})" by auto
  (* Induction step: exactly the parametricity assumption param. *)
  show "(A ===> rel_spmf R) (F f) (G g)" if "(A ===> rel_spmf R) f g" for f g
    using that by(rule rel_funD[OF param])
qed
1820
1821end
1822
1823end
1824
1825end
1826
1827subsection \<open>Restrictions on spmfs\<close>
1828
(* Restriction p \<upharpoonleft> A: outcomes Some y with y \<in> A keep their mass,
   while Some y with y \<notin> A is sent to None (as is None itself).  The result
   is not renormalised. *)
definition restrict_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf" (infixl "\<upharpoonleft>" 110)
where "p \<upharpoonleft> A = map_pmf (\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) p"
1831
(* The support of a restriction is the original support intersected with A. *)
lemma set_restrict_spmf [simp]: "set_spmf (p \<upharpoonleft> A) = set_spmf p \<inter> A"
by(fastforce simp add: restrict_spmf_def set_spmf_def split: bind_splits if_split_asm)
1834
(* Restricting after map_spmf f equals mapping after restricting to the
   preimage f -` A. *)
lemma restrict_map_spmf: "map_spmf f p \<upharpoonleft> A = map_spmf f (p \<upharpoonleft> (f -` A))"
by(simp add: restrict_spmf_def pmf.map_comp o_def map_option_bind bind_map_option if_distrib cong del: if_weak_cong)
1837
(* Two successive restrictions equal one restriction to the intersection. *)
lemma restrict_restrict_spmf [simp]: "p \<upharpoonleft> A \<upharpoonleft> B = p \<upharpoonleft> (A \<inter> B)"
by(auto simp add: restrict_spmf_def pmf.map_comp o_def intro!: pmf.map_cong bind_option_cong)
1840
(* Restricting to the empty set moves all mass to None. *)
lemma restrict_spmf_empty [simp]: "p \<upharpoonleft> {} = return_pmf None"
by(simp add: restrict_spmf_def)
1843
(* Restricting to UNIV is the identity. *)
lemma restrict_spmf_UNIV [simp]: "p \<upharpoonleft> UNIV = p"
by(simp add: restrict_spmf_def)
1846
(* Outcomes outside A have probability 0 after restriction to A; follows from
   set_restrict_spmf via spmf_eq_0_set_spmf. *)
lemma spmf_restrict_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = 0"
by(simp add: spmf_eq_0_set_spmf)
1849
(* emeasure version: the mass of X under p \<upharpoonleft> A is the mass of X \<inter> A
   under p. *)
lemma emeasure_restrict_spmf [simp]:
  "emeasure (measure_spmf (p \<upharpoonleft> A)) X = emeasure (measure_spmf p) (X \<inter> A)"
by(auto simp add: restrict_spmf_def measure_spmf_def emeasure_distr measurable_restrict_space1 emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure] split: bind_splits if_split_asm)
1853
(* Real-valued counterpart of emeasure_restrict_spmf, obtained by injectivity of
   ennreal on non-negative measures. *)
lemma measure_restrict_spmf [simp]:
  "measure (measure_spmf (p \<upharpoonleft> A)) X = measure (measure_spmf p) (X \<inter> A)"
using emeasure_restrict_spmf[of p A X]
by(simp only: measure_spmf.emeasure_eq_measure ennreal_inj measure_nonneg)
1858
(* Pointwise probabilities of a restriction: unchanged inside A, zero outside. *)
lemma spmf_restrict_spmf: "spmf (p \<upharpoonleft> A) x = (if x \<in> A then spmf p x else 0)"
by(simp add: spmf_conv_measure_spmf)
1861
(* simp specialisation of spmf_restrict_spmf for points inside A. *)
lemma spmf_restrict_spmf_inside [simp]: "x \<in> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = spmf p x"
by(simp add: spmf_restrict_spmf)
1864
(* After restriction, None carries its original probability plus all the mass
   that p assigns to the complement of A. *)
lemma pmf_restrict_spmf_None: "pmf (p \<upharpoonleft> A) None = pmf p None + measure (measure_spmf p) (- A)"
proof -
  (* Records that {None} and Some ` (- A) are disjoint. *)
  have [simp]: "None \<notin> Some ` (- A)" by auto
  (* The outcomes that the restriction maps to None are exactly None and the
     Some y with y outside A. *)
  have "(\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) -` {None} = {None} \<union> (Some ` (- A))"
    by(auto split: bind_splits if_split_asm)
  then show ?thesis unfolding ereal.inject[symmetric]
    by(simp add: restrict_spmf_def ennreal_pmf_map emeasure_pmf_single del: ereal.inject)
      (simp add: pmf.rep_eq measure_pmf.finite_measure_Union[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf.emeasure_eq_measure)
qed
1874
(* Restricting to a set that contains the whole support changes nothing. *)
lemma restrict_spmf_trivial: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> x \<in> A) \<Longrightarrow> p \<upharpoonleft> A = p"
by(rule spmf_eqI)(auto simp add: spmf_restrict_spmf spmf_eq_0_set_spmf)
1877
(* Subset formulation of restrict_spmf_trivial. *)
lemma restrict_spmf_trivial': "set_spmf p \<subseteq> A \<Longrightarrow> p \<upharpoonleft> A = p"
by(rule restrict_spmf_trivial) blast
1880
(* Restriction of a point mass: kept if x \<in> A, otherwise all mass goes to None. *)
lemma restrict_return_spmf: "return_spmf x \<upharpoonleft> A = (if x \<in> A then return_spmf x else return_pmf None)"
by(simp add: restrict_spmf_def)
1883
(* simp specialisation of restrict_return_spmf for x \<in> A. *)
lemma restrict_return_spmf_inside [simp]: "x \<in> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_spmf x"
by(simp add: restrict_return_spmf)
1886
(* simp specialisation of restrict_return_spmf for x \<notin> A. *)
lemma restrict_return_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_pmf None"
by(simp add: restrict_return_spmf)
1889
(* The empty spmf return_pmf None is unchanged by any restriction. *)
lemma restrict_spmf_return_pmf_None [simp]: "return_pmf None \<upharpoonleft> A = return_pmf None"
by(simp add: restrict_spmf_def)
1892
(* Restriction commutes with bind_pmf: it can be pushed into every continuation. *)
lemma restrict_bind_pmf: "bind_pmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
by(simp add: restrict_spmf_def map_bind_pmf o_def)
1895
(* Restriction commutes with bind_spmf: it can be pushed into every continuation.
   Reduced to restrict_bind_pmf via bind_spmf_def and a case split on the option. *)
lemma restrict_bind_spmf: "bind_spmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
by(auto simp add: bind_spmf_def restrict_bind_pmf cong del: option.case_cong_weak cong: option.case_cong intro!: bind_pmf_cong split: option.split)
1898
(* bind_pmf over a restricted spmf: outcomes outside Some ` A are fed to the
   continuation as None. *)
lemma bind_restrict_pmf: "bind_pmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> Some ` A then g x else g None)"
by(auto simp add: restrict_spmf_def bind_map_pmf fun_eq_iff split: bind_split intro: arg_cong2[where f=bind_pmf])
1901
(* bind_spmf over a restricted spmf: continuations at values outside A are
   replaced by return_pmf None. *)
lemma bind_restrict_spmf: "bind_spmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> A then g x else return_pmf None)"
by(auto simp add: bind_spmf_def bind_restrict_pmf fun_eq_iff intro: arg_cong2[where f=bind_pmf] split: option.split)
1904
(* Restricting a joint spmf to pairs whose second component is y and projecting
   to the first component recovers the joint probability spmf p (x, y). *)
lemma spmf_map_restrict: "spmf (map_spmf fst (p \<upharpoonleft> (snd -` {y}))) x = spmf p (x, y)"
by(subst spmf_map)(auto intro: arg_cong2[where f=measure] simp add: spmf_conv_measure_spmf)
1907
(* If the restrictions p \<upharpoonleft> A and q \<upharpoonleft> B are related by some rel_spmf R,
   then A and B have the same measure under p and q respectively.  This follows
   because rel_spmf preserves total weight (rel_spmf_weightD). *)
lemma measure_eqI_restrict_spmf:
  assumes "rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
  shows "measure (measure_spmf p) A = measure (measure_spmf q) B"
proof -
  from assms have "weight_spmf (restrict_spmf p A) = weight_spmf (restrict_spmf q B)" by(rule rel_spmf_weightD)
  thus ?thesis by(simp add: weight_spmf_def)
qed
1915
1916subsection \<open>Subprobability distributions of sets\<close>
1917
(* Uniform subprobability distribution on A.  For finite non-empty A this is
   pmf_of_set A embedded via spmf_of_pmf; otherwise it is the empty spmf
   return_pmf None. *)
definition spmf_of_set :: "'a set \<Rightarrow> 'a spmf"
where
  "spmf_of_set A = (if finite A \<and> A \<noteq> {} then spmf_of_pmf (pmf_of_set A) else return_pmf None)"
1921
(* Point probability of spmf_of_set A at x: indicator A x / card A. *)
lemma spmf_of_set: "spmf (spmf_of_set A) x = indicator A x / card A"
by(auto simp add: spmf_of_set_def)
1924
(* None has probability 1 if A is infinite or empty, and 0 otherwise.  The A in
   the set comprehension is a bound variable, not the outer A. *)
lemma pmf_spmf_of_set_None [simp]: "pmf (spmf_of_set A) None = indicator {A. infinite A \<or> A = {}} A"
by(simp add: spmf_of_set_def)
1927
(* Support of spmf_of_set A: A itself when A is finite (this covers A = {}),
   and empty otherwise. *)
lemma set_spmf_of_set: "set_spmf (spmf_of_set A) = (if finite A then A else {})"
by(simp add: spmf_of_set_def)
1930
(* simp specialisation of set_spmf_of_set for finite A. *)
lemma set_spmf_of_set_finite [simp]: "finite A \<Longrightarrow> set_spmf (spmf_of_set A) = A"
by(simp add: set_spmf_of_set)
1933
(* The uniform distribution on a singleton is a point mass. *)
lemma spmf_of_set_singleton: "spmf_of_set {x} = return_spmf x"
by(simp add: spmf_of_set_def pmf_of_set_singleton)
1936
(* Mapping a function that is injective on A over the uniform spmf on A yields
   the uniform spmf on the image f ` A. *)
lemma map_spmf_of_set_inj_on [simp]:
  "inj_on f A \<Longrightarrow> map_spmf f (spmf_of_set A) = spmf_of_set (f ` A)"
by(auto simp add: spmf_of_set_def map_pmf_of_set_inj dest: finite_imageD)
1940
(* Folds spmf_of_pmf (pmf_of_set A) into spmf_of_set A for finite non-empty A.
   Several proofs below delete this simp rule when they unfold spmf_of_set_def,
   presumably so that simp does not fold the unfolded definition back. *)
lemma spmf_of_pmf_pmf_of_set [simp]:
  "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> spmf_of_pmf (pmf_of_set A) = spmf_of_set A"
by(simp add: spmf_of_set_def)
1944
(* Total mass of spmf_of_set A: 1 for finite non-empty A, 0 otherwise. *)
lemma weight_spmf_of_set:
  "weight_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then 1 else 0)"
by(auto simp only: spmf_of_set_def weight_spmf_of_pmf weight_return_pmf_None split: if_split)
1948
(* simp specialisation of weight_spmf_of_set for finite non-empty A. *)
lemma weight_spmf_of_set_finite [simp]: "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> weight_spmf (spmf_of_set A) = 1"
by(simp add: weight_spmf_of_set)
1951
(* simp specialisation of weight_spmf_of_set for infinite A. *)
lemma weight_spmf_of_set_infinite [simp]: "infinite A \<Longrightarrow> weight_spmf (spmf_of_set A) = 0"
by(simp add: weight_spmf_of_set)
1954
(* Underlying measure of spmf_of_set A: the uniform measure for finite non-empty
   A, otherwise the null measure on count_space UNIV. *)
lemma measure_spmf_spmf_of_set:
  "measure_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then measure_pmf (pmf_of_set A) else null_measure (count_space UNIV))"
by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)
1958
(* emeasure of A under the uniform spmf on S: the fraction card (S \<inter> A) / card S. *)
lemma emeasure_spmf_of_set:
  "emeasure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
by(auto simp add: measure_spmf_spmf_of_set emeasure_pmf_of_set)
1962
(* Real-valued counterpart of emeasure_spmf_of_set. *)
lemma measure_spmf_of_set:
  "measure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
by(auto simp add: measure_spmf_spmf_of_set measure_pmf_of_set)
1966
(* The non-negative integral over the uniform spmf on A is the average
   sum f A / card A.  The proof splits on whether A is finite. *)
lemma nn_integral_spmf_of_set: "nn_integral (measure_spmf (spmf_of_set A)) f = sum f A / card A"
by(cases "finite A")(auto simp add: spmf_of_set_def nn_integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
1969
(* Lebesgue-integral counterpart of nn_integral_spmf_of_set. *)
lemma integral_spmf_of_set: "integral\<^sup>L (measure_spmf (spmf_of_set A)) f = sum f A / card A"
by(clarsimp simp add: spmf_of_set_def integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
1972
(* Counterexample: R x y \<longleftrightarrow> (x \<noteq> 0 \<longrightarrow> y = 0) relates A = {0, 1} and
   B = {0, 1, 2} under rel_set, but pmf_of_set A and pmf_of_set B are not
   rel_pmf R-related.  In any coupling, every pair (x, y) with y \<in> {1, 2}
   must have x = 0.  Under pmf_of_set B the outcomes 1 and 2 carry mass 2/3,
   whereas 0 carries only 1/2 under pmf_of_set A. *)
notepad begin \<comment> \<open>\<^const>\<open>pmf_of_set\<close> is not fully parametric.\<close>
  define R :: "nat \<Rightarrow> nat \<Rightarrow> bool" where "R x y \<longleftrightarrow> (x \<noteq> 0 \<longrightarrow> y = 0)" for x y
  define A :: "nat set" where "A = {0, 1}"
  define B :: "nat set" where "B = {0, 1, 2}"
  have "rel_set R A B" unfolding R_def[abs_def] A_def B_def rel_set_def by auto
  have "\<not> rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  proof
    assume "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
    then obtain pq where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
      and 1: "map_pmf fst pq = pmf_of_set A"
      and 2: "map_pmf snd pq = pmf_of_set B"
      by cases auto
    (* Illustrative values only: neither fact is referenced later; the sum
       below is proved directly from B_def. *)
    have "pmf (pmf_of_set B) 1 = 1 / 3" by(simp add: B_def)
    have "pmf (pmf_of_set B) 2 = 1 / 3" by(simp add: B_def)

    (* Calculation: 2/3 = mass of {1, 2} under pmf_of_set B, which is bounded
       by the mass of {0} under pmf_of_set A, namely 1/2. *)
    have "2 / 3 = pmf (pmf_of_set B) 1 + pmf (pmf_of_set B) 2" by(simp add: B_def)
    also have "\<dots> = measure (measure_pmf (pmf_of_set B)) ({1} \<union> {2})"
      by(subst measure_pmf.finite_measure_Union)(simp_all add: measure_pmf_single)
    also have "\<dots> = emeasure (measure_pmf pq) (snd -` {2, 1})"
      unfolding 2[symmetric] measure_pmf.emeasure_eq_measure[symmetric] by(simp)
    (* Only pairs satisfying R lie in the support of pq, so the pairs with
       second component 1 or 2 have first component 0. *)
    also have "\<dots> = emeasure (measure_pmf pq) {(0, 2), (0, 1)}"
      by(rule emeasure_eq_AE)(auto simp add: AE_measure_pmf_iff R_def dest!: pq)
    also have "\<dots> \<le> emeasure (measure_pmf pq) (fst -` {0})"
      by(rule emeasure_mono) auto
    also have "\<dots> = emeasure (measure_pmf (pmf_of_set A)) {0}"
      unfolding 1[symmetric] by simp
    also have "\<dots> = pmf (pmf_of_set A) 0"
      by(simp add: measure_pmf_single measure_pmf.emeasure_eq_measure)
    also have "pmf (pmf_of_set A) 0 = 1 / 2" by(simp add: A_def)
    finally show False by(subst (asm) ennreal_le_iff; simp)
  qed
end
2005
(* A bijection f from the finite non-empty set A onto the finite non-empty set B
   whose graph lies in R yields a coupling of the uniform pmfs: the uniform pmf
   on the graph AB = (\<lambda>x. (x, f x)) ` A has marginals pmf_of_set A and
   pmf_of_set B.  (An unused auxiliary definition R' from an earlier version of
   this proof has been removed.) *)
lemma rel_pmf_of_set_bij:
  assumes f: "bij_betw f A B"
  and A: "A \<noteq> {}" "finite A"
  and B: "B \<noteq> {}" "finite B"
  and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
proof(rule pmf.rel_mono_strong)
  (* First prove the coupling for the graph relation \<lambda>x y. (x, y) \<in> AB;
     pmf.rel_mono_strong then weakens it to R, which the final qed step
     discharges using assumption R. *)
  define AB where "AB = (\<lambda>x. (x, f x)) ` A"
  (* The support of the uniform pmf on AB lies inside AB. *)
  have "(x, y) \<in> AB" if "(x, y) \<in> set_pmf (pmf_of_set AB)" for x y
    using that by(auto simp add: AB_def A)
  (* First marginal: the first projection of the graph is A. *)
  moreover have "map_pmf fst (pmf_of_set AB) = pmf_of_set A"
    by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
  moreover
  (* Second marginal: needs injectivity of f on A and f ` A = B. *)
  from f have [simp]: "inj_on f A" by(rule bij_betw_imp_inj_on)
  from f have [simp]: "f ` A = B" by(rule bij_betw_imp_surj_on)
  have "map_pmf snd (pmf_of_set AB) = pmf_of_set B"
    by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
      (simp add: map_pmf_of_set_inj A)
  ultimately show "rel_pmf (\<lambda>x y. (x, y) \<in> AB) (pmf_of_set A) (pmf_of_set B)" ..
qed(auto intro: R)
2027
(* spmf version of rel_pmf_of_set_bij, without finiteness or non-emptiness
   assumptions.  A bijection makes A and B finite together and empty together.
   So either both sides are uniform (use rel_pmf_of_set_bij) or both are
   return_pmf None. *)
lemma rel_spmf_of_set_bij:
  assumes f: "bij_betw f A B"
  and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
proof -
  have "finite A \<longleftrightarrow> finite B" using f by(rule bij_betw_finite)
  moreover have "A = {} \<longleftrightarrow> B = {}" using f by(auto dest: bij_betw_empty2 bij_betw_empty1)
  ultimately show ?thesis using assms
    by(auto simp add: spmf_of_set_def simp del: spmf_of_pmf_pmf_of_set intro: rel_pmf_of_set_bij)
qed
2038
2039context includes lifting_syntax
2040begin
2041
(* Transfer rule for spmf_of_set under a bi-unique relation R.  rel_set-related
   sets are in bijection via a map inside R (bi_unique_rel_set_bij_betw), so
   rel_spmf_of_set_bij applies. *)
lemma rel_spmf_of_set:
  assumes "bi_unique R"
  shows "(rel_set R ===> rel_spmf R) spmf_of_set spmf_of_set"
proof
  fix A B
  assume R: "rel_set R A B"
  with assms obtain f where "bij_betw f A B" and f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
    by(auto dest: bi_unique_rel_set_bij_betw)
  then show "rel_spmf R (spmf_of_set A) (spmf_of_set B)" by(rule rel_spmf_of_set_bij)
qed
2052
2053end
2054
(* Testing membership in A of a uniform sample from the finite non-empty set B
   gives a Bernoulli distribution with success probability
   card (A \<inter> B) / card B. *)
lemma map_mem_spmf_of_set:
  assumes "finite B" "B \<noteq> {}"
  shows "map_spmf (\<lambda>x. x \<in> A) (spmf_of_set B) = spmf_of_pmf (bernoulli_pmf (card (A \<inter> B) / card B))"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i
  (* Probability of outcome i = fraction of B lying in the preimage of {i}. *)
  have "ennreal (spmf ?lhs i) = card (B \<inter> (\<lambda>x. x \<in> A) -` {i}) / (card B)"
    by(subst ennreal_spmf_map)(simp add: measure_spmf_spmf_of_set assms emeasure_pmf_of_set)
  (* The preimage is B \<inter> A for i = True and B - A for i = False. *)
  also have "\<dots> = (if i then card (B \<inter> A) / card B else card (B - A) / card B)"
    by(auto intro: arg_cong[where f=card])
  also have "\<dots> = (if i then card (B \<inter> A) / card B else (card B - card (B \<inter> A)) / card B)"
    by(auto simp add: card_Diff_subset_Int assms)
  (* Match against the Bernoulli probabilities by field arithmetic. *)
  also have "\<dots> = ennreal (spmf ?rhs i)"
    by(simp add: assms card_gt_0_iff field_simps card_mono Int_commute of_nat_diff)
  finally show "spmf ?lhs i = spmf ?rhs i" by simp
qed
2071
(* A fair coin: the uniform spmf on all booleans. *)
abbreviation coin_spmf :: "bool spmf"
where "coin_spmf \<equiv> spmf_of_set UNIV"
2074
(* Comparing a fair coin with a fixed boolean c yields a fair coin again.
   The map ((\<longleftrightarrow>) c) is injective and surjective on bool, so
   map_spmf_of_set_inj_on applies. *)
lemma map_eq_const_coin_spmf: "map_spmf ((=) c) coin_spmf = coin_spmf"
proof -
  have "inj ((\<longleftrightarrow>) c)" "range ((\<longleftrightarrow>) c) = UNIV" by(auto intro: inj_onI)
  then show ?thesis by simp
qed
2080
(* bind formulation of map_eq_const_coin_spmf. *)
lemma bind_coin_spmf_eq_const: "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (b = x)) = coin_spmf"
using map_eq_const_coin_spmf unfolding map_spmf_conv_bind_spmf by simp
2083
(* Variant of bind_coin_spmf_eq_const with the equation oriented as x = b.
   The proof rewrites the right-hand coin_spmf by bind_coin_spmf_eq_const and
   compares the two binds pointwise. *)
lemma bind_coin_spmf_eq_const': "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (x = b)) = coin_spmf"
by(rewrite in "_ = \<hole>" bind_coin_spmf_eq_const[symmetric, of b])(auto intro: bind_spmf_cong)
2086
2087subsection \<open>Losslessness\<close>
2088
(* An spmf is lossless if its total weight is 1.  Equivalently, by
   lossless_iff_pmf_None, None has probability 0. *)
definition lossless_spmf :: "'a spmf \<Rightarrow> bool"
where "lossless_spmf p \<longleftrightarrow> weight_spmf p = 1"
2091
(* An spmf is lossless iff None has probability 0. *)
lemma lossless_iff_pmf_None: "lossless_spmf p \<longleftrightarrow> pmf p None = 0"
by(simp add: lossless_spmf_def pmf_None_eq_weight_spmf)
2094
(* Point masses are lossless. *)
lemma lossless_return_spmf [iff]: "lossless_spmf (return_spmf x)"
by(simp add: lossless_iff_pmf_None)
2097
(* The empty spmf return_pmf None is not lossless. *)
lemma lossless_return_pmf_None [iff]: "\<not> lossless_spmf (return_pmf None)"
by(simp add: lossless_iff_pmf_None)
2100
(* map_spmf preserves and reflects losslessness. *)
lemma lossless_map_spmf [simp]: "lossless_spmf (map_spmf f p) \<longleftrightarrow> lossless_spmf p"
by(auto simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
2103
(* A bind_spmf is lossless iff p is lossless and every continuation at a value
   in the support of p is lossless. *)
lemma lossless_bind_spmf [simp]:
  "lossless_spmf (p \<bind> f) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>x\<in>set_spmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind_spmf_None add_nonneg_eq_0_iff integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_spmf.integrable_const_bound[where B=1] pmf_le_1)
2107
(* Destruction rule: a lossless spmf has weight 1. *)
lemma lossless_weight_spmfD: "lossless_spmf p \<Longrightarrow> weight_spmf p = 1"
by(simp add: lossless_spmf_def)
2110
(* An spmf is lossless iff None is not in the support of the underlying pmf. *)
lemma lossless_iff_set_pmf_None:
  "lossless_spmf p \<longleftrightarrow> None \<notin> set_pmf p"
by (simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
2114
(* The uniform spmf is lossless exactly for finite non-empty sets. *)
lemma lossless_spmf_of_set [simp]: "lossless_spmf (spmf_of_set A) \<longleftrightarrow> finite A \<and> A \<noteq> {}"
by(auto simp add: lossless_spmf_def weight_spmf_of_set)
2117
(* Every spmf_of_pmf is lossless.
   NOTE(review): despite its name (..._spmf_of_spmf), this lemma is about
   spmf_of_pmf.  The name is left unchanged because other theories may refer
   to it. *)
lemma lossless_spmf_spmf_of_spmf [simp]: "lossless_spmf (spmf_of_pmf p)"
by(simp add: lossless_spmf_def)
2120
(* A bind_pmf into spmfs is lossless iff every continuation at a value in the
   support of p is lossless. *)
lemma lossless_spmf_bind_pmf [simp]:
  "lossless_spmf (bind_pmf p f) \<longleftrightarrow> (\<forall>x\<in>set_pmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_pmf.integrable_const_bound[where B=1] AE_measure_pmf_iff pmf_le_1)
2124
(* Lossless spmfs are exactly those of the form spmf_of_pmf p'.  For the forward
   direction, take p' = map_pmf the p; this loses nothing because every element
   of set_pmf p is of the form Some x.  The backward direction is closed by the
   final auto. *)
lemma lossless_spmf_conv_spmf_of_pmf: "lossless_spmf p \<longleftrightarrow> (\<exists>p'. p = spmf_of_pmf p')"
proof
  assume "lossless_spmf p"
  (* None is not in the support, so every outcome is some Some x. *)
  hence *: "\<And>y. y \<in> set_pmf p \<Longrightarrow> \<exists>x. y = Some x"
    by(case_tac y)(simp_all add: lossless_iff_set_pmf_None)

  let ?p = "map_pmf the p"
  have "p = spmf_of_pmf ?p"
  proof(rule spmf_eqI)
    fix i
    (* Compare point probabilities via nn-integrals of indicators; fact *
       justifies replacing the preimage under the by {i} almost everywhere. *)
    have "ennreal (pmf (map_pmf the p) i) = \<integral>\<^sup>+ x. indicator (the -` {i}) x \<partial>p" by(simp add: ennreal_pmf_map)
    also have "\<dots> = \<integral>\<^sup>+ x. indicator {i} x \<partial>measure_spmf p" unfolding measure_spmf_def
      by(subst nn_integral_distr)(auto simp add: nn_integral_restrict_space AE_measure_pmf_iff simp del: nn_integral_indicator intro!: nn_integral_cong_AE split: split_indicator dest!: * )
    also have "\<dots> = spmf p i" by(simp add: emeasure_spmf_single)
    finally show "spmf p i = spmf (spmf_of_pmf ?p) i" by simp
  qed
  thus "\<exists>p'. p = spmf_of_pmf p'" ..
qed auto
2143
(* For a lossless bool spmf, the probability of False is one minus that of True. *)
lemma spmf_False_conv_True: "lossless_spmf p \<Longrightarrow> spmf p False = 1 - spmf p True"
by(clarsimp simp add: lossless_spmf_conv_spmf_of_pmf pmf_False_conv_True)
2146
(* Mirror of spmf_False_conv_True. *)
lemma spmf_True_conv_False: "lossless_spmf p \<Longrightarrow> spmf p True = 1 - spmf p False"
by(simp add: spmf_False_conv_True)
2149
(* A bind equals the point mass return_spmf x iff p is lossless and every
   continuation at a value in the support of p equals return_spmf x. *)
lemma bind_eq_return_spmf:
  "bind_spmf p f = return_spmf x \<longleftrightarrow> (\<forall>y\<in>set_spmf p. f y = return_spmf x) \<and> lossless_spmf p"
by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf lossless_iff_pmf_None pmf_eq_0_set_pmf iff del: not_None_eq split: option.split)
2153
(* p is rel_spmf R-related to the point mass return_spmf x iff p is lossless
   and every element of its support is R-related to x. *)
lemma rel_spmf_return_spmf2:
  "rel_spmf R p (return_spmf x) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R a x)"
by(auto simp add: lossless_iff_set_pmf_None rel_pmf_return_pmf2 option_rel_Some2 in_set_spmf, metis in_set_spmf not_None_eq)
2157
(* Mirror of rel_spmf_return_spmf2 with the point mass on the left, obtained
   via the converse relation. *)
lemma rel_spmf_return_spmf1:
  "rel_spmf R (return_spmf x) p \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R x a)"
using rel_spmf_return_spmf2[of "R\<inverse>\<inverse>"] by(simp add: spmf_rel_conversep)
2161
(* One-sided bind rule: if p is lossless and each continuation f x (for x in the
   support of p) is related to q, then bind_spmf p f is related to q.  The proof
   views q as bind_spmf (return_spmf x) (\<lambda>_. q) for an arbitrary x and applies
   rel_spmf_bindI. *)
lemma rel_spmf_bindI1:
  assumes f: "\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf R (f x) q"
  and p: "lossless_spmf p"
  shows "rel_spmf R (bind_spmf p f) q"
proof -
  fix x :: 'a
  have "rel_spmf R (bind_spmf p f) (bind_spmf (return_spmf x) (\<lambda>_. q))"
    by(rule rel_spmf_bindI[where R="\<lambda>x _. x \<in> set_spmf p"])(simp_all add: rel_spmf_return_spmf2 p f)
  then show ?thesis by simp
qed
2172
(* Mirror of rel_spmf_bindI1 for a bind on the right-hand side, via conversep. *)
lemma rel_spmf_bindI2:
  "\<lbrakk> \<And>x. x \<in> set_spmf q \<Longrightarrow> rel_spmf R p (f x); lossless_spmf q \<rbrakk>
  \<Longrightarrow> rel_spmf R p (bind_spmf q f)"
using rel_spmf_bindI1[of q "conversep R" f p] by(simp add: spmf_rel_conversep)
2177
2178subsection \<open>Scaling\<close>
2179
(* Scales every probability of p by the factor r, clamped so that the result is
   still an spmf.  Negative r is treated as 0 (max 0 r), and the factor is
   capped at inverse (weight_spmf p), so the total mass stays at most 1 (see
   scale_spmf_le_1). *)
definition scale_spmf :: "real \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf"
where
  "scale_spmf r p = embed_spmf (\<lambda>x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x)"
2183
(* The scaled density used in scale_spmf has total mass at most 1.  This is the
   side condition needed for spmf_embed_spmf. *)
lemma scale_spmf_le_1:
  "(\<integral>\<^sup>+ x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x \<partial>count_space UNIV) \<le> 1" (is "?lhs \<le> _")
proof -
  (* Pull the constant factor out of the integral ... *)
  have "?lhs = min (inverse (weight_spmf p)) (max 0 r) * \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
    by(subst nn_integral_cmult[symmetric])(simp_all add: weight_spmf_nonneg max_def min_def ennreal_mult)
  (* ... and bound factor times weight by 1. *)
  also have "\<dots> \<le> 1" unfolding weight_spmf_eq_nn_integral_spmf[symmetric]
    by(simp add: min_def max_def weight_spmf_nonneg order.strict_iff_order field_simps ennreal_mult[symmetric])
  finally show ?thesis .
qed
2193
(* Pointwise probabilities of a scaled spmf.  Here the clamped factor is written
   as max 0 (min (inverse (weight_spmf p)) r). *)
lemma spmf_scale_spmf: "spmf (scale_spmf r p) x = max 0 (min (inverse (weight_spmf p)) r) * spmf p x" (is "?lhs = ?rhs")
unfolding scale_spmf_def
apply(subst spmf_embed_spmf[OF scale_spmf_le_1])
apply(simp add: max_def min_def weight_spmf_le_0 field_simps weight_spmf_nonneg not_le order.strict_iff_order)
apply(metis antisym_conv order_trans weight_spmf_nonneg zero_le_mult_iff zero_le_one)
done
2200
(* For 0 \<le> x \<le> 1: 1 / x \<le> 1 iff x = 1 or x = 0.  The x = 0 case holds
   because division by zero yields 0 in HOL. *)
lemma real_inverse_le_1_iff: fixes x :: real
  shows "\<lbrakk> 0 \<le> x; x \<le> 1 \<rbrakk> \<Longrightarrow> 1 / x \<le> 1 \<longleftrightarrow> x = 1 \<or> x = 0"
by auto
2204
(* For r \<le> 1, the cap by inverse (weight_spmf p) never takes effect:
   probabilities are simply multiplied by max 0 r. *)
lemma spmf_scale_spmf': "r \<le> 1 \<Longrightarrow> spmf (scale_spmf r p) x = max 0 r * spmf p x"
using real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1, of p]
by(auto simp add: spmf_scale_spmf max_def min_def field_simps)(metis pmf_le_0_iff spmf_le_weight)
2208
(* Non-positive factors give the empty spmf return_pmf None. *)
lemma scale_spmf_neg: "r \<le> 0 \<Longrightarrow> scale_spmf r p = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf' max_def)
2211
(* Scaling the empty spmf yields the empty spmf. *)
lemma scale_spmf_return_None [simp]: "scale_spmf r (return_pmf None) = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf)
2214
(* For r \<le> 1, scaling by r equals a Bernoulli choice: flip bernoulli_pmf r and
   either run p (on True) or return None (on False). *)
lemma scale_spmf_conv_bind_bernoulli:
  assumes "r \<le> 1"
  shows "scale_spmf r p = bind_pmf (bernoulli_pmf r) (\<lambda>b. if b then p else return_pmf None)" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix x
  (* Compare in ennreal: the bind expands to a two-term sum over UNIV_bool.
     The arithmetic cases left over by auto are closed by the metis/meson calls. *)
  have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
    unfolding spmf_scale_spmf ennreal_pmf_bind nn_integral_measure_pmf UNIV_bool bernoulli_pmf.rep_eq
    apply(auto simp add: nn_integral_count_space_finite max_def min_def field_simps real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1] weight_spmf_lt_0 not_le ennreal_mult[symmetric])
    apply (metis pmf_le_0_iff spmf_le_weight)
    apply (metis pmf_le_0_iff spmf_le_weight)
    apply (meson le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 not_less order_trans weight_spmf_le_0)
    by (meson divide_le_0_1_iff less_imp_le order_trans weight_spmf_le_0)
  thus "spmf ?lhs x = spmf ?rhs x" by simp
qed
2229
(* Summing spmf p over A (integral w.r.t. count_space A) gives the emeasure of A
   under measure_spmf p.  The proof reindexes along the bijection Some. *)
lemma nn_integral_spmf: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space A) = emeasure (measure_spmf p) A"
apply(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space nn_integral_pmf[symmetric])
apply(rule nn_integral_bij_count_space[where g=Some])
apply(auto simp add: bij_betw_def)
done
2235
(* The measure of a scaled spmf is the original measure scaled by
   min (inverse (weight_spmf p)) r.  The proof compares emeasures through
   nn_integral_spmf. *)
lemma measure_spmf_scale_spmf: "measure_spmf (scale_spmf r p) = scale_measure (min (inverse (weight_spmf p)) r) (measure_spmf p)"
apply(rule measure_eqI)
 apply simp
apply(simp add: nn_integral_spmf[symmetric] spmf_scale_spmf)
apply(subst nn_integral_cmult[symmetric])
apply(auto simp add: max_def min_def ennreal_mult[symmetric] not_le ennreal_lt_0)
done
2243
(* For r \<le> 1 the scaling factor of measure_spmf_scale_spmf simplifies to r.
   The proof splits on whether p has positive weight. *)
lemma measure_spmf_scale_spmf':
  "r \<le> 1 \<Longrightarrow> measure_spmf (scale_spmf r p) = scale_measure r (measure_spmf p)"
unfolding measure_spmf_scale_spmf
apply(cases "weight_spmf p > 0")
 apply(simp add: min.absorb2 field_simps weight_spmf_le_1 mult_le_one)
apply(clarsimp simp add: weight_spmf_le_0 min_def scale_spmf_neg weight_spmf_eq_0 not_less)
done
2251
(* Scaling by 1 is the identity. *)
lemma scale_spmf_1 [simp]: "scale_spmf 1 p = p"
apply(rule spmf_eqI)
apply(simp add: spmf_scale_spmf max_def min_def order.strict_iff_order field_simps weight_spmf_nonneg)
apply(metis antisym_conv divide_le_eq_1 less_imp_le pmf_nonneg spmf_le_weight weight_spmf_nonneg weight_spmf_le_1)
done
2257
(* Scaling by 0 yields the empty spmf. *)
lemma scale_spmf_0 [simp]: "scale_spmf 0 p = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf min_def max_def weight_spmf_le_0)
2260
(* For r \<le> 1, scaling the first spmf of a bind is the same as scaling every
   continuation. *)
lemma bind_scale_spmf:
  assumes r: "r \<le> 1"
  shows "bind_spmf (scale_spmf r p) f = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix x
  (* Compare point probabilities in ennreal, pulling the constant factor
     through the integral. *)
  have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using r
    by(simp add: ennreal_spmf_bind measure_spmf_scale_spmf' nn_integral_scale_measure spmf_scale_spmf')
      (simp add: ennreal_mult ennreal_lt_0 nn_integral_cmult max_def min_def)
  thus "spmf ?lhs x = spmf ?rhs x" by simp
qed
2272
(* For r \<le> 1, scaling a bind is the same as scaling every continuation. *)
lemma scale_bind_spmf:
  assumes "r \<le> 1"
  shows "scale_spmf r (bind_spmf p f) = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix x
  (* Compare point probabilities in ennreal, pulling the constant factor
     through the integral. *)
  have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
    unfolding spmf_scale_spmf'[OF assms]
    by(simp add: ennreal_mult ennreal_spmf_bind spmf_scale_spmf' nn_integral_cmult max_def min_def)
  thus "spmf ?lhs x = spmf ?rhs x" by simp
qed
2284
(* A bind whose continuation ignores its argument is the continuation scaled by
   the weight of p. *)
lemma bind_spmf_const: "bind_spmf p (\<lambda>x. q) = scale_spmf (weight_spmf p) q" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix x
  (* The weight of p is at most 1, so spmf_scale_spmf' applies. *)
  have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)"
    using measure_spmf.subprob_measure_le_1[of p "space (measure_spmf p)"]
    by(subst ennreal_spmf_bind)(simp add: spmf_scale_spmf' weight_spmf_le_1 ennreal_mult mult.commute max_def min_def measure_spmf.emeasure_eq_measure)
  thus "spmf ?lhs x = spmf ?rhs x" by simp
qed
2293
(* map_spmf commutes with scale_spmf. *)
lemma map_scale_spmf: "map_spmf f (scale_spmf r p) = scale_spmf r (map_spmf f p)" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i
  show "spmf ?lhs i = spmf ?rhs i" unfolding spmf_scale_spmf
    by(subst (1 2) spmf_map)(auto simp add: measure_spmf_scale_spmf max_def min_def ennreal_lt_0)
qed
2300
(* Scaling by a positive factor keeps the support; a non-positive factor
   empties it. *)
lemma set_scale_spmf: "set_spmf (scale_spmf r p) = (if r > 0 then set_spmf p else {})"
apply(auto simp add: in_set_spmf_iff_spmf spmf_scale_spmf)
apply(simp add: max_def min_def not_le weight_spmf_lt_0 weight_spmf_eq_0 split: if_split_asm)
done
2305
(* simp specialisation of set_scale_spmf for positive factors. *)
lemma set_scale_spmf' [simp]: "0 < r \<Longrightarrow> set_spmf (scale_spmf r p) = set_spmf p"
by(simp add: set_scale_spmf)
2308
(* Scaling both sides by the same r preserves rel_spmf A.  The relation is only
   needed when r > 0; otherwise both sides become return_pmf None
   (scale_spmf_neg). *)
lemma rel_spmf_scaleI:
  assumes "r > 0 \<Longrightarrow> rel_spmf A p q"
  shows "rel_spmf A (scale_spmf r p) (scale_spmf r q)"
proof(cases "r > 0")
  case True
  (* Decompose the relation into a coupling and scale the coupling. *)
  from assms[OF this] show ?thesis
    by(rule rel_spmfE)(auto simp add: map_scale_spmf[symmetric] spmf_rel_map True intro: rel_spmf_reflI)
qed(simp add: not_less scale_spmf_neg)
2317
(* Weight of a scaled spmf: max 0 r * weight_spmf p, capped at 1.  The proof
   works in ennreal via weight_spmf_eq_nn_integral_spmf and then transfers the
   result back to real. *)
lemma weight_scale_spmf: "weight_spmf (scale_spmf r p) = min 1 (max 0 r * weight_spmf p)"
proof -
  have "ennreal (weight_spmf (scale_spmf r p)) = min 1 (max 0 r * ennreal (weight_spmf p))"
    unfolding weight_spmf_eq_nn_integral_spmf
    apply(simp add: spmf_scale_spmf ennreal_mult zero_ereal_def[symmetric] nn_integral_cmult)
    apply(auto simp add: weight_spmf_eq_nn_integral_spmf[symmetric] field_simps min_def max_def not_le weight_spmf_lt_0 ennreal_mult[symmetric])
    subgoal by(subst (asm) ennreal_mult[symmetric], meson divide_less_0_1_iff le_less_trans not_le weight_spmf_lt_0, simp+, meson not_le pos_divide_le_eq weight_spmf_le_0)
    subgoal by(cases "r \<ge> 0")(simp_all add: ennreal_mult[symmetric] weight_spmf_nonneg ennreal_lt_0, meson le_less_trans not_le pos_divide_le_eq zero_less_divide_1_iff)
    done
  thus ?thesis by(auto simp add: min_def max_def ennreal_mult[symmetric] split: if_split_asm)
qed
2329
(* For 0 \<le> r \<le> 1 neither the clamp nor the cap in weight_scale_spmf is
   active. *)
lemma weight_scale_spmf' [simp]:
  "\<lbrakk> 0 \<le> r; r \<le> 1 \<rbrakk> \<Longrightarrow> weight_spmf (scale_spmf r p) = r * weight_spmf p"
by(simp add: weight_scale_spmf max_def min_def)(metis antisym_conv mult_left_le order_trans weight_spmf_le_1)
2333
(* Probability of None after scaling, in terms of pmf p None; a restatement of
   weight_scale_spmf via pmf_None_eq_weight_spmf. *)
lemma pmf_scale_spmf_None:
  "pmf (scale_spmf k p) None = 1 - min 1 (max 0 k * (1 - pmf p None))"
unfolding pmf_None_eq_weight_spmf by(simp add: weight_scale_spmf)
2337
(* Composition of scalings: the inner factor r' is first clamped,
   max 0 (min (inverse (weight_spmf p)) r'), and then multiplied into the outer
   factor r. *)
lemma scale_scale_spmf:
  "scale_spmf r (scale_spmf r' p) = scale_spmf (r * max 0 (min (inverse (weight_spmf p)) r')) p"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i
  (* Arithmetic core: the product of the two clamped factors equals the clamped
     product factor.  The proof splits on whether p has positive weight. *)
  have "max 0 (min (1 / weight_spmf p) r') *
    max 0 (min (1 / min 1 (weight_spmf p * max 0 r')) r) =
    max 0 (min (1 / weight_spmf p) (r * max 0 (min (1 / weight_spmf p) r')))"
  proof(cases "weight_spmf p > 0")
    case False
    thus ?thesis by(simp add: not_less weight_spmf_le_0)
  next
    case True
    thus ?thesis by(simp add: field_simps max_def min.absorb_iff2[symmetric])(auto simp add: min_def field_simps zero_le_mult_iff)
  qed
  then show "spmf ?lhs i = spmf ?rhs i"
    by(simp add: spmf_scale_spmf field_simps weight_scale_spmf)
qed
2356
(* Simplification rule: for factors r and r' in [0, 1], nested scaling
   is scaling by the product r * r'.
   Proof: case split on whether p has positive weight, then unfold
   scale_scale_spmf and close the remaining arithmetic goals. *)
lemma scale_scale_spmf' [simp]:
  "\<lbrakk> 0 \<le> r; r \<le> 1; 0 \<le> r'; r' \<le> 1 \<rbrakk>
  \<Longrightarrow> scale_spmf r (scale_spmf r' p) = scale_spmf (r * r') p"
apply(cases "weight_spmf p > 0")
apply(auto simp add: scale_scale_spmf min_def max_def field_simps not_le weight_spmf_lt_0 weight_spmf_eq_0 not_less weight_spmf_le_0)
apply(subgoal_tac "1 = r'")
 apply (metis (no_types) div_by_1 eq_iff measure_spmf.subprob_measure_le_1 mult.commute mult_cancel_right1)
apply(meson eq_iff le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 mult_imp_div_pos_le order.trans)
done
2366
(* Scaling leaves p unchanged exactly in three cases:
   p has weight 0, or r = 1, or r >= 1 and p already has weight 1. *)
lemma scale_spmf_eq_same: "scale_spmf r p = p \<longleftrightarrow> weight_spmf p = 0 \<or> r = 1 \<or> r \<ge> 1 \<and> weight_spmf p = 1"
  (is "?lhs \<longleftrightarrow> ?rhs")
proof
  (* Forward direction: compare the weights of both sides via weight_scale_spmf. *)
  assume ?lhs
  hence "weight_spmf (scale_spmf r p) = weight_spmf p" by simp
  hence *: "min 1 (max 0 r * weight_spmf p) = weight_spmf p" by(simp add: weight_scale_spmf)
  hence **: "weight_spmf p = 0 \<or> r \<ge> 1" by(auto simp add: min_def max_def split: if_split_asm)
  show ?rhs
  proof(cases "weight_spmf p = 0")
    case False
    with ** have "r \<ge> 1" by simp
    with * False have "r = 1 \<or> weight_spmf p = 1" by(simp add: max_def min_def not_le split: if_split_asm)
    with \<open>r \<ge> 1\<close> show ?thesis by simp
  qed simp
(* Reverse direction: pointwise equality via spmf_eqI and spmf_scale_spmf. *)
qed(auto intro!: spmf_eqI simp add: spmf_scale_spmf, metis pmf_le_0_iff spmf_le_weight)
2382
(* Mapping a constant function over spmf_of_set A, for a finite
   non-empty A, gives the point mass return_spmf c.  The proof
   rewrites map into bind and applies bind_spmf_const. *)
lemma map_const_spmf_of_set:
  "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> map_spmf (\<lambda>_. c) (spmf_of_set A) = return_spmf c"
by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)
2386
2387subsection \<open>Conditional spmfs\<close>
2388
(* Relates disjointness for the support of the underlying option pmf
   (intersected with the image of A under Some) to disjointness for the
   spmf support. *)
lemma set_pmf_Int_Some: "set_pmf p \<inter> Some ` A = {} \<longleftrightarrow> set_spmf p \<inter> A = {}"
by(auto simp add: in_set_spmf)
2391
(* An event has measure zero under measure_spmf p iff it is disjoint
   from the support of p. *)
lemma measure_spmf_zero_iff: "measure (measure_spmf p) A = 0 \<longleftrightarrow> set_spmf p \<inter> A = {}"
unfolding measure_measure_spmf_conv_measure_pmf by(simp add: measure_pmf_zero_iff set_pmf_Int_Some)
2394
(* Conditioning an spmf on the event A.
   If A misses the support of p, the result is return_pmf None.
   Otherwise it is cond_pmf applied to the underlying pmf and the image
   of A under Some.
   NOTE(review): the case split presumably exists because cond_pmf
   needs the conditioning set to meet the support (compare the rotated
   use of set_cond_pmf in set_cond_spmf) -- confirm. *)
definition cond_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf"
where "cond_spmf p A = (if set_spmf p \<inter> A = {} then return_pmf None else cond_pmf p (Some ` A))"
2397
(* The support of the conditioned spmf is the part of the support of p inside A. *)
lemma set_cond_spmf [simp]: "set_spmf (cond_spmf p A) = set_spmf p \<inter> A"
by(auto 4 4 simp add: cond_spmf_def in_set_spmf iff: set_cond_pmf[THEN set_eq_iff[THEN iffD1], THEN spec, rotated])
2400
(* Conditioning commutes with map_spmf.  Conditioning the image on A is
   the image of conditioning on the preimage of A under f. *)
lemma cond_map_spmf [simp]: "cond_spmf (map_spmf f p) A = map_spmf f (cond_spmf p (f -` A))"
proof -
  (* Pull the event back through map_option f. *)
  have "map_option f -` Some ` A = Some ` f -` A" by auto
  (* The pulled-back event is hit whenever some support point x has f x in A. *)
  moreover have "set_pmf p \<inter> map_option f -` Some ` A \<noteq> {}" if "Some x \<in> set_pmf p" "f x \<in> A" for x
    using that by auto
  ultimately show ?thesis by(auto simp add: cond_spmf_def in_set_spmf cond_map_pmf)
qed
2408
(* Pointwise formula for conditioning: inside A the probability of p is
   divided by the measure of A, and outside A it is 0.
   The lemma has no side condition, so it also covers the case where A
   has measure zero.  In that case cond_spmf_def yields return_pmf None. *)
lemma spmf_cond_spmf [simp]:
  "spmf (cond_spmf p A) x = (if x \<in> A then spmf p x / measure (measure_spmf p) A else 0)"
by(auto simp add: cond_spmf_def pmf_cond set_pmf_Int_Some[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff)
2412
(* A bind yields return_pmf None iff the continuation does so for every
   point in the support of p. *)
lemma bind_eq_return_pmf_None:
  "bind_spmf p f = return_pmf None \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf split: option.splits)
2416
(* Symmetric variant of bind_eq_return_pmf_None, with the equation
   oriented the other way round. *)
lemma return_pmf_None_eq_bind:
  "return_pmf None = bind_spmf p f \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
using bind_eq_return_pmf_None[of p f] by auto
2420
2421(* Conditional probabilities do not seem to interact nicely with bind. *)
2422
2423subsection \<open>Product spmf\<close>
2424
(* Product of two spmfs.  Sample from pair_pmf of the underlying option
   pmfs, return the pair if both components are Some, and None otherwise.
   The resulting probabilities multiply; see spmf_pair below. *)
definition pair_spmf :: "'a spmf \<Rightarrow> 'b spmf \<Rightarrow> ('a \<times> 'b) spmf"
where "pair_spmf p q = bind_pmf (pair_pmf p q) (\<lambda>xy. case xy of (Some x, Some y) \<Rightarrow> return_spmf (x, y) | _ \<Rightarrow> return_pmf None)"
2427
(* The first marginal of pair_spmf p q is p scaled by the weight of q. *)
lemma map_fst_pair_spmf [simp]: "map_spmf fst (pair_spmf p q) = scale_spmf (weight_spmf q) p"
unfolding bind_spmf_const[symmetric]
apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib)
apply(subst bind_commute_pmf)
apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
done
2434
(* The second marginal of pair_spmf p q is q scaled by the weight of p.
   Unlike map_fst_pair_spmf, this proof needs no bind_commute_pmf step. *)
lemma map_snd_pair_spmf [simp]: "map_spmf snd (pair_spmf p q) = scale_spmf (weight_spmf p) q"
unfolding bind_spmf_const[symmetric]
  apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib
    cong del: option.case_cong_weak)
apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
done
2441
(* The support of a product is the Cartesian product of the supports. *)
lemma set_pair_spmf [simp]: "set_spmf (pair_spmf p q) = set_spmf p \<times> set_spmf q"
by(auto 4 3 simp add: pair_spmf_def set_spmf_bind_pmf bind_UNION in_set_spmf intro: rev_bexI split: option.splits)
2444
(* Pointwise probabilities of a product multiply.
   The proof works in ennreal via nonnegative integrals. *)
lemma spmf_pair [simp]: "spmf (pair_spmf p q) (x, y) = spmf p x * spmf q y" (is "?lhs = ?rhs")
proof -
  (* Express the point probability as a nested integral of an indicator. *)
  have "ennreal ?lhs = \<integral>\<^sup>+ a. \<integral>\<^sup>+ b. indicator {(x, y)} (a, b) \<partial>measure_spmf q \<partial>measure_spmf p"
    unfolding measure_spmf_def pair_spmf_def ennreal_pmf_bind nn_integral_pair_pmf'
    by(auto simp add: zero_ereal_def[symmetric] nn_integral_distr nn_integral_restrict_space nn_integral_multc[symmetric] intro!: nn_integral_cong split: split_indicator)
  (* Split the indicator of the pair into a product of one-point indicators. *)
  also have "\<dots> = \<integral>\<^sup>+ a. (\<integral>\<^sup>+ b. indicator {y} b \<partial>measure_spmf q) * indicator {x} a \<partial>measure_spmf p"
    by(subst nn_integral_multc[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
  (* Evaluate both integrals as measures of singletons. *)
  also have "\<dots> = ennreal ?rhs" by(simp add: emeasure_spmf_single max_def ennreal_mult mult.commute)
  finally show ?thesis by simp
qed
2455
(* Mapping the second argument of pair_spmf is the same as mapping the
   second component of the result (apsnd). *)
lemma pair_map_spmf2: "pair_spmf p (map_spmf f q) = map_spmf (apsnd f) (pair_spmf p q)"
by(auto simp add: pair_spmf_def pair_map_pmf2 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
2458
(* Mapping the first argument of pair_spmf is the same as mapping the
   first component of the result (apfst). *)
lemma pair_map_spmf1: "pair_spmf (map_spmf f p) q = map_spmf (apfst f) (pair_spmf p q)"
by(auto simp add: pair_spmf_def pair_map_pmf1 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
2461
(* Combined naturality of pair_spmf in both arguments, derived from
   pair_map_spmf1 and pair_map_spmf2. *)
lemma pair_map_spmf: "pair_spmf (map_spmf f p) (map_spmf g q) = map_spmf (map_prod f g) (pair_spmf p q)"
unfolding pair_map_spmf2 pair_map_spmf1 spmf.map_comp by(simp add: apfst_def apsnd_def o_def prod.map_comp)
2464
(* Monadic characterisation of pair_spmf: sample x from p, then y from q,
   and return (x, y). *)
lemma pair_spmf_alt_def: "pair_spmf p q = bind_spmf p (\<lambda>x. bind_spmf q (\<lambda>y. return_spmf (x, y)))"
by(auto simp add: pair_spmf_def pair_pmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf split: option.split intro: bind_pmf_cong)
2467
(* The weight of a product is the product of the weights. *)
lemma weight_pair_spmf [simp]: "weight_spmf (pair_spmf p q) = weight_spmf p * weight_spmf q"
unfolding pair_spmf_alt_def by(simp add: weight_bind_spmf o_def)
2470
(* Scaling the first argument of pair_spmf can be pulled outside,
   currently only for r <= 1 (see FIXME). *)
lemma pair_scale_spmf1: (* FIXME: generalise to arbitrary r *)
  "r \<le> 1 \<Longrightarrow> pair_spmf (scale_spmf r p) q = scale_spmf r (pair_spmf p q)"
by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
2474
(* Scaling the second argument of pair_spmf can be pulled outside,
   currently only for r <= 1 (see FIXME). *)
lemma pair_scale_spmf2: (* FIXME: generalise to arbitrary r *)
  "r \<le> 1 \<Longrightarrow> pair_spmf p (scale_spmf r q) = scale_spmf r (pair_spmf p q)"
by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
2478
(* return_pmf None is absorbing in the first argument of pair_spmf. *)
lemma pair_spmf_return_None1 [simp]: "pair_spmf (return_pmf None) p = return_pmf None"
by(rule spmf_eqI)(clarsimp)
2481
(* return_pmf None is absorbing in the second argument of pair_spmf. *)
lemma pair_spmf_return_None2 [simp]: "pair_spmf p (return_pmf None) = return_pmf None"
by(rule spmf_eqI)(clarsimp)
2484
(* Pairing with a point mass on the left is mapping Pair x over q. *)
lemma pair_spmf_return_spmf1: "pair_spmf (return_spmf x) q = map_spmf (Pair x) q"
by(rule spmf_eqI)(auto split: split_indicator simp add: spmf_map_inj' inj_on_def intro: spmf_map_outside)
2487
(* Pairing with a point mass on the right is mapping (lambda x. (x, y)) over p. *)
lemma pair_spmf_return_spmf2: "pair_spmf p (return_spmf y) = map_spmf (\<lambda>x. (x, y)) p"
by(rule spmf_eqI)(auto split: split_indicator simp add: inj_on_def intro!: spmf_map_outside spmf_map_inj'[symmetric])
2490
(* The product of two point masses is the point mass on the pair. *)
lemma pair_spmf_return_spmf [simp]: "pair_spmf (return_spmf x) (return_spmf y) = return_spmf (x, y)"
by(simp add: pair_spmf_return_spmf1)
2493
(* Relating two products by rel_prod A B is equivalent to relating
   their marginals.  The marginals are p and p' (related by A) and q and
   q' (related by B), each scaled by the weight of the other factor, as
   in map_fst_pair_spmf and map_snd_pair_spmf. *)
lemma rel_pair_spmf_prod:
  "rel_spmf (rel_prod A B) (pair_spmf p q) (pair_spmf p' q') \<longleftrightarrow>
   rel_spmf A (scale_spmf (weight_spmf q) p) (scale_spmf (weight_spmf q') p') \<and>
   rel_spmf B (scale_spmf (weight_spmf p) q) (scale_spmf (weight_spmf p') q')"
  (is "?lhs \<longleftrightarrow> ?rhs" is "_ \<longleftrightarrow> ?A \<and> ?B" is "_ \<longleftrightarrow> rel_spmf _ ?p ?p' \<and> rel_spmf _ ?q ?q'")
proof(intro iffI conjI)
  (* Direction rhs ==> lhs.  Take couplings pq and pq' of the scaled
     marginals.  Pair them, rearrange the components with ?f, and rescale
     by ?r = 1 / (weight_spmf p * weight_spmf q). *)
  assume ?rhs
  then obtain pq pq' where p: "map_spmf fst pq = ?p" and p': "map_spmf snd pq = ?p'"
    and q: "map_spmf fst pq' = ?q" and q': "map_spmf snd pq' = ?q'"
    and *: "\<And>x x'. (x, x') \<in> set_spmf pq \<Longrightarrow> A x x'"
    and **: "\<And>y y'. (y, y') \<in> set_spmf pq' \<Longrightarrow> B y y'" by(auto elim!: rel_spmfE)
  let ?f = "\<lambda>((x, x'), (y, y')). ((x, y), (x', y'))"
  let ?r = "1 / (weight_spmf p * weight_spmf q)"
  let ?pq = "scale_spmf ?r (map_spmf ?f (pair_spmf pq pq'))"

  (* Auxiliary arithmetic fact, generalised over p and q so that it can be
     instantiated below both for p q and for p' q'.  If both weights are
     non-zero and 1 / (wp * wq) <= wp * wq, then both weights equal 1. *)
  { fix p :: "'x spmf" and q :: "'y spmf"
    assume "weight_spmf q \<noteq> 0"
      and "weight_spmf p \<noteq> 0"
      and "1 / (weight_spmf p * weight_spmf q) \<le> weight_spmf p * weight_spmf q"
    hence "1 \<le> (weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q)"
      by(simp add: pos_divide_le_eq order.strict_iff_order weight_spmf_nonneg)
    moreover have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) \<le> (1 * 1) * (1 * 1)"
      by(intro mult_mono)(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
    ultimately have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) = 1" by simp
    hence *: "weight_spmf p * weight_spmf q = 1"
      by(metis antisym_conv less_le mult_less_cancel_left1 weight_pair_spmf weight_spmf_le_1 weight_spmf_nonneg)
    hence **: "weight_spmf p = 1" by(metis antisym_conv mult_left_le weight_spmf_le_1 weight_spmf_nonneg)
    moreover from * ** have "weight_spmf q = 1" by simp
    moreover note calculation }
  note full = this

  (* ?pq witnesses the coupling.  Its two marginals are the two products,
     and its support condition follows from * and ** in the final qed. *)
  show ?lhs
  proof
    have [simp]: "fst \<circ> ?f = map_prod fst fst" by(simp add: fun_eq_iff)
    have "map_spmf fst ?pq = scale_spmf ?r (pair_spmf ?p ?q)"
      by(simp add: pair_map_spmf[symmetric] p q map_scale_spmf spmf.map_comp)
    also have "\<dots> = pair_spmf p q" using full[of p q]
      by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
        (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
    finally show "map_spmf fst ?pq = \<dots>" .

    have [simp]: "snd \<circ> ?f = map_prod snd snd" by(simp add: fun_eq_iff)
    (* The product weights agree on both sides, so ?r also fits p' and q'. *)
    from \<open>?rhs\<close> have eq: "weight_spmf p * weight_spmf q = weight_spmf p' * weight_spmf q'"
      by(auto dest!: rel_spmf_weightD simp add: weight_spmf_le_1 weight_spmf_nonneg)

    have "map_spmf snd ?pq = scale_spmf ?r (pair_spmf ?p' ?q')"
      by(simp add: pair_map_spmf[symmetric] p' q' map_scale_spmf spmf.map_comp)
    also have "\<dots> = pair_spmf p' q'" using full[of p' q'] eq
      by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
        (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
    finally show "map_spmf snd ?pq = \<dots>" .
  qed(auto simp add: set_scale_spmf split: if_split_asm dest: * ** )
next
  (* Direction lhs ==> rhs.  Project a coupling pq of the two products
     onto its first components (for ?A) and its second components (for ?B). *)
  assume ?lhs
  then obtain pq where pq: "map_spmf fst pq = pair_spmf p q"
    and pq': "map_spmf snd pq = pair_spmf p' q'"
    and *: "\<And>x y x' y'. ((x, y), (x', y')) \<in> set_spmf pq \<Longrightarrow> A x x' \<and> B y y'"
    by(auto elim: rel_spmfE)

  show ?A
  proof
    let ?f = "(\<lambda>((x, y), (x', y')). (x, x'))"
    let ?pq = "map_spmf ?f pq"
    have [simp]: "fst \<circ> ?f = fst \<circ> fst" by(simp add: split_def o_def)
    show "map_spmf fst ?pq = scale_spmf (weight_spmf q) p" using pq
      by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])

    have [simp]: "snd \<circ> ?f = fst \<circ> snd" by(simp add: split_def o_def)
    show "map_spmf snd ?pq = scale_spmf (weight_spmf q') p'" using pq'
      by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  qed(auto dest: * )

  show ?B
  proof
    let ?f = "(\<lambda>((x, y), (x', y')). (y, y'))"
    let ?pq = "map_spmf ?f pq"
    have [simp]: "fst \<circ> ?f = snd \<circ> fst" by(simp add: split_def o_def)
    show "map_spmf fst ?pq = scale_spmf (weight_spmf p) q" using pq
      by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])

    have [simp]: "snd \<circ> ?f = snd \<circ> snd" by(simp add: split_def o_def)
    show "map_spmf snd ?pq = scale_spmf (weight_spmf p') q'" using pq'
      by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  qed(auto dest: * )
qed
2579
(* Associativity of pair_spmf, up to reassociating the tuple components. *)
lemma pair_pair_spmf:
  "pair_spmf (pair_spmf p q) r = map_spmf (\<lambda>(x, (y, z)). ((x, y), z)) (pair_spmf p (pair_spmf q r))"
by(simp add: pair_spmf_alt_def map_spmf_conv_bind_spmf)
2583
(* Commutativity of pair_spmf, up to swapping the components.
   This follows from bind_commute_spmf. *)
lemma pair_commute_spmf:
  "pair_spmf p q = map_spmf (\<lambda>(y, x). (x, y)) (pair_spmf q p)"
unfolding pair_spmf_alt_def by(subst bind_commute_spmf)(simp add: map_spmf_conv_bind_spmf)
2587
2588subsection \<open>Assertions\<close>
2589
(* Assertion: the unit point mass if b holds, otherwise return_pmf None. *)
definition assert_spmf :: "bool \<Rightarrow> unit spmf"
where "assert_spmf b = (if b then return_spmf () else return_pmf None)"
2592
(* Evaluation of assert_spmf on the two boolean constants. *)
lemma assert_spmf_simps [simp]:
  "assert_spmf True = return_spmf ()"
  "assert_spmf False = return_pmf None"
by(simp_all add: assert_spmf_def)
2597
(* The unit value lies in the support of assert_spmf p iff p holds. *)
lemma in_set_assert_spmf [simp]: "x \<in> set_spmf (assert_spmf p) \<longleftrightarrow> p"
by(cases p) simp_all
2600
(* The support of assert_spmf b is empty iff b is false. *)
lemma set_spmf_assert_spmf_eq_empty [simp]: "set_spmf (assert_spmf b) = {} \<longleftrightarrow> \<not> b"
by(cases b) simp_all
2603
(* assert_spmf b is lossless iff b holds. *)
lemma lossless_assert_spmf [iff]: "lossless_spmf (assert_spmf b) \<longleftrightarrow> b"
by(cases b) simp_all
2606
2607subsection \<open>Try\<close>
2608
(* Exception handling for spmfs, written TRY p ELSE q.
   Sample from p.  If the outcome is None, continue with q; otherwise
   return the outcome Some y as return_spmf y. *)
definition try_spmf :: "'a spmf \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf" ("TRY _ ELSE _" [0,60] 59)
where "try_spmf p q = bind_pmf p (\<lambda>x. case x of None \<Rightarrow> q | Some y \<Rightarrow> return_spmf y)"
2611
(* If p is lossless, the handler q is never used and TRY p ELSE q = p. *)
lemma try_spmf_lossless [simp]:
  assumes "lossless_spmf p"
  shows "TRY p ELSE q = p"
proof -
  (* None is not in the support of p (lossless_iff_set_pmf_None), so the
     continuation agrees with return_pmf on the support. *)
  have "TRY p ELSE q = bind_pmf p return_pmf" unfolding try_spmf_def using assms
    by(auto simp add: lossless_iff_set_pmf_None split: option.split intro: bind_pmf_cong)
  thus ?thesis by(simp add: bind_return_pmf')
qed
2620
(* A point mass never fails, so the handler is ignored. *)
lemma try_spmf_return_spmf1: "TRY return_spmf x ELSE q = return_spmf x"
by(simp add: try_spmf_def bind_return_pmf)
2623
(* If the body is return_pmf None, the result is the handler. *)
lemma try_spmf_return_None [simp]: "TRY return_pmf None ELSE q = q"
by(simp add: try_spmf_def bind_return_pmf)
2626
(* A handler of return_pmf None changes nothing. *)
lemma try_spmf_return_pmf_None2 [simp]: "TRY p ELSE return_pmf None = p"
by(simp add: try_spmf_def option.case_distrib[symmetric] bind_return_pmf' case_option_id)
2629
(* map_spmf distributes over both the body and the handler of try_spmf. *)
lemma map_try_spmf: "map_spmf f (try_spmf p q) = try_spmf (map_spmf f p) (map_spmf f q)"
by(simp add: try_spmf_def map_bind_pmf bind_map_pmf option.case_distrib[where h="map_spmf f"] o_def cong del: option.case_cong_weak)
2632
(* An outer bind_pmf can be pulled out of a try_spmf body.
   Compare try_spmf_bind_spmf_lossless, which needs losslessness for bind_spmf. *)
lemma try_spmf_bind_pmf: "TRY (bind_pmf p f) ELSE q = bind_pmf p (\<lambda>x. TRY (f x) ELSE q)"
by(simp add: try_spmf_def bind_assoc_pmf)
2635
(* For a lossless first step p, try_spmf can be pushed into the continuation of bind_spmf. *)
lemma try_spmf_bind_spmf_lossless:
  "lossless_spmf p \<Longrightarrow> TRY (bind_spmf p f) ELSE q = bind_spmf p (\<lambda>x. TRY (f x) ELSE q)"
by(auto simp add: try_spmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf lossless_iff_set_pmf_None intro!: bind_pmf_cong split: option.split)
2639
(* try_spmf_bind_spmf_lossless oriented the other way: pull try_spmf out
   of the continuation when p is lossless. *)
lemma try_spmf_bind_out:
  "lossless_spmf p \<Longrightarrow> bind_spmf p (\<lambda>x. TRY (f x) ELSE q) = TRY (bind_spmf p f) ELSE q"
by(simp add: try_spmf_bind_spmf_lossless)
2643
(* TRY p ELSE q is lossless iff the body or the handler is lossless. *)
lemma lossless_try_spmf [simp]:
  "lossless_spmf (TRY p ELSE q) \<longleftrightarrow> lossless_spmf p \<or> lossless_spmf q"
by(auto simp add: try_spmf_def in_set_spmf lossless_iff_set_pmf_None split: option.splits)
2647
context includes lifting_syntax
begin

(* Transfer rule: try_spmf preserves rel_spmf A in both arguments.
   The ===> relator notation comes from the lifting_syntax bundle. *)
lemma try_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> rel_spmf A ===> rel_spmf A) try_spmf try_spmf"
unfolding try_spmf_def[abs_def] by transfer_prover

end
2656
(* Congruence rule: the handlers only need to agree when the body p' is
   not lossless, since a lossless body ignores its handler. *)
lemma try_spmf_cong:
  "\<lbrakk> p = p'; \<not> lossless_spmf p' \<Longrightarrow> q = q' \<rbrakk> \<Longrightarrow> TRY p ELSE q = TRY p' ELSE q'"
unfolding try_spmf_def
by(rule bind_pmf_cong)(auto split: option.split simp add: lossless_iff_set_pmf_None)
2661
(* Relational version of try_spmf_cong: related bodies and (when p' is
   not lossless) related handlers give related try_spmf results.
   The bind is coupled via rel_option R restricted to the two supports. *)
lemma rel_spmf_try_spmf:
  "\<lbrakk> rel_spmf R p p'; \<not> lossless_spmf p' \<Longrightarrow> rel_spmf R q q' \<rbrakk>
  \<Longrightarrow> rel_spmf R (TRY p ELSE q) (TRY p' ELSE q')"
unfolding try_spmf_def
apply(rule rel_pmf_bindI[where R="\<lambda>x y. rel_option R x y \<and> x \<in> set_pmf p \<and> y \<in> set_pmf p'"])
 apply(erule pmf.rel_mono_strong; simp)
apply(auto split: option.split simp add: lossless_iff_set_pmf_None)
done
2670
(* Pointwise probability of TRY p ELSE q: the probability of x under p,
   plus the mass of None under p times the probability of x under q. *)
lemma spmf_try_spmf:
  "spmf (TRY p ELSE q) x = spmf p x + pmf p None * spmf q x"
proof -
  (* Write the probability as an integral over the outcome of p. *)
  have "ennreal (spmf (TRY p ELSE q) x) = \<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y + indicator {Some x} y \<partial>measure_pmf p"
    unfolding try_spmf_def ennreal_pmf_bind by(rule nn_integral_cong)(simp split: option.split split_indicator)
  (* Split the integral of the sum into the sum of the integrals. *)
  also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y \<partial>measure_pmf p) + \<integral>\<^sup>+ y. indicator {Some x} y \<partial>measure_pmf p"
    by(simp add: nn_integral_add)
  also have "\<dots> = ennreal (spmf q x) * pmf p None + spmf p x" by(simp add: emeasure_pmf_single)
  finally show ?thesis by(simp add: ennreal_mult[symmetric] ennreal_plus[symmetric] del: ennreal_plus)
qed
2681
(* For lossless p, first trying a scaled copy of p and falling back to p
   on None gives back p. *)
lemma try_scale_spmf_same [simp]: "lossless_spmf p \<Longrightarrow> TRY scale_spmf k p ELSE p = p"
by(rule spmf_eqI)(auto simp add: spmf_try_spmf spmf_scale_spmf pmf_scale_spmf_None lossless_iff_pmf_None weight_spmf_conv_pmf_None min_def max_def field_simps)
2684
(* TRY p ELSE q yields None with the product of the None-masses of p and q. *)
lemma pmf_try_spmf_None [simp]: "pmf (TRY p ELSE q) None = pmf p None * pmf q None" (is "?lhs = ?rhs")
proof -
  (* Only the None outcome of p contributes, weighted by pmf q None. *)
  have "?lhs = \<integral> x. pmf q None * indicator {None} x \<partial>measure_pmf p"
    unfolding try_spmf_def pmf_bind by(rule Bochner_Integration.integral_cong)(simp_all split: option.split)
  also have "\<dots> = ?rhs" by(simp add: measure_pmf_single)
  finally show ?thesis .
qed
2692
(* For a lossless handler q, the same handler can also be inserted into
   every continuation of the bind without changing the result. *)
lemma try_bind_spmf_lossless2:
  "lossless_spmf q \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x. TRY (f x) ELSE q)) ELSE q"
by(rule spmf_eqI)(simp add: spmf_try_spmf pmf_bind_spmf_None spmf_bind field_simps measure_spmf.integrable_const_bound[where B=1] pmf_le_1 lossless_iff_pmf_None)
2696
(* Guarded copy of try_bind_spmf_lossless2.  The NO_MATCH premise blocks
   the rule when the continuation f already has the form
   (lambda x. try_spmf (g x) (h x)).
   NOTE(review): presumably meant for use as a rewrite rule, where the
   unguarded version would apply again to its own result -- confirm. *)
lemma try_bind_spmf_lossless2':
  fixes f :: "'a \<Rightarrow> 'b spmf" shows
  "\<lbrakk> NO_MATCH (\<lambda>x :: 'a. try_spmf (g x :: 'b spmf) (h x)) f; lossless_spmf q \<rbrakk>
  \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x :: 'a. TRY (f x) ELSE q)) ELSE q"
by(rule try_bind_spmf_lossless2)
2702
(* A failed assertion inside the body falls back to the handler.
   A successful one continues with f (). *)
lemma try_bind_assert_spmf:
  "TRY (assert_spmf b \<bind> f) ELSE q = (if b then TRY (f ()) ELSE q else q)"
by simp
2706
2707subsection \<open>Miscellaneous\<close>
2708
(* Fundamental lemma of game-playing proofs, for two spmfs p and q that
   are coupled by a relation which
     - agrees on the bad flags (bad1 x = bad2 y), and
     - agrees on the events A and B whenever bad does not occur.
   Then
     - fundamental_lemma_bad: bad has the same probability on both sides;
     - fundamental_lemma: the probabilities of A and B differ by at most
       the probability of bad. *)
lemma assumes "rel_spmf (\<lambda>x y. bad1 x = bad2 y \<and> (\<not> bad2 y \<longrightarrow> A x \<longleftrightarrow> B y)) p q" (is "rel_spmf ?A _ _")
  shows fundamental_lemma_bad: "measure (measure_spmf p) {x. bad1 x} = measure (measure_spmf q) {y. bad2 y}" (is "?bad")
  and fundamental_lemma: "\<bar>measure (measure_spmf p) {x. A x} - measure (measure_spmf q) {y. B y}\<bar> \<le>
    measure (measure_spmf p) {x. bad1 x}" (is ?fundamental)
proof -
  (* The good events (A resp. B without bad) have equal probability, by
     parametricity of measure_spmf and Collect. *)
  have good: "rel_fun ?A (=) (\<lambda>x. A x \<and> \<not> bad1 x) (\<lambda>y. B y \<and> \<not> bad2 y)" by(auto simp add: rel_fun_def)
  from assms have 1: "measure (measure_spmf p) {x. A x \<and> \<not> bad1 x} = measure (measure_spmf q) {y. B y \<and> \<not> bad2 y}"
    by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF good])

  (* The same parametricity argument for the bad events gives fundamental_lemma_bad. *)
  have bad: "rel_fun ?A (=) bad1 bad2" by(simp add: rel_fun_def)
  show 2: ?bad using assms
    by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF bad])

  (* Split A and B along bad.  The good parts cancel by fact 1.  The
     remaining difference of the bad parts is bounded by their maximum,
     and hence by the probability of bad. *)
  let ?\<mu>p = "measure (measure_spmf p)" and ?\<mu>q = "measure (measure_spmf q)"
  have "{x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x} = {x. A x}"
    and "{y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y} = {y. B y}" by auto
  then have "\<bar>?\<mu>p {x. A x} - ?\<mu>q {x. B x}\<bar> = \<bar>?\<mu>p ({x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x}) - ?\<mu>q ({y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y})\<bar>"
    by simp
  also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} + ?\<mu>p {x. A x \<and> \<not> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y} - ?\<mu>q {y. B y \<and> \<not> bad2 y}\<bar>"
    by(subst (1 2) measure_Union)(auto)
  also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y}\<bar>" using 1 by simp
  also have "\<dots> \<le> max (?\<mu>p {x. A x \<and> bad1 x}) (?\<mu>q {y. B y \<and> bad2 y})"
    by(rule abs_leI)(auto simp add: max_def not_le, simp_all only: add_increasing measure_nonneg mult_2)
  also have "\<dots> \<le> max (?\<mu>p {x. bad1 x}) (?\<mu>q {y. bad2 y})"
    by(rule max.mono; rule measure_spmf.finite_measure_mono; auto)
  also note 2[symmetric]
  finally show ?fundamental by simp
qed
2737
2738end
2739