1(* Code to accompany Proof Tools chapter of the Manual *)
2(* Ultimate entry-point is DPLL_TAUT *)
3open HolKernel Parse boolLib
4
5datatype result = Unsat of thm | Sat of term -> term
6
7fun count_vars ds acc =
8    case ds of
9      [] => acc
10    | lit::lits => let
11        val v = dest_neg lit handle HOL_ERR _ => lit
12      in
13        case Binarymap.peek (acc, v) of
14          NONE => count_vars lits (Binarymap.insert(acc,v,1))
15        | SOME n => count_vars lits (Binarymap.insert(acc,v,n + 1))
16      end
17
18fun getBiggest acc =
19    #1 (Binarymap.foldl(fn (v,cnt,a as (bestv,bestcnt)) =>
20                           if cnt > bestcnt then (v,cnt) else a)
21                       (boolSyntax.T, 0)
22                       acc)
23
24fun find_splitting_var phi = let
25  fun recurse acc [] = getBiggest acc
26    | recurse acc (c::cs) = let
27        val ds = strip_disj c
28      in
29        case ds of
30          [lit] => (dest_neg lit handle HOL_ERR _ => lit)
31        | _ => recurse (count_vars ds acc) cs
32      end
33in
34  recurse (Binarymap.mkDict Term.compare) (strip_conj phi)
35end
36
37fun casesplit v th = let
38  val eqT = ASSUME (mk_eq(v, boolSyntax.T))
39  val eqF = ASSUME (mk_eq(v, boolSyntax.F))
40in
41  (REWRITE_RULE [eqT] th, REWRITE_RULE [eqF] th)
42end
43
44fun mk_satmap th = let
45  val hyps = hypset th
46  fun foldthis (t,acc) = let
47    val (l,r) = dest_eq t
48  in
49    Binarymap.insert(acc,l,r)
50  end handle HOL_ERR _ => acc
51  val fmap = HOLset.foldl foldthis (Binarymap.mkDict Term.compare) hyps
52in
53  Sat (fn v => Binarymap.find(fmap,v)
54          handle Binarymap.NotFound => boolSyntax.T)
55end
56
57
58fun CoreDPLL form = let
59  val initial_th = ASSUME form
60  fun recurse th = let
61    val c = concl th
62  in
63    if c = boolSyntax.T then
64      mk_satmap th
65    else if c = boolSyntax.F then
66      Unsat th
67    else let
68        val v = find_splitting_var c
69        val (l,r) = casesplit v th
70      in
71        case recurse l of
72          Unsat l_false => let
73          in
74            case recurse r of
75              Unsat r_false =>
76                Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX) l_false r_false)
77            | x => x
78          end
79        | x => x
80      end
81  end
82in
83  case (recurse initial_th) of
84    Unsat th => Unsat (CONV_RULE (REWR_CONV IMP_F_EQ_F) (DISCH form th))
85  | x => x
86end
87   fun DPLL t = let
88     val (transform, body) = let
89       val (vector, body) = dest_exists t
90       fun transform body_eq_F = let
91         val body_imp_F = CONV_RULE (REWR_CONV (GSYM IMP_F_EQ_F)) body_eq_F
92         val fa_body_imp_F = GEN vector body_imp_F
93         val ex_body_imp_F = CONV_RULE FORALL_IMP_CONV fa_body_imp_F
94       in
95         CONV_RULE (REWR_CONV IMP_F_EQ_F) ex_body_imp_F
96       end
97     in
98       (transform, body)
99     end handle HOL_ERR _ => (I, t)
100   in
101     case CoreDPLL body of
102       Unsat body_eq_F => Unsat (transform body_eq_F)
103     | x => x
104   end
105val NEG_EQ_F = prove(``(~p = F) = p``, REWRITE_TAC []);
106val toCNF = defCNF.DEF_CNF_VECTOR_CONV
107fun DPLL_UNIV t = let
108  val (vs, phi) = strip_forall t
109  val cnf_eqn = toCNF (mk_neg phi)
110  val phi' = rhs (concl cnf_eqn)
111in
112  case DPLL phi' of
113    Unsat phi'_eq_F => let
114      val negphi_eq_F = TRANS cnf_eqn phi'_eq_F
115      val phi_thm = CONV_RULE (REWR_CONV NEG_EQ_F) negphi_eq_F
116    in
117      EQT_INTRO (GENL vs phi_thm)
118    end
119  | Sat f => let
120      val t_assumed = ASSUME t
121      fun spec th =
122          spec (SPEC (f (#1 (dest_forall (concl th)))) th)
123          handle HOL_ERR _ => REWRITE_RULE [] th
124    in
125      CONV_RULE (REWR_CONV IMP_F_EQ_F) (DISCH t (spec t_assumed))
126    end
127end
128
129fun dest_bool_eq t = let
130  val (l,r) = dest_eq t
131  val _ = type_of l = bool orelse
132          raise mk_HOL_ERR "dpll" "dest_bool_eq" "Eq not on bools"
133in
134  (l,r)
135end
136fun var_leaves acc t = let
137  val (l,r) = dest_conj t handle HOL_ERR _ =>
138              dest_disj t handle HOL_ERR _ =>
139              dest_imp t handle HOL_ERR _ =>
140              dest_bool_eq t
141in
142  var_leaves (var_leaves acc l) r
143end handle HOL_ERR _ =>
144           if type_of t <> bool then
145             raise mk_HOL_ERR "dpll" "var_leaves" "Term not boolean"
146           else if t = boolSyntax.T then acc
147           else if t = boolSyntax.F then acc
148           else HOLset.add(acc, t)
149
150fun DPLL_TAUT tm =
151    let val (univs,tm') = strip_forall tm
152        val insts = HOLset.listItems (var_leaves empty_tmset tm')
153        val vars = map (fn t => genvar bool) insts
154        val theta = map2 (curry (op |->)) insts vars
155        val tm'' = list_mk_forall (vars,subst theta tm')
156    in
157      EQT_INTRO (GENL univs
158                      (SPECL insts (EQT_ELIM (DPLL_UNIV tm''))))
159    end
160
161
162
163(* implementation of DPLL ends *)
164
165(* ----------------------------------------------------------------------
166    Code below, due to John Harrison, generates tautologies stating that
167    two different implementations of binary addition are equivalent
168   ---------------------------------------------------------------------- *)
169
170fun halfsum x y = mk_eq(x,mk_neg y)
171fun halfcarry x y = mk_conj(x,y)
172fun ha x y s c = mk_conj(mk_eq(s,halfsum x y), mk_eq(c,halfcarry x y))
173
174fun carry x y z = mk_disj(mk_conj(x,y), mk_conj(mk_disj(x,y), z))
175fun sum x y z = halfsum (halfsum x y) z;
176fun fa x y z s c = mk_conj(mk_eq(s,sum x y z), mk_eq(c,carry x y z))
177
178fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T
179
180fun ripplecarry x y c out n =
181    list_conj
182      (List.tabulate(n, (fn i => fa (x i) (y i) (c i) (out i) (c (i + 1)))))
183
184fun mk_index s i = mk_var(s ^ "_" ^ Int.toString i, bool)
185
186val [x,y,out,c] = map mk_index ["X", "Y", "OUT", "C"]
187val twobit_adder = ripplecarry x y c out 2
188
189fun simp t =
190    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))
191
192fun ripplecarry0 x y c out n =
193    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F
194                                   else c i) out n)
195
196fun ripplecarry1 x y c out n =
197    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T
198                                   else c i) out n)
199
200fun mux sel in0 in1 = mk_disj(mk_conj(mk_neg sel,in0), mk_conj(sel,in1))
201
202fun offset n x i = x (n + i)
203fun carryselect x y c0 c1 s0 s1 c s n k = let
204  val k' = Int.min(n,k)
205  val fm =
206      mk_conj(mk_conj(ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
207              mk_conj(mk_eq(c k', mux (c 0) (c0 k') (c1 k')),
208                      list_conj
209                      (List.tabulate
210                       (k',
211                        (fn i => mk_eq(s i, mux (c 0) (s0 i) (s1 i)))))))
212in
213  if k' < k then fm
214  else mk_conj(fm, carryselect (offset k x) (offset k y)
215                               (offset k c0) (offset k c1)
216                               (offset k s0) (offset k s1)
217                               (offset k c) (offset k s)
218                               (n - k) k)
219end
220
221(* call with positive n and k to generate tautologies *)
222fun mk_adder_test n k = let
223  val [x,y,c,s,c0,s0,c1,s1,c2,s2] =
224      map mk_index ["x", "y", "c", "s", "c0", "s0", "c1", "s1", "c2", "s2"]
225in
226  simp
227    (mk_imp(mk_conj(mk_conj(carryselect x y c0 c1 s0 s1 c s n k, mk_neg (c 0)),
228                    ripplecarry0 x y c2 s2 n),
229            mk_conj(mk_eq(c n, c2 n),
230                    list_conj(List.tabulate(n, (fn i => mk_eq(s i, s2 i)))))))
231end
232
233(* example in tutorial is *)
234
235val example = gen_all (mk_adder_test 3 2)
236
237(* test them here:
238time DPLL_UNIV example;
239time tautLib.TAUT_PROVE example;
240*)
241