1(* ========================================================================= *)
2(* Tools for reasoning about wp and wlp.                                     *)
3(* ========================================================================= *)
4
5structure wpTools :> wpTools =
6struct
7
8open HolKernel Parse boolLib bossLib simpLib metisLib intLib
9     integerTheory stringTheory combinTheory listTheory
10     posrealTheory posrealLib expectationTheory syntaxTheory wpTheory;
11
12(* ------------------------------------------------------------------------- *)
13(* Automatic tool for calculating wlp verification conditions.               *)
14(* ------------------------------------------------------------------------- *)
15
(* Build a HOL_ERR exception tagged with this structure's name:
   [ERR fname msg] names the failing function and gives a message. *)
fun ERR fname msg = mk_HOL_ERR "wpTools" fname msg;
17
local
  open folTools;

  (* Parameters for mapping HOL theorems into first-order clauses: no
     equality, boolean or combinator handling; a first-order mapping that
     keeps type information. *)
  val prolog_parm =
    {equality = false, boolean = false, combinator = false,
     mapping_parm = {higher_order = false, with_types = true}};
in
  (* prolog ths: normalise ths with FOL_NORM, build a logic map from the
     result, add the rule  x |- x  (x a boolean variable), and return the
     query function produced by FOL_FIND using mlibMeson.prolog with no
     resource limit (mlibMeter.unlimited).
     NOTE(review): the extra  x |- x  rule appears to let any boolean subgoal
     be closed by assumption, leaving it as a hypothesis of the answer
     theorem; vc_tac later folds such hypotheses in with DISCH_CONJ. *)
  fun prolog ths =
    let
      val (cs, norm_ths) = FOL_NORM ths
      val base_map = build_map (prolog_parm, cs, norm_ths)
      val catch_all = (([``x:bool``], []), ASSUME ``x:bool``)
    in
      FOL_FIND mlibMeson.prolog (add_thm catch_all base_map) mlibMeter.unlimited
    end;
end;
36
(* ------------------------------------------------------------------------- *)
(* find_subterm p tm: find a subterm of tm satisfying p (via find_term), then *)
(* keep refining it to a p-satisfying proper subterm of itself, searching the *)
(* rator before the rand of a combination and the body of an abstraction.    *)
(* The result therefore has no proper subterm satisfying p.                   *)
(* Raises HOL_ERR (from find_term) if no subterm of tm satisfies p.           *)
(* NOTE(review): because it descends into abstraction bodies, the result may  *)
(* contain variables bound by an enclosing abstraction of tm.                 *)
(* ------------------------------------------------------------------------- *)
fun find_subterm p =
  let
    val search = find_term p
    (* Refine within [sub]; if nothing there satisfies p, run [otherwise]. *)
    fun refine_in sub otherwise =
        deepest (search sub) handle HOL_ERR _ => otherwise ()
    and deepest tm =
        if is_comb tm then
          refine_in (rator tm) (fn () => refine_in (rand tm) (fn () => tm))
        else if is_abs tm then
          refine_in (body tm) (fn () => tm)
        else tm
  in
    deepest o search
  end;
50
(* ------------------------------------------------------------------------- *)
(* abbrev_tac n p: take the subterm t of the goal chosen by find_subterm p,  *)
(* make a variable v named n (renamed by numvariant away from the free       *)
(* variables of the goal and assumptions), rewrite t to v in the goal, and   *)
(* add the assumption  t = v.  Fails if no subterm of the goal satisfies p.  *)
(* ------------------------------------------------------------------------- *)
fun abbrev_tac n p (asl,g) =
  let
    val target = find_subterm p g
    val avoid = free_varsl (g :: asl)
    val abbrev = numvariant avoid (mk_var (n, type_of target))
    val defn = mk_eq (target, abbrev)
    (* |- ?abbrev. target = abbrev, witnessed by target itself *)
    val exists_th = EXISTS (mk_exists (abbrev, defn), target) (REFL target)
    fun use_defn th = PURE_REWRITE_TAC [th] THEN ASSUME_TAC th
  in
    EVERY [MP_TAC exists_th,
           CONV_TAC LEFT_IMP_EXISTS_CONV,
           GEN_TAC,
           DISCH_THEN use_defn] (asl,g)
  end;
