(* Taken from code by Michael Norrish that       *)
(* accompanies the Proof Tools chapter of the    *)
(* Manual. Acts as a failsafe for SAT_TAUT_PROVE *)
(* on trivial problems.                          *)
(* The ultimate entry point is DPLL_TAUT.        *)
6
structure dpll =
struct

open HolKernel Parse boolLib def_cnf satCommonTools satTools

(* Outcome of the DPLL search, as produced by CoreDPLL.
   Unsat th : th has conclusion F.
   Sat f    : f maps each variable to the value it was assigned on the
              successful branch; variables with no recorded value map to T. *)
datatype result = Unsat of thm | Sat of term -> term

(* count_vars ds acc
   Add the atom occurrences of the literal list ds to the counts already
   held in acc.  A negated literal ~v is counted under v. *)
fun count_vars ds acc = let
  fun bump (lit, tally) = let
    val atom = dest_neg lit handle HOL_ERR _ => lit
    val seen = case Binarymap.peek (tally, atom) of
                 NONE => 0
               | SOME n => n
  in
    Binarymap.insert (tally, atom, seen + 1)
  end
in
  List.foldl bump acc ds
end

(* getBiggest acc
   Return the key of acc with the largest count.  The comparison is strict,
   so ties go to the first such key visited by Binarymap.foldl.  An empty
   map yields boolSyntax.T. *)
fun getBiggest acc = let
  fun keep_max (v, cnt, champion as (_, best)) =
      if best < cnt then (v, cnt) else champion
  val (winner, _) = Binarymap.foldl keep_max (boolSyntax.T, 0) acc
in
  winner
end

(* find_splitting_var phi
   Choose the variable to branch on in the CNF term phi.  Clauses are scanned
   left to right: the atom of the first unit clause met is returned at once.
   If there is no unit clause, the atom occurring most often across all
   clauses is returned (see getBiggest). *)
fun find_splitting_var phi = let
  fun scan tally clauses =
      case clauses of
        [] => getBiggest tally
      | clause :: rest =>
          (case strip_disj clause of
             [lone] => (dest_neg lone handle HOL_ERR _ => lone)
           | lits => scan (count_vars lits tally) rest)
in
  scan (Binarymap.mkDict Term.compare) (strip_conj phi)
end

(* casesplit v th
   th is [assignments, cnf] |- current.  Returns the pair
     ([assignments, v = T, cnf] |- current[T/v],
      [assignments, v = F, cnf] |- current[F/v])
   where each conclusion has been simplified by REWRITE_RULE. *)
fun casesplit v th = let
  fun branch value = REWRITE_RULE [ASSUME (mk_eq (v, value))] th
in
  (branch boolSyntax.T, branch boolSyntax.F)
end

(* mk_satmap th
   Read an assignment off the hypotheses of th.  Every hypothesis of the form
   l = r records l |-> r, and any other hypothesis is skipped.  Lookups of
   variables with no recorded value return T. *)
fun mk_satmap th = let
  fun record (hyp, tbl) =
      if is_eq hyp then Binarymap.insert (tbl, lhs hyp, rhs hyp) else tbl
  val assignment =
      HOLset.foldl record (Binarymap.mkDict Term.compare) (hypset th)
  fun lookup v =
      case Binarymap.peek (assignment, v) of
        SOME value => value
      | NONE => boolSyntax.T
in
  Sat lookup
end

(* CoreDPLL initial_th
   initial_th is [ci] |- cnf.  Splits on variables (find_splitting_var,
   casesplit) until every explored branch rewrites to T or F.
   - A branch reaching T returns Sat (mk_satmap of that branch).  The v = T
     branch of each split is explored before the v = F branch.
   - If both branches of a split on v reach F, the two refutations are
     merged with DISJ_CASES on BOOL_CASES_AX, which discharges the v = T and
     v = F hypotheses.  The top-level Unsat theorem is therefore
     [ci] |- F. *)
fun CoreDPLL initial_th = let
  fun search th = let
    val current = concl th
  in
    if aconv current boolSyntax.T then mk_satmap th
    else if aconv current boolSyntax.F then Unsat th
    else split th current
  end
  and split th current = let
    val v = find_splitting_var current
    val (th_T, th_F) = casesplit v th
  in
    case search th_T of
      found as Sat _ => found
    | Unsat refute_T =>
        (case search th_F of
           found as Sat _ => found
         | Unsat refute_F =>
             Unsat (DISJ_CASES (SPEC v BOOL_CASES_AX) refute_T refute_F))
  end
in
  search initial_th
end

