(* Code to accompany the Proof Tools chapter of the Manual *)
(* The ultimate entry-point is DPLL_TAUT *)
open HolKernel Parse boolLib

(* Outcome of a DPLL search.
   Unsat th : th is a theorem refuting the formula being examined.
   Sat f    : f maps each variable to the boolean value (T or F) it takes
              in a satisfying assignment. *)
datatype result = Unsat of thm | Sat of term -> term

(* count_vars ds acc: for every literal in ds, strip a leading negation (if
   any) and increment the occurrence count of the underlying atom in the
   Binarymap acc.  Returns the updated map. *)
fun count_vars ds acc =
    case ds of
        [] => acc
      | lit :: lits =>
        let
          val v = dest_neg lit handle HOL_ERR _ => lit
          val n = case Binarymap.peek (acc, v) of
                      NONE => 0
                    | SOME n => n
        in
          count_vars lits (Binarymap.insert (acc, v, n + 1))
        end

(* getBiggest acc: the key with the strictly largest count in acc (on ties,
   the first one met by Binarymap.foldl wins); T if acc is empty. *)
fun getBiggest acc =
    #1 (Binarymap.foldl
          (fn (v, cnt, best as (_, bestcnt)) =>
              if cnt > bestcnt then (v, cnt) else best)
          (boolSyntax.T, 0)
          acc)

(* find_splitting_var phi: choose the variable to case-split on.  phi is
   treated as a conjunction of clauses (disjunctions of literals).  The
   clauses are scanned in order; as soon as a unit clause is met, its
   (un-negated) atom is returned.  Otherwise the atom with the most
   literal occurrences over all clauses is returned. *)
fun find_splitting_var phi = let
  fun recurse acc [] = getBiggest acc
    | recurse acc (c :: cs) =
      case strip_disj c of
          [lit] => (dest_neg lit handle HOL_ERR _ => lit)
        | ds => recurse (count_vars ds acc) cs
in
  recurse (Binarymap.mkDict Term.compare) (strip_conj phi)
end

(* casesplit v th: rewrite th under the assumption v = T and, separately,
   under v = F.  Each resulting theorem keeps its assumption as an extra
   hypothesis, which is later discharged (CoreDPLL) or read back as part
   of a satisfying assignment (mk_satmap). *)
fun casesplit v th = let
  val eqT = ASSUME (mk_eq (v, boolSyntax.T))
  val eqF = ASSUME (mk_eq (v, boolSyntax.F))
in
  (REWRITE_RULE [eqT] th, REWRITE_RULE [eqF] th)
end

(* mk_satmap th: build a Sat result from the equational hypotheses of th
   (the v = T / v = F assumptions introduced by casesplit).  Variables not
   mentioned in any such hypothesis are mapped to T. *)
fun mk_satmap th = let
  fun foldthis (t, acc) = let
    val (l, r) = dest_eq t
  in
    Binarymap.insert (acc, l, r)
  end handle HOL_ERR _ => acc
  val fmap = HOLset.foldl foldthis (Binarymap.mkDict Term.compare) (hypset th)
in
  Sat (fn v => Binarymap.find (fmap, v)
               handle Binarymap.NotFound => boolSyntax.T)
end

(* CoreDPLL form: DPLL search on a quantifier-free formula form, assumed to
   be in conjunctive normal form.  Returns Unsat |- form = F when form is
   unsatisfiable, otherwise a Sat assignment.
   Terms are compared with aconv rather than =, so the code does not rely
   on term being an SML equality type. *)
fun CoreDPLL form = let
  val initial_th = ASSUME form
  fun recurse th = let
    val c = concl th
  in
    if aconv c boolSyntax.T then mk_satmap th
    else if aconv c boolSyntax.F then Unsat th
    else let
      val v = find_splitting_var c
      val (l, r) = casesplit v th
    in
      (* Both branches must be refuted for the whole to be Unsat; the two
         refutations are then combined by cases on v = T \/ v = F. *)
      case recurse l of
          Unsat l_false =>
          (case recurse r of
               Unsat r_false =>
               Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX) l_false r_false)
             | sat => sat)
        | sat => sat
    end
  end
in
  case recurse initial_th of
      (* form |- F  ==>  |- form ==> F  ==>  |- form = F *)
      Unsat th => Unsat (CONV_RULE (REWR_CONV IMP_F_EQ_F) (DISCH form th))
    | sat => sat
