(* ----------------------------------------------------------------------
    A small, proof-producing DPLL tautology checker.

    Taken from code by Michael Norrish that accompanies the Proof Tools
    chapter of the Manual. It acts as a fail-safe for SAT_TAUT_PROVE on
    trivial problems.

    The ultimate entry point is DPLL_TAUT.
   ---------------------------------------------------------------------- *)

structure dpll =
struct

open HolKernel Parse boolLib def_cnf satCommonTools satTools

(* Outcome of the DPLL search (see CoreDPLL):
     Unsat th  -- th has conclusion F;
     Sat f     -- f maps each variable assigned on the successful branch
                  to the value it was given; any other term maps to
                  boolSyntax.T (see mk_satmap). *)
datatype result = Unsat of thm | Sat of term -> term

local
  (* The atom underlying a literal: v for both v and ~v. *)
  fun atom_of lit = dest_neg lit handle HOL_ERR _ => lit
in

(* count_vars ds acc: for each literal in ds (left to right), increment the
   count of its atom in the map acc, starting from 0 for unseen atoms.
   Returns the updated map. *)
fun count_vars ds acc =
  let
    fun bump (lit, counts) =
      let
        val v = atom_of lit
        val n = case Binarymap.peek (counts, v) of NONE => 0 | SOME k => k
      in
        Binarymap.insert (counts, v, n + 1)
      end
  in
    List.foldl bump acc ds
  end

(* getBiggest acc: the key of acc with the strictly largest count. On ties
   the key visited first by Binarymap.foldl wins. Returns boolSyntax.T when
   acc is empty. *)
fun getBiggest acc =
  let
    fun better (v, cnt, best as (_, bestcnt)) =
      if cnt > bestcnt then (v, cnt) else best
  in
    #1 (Binarymap.foldl better (boolSyntax.T, 0) acc)
  end

(* find_splitting_var phi: choose the next variable to split on.
   phi is split into conjuncts (clauses), and each clause into disjuncts
   (literals). Scanning the clauses left to right, the atom of the first
   unit clause is returned immediately. If there is no unit clause, the
   atom occurring most often over all clauses is returned (getBiggest). *)
fun find_splitting_var phi =
  let
    fun scan counts [] = getBiggest counts
      | scan counts (c :: cs) =
          (case strip_disj c of
             [lit] => atom_of lit
           | lits => scan (count_vars lits counts) cs)
  in
    scan (Binarymap.mkDict Term.compare) (strip_conj phi)
  end

end (* local atom_of *)

(* casesplit v th: given th : [assignments, cnf] |- current, return
     (th rewritten with the assumption v = T,
      th rewritten with the assumption v = F),
   i.e. roughly [assignments, v = T, cnf] |- current[T/v] and
                [assignments, v = F, cnf] |- current[F/v],
   with REWRITE_RULE also simplifying the resulting conclusions. *)
fun casesplit v th =
  let
    val eqT = ASSUME (mk_eq (v, boolSyntax.T))   (* v = T |- v = T *)
    val eqF = ASSUME (mk_eq (v, boolSyntax.F))   (* v = F |- v = F *)
  in
    (REWRITE_RULE [eqT] th, REWRITE_RULE [eqF] th)
  end

(* mk_satmap th: read an assignment off the hypotheses of th. Each
   hypothesis of the form l = r records l |-> r; hypotheses that are not
   equations are skipped. The resulting function returns the recorded
   value for a recorded term and boolSyntax.T for anything else. *)
fun mk_satmap th =
  let
    fun record (t, fmap) =
      let val (l, r) = dest_eq t in Binarymap.insert (fmap, l, r) end
      handle HOL_ERR _ => fmap
    val assignment =
      HOLset.foldl record (Binarymap.mkDict Term.compare) (hypset th)
  in
    Sat (fn v => Binarymap.find (assignment, v)
                 handle Binarymap.NotFound => boolSyntax.T)
  end

