(* Taken from code by Michael Norrish accompanying the Proof *)
(* Tools chapter of the Manual. Acts as a failsafe for       *)
(* SAT_TAUT_PROVE on trivial problems.                       *)
(* The ultimate entry point is DPLL_TAUT.                    *)
6
structure dpll =
struct

open HolKernel Parse boolLib def_cnf satCommonTools satTools

(* Outcome of CoreDPLL.
   Unsat th : th has conclusion F, under the input clause hypotheses.
   Sat f    : f maps a variable to the value (T or F) it was assigned on the
              satisfying branch; variables that were never assigned map to T. *)
datatype result = Unsat of thm | Sat of term -> term

(* count_vars ds acc: for each literal in ds, add one to acc's count for its
   underlying variable (a negated literal ~v counts towards v). *)
fun count_vars ds acc =
    case ds of
      [] => acc
    | lit :: lits => let
        val v = dest_neg lit handle HOL_ERR _ => lit
        val seen = getOpt (Binarymap.peek (acc, v), 0)
      in
        count_vars lits (Binarymap.insert (acc, v, seen + 1))
      end

(* The key of acc with the largest count (ties go to the key visited first by
   Binarymap.foldl).  Returns boolSyntax.T when acc is empty; that should not
   arise from find_splitting_var, which only gets here after every clause has
   contributed at least two literals. *)
fun getBiggest acc =
    #1 (Binarymap.foldl (fn (v, cnt, best as (_, bestcnt)) =>
                            if cnt > bestcnt then (v, cnt) else best)
                        (boolSyntax.T, 0)
                        acc)

(* Choose the variable to split on in the CNF term phi: the variable of the
   first unit clause seen, otherwise the variable occurring most often. *)
fun find_splitting_var phi = let
  fun recurse acc [] = getBiggest acc
    | recurse acc (c :: cs) =
        (case strip_disj c of
           [lit] => (dest_neg lit handle HOL_ERR _ => lit)
         | ds => recurse (count_vars ds acc) cs)
in
  recurse (Binarymap.mkDict Term.compare) (strip_conj phi)
end

(* Split th : [assignments, cnf] |- current on variable v.  Returns
     ([assignments, v = T, cnf] |- current[T/v],
      [assignments, v = F, cnf] |- current[F/v])
   by rewriting with the assumed equations, which become extra hypotheses. *)
fun casesplit v th = let
  val eqT = ASSUME (mk_eq (v, boolSyntax.T)) (* v = T |- v = T *)
  val eqF = ASSUME (mk_eq (v, boolSyntax.F)) (* v = F |- v = F *)
in
  (REWRITE_RULE [eqT] th, REWRITE_RULE [eqF] th)
end

(* Read a satisfying assignment off the hypotheses of th: every hypothesis of
   the form l = r binds l to r (non-equations are skipped).  Lookups of
   unbound variables default to T. *)
fun mk_satmap th = let
  fun add_binding (t, acc) = let
    val (l, r) = dest_eq t
  in
    Binarymap.insert (acc, l, r)
  end handle HOL_ERR _ => acc
  val fmap =
      HOLset.foldl add_binding (Binarymap.mkDict Term.compare) (hypset th)
in
  Sat (fn v => getOpt (Binarymap.peek (fmap, v), boolSyntax.T))
end

(* Basic DPLL search by case splitting, with no unit propagation beyond what
   REWRITE_RULE performs.  initial_th is [ci] |- cnf.  Returns Unsat with a
   theorem [ci] |- F if every branch rewrites to F.  Otherwise returns Sat for
   the first branch whose conclusion rewrites to T. *)
fun CoreDPLL initial_th = let
  fun recurse th = let (* th is [assigns, ci] |- curr *)
    val c = concl th
  in
    if c = boolSyntax.T then
      mk_satmap th
    else if c = boolSyntax.F then
      Unsat th
    else let
        val v = find_splitting_var c
        (* l : [assigns, v=T, ci] |- curr[T/v]
           r : [assigns, v=F, ci] |- curr[F/v] *)
        val (l, r) = casesplit v th
      in
        case recurse l of
          Unsat l_false =>
            (case recurse r of
               Unsat r_false =>
                 (* Both branches derive F: discharge v = T and v = F using
                    |- (v = T) \/ (v = F).  The hypotheses left are those of
                    both branches, minus the two split equations. *)
                 Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX) l_false r_false)
             | sat => sat)
        | sat => sat
      end
  end
in
  recurse initial_th
end