end

(* DPLL t: like CoreDPLL, but t may be an existential ?v. body (as produced
   by the definitional CNF conversion).  An Unsat result for body,
   |- body = F, is lifted to |- (?v. body) = F. *)
fun DPLL t = let
  val (transform, body) = let
    val (vector, body) = dest_exists t
    fun transform body_eq_F = let
      val body_imp_F = CONV_RULE (REWR_CONV (GSYM IMP_F_EQ_F)) body_eq_F
      val fa_body_imp_F = GEN vector body_imp_F
      val ex_body_imp_F = CONV_RULE FORALL_IMP_CONV fa_body_imp_F
    in
      CONV_RULE (REWR_CONV IMP_F_EQ_F) ex_body_imp_F
    end
  in
    (transform, body)
  end handle HOL_ERR _ => (I, t)
in
  case CoreDPLL body of
      Unsat body_eq_F => Unsat (transform body_eq_F)
    | sat => sat
end

val NEG_EQ_F = prove(``(~p = F) = p``, REWRITE_TAC []);
val toCNF = defCNF.DEF_CNF_VECTOR_CONV

(* DPLL_UNIV t: t is !vs. phi with phi propositional.  Runs DPLL on the
   CNF of ~phi.  Returns |- t = T if ~phi is unsatisfiable.  Otherwise it
   instantiates t with the satisfying assignment found and returns
   |- t = F. *)
fun DPLL_UNIV t = let
  val (vs, phi) = strip_forall t
  val cnf_eqn = toCNF (mk_neg phi)
  val phi' = rhs (concl cnf_eqn)
in
  case DPLL phi' of
      Unsat phi'_eq_F => let
        val negphi_eq_F = TRANS cnf_eqn phi'_eq_F
        val phi_thm = CONV_RULE (REWR_CONV NEG_EQ_F) negphi_eq_F
      in
        EQT_INTRO (GENL vs phi_thm)
      end
    | Sat f => let
        val t_assumed = ASSUME t
        (* Specialise every outer universal with its value under f, then
           evaluate the resulting ground formula with REWRITE_RULE []. *)
        fun spec th =
            spec (SPEC (f (#1 (dest_forall (concl th)))) th)
            handle HOL_ERR _ => REWRITE_RULE [] th
      in
        CONV_RULE (REWR_CONV IMP_F_EQ_F) (DISCH t (spec t_assumed))
      end
end

(* dest_bool_eq t: destructs an equation between booleans; raises HOL_ERR
   if t is not an equation or its sides are not of type bool. *)
fun dest_bool_eq t = let
  val (l, r) = dest_eq t
in
  if Type.compare (type_of l, bool) = EQUAL then (l, r)
  else raise mk_HOL_ERR "dpll" "dest_bool_eq" "Eq not on bools"
end

(* var_leaves acc t: add to the term set acc every "atom" of the boolean
   term t, i.e. every maximal subterm that is not a conjunction,
   disjunction, implication or boolean equation.  T and F are skipped.
   NOTE(review): negations are presumably covered by dest_imp, which in
   HOL4 also recognises ~p as p ==> F.
   Raises HOL_ERR if t is not of type bool.  Only the destructor attempts
   are guarded by the exception handler, so errors from the recursive
   calls are not mistaken for "t is a leaf". *)
fun var_leaves acc t = let
  fun dest_binop tm =
      SOME (dest_conj tm) handle HOL_ERR _ =>
      SOME (dest_disj tm) handle HOL_ERR _ =>
      SOME (dest_imp tm) handle HOL_ERR _ =>
      SOME (dest_bool_eq tm) handle HOL_ERR _ => NONE
in
  case dest_binop t of
      SOME (l, r) => var_leaves (var_leaves acc l) r
    | NONE =>
      if Type.compare (type_of t, bool) <> EQUAL then
        raise mk_HOL_ERR "dpll" "var_leaves" "Term not boolean"
      else if aconv t boolSyntax.T orelse aconv t boolSyntax.F then acc
      else HOLset.add (acc, t)
end

(* DPLL_TAUT tm: prove |- tm = T for a (possibly universally quantified)
   propositional tautology tm.  Each atom of the body is first replaced by
   a fresh boolean variable.  The generalised formula is decided with
   DPLL_UNIV, and the atoms are then specialised back in.  Raises HOL_ERR
   (from EQT_ELIM) when tm is not a tautology. *)
fun DPLL_TAUT tm = let
  val (univs, tm') = strip_forall tm
  val insts = HOLset.listItems (var_leaves empty_tmset tm')
  val vars = map (fn _ => genvar bool) insts
  val theta = map2 (curry (op |->)) insts vars
  val tm'' = list_mk_forall (vars, subst theta tm')
in
  EQT_INTRO (GENL univs
               (SPECL insts (EQT_ELIM (DPLL_UNIV tm''))))
end

(* implementation of DPLL ends *)

(* ----------------------------------------------------------------------
    Code below, due to John Harrison, generates tautologies stating that
    two different implementations of binary addition are equivalent
   ---------------------------------------------------------------------- *)

fun halfsum x y = mk_eq (x, mk_neg y)
fun halfcarry x y = mk_conj (x, y)
fun ha x y s c = mk_conj (mk_eq (s, halfsum x y), mk_eq (c, halfcarry x y))

fun carry x y z = mk_disj (mk_conj (x, y), mk_conj (mk_disj (x, y), z))
fun sum x y z = halfsum (halfsum x y) z;
fun fa x y z s c = mk_conj (mk_eq (s, sum x y z), mk_eq (c, carry x y z))

(* Conjunction of a list of terms; T for the empty list. *)
fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T

(* n-bit ripple-carry adder: full adders for bit positions 0 .. n-1. *)
fun ripplecarry x y c out n =
    list_conj
      (List.tabulate (n, fn i => fa (x i) (y i) (c i) (out i) (c (i + 1))))

(* mk_index s i: the boolean variable named s_i *)
fun mk_index s i = mk_var (s ^ "_" ^ Int.toString i, bool)

val x = mk_index "X"
val y = mk_index "Y"
val out = mk_index "OUT"
val c = mk_index "C"
val twobit_adder = ripplecarry x y c out 2

(* Rewrite t with REWRITE_CONV, including GSYM CONJ_ASSOC and
   GSYM DISJ_ASSOC; QCONV ensures an unchanged term is still returned. *)
fun simp t =
    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))

(* Ripple-carry adder with carry-in fixed to F (resp. T). *)
fun ripplecarry0 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F else c i) out n)

fun ripplecarry1 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T else c i) out n)