(* CoreDPLL initial_th: DPLL search starting from initial_th : [ci] |- cnf.
   At each step the current conclusion is inspected:
     - T: the branch is satisfiable; return mk_satmap of the theorem.
     - F: the branch is refuted; return Unsat of the theorem.
     - otherwise: pick v with find_splitting_var, then search the v = T
       branch first and the v = F branch second.
   The first Sat found is returned unchanged. If both branches of a split
   are Unsat, the two refutations are combined with DISJ_CASES on
   BOOL_CASES_AX instantiated at v. At the top level the Unsat theorem is
   [ci] |- F. *)
fun CoreDPLL initial_th =
  let
    fun search th =                        (* th : [assigns, ci] |- curr *)
      let
        val curr = concl th
      in
        if aconv curr boolSyntax.T then mk_satmap th
        else if aconv curr boolSyntax.F then Unsat th
        else
          let
            val v = find_splitting_var curr
            val (thT, thF) = casesplit v th
            (* thT : [assigns, v = T, ci] |- curr[T/v]
               thF : [assigns, v = F, ci] |- curr[F/v] *)
          in
            case search thT of
              Unsat falseT =>
                (case search thF of
                   Unsat falseF =>
                     (* [assigns_r \ v, assigns_l \ v, ci] |- F *)
                     Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX) falseT falseF)
                 | sat => sat)
            | sat => sat
          end
      end
  in
    search initial_th
  end

(* doCNF neg_tm: convert neg_tm to definitional CNF via to_cnf.
   Returns (cnfv, cnf_thm, lfn, clauses) where
     clauses  is the list of (ci, [~t] |- ci') pairs from to_cnf, in array
              order (ci' is the expanded form of ci; here neg_tm = ~t);
     cnf_thm  is [ci] |- cnf, the conjunction of all the ASSUMEd clauses.
              Because it is built by folding CONJ from the first clause
              onwards, the conjuncts appear in reverse clause order;
     cnfv, lfn are passed through from to_cnf.
   NOTE(review): hd raises Empty if to_cnf yields no clauses. *)
fun doCNF neg_tm =
  let
    val (cnfv, _, lfn, clauseth) = to_cnf false neg_tm
    val clauses = Array.foldr (op ::) [] clauseth
    val cnf_thm =
      List.foldl (fn ((c, _), acc) => CONJ (ASSUME c) acc)
                 (ASSUME (fst (hd clauses)))
                 (tl clauses)
  in
    (cnfv, cnf_thm, lfn, clauses)
  end

(* undoCNF lfn clauses th: given th : [ci] |- F, instantiate each variable
   v of lfn to its recorded term t (INST). Then fold PROVE_HYP over the
   clause theorems [~t] |- ci' from doCNF, so that the result is
   ~t |- F. *)
fun undoCNF lfn clauses th =
  let
    val insts = RBM.foldl (fn (v, t, acc) => (v |-> t) :: acc) [] lfn
  in
    List.foldl (fn ((_, cth), acc) => PROVE_HYP cth acc)
               (INST insts th)
               clauses                                      (* ~t |- F *)
  end

local
  (* The literal recording f's value for v: v when f v is T, else ~v. *)
  fun lit_of f v = if is_T (f v) then v else mk_neg v
in

(* mk_model_thm cnfv lfn t f: turn the satisfying assignment f into a
   theorem via satCheck model (mk_neg t). The caller (DPLL_TAUT) describes
   this as |- model ==> ~t.
     - cnfv = SOME _: the model ranges over the variables of lfn. A
       literal is kept only if its variable x is mapped by lfn to a
       variable y, in which case it is renamed to use y. Literals for
       which this fails are dropped by mapfilter.
     - cnfv = NONE: the model ranges over the free variables of t. *)
fun mk_model_thm cnfv lfn t f =
  if isSome cnfv then
    let
      val model = List.map (lit_of f o fst) (RBM.listItems lfn)
      fun rename l =
        let
          val x = hd (free_vars l)
          val y = rbapply lfn x
        in
          if is_var y then subst [x |-> y] l else failwith ""
        end
    in
      satCheck (mapfilter rename model) (mk_neg t)
    end
  else
    satCheck (List.map (lit_of f) (free_vars t)) (mk_neg t)

end (* local lit_of *)

(* DPLL_TAUT t: run DPLL on the definitional CNF of ~t.
     - If the CNF is unsatisfiable, returns [~t] |- F (so t is a tautology).
     - Otherwise returns the model theorem |- model ==> ~t built by
       mk_model_thm, i.e. a counter-model to t. *)
fun DPLL_TAUT t =
  let
    val (cnfv, cnf_thm, lfn, clauses) = doCNF (mk_neg t)
    (* cnf_thm is [ci] |- dCNF(~t) *)
  in
    case CoreDPLL cnf_thm of
      Unsat refutation =>                                    (* [ci] |- F *)
        undoCNF lfn clauses refutation                       (* [~t] |- F *)
    | Sat f => mk_model_thm cnfv lfn t f               (* |- model ==> ~t *)
  end

(* implementation of DPLL ends *)

(* ----------------------------------------------------------------------
    Code below, due to John Harrison, generates tautologies stating that
    two different implementations of binary addition are equivalent.
    It is commented out and therefore not compiled as part of dpll.
   ---------------------------------------------------------------------- *)

(*
fun halfsum x y = mk_eq(x,mk_neg y)
fun halfcarry x y = mk_conj(x,y)
fun ha x y s c = mk_conj(mk_eq(s,halfsum x y), mk_eq(c,halfcarry x y))

fun carry x y z = mk_disj(mk_conj(x,y), mk_conj(mk_disj(x,y), z))
fun sum x y z = halfsum (halfsum x y) z;
fun fa x y z s c = mk_conj(mk_eq(s,sum x y z), mk_eq(c,carry x y z))

fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T

fun ripplecarry x y c out n =
    list_conj
      (List.tabulate(n, (fn i => fa (x i) (y i) (c i) (out i) (c (i + 1)))))

fun mk_index s i = mk_var(s ^ "_" ^ Int.toString i, bool)

val [x,y,out,c] = map mk_index ["X", "Y", "OUT", "C"]
val twobit_adder = ripplecarry x y c out 2

fun simp t =
    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))

fun ripplecarry0 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F
                                   else c i) out n)

