1(*  Title:      Tools/Argo/argo_cdcl.ML
2    Author:     Sascha Boehme
3
4Propositional satisfiability solver in the style of conflict-driven
5clause-learning (CDCL). It features:
6
7 * conflict analysis and clause learning based on the first unique implication point
8 * nonchronological backtracking
9 * dynamic variable ordering (VSIDS)
10 * restarting
11 * polarity caching
12 * propagation via two watched literals
13 * special propagation of binary clauses 
14 * minimizing learned clauses
15 * support for external knowledge
16
17These features might be added:
18
19 * pruning of unnecessary learned clauses
20 * rebuilding the variable heap
21 * aligning the restart level with the decision heuristics: keep decisions that would
22   be recovered instead of backjumping to level 0
23
24The implementation is inspired by:
25
26  Niklas E'en and Niklas S"orensson. An Extensible SAT-solver. In Enrico
27  Giunchiglia and Armando Tacchella, editors, Theory and Applications of
28  Satisfiability Testing. Volume 2919 of Lecture Notes in Computer
29  Science, pages 502-518. Springer, 2003.
30
31  Niklas S"orensson and Armin Biere. Minimizing Learned Clauses. In
32  Oliver Kullmann, editor, Theory and Applications of Satisfiability
33  Testing. Volume 5584 of Lecture Notes in Computer Science,
34  pages 237-243. Springer, 2009.
35*)
36
signature ARGO_CDCL =
sig
  (* types *)

  (*callback mapping a literal and a caller-supplied state to a clause and an updated state*)
  type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a

  (* context *)

  type context
  val context: context  (*the initial context*)
  val assignment_of: context -> Argo_Lit.literal -> bool option

  (* enriching the context *)

  val add_atom: Argo_Term.term -> context -> context
  val add_axiom: Argo_Cls.clause -> context -> int * context

  (* main operations *)

  val assume:
    'a explain -> Argo_Lit.literal -> context -> 'a -> Argo_Cls.clause option * context * 'a
  val propagate: context -> Argo_Common.literal Argo_Common.implied * context
  val decide: context -> context option
  val analyze: 'a explain -> Argo_Cls.clause -> context -> 'a -> int * context * 'a
  val restart: context -> int * context
end
59
60structure Argo_Cdcl: ARGO_CDCL =
61struct
62
63(* basic types and operations *)
64
(*callback mapping a literal and a caller-supplied state to a clause and an updated state*)
type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a
66
(*
  Why a literal holds. Level0 carries a proof; Decided carries a level, a trail height and
  the valuation captured at decision time; Implied carries a level, a trail height, the
  justifying literals with their reasons, and a proof; External carries only a level.
*)
datatype reason =
    Level0 of Argo_Proof.proof
  | Decided of int * int * (bool * reason) Argo_Termtab.table
  | Implied of int * int * (Argo_Lit.literal * reason) list * Argo_Proof.proof
  | External of int
72
(*the decision level recorded in a reason; Level0 reasons are on level 0*)
fun level_of reason =
  (case reason of
    Level0 _ => 0
  | Decided (lvl, _, _) => lvl
  | Implied (lvl, _, _, _) => lvl
  | External lvl => lvl)
77
(*a literal paired with the reason for its assignment*)
type justified = Argo_Lit.literal * reason

(*two clause lists kept per term; see map_lit_watches and watches_of for which is which*)
type watches = Argo_Cls.clause list * Argo_Cls.clause list
81
(*look up the watch lists stored for a term, if any*)
fun get_watches table term = Argo_Termtab.lookup table term

(*update the watch lists of a term, starting from two empty lists if the term has no entry*)
fun map_watches update term table = Argo_Termtab.map_default (term, ([], [])) update table
84
(*
  Apply f to the watch list belonging to a literal: the second component for positive
  literals, the first component for negative literals.
*)
fun map_lit_watches f lit =
  (case lit of
    Argo_Lit.Pos t => map_watches (apsnd f) t
  | Argo_Lit.Neg t => map_watches (apfst f) t)