(* doCNF neg_tm
   Convert neg_tm to CNF with to_cnf false.  Returns (cnfv, cnf_thm, lfn,
   clauses), where:
   - clauses lists the (ci, [~t] |- ci') pairs produced by to_cnf, ci' being
     the expanded form of ci;
   - cnf_thm is [ci] |- the conjunction of every ci.
   Raises Empty if to_cnf yields no clauses. *)
fun doCNF neg_tm = let
  val (cnfv, _, lfn, clauseth) = to_cnf false neg_tm
  val clauses = Array.foldr (op ::) [] clauseth
  val (first_clause, _) = hd clauses
  fun conj_in ((c, _), thm_so_far) = CONJ (ASSUME c) thm_so_far
  val cnf_thm = List.foldl conj_in (ASSUME first_clause) (tl clauses)
in
  (cnfv, cnf_thm, lfn, clauses)
end

(* undoCNF lfn clauses th
   th is [ci] |- F.  Instantiate it with every v |-> t pair in lfn, then
   discharge the clause hypotheses with PROVE_HYP using each theorem in
   clauses in turn, giving [~t] |- F. *)
fun undoCNF lfn clauses th = let
  val inst_list = RBM.foldl (fn (v, t, acc) => (v |-> t) :: acc) [] lfn
  fun discharge ((_, clause_th), acc_th) = PROVE_HYP clause_th acc_th
in
  List.foldl discharge (INST inst_list th) clauses
end

(* mk_model_thm cnfv lfn t f
   Turn the assignment f into a list of literals (v when f v is T, ~v
   otherwise) and pass it to satCheck together with ~t.
   - cnfv = NONE: the literals range over the free variables of t.
   - cnfv = SOME _: the literals range over the variables in lfn.  Each one
     is mapped back through lfn, and literals whose image is not a variable
     are dropped by mapfilter. *)
fun mk_model_thm cnfv lfn t f = let
  fun literal_of v = if is_T (f v) then v else mk_neg v
in
  case cnfv of
    NONE => satCheck (List.map literal_of (free_vars t)) (mk_neg t)
  | SOME _ => let
      val cnf_vars = List.map fst (RBM.listItems lfn)
      fun back_translate lit = let
        val x = hd (free_vars lit)
        val y = rbapply lfn x
      in
        if is_var y then subst [x |-> y] lit else failwith ""
      end
      val model = mapfilter back_translate (List.map literal_of cnf_vars)
    in
      satCheck model (mk_neg t)
    end
end

(* DPLL_TAUT t
   Entry point.  Runs CoreDPLL on the CNF of ~t.
   - Unsatisfiable: the result is [~t] |- F (built by undoCNF).
   - Satisfiable: the result is the satCheck theorem for the model found,
     of the form |- model ==> ~t. *)
fun DPLL_TAUT t = let
  val (cnfv, cnf_thm, lfn, clauses) = doCNF (mk_neg t)
in
  case CoreDPLL cnf_thm of
    Sat model_fn => mk_model_thm cnfv lfn t model_fn
  | Unsat refutation => undoCNF lfn clauses refutation
end

(* implementation of DPLL ends *)

(* ----------------------------------------------------------------------
    Code below, due to John Harrison, generates tautologies stating that
    two different implementations of binary addition are equivalent
   ---------------------------------------------------------------------- *)

(*
fun halfsum x y = mk_eq(x,mk_neg y)
fun halfcarry x y = mk_conj(x,y)
fun ha x y s c = mk_conj(mk_eq(s,halfsum x y), mk_eq(c,halfcarry x y))

fun carry x y z = mk_disj(mk_conj(x,y), mk_conj(mk_disj(x,y), z))
fun sum x y z = halfsum (halfsum x y) z;
fun fa x y z s c = mk_conj(mk_eq(s,sum x y z), mk_eq(c,carry x y z))

fun list_conj cs = list_mk_conj cs handle HOL_ERR _ => boolSyntax.T

fun ripplecarry x y c out n =
    list_conj
      (List.tabulate(n, (fn i => fa (x i) (y i) (c i) (out i) (c (i + 1)))))

fun mk_index s i = mk_var(s ^ "_" ^ Int.toString i, bool)

val [x,y,out,c] = map mk_index ["X", "Y", "OUT", "C"]
val twobit_adder = ripplecarry x y c out 2

fun simp t =
    rhs (concl (QCONV (REWRITE_CONV [GSYM CONJ_ASSOC, GSYM DISJ_ASSOC]) t))

fun ripplecarry0 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.F
                                   else c i) out n)

fun ripplecarry1 x y c out n =
    simp (ripplecarry x y (fn i => if i = 0 then boolSyntax.T
                                   else c i) out n)

fun mux sel in0 in1 = mk_disj(mk_conj(mk_neg sel,in0), mk_conj(sel,in1))

fun offset n x i = x (n + i)
fun carryselect x y c0 c1 s0 s1 c s n k = let
  val k' = Int.min(n,k)
  val fm =
      mk_conj(mk_conj(ripplecarry0 x y c0 s0 k', ripplecarry1 x y c1 s1 k'),
              mk_conj(mk_eq(c k', mux (c 0) (c0 k') (c1 k')),
                      list_conj
                      (List.tabulate
                       (k',
                        (fn i => mk_eq(s i, mux (c 0) (s0 i) (s1 i)))))))
in
  if k' < k then fm
  else mk_conj(fm, carryselect (offset k x) (offset k y)
                               (offset k c0) (offset k c1)
                               (offset k s0) (offset k s1)
                               (offset k c) (offset k s)
                               (n - k) k)
end

(* call with positive n and k to generate tautologies *)
fun mk_adder_test n k = let
  val [x,y,c,s,c0,s0,c1,s1,c2,s2] =
      map mk_index ["x", "y", "c", "s", "c0", "s0", "c1", "s1", "c2", "s2"]
in
  simp
    (mk_imp(mk_conj(mk_conj(carryselect x y c0 c1 s0 s1 c s n k, mk_neg (c 0)),
                    ripplecarry0 x y c2 s2 n),
            mk_conj(mk_eq(c n, c2 n),
                    list_conj(List.tabulate(n, (fn i => mk_eq(s i, s2 i)))))))
end

(* example in tutorial is *)

val example = gen_all (mk_adder_test 3 2)

*)
end
210