(* ---------------------------------------------------------------------- *)
(* Taken from code by Michael Norrish to accompany the Proof Tools        *)
(* chapter of the (HOL) Manual.  Acts as a failsafe for SAT_TAUT_PROVE    *)
(* on trivial problems.  The ultimate entry point is DPLL_TAUT.           *)
(* ---------------------------------------------------------------------- *)

structure dpll =
struct

open HolKernel Parse boolLib def_cnf satCommonTools satTools

(* Outcome of the core search:
     Unsat th -- th is a theorem whose conclusion is F;
     Sat f    -- f maps each variable to the value it was assigned in the
                 successful branch (T for variables with no assignment). *)
datatype result = Unsat of thm | Sat of term -> term

(* strip_lit_neg lit: the variable underlying a literal, i.e. v for both
   v and ~v (anything that is not a negation is returned unchanged). *)
fun strip_lit_neg lit = dest_neg lit handle HOL_ERR _ => lit

(* count_vars ds acc: for each literal in ds, increment the occurrence
   count of its underlying variable in the dictionary acc (starting from 0
   for variables not yet present).  Returns the updated dictionary. *)
fun count_vars ds acc =
    case ds of
        [] => acc
      | lit :: lits =>
        let
          val v = strip_lit_neg lit
          val seen = case Binarymap.peek (acc, v) of
                         NONE => 0
                       | SOME n => n
        in
          count_vars lits (Binarymap.insert (acc, v, seen + 1))
        end

(* getBiggest acc: the key of acc with the strictly largest count.  Ties
   keep the key met first in Binarymap.foldl order; an empty dictionary
   yields T. *)
fun getBiggest acc =
    let
      fun better (v, cnt, best as (_, bestcnt)) =
          if cnt > bestcnt then (v, cnt) else best
    in
      #1 (Binarymap.foldl better (boolSyntax.T, 0) acc)
    end

(* find_splitting_var phi: phi is read as a conjunction of disjunctive
   clauses.  Returns the variable of the first unit (single-literal) clause
   encountered; if no clause is a unit, returns the variable that occurs
   most often over all clauses. *)
fun find_splitting_var phi =
    let
      fun scan counts [] = getBiggest counts
        | scan counts (cl :: rest) =
          (case strip_disj cl of
               [lit] => strip_lit_neg lit
             | lits => scan (count_vars lits counts) rest)
    in
      scan (Binarymap.mkDict Term.compare) (strip_conj phi)
    end

(* casesplit v th: th is [assignments, clauses] |- current.  Returns the
   pair obtained by rewriting th with the assumption v = T and with the
   assumption v = F respectively (REWRITE_RULE also simplifies), i.e.
     ([assignments, v = T, clauses] |- current[T/v],
      [assignments, v = F, clauses] |- current[F/v]). *)
fun casesplit v th =
    let
      val assume_T = ASSUME (mk_eq (v, boolSyntax.T))
      val assume_F = ASSUME (mk_eq (v, boolSyntax.F))
    in
      (REWRITE_RULE [assume_T] th, REWRITE_RULE [assume_F] th)
    end

(* mk_satmap th: read an assignment off the hypotheses of th.  Every
   hypothesis of the form l = r contributes the binding l |-> r;
   hypotheses that are not equations are skipped.  Variables with no
   binding are mapped to T. *)
fun mk_satmap th =
    let
      fun add_binding (hyp, fmap) =
          let val (l, r) = dest_eq hyp
          in Binarymap.insert (fmap, l, r) end
          handle HOL_ERR _ => fmap
      val fmap = HOLset.foldl add_binding
                              (Binarymap.mkDict Term.compare)
                              (hypset th)
    in
      Sat (fn v => Binarymap.find (fmap, v)
                   handle Binarymap.NotFound => boolSyntax.T)
    end

(* CoreDPLL initial_th: initial_th is [clauses] |- cnf.  Repeatedly picks
   a variable with find_splitting_var and case-splits on it, until the
   conclusion of a branch is the constant T or F.
     - If some branch reaches T, returns the Sat assignment read from that
       branch's hypotheses (the v = T branch is explored first).
     - If every branch reaches F, the two halves of each split are joined
       with DISJ_CASES on BOOL_CASES_AX (discharging v = T and v = F), and
       Unsat th is returned with th : [clauses] |- F. *)
fun CoreDPLL initial_th =
    let
      fun search th =
          let
            val c = concl th
          in
            (* same_const rather than structural (=) on terms *)
            if same_const c boolSyntax.T then
              mk_satmap th
            else if same_const c boolSyntax.F then
              Unsat th
            else
              let
                val v = find_splitting_var c
                val (th_T, th_F) = casesplit v th
              in
                case search th_T of
                    Unsat false_T =>
                    (case search th_F of
                         Unsat false_F =>
                         Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX)
                                           false_T false_F)
                       | sat => sat)
                  | sat => sat
              end
          end
    in
      search initial_th
    end

(* doCNF neg_tm: convert neg_tm to definitional CNF with to_cnf and return
   (cnfv, cnf_thm, lfn, clauses) where
     - clauses is the list of (ci, thm) pairs from to_cnf's clause array,
       in array order (original notes: thm is [~t] |- ci', ci' being the
       expanded form of ci);
     - cnf_thm : [c1, ..., cn] |- conjunction of all the ci, built by
       CONJ-ing ASSUMEd clauses;
     - cnfv and lfn are passed through from to_cnf unchanged.
   NOTE(review): assumes to_cnf yields at least one clause -- an empty
   clause array makes hd raise Empty. *)
