1(* ========================================================================= *)
2(* Tools for reasoning about wp and wlp.                                     *)
3(* ========================================================================= *)
4
structure wpTools :> wpTools =
struct

open HolKernel Parse boolLib bossLib simpLib metisLib intLib
     integerTheory stringTheory combinTheory listTheory
     posrealTheory posrealLib expectationTheory syntaxTheory wpTheory;

(* ------------------------------------------------------------------------- *)
(* Automatic tool for calculating wlp verification conditions.               *)
(*                                                                           *)
(* The main tactics built below are vc_tac (one VC-generation step using a   *)
(* Prolog-style search over the wlp_..._vc rules of wpTheory), leq_tac       *)
(* (tries to prove the resulting Leq goals pointwise), and pure_wlp_tac /    *)
(* wlp_tac, which chain them into a complete pipeline.                       *)
(* ------------------------------------------------------------------------- *)

(* Exception builder for this structure, used as  ERR function message.     *)
val ERR = mk_HOL_ERR "wpTools";

(* ------------------------------------------------------------------------- *)
(* prolog ths: build a Prolog-style solver (mlibMeson.prolog via folTools)   *)
(* whose rules are the theorems ths.                                         *)
(*                                                                           *)
(* The rules are normalised with FOL_NORM and compiled into a logic map.     *)
(* The rule  x |- x  (with x:bool listed as a variable) is then added.       *)
(* NOTE(review): that extra rule presumably lets the search close any        *)
(* subgoal the other rules cannot, leaving it as a hypothesis of the         *)
(* returned theorem -- confirm against folTools.add_thm.  vc_tac turns such  *)
(* hypotheses into subgoals via DISCH_CONJ.                                  *)
(*                                                                           *)
(* The result is FOL_FIND partially applied with no resource bound           *)
(* (mlibMeter.unlimited); see vc_tac for how it is queried.                  *)
(* ------------------------------------------------------------------------- *)

local
  open folTools;

  (* folTools mapping parameters: every optional feature is switched off    *)
  (* except with_types.                                                      *)
  val prolog_parm =
    {equality = false,
     boolean = false,
     combinator = false,
     mapping_parm = {higher_order = false, with_types = true}};
in
  val prolog = fn ths =>
    let
      val (cs,ths) = FOL_NORM ths
      val lmap = build_map (prolog_parm, cs, ths)
      val lmap = add_thm (([``x:bool``], []), ASSUME ``x:bool``) lmap
    in
      FOL_FIND mlibMeson.prolog lmap mlibMeter.unlimited
    end;
end;