(* Convert neg_tm to definitional CNF with to_cnf.  Returns
   (cnfv, cnf_thm, lfn, clauses) where
     clauses : (ci, [~t] |- ci') pairs, ci' being the expanded form of ci
     cnf_thm : [ci] |- conjunction of all the ci.
   The conjunction of no clauses is T, so an empty clause array yields
   |- T.  Previously hd/tl raised the Basis exception Empty here. *)
fun doCNF neg_tm = let
  val (cnfv, _, lfn, clauseth) = to_cnf false neg_tm
  val clauses = Array.foldr (op ::) [] clauseth
  val cnf_thm =
      case clauses of
        [] => boolTheory.TRUTH
      | (c0, _) :: rest =>
          List.foldl (fn ((c, _), cnf) => CONJ (ASSUME c) cnf) (ASSUME c0) rest
in
  (cnfv, cnf_thm, lfn, clauses)
end

(* Turn th : [ci] |- F into [~t] |- F.  First instantiate each definitional
   variable v with the term lfn binds it to.  Then discharge every expanded
   clause hypothesis using its [~t] |- ci' theorem from clauses. *)
fun undoCNF lfn clauses th = let
  val insts = RBM.foldl (fn (v, t, acc) => (v |-> t) :: acc) [] lfn
  val inst_th = INST insts th
in
  List.foldl (fn ((_, cth), acc) => PROVE_HYP cth acc) inst_th clauses
end

(* Build the model theorem (|- model ==> ~t) for a satisfying assignment f,
   by handing a list of literals to satCheck.
   When cnfv is SOME, the literals range over the keys of lfn.  A literal is
   kept, renamed to the original variable, only when lfn maps its variable to
   a plain variable; otherwise it is dropped.
   When cnfv is NONE, the literals range over the free variables of t. *)
fun mk_model_thm cnfv lfn t f = let
  fun literal v = if is_T (f v) then v else mk_neg v
  (* failwith makes mapfilter drop the literal *)
  fun restore l = let
    val x = hd (free_vars l)
    val y = rbapply lfn x
  in
    if is_var y then subst [x |-> y] l else failwith ""
  end
  val model =
      if isSome cnfv then
        mapfilter restore (List.map (literal o fst) (RBM.listItems lfn))
      else
        List.map literal (free_vars t)
in
  satCheck model (mk_neg t)
end

(* Entry point.  Returns [~t] |- F when ~t is unsatisfiable (t is a
   tautology).  Otherwise returns a model theorem |- model ==> ~t. *)
fun DPLL_TAUT t = let
  val (cnfv, cnf_thm, lfn, clauses) = doCNF (mk_neg t) (* [ci] |- dCNF(~t) *)
in
  case CoreDPLL cnf_thm of
    Unsat cnf_entails_F => undoCNF lfn clauses cnf_entails_F
  | Sat f => mk_model_thm cnfv lfn t f
end

(* implementation of DPLL ends *)

(* ----------------------------------------------------------------------
    Code below, due to John Harrison, generates tautologies stating that
    two different implementations of binary addition are equivalent
   ---------------------------------------------------------------------- *)

(*
fun halfsum x y = mk_eq(x,mk_neg y)
fun halfcarry x y = mk_conj(x,y)
fun ha x y s c = mk_conj(mk_eq(s,halfsum x y), mk_eq(c,halfcarry x y))

fun carry x y z = mk_disj(mk_conj(x,y), mk_conj(mk_disj(x,y), z))
fun sum x y z = halfsum (halfsum x y) z;
fun fa x y z s c = mk_conj(mk_eq(s,sum x y z), mk_eq(c,carry x y z))

fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T

fun ripplecarry x y c out n =
    list_conj
      (List.tabulate(n, (fn i => fa (x i) (y i) (c i) (out i) (c (i + 1)))))

fun mk_index s i = mk_var(s ^ "_" ^ Int.toString i, bool)

val [x,y,out,c] = map mk_index ["X", "Y", "OUT", "C"]
val twobit_adder = ripplecarry x y c out 2

fun simp t =
    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))

fun ripplecarry0 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F
                                   else c i) out n)

fun ripplecarry1 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T
                                   else c i) out n)

fun mux sel in0 in1 = mk_disj(mk_conj(mk_neg sel,in0), mk_conj(sel,in1))

fun offset n x i = x (n + i)
fun carryselect x y c0 c1 s0 s1 c s n k = let
  val k' = Int.min(n,k)
  val fm =
      mk_conj(mk_conj(ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
              mk_conj(mk_eq(c k', mux (c 0) (c0 k') (c1 k')),
                      list_conj
                      (List.tabulate
                       (k',
                        (fn i => mk_eq(s i, mux (c 0) (s0 i) (s1 i)))))))
in
  if k' < k then fm
  else mk_conj(fm, carryselect (offset k x) (offset k y)
                               (offset k c0) (offset k c1)
                               (offset k s0) (offset k s1)
                               (offset k c) (offset k s)
                               (n - k) k)
end

(* call with positive n and k to generate tautologies *)
fun mk_adder_test n k = let
  val [x,y,c,s,c0,s0,c1,s1,c2,s2] =
      map mk_index ["x", "y", "c", "s", "c0", "s0", "c1", "s1", "c2", "s2"]
in
  simp
    (mk_imp(mk_conj(mk_conj(carryselect x y c0 c1 s0 s1 c s n k, mk_neg (c 0)),
                    ripplecarry0 x y c2 s2 n),
            mk_conj(mk_eq(c n, c2 n),
                    list_conj(List.tabulate(n, (fn i => mk_eq(s i, s2 i)))))))
end

(* example in tutorial is *)

val example = gen_all (mk_adder_test 3 2)

*)
end
203