fun doCNF neg_tm =
    let
      val (cnfv, _, lfn, clauseth) = to_cnf false neg_tm
      val clauses = Array.foldr (op ::) [] clauseth
      val cnf_thm =
          List.foldl (fn ((c, _), acc) => CONJ (ASSUME c) acc)
                     (ASSUME (fst (hd clauses)))
                     (tl clauses)
    in
      (cnfv, cnf_thm, lfn, clauses)
    end

(* undoCNF lfn clauses th: th is [c1, ..., cn] |- F.  Instantiates th with
   the v |-> t bindings stored in lfn, then discharges the clause
   hypotheses with PROVE_HYP using each clause theorem in turn.  Per the
   original notes the result is [~t] |- F. *)
fun undoCNF lfn clauses th =
    let
      val insts = RBM.foldl (fn (v, t, acc) => (v |-> t) :: acc) [] lfn
      fun discharge ((_, cth), thm) = PROVE_HYP cth thm
    in
      List.foldl discharge (INST insts th) clauses
    end

(* mk_model_thm cnfv lfn t f: f is the assignment found by CoreDPLL.
   Builds a model (a list of literals: v when f v is T, ~v otherwise) and
   passes it with ~t to satCheck.
     - If cnfv is SOME _, the model ranges over the keys of lfn; each
       literal's variable is renamed through lfn, and literals whose
       lfn image is not a variable are dropped (presumably the auxiliary
       variables introduced by the CNF conversion -- confirm).
     - Otherwise the model ranges over the free variables of t.
   Original notes describe the result as |- model ==> ~t. *)
fun mk_model_thm cnfv lfn t f =
    let
      fun literal v = if is_T (f v) then v else mk_neg v
      (* fails (so mapfilter drops the literal) unless lfn maps its
         variable to a variable *)
      fun rename l =
          let
            val x = hd (free_vars l)
            val y = rbapply lfn x
          in
            if is_var y then subst [x |-> y] l else failwith ""
          end
      val model =
          if isSome cnfv then
            mapfilter rename (List.map (literal o fst) (RBM.listItems lfn))
          else
            List.map literal (free_vars t)
    in
      satCheck model (mk_neg t)
    end

(* DPLL_TAUT t: decide the propositional term t by refuting ~t.
     - Unsat: returns [~t] |- F (via undoCNF).
     - Sat:   returns the satCheck'ed model theorem from mk_model_thm
              (original notes: |- model ==> ~t). *)
fun DPLL_TAUT t =
    let
      val (cnfv, cnf_thm, lfn, clauses) = doCNF (mk_neg t)
    in
      case CoreDPLL cnf_thm of
          Unsat cnf_entails_F => undoCNF lfn clauses cnf_entails_F
        | Sat f => mk_model_thm cnfv lfn t f
    end

(* implementation of DPLL ends *)

(* ----------------------------------------------------------------------
   Code below, due to John Harrison, generates tautologies stating that
   two different implementations of binary addition are equivalent
   ---------------------------------------------------------------------- *)

(*
fun halfsum x y = mk_eq(x,mk_neg y)
fun halfcarry x y = mk_conj(x,y)
fun ha x y s c = mk_conj(mk_eq(s,halfsum x y), mk_eq(c,halfcarry x y))

fun carry x y z = mk_disj(mk_conj(x,y), mk_conj(mk_disj(x,y), z))
fun sum x y z = halfsum (halfsum x y) z;
fun fa x y z s c = mk_conj(mk_eq(s,sum x y z), mk_eq(c,carry x y z))

fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T

fun ripplecarry x y c out n =
    list_conj
      (List.tabulate(n, (fn i => fa (x i) (y i) (c i) (out i) (c (i + 1)))))

fun mk_index s i = mk_var(s ^ "_" ^ Int.toString i, bool)

val [x,y,out,c] = map mk_index ["X", "Y", "OUT", "C"]
val twobit_adder = ripplecarry x y c out 2

fun simp t =
    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))

fun ripplecarry0 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F
                                   else c i) out n)

fun ripplecarry1 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T
                                   else c i) out n)

fun mux sel in0 in1 = mk_disj(mk_conj(mk_neg sel,in0), mk_conj(sel,in1))

fun offset n x i = x (n + i)

fun carryselect x y c0 c1 s0 s1 c s n k = let
  val k' = Int.min(n,k)
  val fm =
      mk_conj(mk_conj(ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
              mk_conj(mk_eq(c k', mux (c 0) (c0 k') (c1 k')),
                      list_conj
                        (List.tabulate
                           (k',
                            (fn i => mk_eq(s i, mux (c 0) (s0 i) (s1 i)))))))
in
  if k' < k then fm
  else mk_conj(fm, carryselect (offset k x) (offset k y)
                               (offset k c0) (offset k c1)
                               (offset k s0) (offset k s1)
                               (offset k c) (offset k s)
                               (n - k) k)
end

(* call with positive n and k to generate tautologies *)
fun mk_adder_test n k = let
  val [x,y,c,s,c0,s0,c1,s1,c2,s2] =
      map mk_index ["x", "y", "c", "s", "c0", "s0", "c1", "s1", "c2", "s2"]
in
  simp
    (mk_imp(mk_conj(mk_conj(carryselect x y c0 c1 s0 s1 c s n k, mk_neg (c 0)),
                    ripplecarry0 x y c2 s2 n),
            mk_conj(mk_eq(c n, c2 n),
                    list_conj(List.tabulate(n, (fn i => mk_eq(s i, s2 i)))))))
end

(* example in tutorial is *)

val example = gen_all (mk_adder_test 3 2)

*)
end