61
local
  (* |- !p q. ((\v. p v) = q) ==> !v. q v = p v
     Turns an abbreviation  (\v. body) = lam  into the pointwise equation
     !v. lam v = body. *)
  val abs_eq_to_pointwise = prove
    (``!p q. ((\v. p v) = q) ==> !v. q v = p v``,
     EVERY [CONV_TAC (DEPTH_CONV ETA_CONV),
            REWRITE_TAC [GSYM FUN_EQ_THM],
            MATCH_ACCEPT_TAC EQ_SYM]);

  (* Replace a popped  (\v. body) = lam  assumption by its pointwise form. *)
  fun pointwise th = ASSUME_TAC (HO_MATCH_MP abs_eq_to_pointwise th)
in
  (* Abbreviate a lambda abstraction of the goal that contains no other
     abstraction as a fresh variable "lam", and assume its pointwise
     definition  !v. lam v = ...  in place of the abbreviation equation. *)
  val lam_abbrev_tac = abbrev_tac "lam" is_abs THEN POP_ASSUM pointwise;
end;
73
local
  (* [fold_imps k] rewrites  a0 ==> a1 ==> ... ==> ak ==> c  into
     a0 /\ (a1 /\ ... /\ ak) ==> c  using AND_IMP_INTRO exactly k times.
     Only the k+1 outermost antecedents are touched; the conclusion c is
     never inspected, so an implicational conclusion is left intact. *)
  fun fold_imps 0 = ALL_CONV
    | fold_imps k = RAND_CONV (fold_imps (k - 1)) THENC REWR_CONV AND_IMP_INTRO;

  (* c  -->  T ==> c *)
  val tconj = REWR_CONV (GSYM (CONJUNCT1 (SPEC_ALL IMP_CLAUSES)));
in
  (* DISCH_CONJ th: discharge every hypothesis of th into one antecedent:
       h1, ..., hn |- c   gives   |- h1 /\ ... /\ hn ==> c   (n >= 2)
       h |- c             gives   |- h ==> c
       |- c               gives   |- T ==> c
     so that the result always has the shape  A ==> c  (e.g. for
     MATCH_MP_TAC).  The number of folds is taken from the hypothesis count,
     rather than folding for as long as AND_IMP_INTRO happens to match. *)
  fun DISCH_CONJ th =
      let
        val n = length (hyp th)
        val conv = if n = 0 then tconj else fold_imps (n - 1)
      in
        CONV_RULE conv (DISCH_ALL th)
      end;