(* ------------------------------------------------------------------------- *)
(* find_subterm p t returns a subterm of t that satisfies p and has no       *)
(* proper subterm satisfying p.                                              *)
(* After find_term locates a first match, the search keeps descending into   *)
(* a further match found in the current match's rator (tried first), its     *)
(* rand, or its abstraction body.                                            *)
(* Raises HOL_ERR (from find_term) if no subterm of t satisfies p.           *)
(* ------------------------------------------------------------------------- *)

fun find_subterm p =
  let
    val f = find_term p
    (* Given a match t, look for a further match inside it (rator first,
       then rand; or the body of an abstraction) and descend into it.
       Return t once none of its immediate parts contains a match, so g
       itself never raises HOL_ERR. *)
    fun g t =
      if is_comb t then
        g (f (rator t))
        handle HOL_ERR _ => (g (f (rand t)) handle HOL_ERR _ => t)
      else if is_abs t then
        g (f (body t)) handle HOL_ERR _ => t
      else t
  in
    fn t => g (f t)
  end;

(* ------------------------------------------------------------------------- *)
(* abbrev_tac n p: abbreviate a subterm of the goal.                         *)
(* Picks t = find_subterm p g.  Only the goal is searched, not the           *)
(* assumptions.  Makes a variable v of the type of t named after n, varied   *)
(* away from the free variables of the goal and assumptions, with            *)
(* Globals.priming temporarily set to SOME "".                               *)
(* Then replaces t by v in the goal and adds the assumption  t = v  (as the  *)
(* most recent assumption).                                                  *)
(* Fails with HOL_ERR when the goal has no subterm satisfying p.             *)
(* ------------------------------------------------------------------------- *)

fun abbrev_tac n p (asl,g) =
  let
    val t = find_subterm p g
    val v =
      with_flag (Globals.priming, SOME "")
      (variant (free_varsl (g :: asl))) (mk_var (n, type_of t))
    (* |- ?v. t = v, witnessed by t itself. *)
    val th = EXISTS (mk_exists (v, mk_eq (t, v)), t) (REFL t)
  in
    (* The goal becomes  !v. t = v ==> g ; after GEN_TAC the equation is
       used to rewrite the goal and is then kept as an assumption. *)
    MP_TAC th THEN CONV_TAC LEFT_IMP_EXISTS_CONV THEN GEN_TAC
    THEN DISCH_THEN (fn th => PURE_REWRITE_TAC [th] THEN ASSUME_TAC th)
  end (asl,g);

(* ------------------------------------------------------------------------- *)
(* lam_abbrev_tac abbreviates a lambda abstraction of the goal that has no   *)
(* abstraction as a proper subterm (abbrev_tac with is_abs) as a variable    *)
(* named lam (or a variant).                                                 *)
(* It then replaces the new assumption  (\v. p v) = lam  by its pointwise    *)
(* form  !v. lam v = p v  using mp.                                          *)
(* Fails when the goal has no abstraction, so REPEAT lam_abbrev_tac          *)
(* abbreviates all of them, innermost first.                                 *)
(* ------------------------------------------------------------------------- *)

local
  val mp = prove
    (``!p q. ((\v. p v) = q) ==> !v. q v = p v``,
     CONV_TAC (DEPTH_CONV ETA_CONV)
     THEN REWRITE_TAC [GSYM FUN_EQ_THM]
     THEN MATCH_ACCEPT_TAC EQ_SYM);
in
  (* POP_ASSUM takes the equation abbrev_tac has just assumed. *)
  val lam_abbrev_tac =
    abbrev_tac "lam" is_abs
    THEN POP_ASSUM (ASSUME_TAC o HO_MATCH_MP mp);
end;

(* ------------------------------------------------------------------------- *)
(* DISCH_CONJ th: discharge every hypothesis of th into a single implication *)
(* whose antecedent is their conjunction:                                    *)
(*   h1, ..., hn |- c   becomes   |- h1 /\ ... /\ hn ==> c                   *)
(* With one hypothesis the result is  |- h1 ==> c .  With none it is         *)
(* |- T ==> c .  So the result is always an implication, as the              *)
(* MATCH_MP_TAC in vc_tac requires.                                          *)
(* NOTE(review): if c is itself an implication and n >= 2, gconj also folds  *)
(* the antecedent of c into the conjunction.                                 *)
(* ------------------------------------------------------------------------- *)

local
  (* Fold the inner part of the implication chain first, then merge the
     outermost antecedent with AND_IMP_INTRO. *)
  fun gconj tm = (TRY_CONV (RAND_CONV gconj) THENC REWR_CONV AND_IMP_INTRO) tm;
  (* Rewrites c to  T ==> c . *)
  val tconj = REWR_CONV (GSYM (CONJUNCT1 (SPEC_ALL IMP_CLAUSES)));
  fun sel [] = tconj | sel [_] = ALL_CONV | sel (_ :: _ :: _) = gconj;
in
  fun DISCH_CONJ th = CONV_RULE (sel (hyp th)) (DISCH_ALL th);
end;

(* ------------------------------------------------------------------------- *)
(* vc_tac: one verification-condition step on a goal  Leq pre wlp_post .     *)
(* ------------------------------------------------------------------------- *)

local
  (* The Leq constant at type 'a state expect -> 'a state expect -> bool. *)
  val leq_tm = ``Leq : 'a state expect -> 'a state expect -> bool``;

  (* Split  Leq a b  into (a, b); on failure raises the ERR "dest_leq" ""
     exception handed to dest_binop. *)
  val dest_leq = dest_binop leq_tm (ERR "dest_leq" "");

  (* Build  Leq a b , instantiating 'a from the type of a.
     NOTE(review): assumes type_of a is a type whose first argument is a
     function type, and takes that function's range as 'a.  Presumably this
     is the value type of the state -- confirm against the definitions of
     state and expect. *)
  fun mk_leq (a,b) =
      let
        val (_,ty) = dom_rng (hd (snd (dest_type (type_of a))))
        val tm = inst [alpha |-> ty] leq_tm
      in
        mk_comb (mk_comb (tm,a), b)
      end;

  (* Prolog solver over the wlp verification-condition rules of wpTheory. *)
  val vc_solve = prolog
    [wlp_abort_vc, wlp_consume_vc, wlp_assign_vc, wlp_seq_vc, wlp_nondet_vc,
     wlp_prob_vc, wlp_while_vc, wlp_skip_vc, wlp_if_vc, wlp_assert_vc];
in
  (* vc_tac, on a goal  Leq pre wlp_post :
     asks vc_solve to prove  Leq var wlp_post , where var is a fresh genvar
     the solver may instantiate.  It takes the first returned theorem
     Leq mid wlp_post  and proves the goal by transitivity through mid:
       - Leq pre mid      : accepted by leq_refl when it applies, otherwise
                            left as a subgoal;
       - Leq mid wlp_post : reduced by MATCH_MP_TAC with DISCH_CONJ of the
                            solver theorem to the conjunction of its
                            hypotheses; closed by TRUTH when that is just T.
     Any remaining conjunctions are then split into separate subgoals.
     Fails if the goal is not a Leq term.
     NOTE(review): assumes leq_trans has the shape
       Leq x y /\ Leq y z ==> Leq x z .
     Also, hd raises Empty (not HOL_ERR) if the solver returns no theorems;
     confirm what FOL_FIND does when the search fails. *)
  fun vc_tac (asl,goal) =
    let
      val (pre,wlp_post) = dest_leq goal
      val var = genvar (type_of wlp_post)
      val query = mk_leq (var, wlp_post)
      val result = hd (snd (vc_solve (([var],[]),[query])))
      val (mid,_) = dest_leq (concl result)
      val mp_th = DISCH_CONJ result
    in
      MATCH_MP_TAC leq_trans
      THEN EXISTS_TAC mid
      THEN CONJ_TAC
      THENL [TRY (MATCH_ACCEPT_TAC leq_refl),
             MATCH_MP_TAC mp_th THEN TRY (ACCEPT_TAC TRUTH)]
      THEN REPEAT CONJ_TAC
    end (asl,goal);
end;

(* ------------------------------------------------------------------------- *)
(* dest_lam tm: recognise a pointwise definition  !v. la v = r , where la is *)
(* a variable that is not free in r, and return (v, la, r).  This is the     *)
(* shape of the assumptions introduced by lam_abbrev_tac.                    *)
(* Raises HOL_ERR for any other term.                                        *)
(* ------------------------------------------------------------------------- *)

fun dest_lam tm =
  let
    val (v,b) = dest_forall tm
    val (l,r) = dest_eq b
    val (la,lb) = dest_comb l
    val _ = is_var la orelse raise ERR "dest_lam" "lhs rator not var"
    val _ = lb = v orelse raise ERR "dest_lam" "lhs rand not bvar"
    val _ = not (free_in la r) orelse raise ERR "dest_lam" "recursive"
  in
    (v,la,r)
  end;

(* As dest_lam, but also requires that v is not free in r, i.e. la denotes a *)
(* constant function.                                                        *)
fun dest_triv_lam tm =
  let
    val (v,la,r) = dest_lam tm
    val _ = not (free_in v r) orelse raise ERR "dest_triv_lam" ""
  in
    (v,la,r)
  end;

val is_triv_lam = can dest_triv_lam;

(* ------------------------------------------------------------------------- *)
(* elim_redundant_lam_tac: discard dest_lam-shaped assumptions whose defined *)
(* variable la is no longer used.                                            *)
(* Assumptions are scanned in order.  One is dropped when la is not free in  *)
(* the goal, in any assumption already kept, or in any assumption still to   *)
(* be scanned.  Assumptions of any other shape are always kept.              *)
(* The check is conservative: a variable used only by an assumption that is  *)
(* dropped later in the scan keeps its own definition.                       *)
(* ------------------------------------------------------------------------- *)

local
  (* acc: assumptions to drop; ys: the goal plus assumptions kept so far;
     the third argument: assumptions not yet scanned.  Once the scan is
     done, each assumption in acc is removed with UNDISCH_THEN. *)
  fun f acc _ [] = EVERY (map (fn x => UNDISCH_THEN x (K ALL_TAC)) acc)
    | f acc ys (x :: xs) =
    let
      (* p is true when x must be kept. *)
      val p =
        case total dest_lam x of NONE => true
        | SOME (_,l,_) => List.exists (free_in l) (ys @ xs)
    in
      if p then f acc (x :: ys) xs else f (x :: acc) ys xs
    end;
in
  fun elim_redundant_lam_tac (asl,g) = f [] [g] asl (asl,g);
end;

(* ------------------------------------------------------------------------- *)
(* if_cases_tac thtac: case split on the condition c of a conditional in the *)
(* goal that has no conditional as a proper subterm (find_subterm is_cond).  *)
(* The two cases  c = T  and  c = F  come from BOOL_CASES_AX.  In each case  *)
(* the goal is simplified with if_ss plus the case theorem, and then thtac   *)
(* is applied to the case theorem.                                           *)
(* Fails with HOL_ERR when the goal has no conditional, so it can be         *)
(* REPEATed.                                                                 *)
(* ------------------------------------------------------------------------- *)

local
  (* First and second conjuncts of COND_CLAUSES: the rewrites for a
     conditional whose condition is T, respectively F. *)
  val if_t = (GEN_ALL o CONJUNCT1 o SPEC_ALL) COND_CLAUSES;
  val if_f = (GEN_ALL o CONJUNCT2 o SPEC_ALL) COND_CLAUSES;
  (* pure_ss extended with just those two rewrites. *)
  val if_ss = simpLib.++
    (pureSimps.pure_ss,
     simpLib.SSFRAG {ac = [], congs = [], convs = [], dprocs = [],
                      filter = NONE, rewrs = [if_t, if_f]});
in
  fun if_cases_tac thtac (asl,g) =
    let
      val (c,_,_) = dest_cond (find_subterm is_cond g)
    in
      MP_TAC (SPEC c BOOL_CASES_AX)
      THEN STRIP_TAC
      THEN POP_ASSUM (fn th => SIMP_TAC if_ss [th] THEN thtac th)
    end (asl,g);
end;

(* ------------------------------------------------------------------------- *)
(* wlp_assume_tac th: simplify th in four successive passes and add the      *)
(* result to the assumptions with STRIP_ASSUME_TAC.                          *)
(* Only the first pass also uses the current assumptions, which ASSUM_LIST   *)
(* supplies as ths.                                                          *)
(* ------------------------------------------------------------------------- *)

local
  (* Integer rewrites used by the second (int_ss) pass. *)
  val int_simps =
    [INT_LE_REFL, INT_SUB_ADD, INT_SUB_RZERO, INT_EQ_LADD, INT_EQ_RADD,
     INT_EQ_SUB_RADD, INT_EQ_SUB_LADD,
     INT_LE_SUB_RADD, INT_LE_SUB_LADD, INT_ADD_RID, INT_ADD_LID,
     INT_ADD_RID_UNIQ, INT_ADD_LID_UNIQ, GSYM INT_ADD_ASSOC,
     int_arithTheory.move_sub, int_arithTheory.less_to_leq_samer,
     INT_NOT_LE]

  fun assume_tac th ths =
    let
      (* Uncomment to trace each theorem as it is assumed: *)
      (*val () = print ("\nthm = " ^ thm_to_string th ^ "\n");*)
      (* Pass 1: arithmetic, string/char injectivity, assign_def and the
         current assumptions. *)
      val th1 = SIMP_RULE arith_ss ([STRING_11, CHR_11, assign_def] @ ths) th
      (* Pass 2: integer rewrites. *)
      val th2 = SIMP_RULE int_ss int_simps th1
      (* Pass 3: posreal reduction. *)
      val th3 = SIMP_RULE posreal_reduce_ss [] th2
      (* Pass 4: boolean simplification with COND_elim_ss. *)
      val th4 = SIMP_RULE (simpLib.++ (bool_ss, boolSimps.COND_elim_ss)) [] th3
    in
      STRIP_ASSUME_TAC th4
    end;
in
  val wlp_assume_tac = (ASSUM_LIST o assume_tac);
end;

(* ------------------------------------------------------------------------- *)
(* leq_tac: try to prove a goal  Leq a b  pointwise.  In order, it:          *)
(*   1. rewrites the goal with Leq_def and applies GEN_TAC (Leq_def is       *)
(*      presumably a universally quantified pointwise statement);            *)
(*   2. rewrites with the dest_triv_lam-shaped (constant) lam assumptions    *)
(*      plus the lemmas from simps;                                          *)
(*   3. case splits on every conditional, assuming each case theorem with    *)
(*      wlp_assume_tac;                                                      *)
(*   4. rewrites with all assumptions plus simps, and case splits again;     *)
(*   5. drops unused lam assumptions, simplifies with posreal_ss, expands    *)
(*      any abbreviations, and finishes with posreal_reduce_ss and           *)
(*      bound1_def.                                                          *)
(* Fails immediately if the goal does not match the lhs of Leq_def.          *)
(* ------------------------------------------------------------------------- *)

local
  (* Extra definitions and lemmas added to the rewrites in leq_tac. *)
  fun simps ths =
      ths @ [min_alt, Lin_def, Cond_def, o_THM, magic_alt, assign_def];
in
  val leq_tac =
    CONV_TAC (REWR_CONV Leq_def)
    THEN GEN_TAC
    THEN ASSUM_LIST
         (SIMP_TAC pureSimps.pure_ss o simps o filter (is_triv_lam o concl))
    THEN REPEAT (if_cases_tac wlp_assume_tac)
    THEN ASSUM_LIST (SIMP_TAC pureSimps.pure_ss o simps)
    THEN REPEAT (if_cases_tac wlp_assume_tac)
    THEN elim_redundant_lam_tac
    THEN RW_TAC posreal_ss []
    THEN Q.UNABBREV_ALL_TAC
    THEN RW_TAC posreal_reduce_ss [bound1_def];
end;

(* ------------------------------------------------------------------------- *)
(* The complete wlp pipeline.                                                *)
(* ------------------------------------------------------------------------- *)

local
  (* Definitions unfolded by wlp_tac before VC generation: MAP, LENGTH and
     the Nondets/NondetAssign/Probs/ProbAssign/Program/Guards definitions. *)
  val expand =
    [MAP, LENGTH,
     Nondets_def, NondetAssign_def, Probs_def, ProbAssign_def,
     Program_def, Guards_def, guards_def];
in
  (* Abbreviate every lambda abstraction in the goal (innermost first), then
     perform one vc_tac step. *)
  val pure_wlp_tac =
    REPEAT lam_abbrev_tac
    THEN vc_tac;

  (* Unfold the definitions in expand with RW_TAC posreal_ss, simplify the
     goal and assumptions with posreal_reduce_ss, generate VCs with
     pure_wlp_tac, then run leq_tac on every resulting subgoal.
     NOTE(review): leq_tac fails on any subgoal that is not a Leq goal, and
     THEN makes the whole tactic fail in that case.  This presumably expects
     every subgoal left by vc_tac to be a Leq goal -- confirm. *)
  val wlp_tac =
    RW_TAC posreal_ss expand
    THEN FULL_SIMP_TAC posreal_reduce_ss []
    THEN pure_wlp_tac
    THEN leq_tac;
end;

end
231