87
(*
  The clauses read for a literal: the first component for positive literals, the second
  component for negative literals, i.e., the list that map_lit_watches maintains for the
  literal of opposite polarity. Terms without an entry yield the empty list.
*)
fun watches_of wts lit =
  let
    fun select (Argo_Lit.Pos t) = Option.map fst (get_watches wts t)
      | select (Argo_Lit.Neg t) = Option.map snd (get_watches wts t)
  in the_default [] (select lit) end
90
(*add a clause to the front of the watch list of a literal*)
fun attach cls lit = map_lit_watches (fn watching => cls :: watching) lit

(*remove a clause, compared via Argo_Cls.eq_clause, from the watch list of a literal*)
fun detach cls lit = map_lit_watches (remove Argo_Cls.eq_clause cls) lit
93
94
95(* literal values *)
96
(*the stored polarity and reason of the term underlying a literal, ignoring the literal's sign*)
fun raw_val_of vals = Argo_Termtab.lookup vals o Argo_Lit.term_of
98
(*the truth value and reason of a literal; the stored polarity is flipped for negative literals*)
fun val_of vals lit =
  (case lit of
    Argo_Lit.Pos t => Argo_Termtab.lookup vals t
  | Argo_Lit.Neg t =>
      (case Argo_Termtab.lookup vals t of
        SOME (b, r) => SOME (not b, r)
      | NONE => NONE))
101
(*the truth value of a literal without its reason; NONE if the underlying term is unassigned*)
fun value_of vals lit = Option.map fst (val_of vals lit)
104
(*pair a literal with the reason stored for its term, if the term is assigned*)
fun justified vals lit =
  (case raw_val_of vals lit of
    SOME (_, r) => SOME (lit, r)
  | NONE => NONE)

(*the reason stored for the term of a literal; fails via "the" if the term is unassigned*)
fun the_reason_of vals lit = raw_val_of vals lit |> the |> snd
107
(*record a literal as true: positive literals store polarity true, negative ones false*)
fun assign lit r =
  (case lit of
    Argo_Lit.Pos t => Argo_Termtab.update (t, (true, r))
  | Argo_Lit.Neg t => Argo_Termtab.update (t, (false, r)))
110
111
112(* context *)
113
(*the trail: its height together with the sequence of assigned literals*)
type trail = int * justified list
115
type context = {
  (*literals awaiting propagation*)
  units: Argo_Common.literal list,
  (*current decision level*)
  level: int,
  (*trail height and sequence of assigned literals*)
  trail: trail,
  (*polarity and reason of every assigned term*)
  vals: (bool * reason) Argo_Termtab.table,
  (*watched clauses, indexed by term*)
  wts: watches Argo_Termtab.table,
  (*max-priority heap driving the decision heuristics*)
  heap: Argo_Heap.heap,
  (*information about clauses*)
  clss: Argo_Cls.table,
  (*proof context*)
  prf: Argo_Proof.context}
125
(*positional constructor for context records*)
fun mk_context units level trail vals wts heap clss prf: context = {
  units = units,
  level = level,
  trail = trail,
  vals = vals,
  wts = wts,
  heap = heap,
  clss = clss,
  prf = prf}
128
(*the initial context: nothing to propagate, level 0, empty trail and empty tables*)
val context: context = {
  units = [],
  level = 0,
  trail = (0, []),
  vals = Argo_Termtab.empty,
  wts = Argo_Termtab.empty,
  heap = Argo_Heap.heap,
  clss = Argo_Cls.table,
  prf = Argo_Proof.cdcl_context}
132
(*
  Walk backwards through the trail, re-inserting every dropped literal into the heap, until
  the decision of level n + 1 has been dropped. The trail height and the valuation stored
  with that decision are restored. Running out of trail raises Fail "bad trail".
*)
fun drop_levels n reason trail heap =
  (case reason of
    Decided (lvl, height, vals) =>
      if lvl = n + 1 then ((height, trail), vals, heap)
      else drop_literal n trail heap
  | _ => drop_literal n trail heap)

and drop_literal n trail heap =
  (case trail of
    [] => raise Fail "bad trail"
  | (lit, reason) :: rest => drop_levels n reason rest (Argo_Heap.insert lit heap))
139
(*
  Backjump to the given decision level and return the number of dropped levels together
  with the resulting context. Pending units are discarded when levels are dropped. Nothing
  happens if the target level is not below the current level.

  The target level is clamped to 0 once and the clamped value is used for dropping trail
  entries, for the level of the resulting context, and for the number of dropped levels.
  Hence a negative target behaves like a backjump to level 0 instead of producing a
  context with a negative level.
*)
fun backjump_to new_level (cx as {level, trail=(_, tr), wts, heap, clss, prf, ...}: context) =
  let val target = Integer.max 0 new_level
  in
    if target >= level then (0, cx)
    else
      let val (trail, vals, heap) = drop_literal target tr heap
      in (level - target, mk_context [] target trail vals wts heap clss prf) end
  end
145
146
147(* proofs *)
148
(*replace the proof of a clause by the one Argo_Proof.mk_clause yields for its literals*)
fun tag_clause (lits, p) prf =
  let val (p', prf') = Argo_Proof.mk_clause lits p prf
  in ((lits, p'), prf') end
150
(*unit-resolve the proof p with the level-0 proof of a literal; other reasons are rejected*)
fun level0_unit_proof (lit, reason) (p, prf) =
  (case reason of
    Level0 p' => Argo_Proof.mk_unit_res lit p p' prf
  | _ => raise Fail "bad reason")
153
(*unit-resolve the proof p successively with the level-0 proofs of all given literals*)
fun level0_unit_proofs lrs p prf = (p, prf) |> fold level0_unit_proof lrs
155
(*
  Resolve the proof of a clause with the level-0 proofs of all its literals and pass the
  resulting proof to Argo_Proof.unsat.
*)
fun unsat ({vals, prf, ...}: context) (lits, p) =
  let
    fun with_reason lit = (lit, the_reason_of vals lit)
    val (p', _) = level0_unit_proofs (map with_reason lits) p prf
  in Argo_Proof.unsat p' end
159
160
161(* literal operations *)
162
(*
  Assign a literal with the given reason: queue it for propagation together with its
  optional proof, put it on top of the trail, and install the given proof context.
*)
fun push lit p reason prf ({units, level, trail=(h, tr), vals, wts, heap, clss, ...}: context) =
  mk_context ((lit, p) :: units) level (h + 1, (lit, reason) :: tr) (assign lit reason vals)
    wts heap clss prf
166
(*assign a literal on level 0 after resolving its proof with the proofs of its justification*)
fun push_level0 lit p lrs (cx as {prf, ...}: context) =
  level0_unit_proofs lrs p prf
  |-> (fn p' => fn prf' => push lit (SOME p') (Level0 p') prf' cx)
170
(*
  Assign an implied literal. Above level 0 the implication is recorded with the current
  level, trail height, justification and proof; otherwise the literal is pushed on level 0.
*)
fun push_implied lit p lrs (cx as {level, trail=(h, _), prf, ...}: context) =
  if level <= 0 then push_level0 lit p lrs cx
  else push lit NONE (Implied (level, h, lrs, p)) prf cx
174
(*assign a decision literal, remembering the level, trail height and valuation before it*)
fun push_decided lit (cx as {level, trail=(h, _), vals, prf, ...}: context) =
  let val reason = Decided (level, h, vals)
  in push lit NONE reason prf cx end
177
(*the truth value currently assigned to a literal, if any*)
fun assignment_of (cx: context) = value_of (#vals cx)
179
(*move a clause from the watch list of literal old to the watch list of literal new*)
fun replace_watches old new cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let val wts' = wts |> detach cls old |> attach cls new
  in mk_context units level trail vals wts' heap clss prf end
182
183
184(* clause operations *)
185
(*tag a clause via tag_clause and store the updated proof context*)
fun as_clause cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  tag_clause cls prf ||> mk_context units level trail vals wts heap clss
189
(*store the watched literal pair of a clause; clauses with exactly two literals are skipped*)
fun note_watches (cls as (lits, _)) lp clss =
  (case lits of
    [_, _] => clss
  | _ => Argo_Cls.put_watches cls lp clss)
192
(*
  Watch a clause by the two given literals, note the watched pair, and apply
  Argo_Heap.count to every literal of the clause.
*)
fun attach_clause lit1 lit2 (cls as (lits, _))
    ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let
    val wts' = wts |> attach cls lit2 |> attach cls lit1
    val heap' = fold Argo_Heap.count lits heap
    val clss' = note_watches cls (lit1, lit2) clss
  in mk_context units level trail vals wts' heap' clss' prf end
199
(*store the watched pair (l1, l2) for a clause, but only if the flag signals a change*)
fun change_watches cls (changed, l1, l2) (cx: context) =
  if not changed then cx
  else
    let val {units, level, trail, vals, wts, heap, clss, prf} = cx
    in mk_context units level trail vals wts heap (Argo_Cls.put_watches cls (l1, l2) clss) prf end
203
(*assign the implied literal of a clause and then watch the clause by the two literals*)
fun add_asserting lit lit' (cls as (_, p)) lrs cx =
  cx |> push_implied lit p lrs |> attach_clause lit lit' cls
206
207(*
208  When learning a non-unit clause, the context is backtracked to the highest decision level
209  of the assigned literals.
210*)
211
(*
  Unit clauses are asserted on level 0. For other clauses, the literal with the highest
  level among the justifications becomes the second watched literal, the context is
  backjumped to that level, and the head literal of the clause is asserted there.
*)
fun learn_clause lrs (cls as (lits, p)) cx =
  (case lits of
    [lit] => backjump_to 0 cx ||> push_level0 lit p []
  | _ =>
      let
        fun higher (l, r) (best as (_, lvl)) =
          let val lvl' = level_of r
          in if lvl' > lvl then (l, lvl') else best end
        val asserted = hd lits
        val (lit, lvl) = fold higher lrs (asserted, 0)
      in backjump_to lvl cx ||> add_asserting asserted lit cls lrs end)
218
219(*
220  An axiom with one unassigned literal and all remaining literals being assigned to
221  false is asserting. An axiom with all literals assigned to false on level 0 makes the
222  context unsatisfiable. An axiom with all literals assigned to false on higher levels
223  causes backjumping before the highest level, and then the axiom might be asserting if
224  only one literal is unassigned on that level.
225*)
226
(*keep the candidate with the smaller index; the present candidate wins ties*)
fun min lit i best =
  (case best of
    NONE => SOME (lit, i)
  | SOME (_, j) => if i < j then SOME (lit, i) else best)
229
(*order justified literals by descending level*)
fun level_ord (lr1, lr2) = int_ord (apply2 (level_of o snd) (lr2, lr1))

(*insert a justified literal into a list ordered by level_ord*)
fun add_max lr lrs = Ord_List.insert level_ord lr lrs
232
(*
  Partition the literals of a clause by their values: t tracks the true literal of minimal
  level, us collects unassigned literals, and fs collects false literals together with
  their reasons, ordered by descending level.
*)
fun part vs ls t us fs =
  (case (vs, ls) of
    ([], []) => (t, us, fs)
  | (v :: vs', l :: ls') =>
      (case v of
        NONE => part vs' ls' t (l :: us) fs
      | SOME (true, r) => part vs' ls' (min l (level_of r) t) us fs
      | SOME (false, r) => part vs' ls' t us (add_max (l, r) fs))
  | _ => raise Fail "mismatch between values and literals")
238
(*
  Backjump to just below the level of the first literal and then add the clause: it is
  merely watched if both literals share a level, and asserting otherwise.
*)
fun backjump_add (lit, r) (lit', r') cls lrs cx =
  let
    val lvl = level_of r
    fun add cx' =
      if lvl = level_of r' then attach_clause lit lit' cls cx'
      else add_asserting lit lit' cls lrs cx'
  in backjump_to (lvl - 1) cx ||> add end
245
(*
  Dispatch on the partition of the axiom's literals into a true literal of minimal level,
  the unassigned literals, and the false literals ordered by descending level.
*)
fun analyze_axiom vs (cls as (lits, p), cx) =
  let
    fun assert_unit lit = backjump_to 0 cx ||> push_implied lit p []
    fun watch lit lit' = (0, attach_clause lit lit' cls cx)
  in
    (case part vs lits NONE [] [] of
      (SOME (lit, lvl), us, fs) =>
        (*some literal is already true*)
        (case (us, fs) of
          ([], []) => if lvl > 0 then assert_unit lit else (0, cx)
        | ([], (lit', _) :: _) => if lvl > 0 then watch lit lit' else (0, cx)
        | (lit' :: _, _) => if lvl > 0 then watch lit lit' else (0, cx))
    | (NONE, [], fs) =>
        (*all literals are false*)
        (case fs of
          (_, Level0 _) :: _ => unsat cx cls
        | [(lit, _)] => assert_unit lit
        | lr :: lr' :: _ => backjump_add lr lr' cls fs cx
        | [] => raise Fail "bad clause")
    | (NONE, [lit], []) => assert_unit lit
    | (NONE, [lit], fs as (lit', _) :: _) => (0, add_asserting lit lit' cls fs cx)
    | (NONE, lit1 :: lit2 :: _, _) => watch lit1 lit2)
  end
259
260
261(* enriching the context *)
262
(*insert the positive literal of a term into the decision heap*)
fun add_atom t ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  mk_context units level trail vals wts (Argo_Heap.insert (Argo_Lit.Pos t) heap) clss prf
266
(*
  Add an axiom to the context. The empty clause is passed to Argo_Proof.unsat, clauses with
  duplicate literals are rejected, and clauses with dual literals are ignored. The result
  is the number of dropped levels and the updated context.
*)
fun add_axiom (cls as (lits, p)) (cx as {vals, ...}: context) =
  if null lits then Argo_Proof.unsat p
  else if has_duplicates Argo_Lit.eq_lit lits then raise Fail "clause with duplicate literals"
  else if has_duplicates Argo_Lit.dual_lit lits then (0, cx)
  else
    let val vs = map (val_of vals) lits
    in analyze_axiom vs (as_clause cls cx) end
272
273
274(* external knowledge *)
275
(*
  Assume an externally known literal. A literal that is already true changes nothing. A
  false literal yields its explanation as conflicting clause, or unsatisfiability on level 0.
  An unassigned literal is pushed: with its explanation on level 0, and as external
  knowledge of the current level otherwise.
*)
fun assume explain lit (cx as {level, vals, prf, ...}: context) x =
  let
    val at_level0 = (level = 0)

    fun contradicted () =
      let val (cls, x') = explain lit x
      in if at_level0 then unsat cx cls else (SOME cls, cx, x') end

    fun fresh () =
      if at_level0 then
        let
          val ((lits, p), x') = explain lit x
          val lrs = map_filter (justified vals) lits
        in (NONE, push_level0 lit p lrs cx, x') end
      else (NONE, push lit NONE (External level) prf cx, x)
  in
    (case value_of vals lit of
      SOME true => (NONE, cx, x)
    | SOME false => contradicted ()
    | NONE => fresh ())
  end
287
288
289(* propagation *)
290
(*raised during propagation with the conflicting clause and the context reached so far*)
exception CONFLICT of Argo_Cls.clause * context
292
(*
  Order a pair of literals such that the literal matching lit (via Argo_Lit.eq_id) comes
  last; the flag tells whether the pair had to be swapped.
*)
fun order_lits_by lit (l1, l2) =
  (case Argo_Lit.eq_id (l1, lit) of
    true => (true, l2, l1)
  | false => (false, l1, l2))
295
(*
  Propagate a binary clause: an unassigned literal is implied by the other one, a true
  literal needs no action, and a false literal is a conflict (unsatisfiability on level 0).
*)
fun prop_binary (_, implied_lit, other_lit) (cls as (_, p)) (cx as {level, vals, ...}: context) =
  (case value_of vals implied_lit of
    SOME true => cx
  | SOME false => if level = 0 then unsat cx cls else raise CONFLICT (cls, cx)
  | NONE =>
      let val why = (other_lit, the_reason_of vals other_lit)
      in push_implied implied_lit p [why] cx end)
301
(*result of searching a clause: a literal that is not false, or the reasons of all false ones*)
datatype next = Lit of Argo_Lit.literal | None of justified list
303
(*continue with f on a false literal, recording its reason; stop at any other literal*)
fun with_non_false f l v lrs =
  (case v of
    SOME (false, r) => f ((l, r) :: lrs)
  | _ => Lit l)
306
(*
  Find the first literal other than lit that is not false. If there is none, the false
  literals with their reasons are returned instead.
*)
fun first_non_false vals lit lits lrs =
  (case lits of
    [] => None lrs
  | l :: ls =>
      let val continue = first_non_false vals lit ls
      in
        if Argo_Lit.eq_lit (l, lit) then continue lrs
        else with_non_false continue l (val_of vals l) lrs
      end)
311
(*
  Propagate a clause with more than two literals. If the first literal of the ordered pair
  is true, only the watched pair is updated. Otherwise a replacement for the second literal
  is sought; if all other literals are false, the first literal is implied when
  unassigned, and the clause is conflicting when that literal is false as well.
*)
fun prop_nary (lp as (_, lit1, lit2)) (cls as (lits, p)) (cx as {level, vals, ...}: context) =
  (case value_of vals lit1 of
    SOME true => change_watches cls lp cx
  | v =>
      (case first_non_false vals lit1 lits [] of
        Lit lit2' =>
          cx |> replace_watches lit2 lit2' cls |> change_watches cls (true, lit1, lit2')
      | None lrs =>
          (case v of
            NONE => push_implied lit1 p lrs (change_watches cls lp cx)
          | SOME _ =>
              if level = 0 then unsat cx cls
              else raise CONFLICT (cls, change_watches cls lp cx))))
324
(*propagate a clause: two-literal clauses are handled directly, others via their watches*)
fun prop_cls lit (cls as (lits, _)) (cx as {clss, ...}: context) =
  (case lits of
    [l1, l2] => prop_binary (order_lits_by lit (l1, l2)) cls cx
  | _ => prop_nary (order_lits_by lit (Argo_Cls.get_watches clss cls)) cls cx)
328
(*record a unit as propagated and propagate every clause returned by watches_of for it*)
fun prop_lit (lp as (lit, _)) (lps, cx: context) =
  let val watching = watches_of (#wts cx) lit
  in (lp :: lps, fold (prop_cls lit) watching cx) end
331
(*
  Propagate pending units until none are left. Each round clears the pending units of the
  context and processes the previously pending ones via fold_rev.
*)
fun prop lps (cx as {units, level, trail, vals, wts, heap, clss, prf}: context) =
  if null units then (Argo_Common.Implied (rev lps), cx)
  else
    let
      val cleared = mk_context [] level trail vals wts heap clss prf
      val (lps', cx') = fold_rev prop_lit units (lps, cleared)
    in prop lps' cx' end
335
(*propagate all pending units; a conflict is returned together with the context it arose in*)
fun propagate cx =
  prop [] cx handle CONFLICT (cls, cx') => (Argo_Common.Conflict cls, cx')
338
339
340(* decisions *)
341
342(*
  Decisions are based on an activity heuristic. The most active variable that is
  still unassigned is chosen.
345*)
346
(*
  Extract literals from the heap until one with an unassigned term is found, and decide
  it on a new level. The result is NONE if the heap runs empty.
*)
fun decide ({units, level, trail, vals, wts, heap, clss, prf}: context) =
  let
    fun assigned lit = Argo_Termtab.defined vals (Argo_Lit.term_of lit)
    fun next_level heap' = mk_context units (level + 1) trail vals wts heap' clss prf

    fun pick heap' =
      (case Argo_Heap.extract heap' of
        NONE => NONE
      | SOME (lit, rest) =>
          if assigned lit then pick rest
          else SOME (push_decided lit (next_level rest)))
  in pick heap end
354
355
356(* conflict analysis and clause learning *)
357
358(*
359  Learned clauses often contain literals that are redundant, because they are
360  subsumed by other literals of the clause. By analyzing the implication graph beyond
361  the unique implication point, such redundant literals can be identified and hence
362  removed from the learned clause. Only literals occurring in the learned clause and
363  their reasons need to be analyzed.
364*)
365
(*raised when a literal turns out not to be redundant*)
exception ESSENTIAL of unit
367
(*
  Order resolution steps by descending trail height. Two steps with negative height are
  ordered by the signed identifiers of their literals instead.
*)
fun history_ord ((h1, lit1, _), (h2, lit2, _)) =
  if h1 >= 0 orelse h2 >= 0 then int_ord (h2, h1)
  else int_ord (Argo_Lit.signed_id_of lit1, Argo_Lit.signed_id_of lit2)
371
(*
  Collect the resolution steps needed to eliminate a literal by following its reasons.
  Level-0 literals are recorded with height ~1. A decision or external literal that is
  not stopped makes the literal under analysis essential.
*)
fun rec_redundant stop (lit, reason) lps =
  (case reason of
    Implied (lvl, h, lrs, p) =>
      if stop lit lvl then lps
      else fold (rec_redundant stop) lrs ((h, lit, p) :: lps)
  | Decided (lvl, _, _) => if stop lit lvl then lps else raise ESSENTIAL ()
  | Level0 p => (~1, lit, p) :: lps
  | _ => raise ESSENTIAL ())
380
(*
  Try to eliminate an implied literal: on success its resolution steps are added, and
  otherwise the literal is kept as essential. Literals that are not implied are essential.
*)
fun redundant stop (lr as (lit, reason)) (lps, essential_lrs) =
  let val keep = (lps, lr :: essential_lrs)
  in
    (case reason of
      Implied (_, h, lrs, p) =>
        ((fold (rec_redundant stop) lrs ((h, lit, p) :: lps), essential_lrs)
          handle ESSENTIAL () => keep)
    | _ => keep)
  end
385
(*apply one recorded unit resolution step to the proof p*)
fun resolve_step (_, lit, unit_proof) (p, prf) = Argo_Proof.mk_unit_res lit p unit_proof prf
387
(*
  Minimize a learned clause: drop the redundant literals, return the essential ones, and
  resolve the proof with the collected steps, sorted and made distinct via history_ord.
*)
fun reduce lrs p prf =
  let
    val lemma_lits = map fst lrs
    val lemma_levels = fold (fn (_, r) => insert (op =) (level_of r)) lrs []

    (*stop at literals of the clause; give up on levels that do not occur in the clause*)
    fun stop lit lvl =
      member Argo_Lit.eq_lit lemma_lits lit orelse
        (if member (op =) lemma_levels lvl then false else raise ESSENTIAL ())

    val (lps, essential) = fold (redundant stop) lrs ([], [])
    val steps = sort_distinct history_ord lps
  in (essential, fold resolve_step steps (p, prf)) end
399
400(*
401  Literals that are candidates for the learned lemma are marked and unmarked while
402  traversing backwards through the trail. The last remaining marked literal is the first
403  unique implication point.
404*)
405
(*remove a literal, compared via Argo_Lit.eq_id, from the marked literals*)
fun unmark lit marks = remove Argo_Lit.eq_id lit marks

(*check whether a literal, compared via Argo_Lit.eq_id, is among the marked literals*)
fun marked marks lit = member Argo_Lit.eq_id marks lit
408
409(*
  Whenever an implication is recorded, the reasons for the false literals of the
  asserting clause are known. It is reasonable to store this justification list as part
412  of the implication reason. Consequently, the implementation of conflict analysis can
413  benefit from this information, which does not need to be re-computed here.
414*)
415
(*
  The justification and proof of an assigned literal: taken from the reason of an
  implication, or obtained from the explain callback for external knowledge.
*)
fun justification_for explain vals lit reason x =
  (case reason of
    Implied (_, _, lrs, p) => (lrs, p, x)
  | External _ =>
      let val ((lits, p), x') = explain lit x
      in (map_filter (justified vals) lits, p, x') end
  | _ => raise Fail "bad reason")
421
(*
  The first entry whose literal satisfies the predicate, together with the entries
  following it. Raises Empty if there is no such entry.
*)
fun first_lit pred lrs =
  (case lrs of
    [] => raise Empty
  | (lr as (lit, _)) :: rest => if pred lit then (lr, rest) else first_lit pred rest)
424
425(*
426  Beginning from the conflicting clause, the implication graph is traversed to the first
427  unique implication point. This breadth-first search is controlled by the topological order of
428  the trail, which is traversed backwards. While traversing through the trail, the conflict
429  literals of lower levels are collected to form the conflict lemma together with the unique
430  implication point. Conflict literals assigned on level 0 are excluded from the conflict lemma.
431  Conflict literals assigned on the current level are candidates for the first unique
432  implication point.
433*)
434
(*
  Analyze a conflicting clause and learn a lemma from it. The result is the number of
  levels dropped by learning, the updated context, and the updated state of the explain
  callback. On level 0, the conflict is passed to unsat instead.
*)
fun analyze explain cls (cx as {level, trail, vals, wts, heap, clss, prf, ...}: context) x =
  let
    (*process the literals of the current clause, then continue with the trail*)
    fun from_clause clause_lrs trail ms lrs h p prf x =
      (case clause_lrs of
        [] => from_trail (first_lit (marked ms) trail) ms lrs h p prf x
      | (lit, r) :: rest => from_reason r lit rest trail ms lrs h p prf x)

    (*
      Level-0 literals are resolved away, literals of the current level get marked, and
      literals of other levels are collected for the lemma. Newly marked or collected
      literals are passed to Argo_Heap.increase.
    *)
    and from_reason r lit clause_lrs trail ms lrs h p prf x =
      (case r of
        Level0 p' =>
          let val (p, prf) = Argo_Proof.mk_unit_res lit p p' prf
          in from_clause clause_lrs trail ms lrs h p prf x end
      | _ =>
          if level_of r <> level then
            if AList.defined Argo_Lit.eq_id lrs lit then
              from_clause clause_lrs trail ms lrs h p prf x
            else
              from_clause clause_lrs trail ms ((lit, r) :: lrs) (Argo_Heap.increase lit h) p prf x
          else if marked ms lit then from_clause clause_lrs trail ms lrs h p prf x
          else from_clause clause_lrs trail (lit :: ms) lrs (Argo_Heap.increase lit h) p prf x)

    (*
      With a single marked literal left, the lemma is completed and minimized. Otherwise
      the marked literal is resolved with its justification.
    *)
    and from_trail ((lit, r), trail) ms lrs h p prf x =
      (case ms of
        [_] =>
          let val (lrs, (p, prf)) = reduce lrs p prf
          in (Argo_Lit.negate lit :: map fst lrs, lrs, h, p, prf, x) end
      | _ =>
          let
            val (clause_lrs, p', x) = justification_for explain vals lit r x
            val (p, prf) = Argo_Proof.mk_unit_res lit p' p prf
          in from_clause clause_lrs trail (unmark lit ms) lrs h p prf x end)

    val (ls, p) = cls
    val conflict_lrs =
      if level = 0 then unsat cx cls else map (fn l => (l, the_reason_of vals l)) ls
    val (lits, lemma_lrs, heap', p', prf', x') =
      from_clause conflict_lrs (snd trail) [] [] heap p prf x
    val cx' = mk_context [] level trail vals wts (Argo_Heap.decay heap') clss prf'
    val (levels, cx'') = learn_clause lemma_lrs (lits, p') cx'
  in (levels, cx'', x') end
471
472
473(* restarting *)
474
(*restart by backjumping to level 0; yields the number of dropped levels and the context*)
fun restart (cx: context) = backjump_to 0 cx
476
477end
478