end;
81
local
  val leq_tm = ``Leq : 'a state expect -> 'a state expect -> bool``;

  (* Split  Leq a b  into (a,b); raises HOL_ERR otherwise. *)
  val dest_leq = dest_binop leq_tm (ERR "dest_leq" "");

  (* Build  Leq a b, instantiating the type variable of Leq with the range
     of the first type argument of a's type.
     NOTE(review): assumes type_of a unfolds to  (dom -> ty) -> ...  as the
     'a state expect abbreviation presumably does -- TODO confirm. *)
  fun mk_leq (a,b) =
      let
        val (_,ty) = dom_rng (hd (snd (dest_type (type_of a))))
        val tm = inst [alpha |-> ty] leq_tm
      in
        mk_comb (mk_comb (tm,a), b)
      end;

  (* Prolog-style solver over the wlp verification-condition rules. *)
  val vc_solve = prolog
    [wlp_abort_vc, wlp_consume_vc, wlp_assign_vc, wlp_seq_vc, wlp_nondet_vc,
     wlp_prob_vc, wlp_while_vc, wlp_skip_vc, wlp_if_vc, wlp_assert_vc];
in
  (* vc_tac: on a goal  Leq pre wlp_post, ask vc_solve for some mid with
     Leq mid wlp_post (side conditions appear as hypotheses), then reduce the
     goal via leq_trans to  Leq pre mid  (closed by leq_refl if possible)
     plus the discharged side conditions, split into separate conjuncts.
     Raises HOL_ERR if the goal is not a Leq or the solver finds nothing. *)
  fun vc_tac (asl,goal) =
    let
      val (_,wlp_post) = dest_leq goal
      val var = genvar (type_of wlp_post)
      val query = mk_leq (var, wlp_post)
      (* Take the first solution.  An empty answer is reported as HOL_ERR
         (not the SML Empty raised by hd) so that tactic combinators which
         handle HOL_ERR, such as TRY and ORELSE, can recover from it. *)
      val result =
          case snd (vc_solve (([var],[]),[query])) of
            th :: _ => th
          | [] => raise ERR "vc_tac" "no verification condition found"
      val (mid,_) = dest_leq (concl result)
      val mp_th = DISCH_CONJ result
    in
      (MATCH_MP_TAC leq_trans
       THEN EXISTS_TAC mid
       THEN CONJ_TAC
       THENL [TRY (MATCH_ACCEPT_TAC leq_refl),
              MATCH_MP_TAC mp_th THEN TRY (ACCEPT_TAC TRUTH)]
       THEN REPEAT CONJ_TAC) (asl,goal)
    end;
end;
116
(* ------------------------------------------------------------------------- *)
(* dest_lam: recognise a pointwise definition  !v. f v = r  where f is a     *)
(* variable, its argument is exactly the quantified variable v, and f does   *)
(* not occur free in r.  Returns (v, f, r); raises HOL_ERR otherwise.        *)
(* ------------------------------------------------------------------------- *)
fun dest_lam tm =
  let
    fun require (ok, why) = if ok then () else raise ERR "dest_lam" why
    val (bv, eqn) = dest_forall tm
    val (lhs_tm, rhs_tm) = dest_eq eqn
    val (fvar, arg) = dest_comb lhs_tm
  in
    require (is_var fvar, "lhs rator not var");
    require (arg = bv, "lhs rand not bvar");
    require (not (free_in fvar rhs_tm), "recursive");
    (bv, fvar, rhs_tm)
  end;
128
(* dest_triv_lam: like dest_lam, but additionally requires that the right-hand
   side r does not mention the quantified variable v (f is constant). *)
fun dest_triv_lam tm =
  let
    val (bv, fvar, rhs_tm) = dest_lam tm
  in
    if free_in bv rhs_tm then raise ERR "dest_triv_lam" ""
    else (bv, fvar, rhs_tm)
  end;

(* Does tm have the shape accepted by dest_triv_lam? *)
fun is_triv_lam tm = can dest_triv_lam tm;
138
local
  (* Scan the assumptions left to right.  A dest_lam-shaped assumption whose
     defined variable is not free in the goal, the assumptions kept so far,
     or the assumptions still to be scanned is collected for dropping;
     everything else is kept.  Returns the assumptions to drop. *)
  fun partition_lams dropped _ [] = dropped
    | partition_lams dropped kept (a :: rest) =
      let
        val needed =
            case total dest_lam a of
              NONE => true
            | SOME (_, fvar, _) => List.exists (free_in fvar) (kept @ rest)
      in
        if needed then partition_lams dropped (a :: kept) rest
        else partition_lams (a :: dropped) kept rest
      end;

  (* Remove the assumption tm from the goal. *)
  fun discard tm = UNDISCH_THEN tm (K ALL_TAC);
in
  (* Drop pointwise lam definitions that are no longer referenced (single
     pass; a definition only used by a dropped one is kept). *)
  fun elim_redundant_lam_tac (asl,g) =
    EVERY (map discard (partition_lams [] [g] asl)) (asl,g);
end;
152
local
  val cond_clauses = SPEC_ALL COND_CLAUSES;
  val if_t = GEN_ALL (CONJUNCT1 cond_clauses);
  val if_f = GEN_ALL (CONJUNCT2 cond_clauses);
  (* Pure simpset that only knows  (if T then a else b) = a  and
     (if F then a else b) = b. *)
  val if_ss = simpLib.++
    (pureSimps.pure_ss,
     simpLib.SSFRAG {ac = [], congs = [], convs = [], dprocs = [],
                      filter = NONE, rewrs = [if_t, if_f]});
in
  (* if_cases_tac thtac: case split on the guard c of a conditional chosen by
     find_subterm is_cond.  In each subgoal the case assumption (c = T or
     c = F) is popped, used to simplify the goal with if_ss, and then handed
     to thtac.  Fails if the goal contains no conditional. *)
  fun if_cases_tac thtac (asl,g) =
    let
      val (guard,_,_) = dest_cond (find_subterm is_cond g)
      fun use_case th = SIMP_TAC if_ss [th] THEN thtac th
    in
      EVERY [MP_TAC (SPEC guard BOOL_CASES_AX), STRIP_TAC, POP_ASSUM use_case]
        (asl,g)
    end;
end;
170
local
  (* Integer rewrites used to normalise case assumptions. *)
  val int_simps =
    [INT_LE_REFL, INT_SUB_ADD, INT_SUB_RZERO, INT_EQ_LADD, INT_EQ_RADD,
     INT_EQ_SUB_RADD, INT_EQ_SUB_LADD,
     INT_LE_SUB_RADD, INT_LE_SUB_LADD, INT_ADD_RID, INT_ADD_LID,
     INT_ADD_RID_UNIQ, INT_ADD_LID_UNIQ, GSYM INT_ADD_ASSOC,
     int_arithTheory.move_sub, int_arithTheory.less_to_leq_samer,
     INT_NOT_LE]

  (* Simplify th in four stages (string/char injectivity, assign_def and the
     given theorems ths; integer rewrites; posreal reduction; conditional
     elimination), then strip the result into the assumptions. *)
  fun assume_tac th ths =
    let
      val cond_ss = simpLib.++ (bool_ss, boolSimps.COND_elim_ss)
      val normalise =
          SIMP_RULE cond_ss []
          o SIMP_RULE posreal_reduce_ss []
          o SIMP_RULE int_ss int_simps
          o SIMP_RULE arith_ss ([STRING_11, CHR_11, assign_def] @ ths)
    in
      STRIP_ASSUME_TAC (normalise th)
    end;
in
  (* wlp_assume_tac th: normalise th using the current assumptions and add it. *)
  fun wlp_assume_tac th = ASSUM_LIST (assume_tac th);
end;
193
local
  (* Caller theorems plus the standard rewrites for unfolding Leq goals. *)
  fun simps ths =
      ths @ [min_alt, Lin_def, Cond_def, o_THM, magic_alt, assign_def];

  (* Rewrite the goal with the selected assumptions and the standard simps. *)
  fun rewrite_with select =
      ASSUM_LIST (SIMP_TAC pureSimps.pure_ss o simps o select);

  (* Case split on every remaining conditional, normalising each case. *)
  val split_ifs = REPEAT (if_cases_tac wlp_assume_tac);
in
  (* leq_tac: unfold Leq pointwise, substitute lam definitions (constant ones
     first), split conditionals, drop unused lam definitions, and finish
     with posreal simplification. *)
  val leq_tac =
    EVERY
      [CONV_TAC (REWR_CONV Leq_def),
       GEN_TAC,
       rewrite_with (filter (is_triv_lam o concl)),
       split_ifs,
       rewrite_with (fn ths => ths),
       split_ifs,
       elim_redundant_lam_tac,
       RW_TAC posreal_ss [],
       Q.UNABBREV_ALL_TAC,
       RW_TAC posreal_reduce_ss [bound1_def]];
end;
211
local
  (* Definitions unfolded before VC generation: MAP, LENGTH and the listed
     *_def theorems (presumably the derived program constructs). *)
  val expand =
    [MAP, LENGTH,
     Nondets_def, NondetAssign_def, Probs_def, ProbAssign_def,
     Program_def, Guards_def, guards_def];
in
  (* pure_wlp_tac: abbreviate every lambda abstraction in the goal, then
     generate the wlp verification conditions with vc_tac. *)
  val pure_wlp_tac = EVERY [REPEAT lam_abbrev_tac, vc_tac];

  (* wlp_tac: unfold program definitions, simplify, generate VCs, and try to
     discharge the resulting Leq goal with leq_tac. *)
  val wlp_tac =
    EVERY [RW_TAC posreal_ss expand,
           FULL_SIMP_TAC posreal_reduce_ss [],
           pure_wlp_tac,
           leq_tac];
end;
228
229end
230