fun ripplecarry1 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T
                                   else c i) out n)

fun mux sel in0 in1 = mk_disj(mk_conj(mk_neg sel,in0), mk_conj(sel,in1))

fun offset n x i = x (n + i)
fun carryselect x y c0 c1 s0 s1 c s n k = let
  val k' = Int.min(n,k)
  val fm =
      mk_conj(mk_conj(ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
              mk_conj(mk_eq(c k', mux (c 0) (c0 k') (c1 k')),
                      list_conj
                        (List.tabulate
                           (k',
                            (fn i => mk_eq(s i, mux (c 0) (s0 i) (s1 i)))))))
in
  if k' < k then fm
  else mk_conj(fm, carryselect (offset k x) (offset k y)
                               (offset k c0) (offset k c1)
                               (offset k s0) (offset k s1)
                               (offset k c) (offset k s)
                               (n - k) k)
end

(* call with positive n and k to generate tautologies *)
fun mk_adder_test n k = let
  val [x,y,c,s,c0,s0,c1,s1,c2,s2] =
      map mk_index ["x", "y", "c", "s", "c0", "s0", "c1", "s1", "c2", "s2"]
in
  simp
    (mk_imp(mk_conj(mk_conj(carryselect x y c0 c1 s0 s1 c s n k, mk_neg (c 0)),
                    ripplecarry0 x y c2 s2 n),
            mk_conj(mk_eq(c n, c2 n),
                    list_conj(List.tabulate(n, (fn i => mk_eq(s i, s2 i)))))))
end

(* example in tutorial is *)

val example = gen_all (mk_adder_test 3 2)

*)
end