(* 2-to-1 multiplexer: in0 when sel is false, in1 when sel is true. *)
fun mux sel in0 in1 = mk_disj (mk_conj (mk_neg sel, in0), mk_conj (sel, in1))

fun offset n x i = x (n + i)

(* Carry-select adder over n bits, built from blocks of k bits. *)
fun carryselect x y c0 c1 s0 s1 c s n k = let
  val k' = Int.min (n, k)
  val fm =
      mk_conj (mk_conj (ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
               mk_conj (mk_eq (c k', mux (c 0) (c0 k') (c1 k')),
                        list_conj
                          (List.tabulate
                             (k',
                              fn i => mk_eq (s i, mux (c 0) (s0 i) (s1 i))))))
in
  if k' < k then fm
  else mk_conj (fm, carryselect (offset k x) (offset k y)
                                (offset k c0) (offset k c1)
                                (offset k s0) (offset k s1)
                                (offset k c) (offset k s)
                                (n - k) k)
end

(* Call with positive n and k to generate tautologies: a carry-select
   adder and a ripple-carry adder produce the same n-bit sum and carry. *)
fun mk_adder_test n k = let
  val x = mk_index "x"
  val y = mk_index "y"
  val c = mk_index "c"
  val s = mk_index "s"
  val c0 = mk_index "c0"
  val s0 = mk_index "s0"
  val c1 = mk_index "c1"
  val s1 = mk_index "s1"
  val c2 = mk_index "c2"
  val s2 = mk_index "s2"
in
  simp
    (mk_imp (mk_conj (mk_conj (carryselect x y c0 c1 s0 s1 c s n k,
                               mk_neg (c 0)),
                      ripplecarry0 x y c2 s2 n),
             mk_conj (mk_eq (c n, c2 n),
                      list_conj (List.tabulate (n, fn i => mk_eq (s i, s2 i))))))
end

(* The example used in the tutorial: *)

val example = gen_all (mk_adder_test 3 2)

(* Test them here:
     time DPLL_UNIV example;
     time tautLib.TAUT_PROVE example;
*)