1(*  Title:      Tools/Argo/argo_simplex.ML
2    Author:     Sascha Boehme
3
4Linear arithmetic reasoning based on the simplex algorithm. It features:
5
6 * simplification and normalization of arithmetic expressions
7 * decision procedure for reals
8
9These features might be added:
10
11 * propagating implied inequality literals while assuming external knowledge
12 * propagating equalities for fixed variables to all other theory solvers
13 * pruning the tableau after new atoms have been added: eliminate unnecessary
14   variables
15
16The implementation is inspired by:
17
18  Bruno Dutertre and Leonardo de Moura. A fast linear-arithmetic solver
19  for DPLL(T). In Computer Aided Verification, pages 81-94. Springer, 2006.
20*)
21
signature ARGO_SIMPLEX =
sig
  (* solver state and its initial value *)
  type context
  val context: context

  (* declaring atoms and preparing the solver afterwards *)
  val add_atom: Argo_Term.term -> context -> Argo_Lit.literal option * context
  val prepare: context -> context

  (* asserting literals, checking consistency and explaining implied literals *)
  val assume: Argo_Common.literal -> context -> Argo_Lit.literal Argo_Common.implied * context
  val check: context -> Argo_Lit.literal Argo_Common.implied * context
  val explain: Argo_Lit.literal -> context -> (Argo_Cls.clause * context) option

  (* backtracking levels *)
  val add_level: context -> context
  val backtrack: context -> context
end
39
40structure Argo_Simplex: ARGO_SIMPLEX =
41struct
42
43(* extended rationals *)
44
(*
  Extended rationals (c, k) are reals (c + k * e) where e is some small positive real number.
  Extended rationals are used to represent a strict inequality by a non-strict inequality:
    c < x  ~~  c + e <= x
    x < c  ~~  x <= c - e
  Hence a strict lower bound c is represented by (c, 1) and a strict upper bound c by (c, -1).
*)
51
type erat = Rat.rat * Rat.rat

(* the extended rational 0 + 0 * e *)
val erat_zero = (@0, @0)

(* arithmetic works componentwise; scaling multiplies both components by a rational *)
fun add (a1, b1) (a2, b2) = (a1 + a2, b1 + b2)
fun sub (a1, b1) (a2, b2) = (a1 - a2, b1 - b2)
fun mul r (a, b) = (r * a, r * b)

(* lexicographic order: the e-coefficients are only compared when the rational parts agree *)
fun erat_ord ((a1, b1), (a2, b2)) =
  (case Rat.ord (a1, a2) of
    EQUAL => Rat.ord (b1, b2)
  | result => result)

fun less_eq x y = erat_ord (x, y) <> GREATER
fun less x y = erat_ord (x, y) = LESS
64
65
66(* term functions *)
67
(* splits a monomial n * t into (t, n); any other term t is read as (t, 1) *)
fun dest_monom t =
  (case t of
    Argo_Term.T (_, Argo_Expr.Mul, [Argo_Term.T (_, Argo_Expr.Num n, _), u]) => (u, n)
  | _ => (t, @1))
70
datatype node = Var of Argo_Term.term | Num of Rat.rat
datatype ineq = Lower of Argo_Term.term * erat | Upper of Argo_Term.term * erat

(* numerals become Num nodes, every other term is treated as a variable *)
fun dest_node t =
  (case t of
    Argo_Term.T (_, Argo_Expr.Num n, _) => Num n
  | _ => Var t)

(*
  Turns a possibly negated inequality into a positive one:
  not (t1 <= t2) becomes t2 < t1 and not (t1 < t2) becomes t2 <= t1.
  Kinds other than Le and Lt yield NONE.
*)
fun dest_atom pol k t1 t2 =
  let
    fun mk kind lhs rhs = SOME (kind, dest_node lhs, dest_node rhs)
  in
    (case (pol, k) of
      (true, Argo_Expr.Le) => mk Argo_Expr.Le t1 t2
    | (true, Argo_Expr.Lt) => mk Argo_Expr.Lt t1 t2
    | (false, Argo_Expr.Le) => mk Argo_Expr.Lt t2 t1
    | (false, Argo_Expr.Lt) => mk Argo_Expr.Le t2 t1
    | _ => NONE)
  end
82
(*
  Reads a binary atom with polarity pol as a bound on a variable:
  x <= n and x < n give upper bounds, n <= x and n < x give lower bounds.
  Strict bounds carry the e-coefficient ~1 (upper) or 1 (lower).
  Atoms not comparing a variable with a numeral yield NONE.
*)
fun ineq_of pol t =
  let
    fun bound (Argo_Expr.Le, Var x, Num n) = SOME (Upper (x, (n, @0)))
      | bound (Argo_Expr.Le, Num n, Var x) = SOME (Lower (x, (n, @0)))
      | bound (Argo_Expr.Lt, Var x, Num n) = SOME (Upper (x, (n, @~1)))
      | bound (Argo_Expr.Lt, Num n, Var x) = SOME (Lower (x, (n, @1)))
      | bound _ = NONE
  in
    (case t of
      Argo_Term.T (_, k, [t1, t2]) =>
        (case dest_atom pol k t1 t2 of
          SOME d => bound d
        | NONE => NONE)
    | _ => NONE)
  end
91
92
93(* proofs *)
94
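(*
  Conflicts are justified by linear combinations of inequality literals.
  Each literal is turned into a positive inequality and scaled by its rational
  coefficient. The scaled inequalities are summed up, and the normalized sum is
  rewritten to false. Literals without a proof are introduced as hypotheses, and
  their negations are collected for the conflict clause.
*)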
98
(* builds e1 < e2 when is_lt holds and e1 <= e2 otherwise *)
fun mk_ineq is_lt e1 e2 =
  if is_lt then Argo_Expr.mk_lt e1 e2 else Argo_Expr.mk_le e1 e2

(* the proof-level inequality kind corresponding to mk_ineq *)
fun ineq_rule_of true = Argo_Proof.Lt
  | ineq_rule_of false = Argo_Proof.Le

(* instantiates f with Argo_Rewr.context and lifts it via Argo_Rewr.rewrite_top *)
fun rewrite_top f =
  let val rule = f Argo_Rewr.context
  in Argo_Rewr.rewrite_top rule end
103
(*
  Rewrites not (e1 <= e2) into e2 < e1 and not (e1 < e2) into e2 <= e1;
  every other expression is kept.
*)
fun unnegate_conv e =
  let
    fun flip rule mk e1 e2 = Argo_Rewr.rewr (Argo_Proof.Rewr_Not_Ineq rule) (mk e2 e1) e
  in
    (case e of
      Argo_Expr.E (Argo_Expr.Not, [Argo_Expr.E (Argo_Expr.Le, [e1, e2])]) =>
        flip Argo_Proof.Le Argo_Expr.mk_lt e1 e2
    | Argo_Expr.E (Argo_Expr.Not, [Argo_Expr.E (Argo_Expr.Lt, [e1, e2])]) =>
        flip Argo_Proof.Lt Argo_Expr.mk_le e1 e2
    | _ => Argo_Rewr.keep e)
  end
109
(*
  Multiplies both sides of the inequality e1 (r) e2 by n.
  For a non-positive n the sides are swapped.
  The products are afterwards normalized by norm_mul.
*)
fun scale_conv r mk n e1 e2 =
  let
    fun times e = Argo_Expr.mk_mul (Argo_Expr.mk_num n) e
    val (lhs, rhs) = if n > @0 then (e1, e2) else (e2, e1)
    val step = Argo_Rewr.rewr (Argo_Proof.Rewr_Ineq_Mul (r, n)) (mk (times lhs) (times rhs))
    val normalize = Argo_Rewr.args (rewrite_top Argo_Rewr.norm_mul)
  in Argo_Rewr.seq [step, normalize] end
116
(* splits e1 <= e2 into (false, e1, e2) and e1 < e2 into (true, e1, e2) *)
fun dest_ineq e =
  (case e of
    Argo_Expr.E (Argo_Expr.Le, [lhs, rhs]) => SOME (false, lhs, rhs)
  | Argo_Expr.E (Argo_Expr.Lt, [lhs, rhs]) => SOME (true, lhs, rhs)
  | _ => NONE)
120
(*
  Scales the inequality e by n.
  A factor of 1 keeps e unchanged.
  Raises Fail when e is not an inequality.
*)
fun scale_ineq_conv n e =
  let
    fun scale (is_lt, e1, e2) = scale_conv (ineq_rule_of is_lt) (mk_ineq is_lt) n e1 e2 e
  in
    if n = @1 then Argo_Rewr.keep e
    else
      (case dest_ineq e of
        SOME ineq => scale ineq
      | NONE => raise Fail "bad inequality")
  end
127
(*
  Unnegates the literal's expression and scales it by n, keeping track of the proof p.
  The result still expects the proof context (it is used with fold_map).
*)
fun simp_lit (n, (lit, p)) =
  let
    val e = Argo_Lit.signed_expr_of lit
    val conv = Argo_Rewr.seq [unnegate_conv, scale_ineq_conv n]
  in Argo_Rewr.with_proof conv (e, p) end
131
(* normalizes a sum *)
val combine_conv = rewrite_top Argo_Rewr.norm_add

(* rewrites an inequality of kind r to false, justified by Rewr_Ineq_Nums *)
fun reduce_conv r =
  let val rule = Argo_Proof.Rewr_Ineq_Nums (r, false)
  in Argo_Rewr.rewr rule Argo_Expr.false_expr end
134
(*
  Sums up the inequalities es side by side into one inequality.
  The sum is strict if any summand is strict.
  Both sides are normalized and the result is rewritten to false,
  extending the proof p.
*)
fun simp_combine es p prf =
  let
    fun add_ineq e (strict, (ls, rs)) =
      let val (strict', l, r) = the (dest_ineq e)
      in (strict orelse strict', (l :: ls, r :: rs)) end
    val (strict, (ls, rs)) = fold_rev add_ineq es (false, ([], []))
    val lhs = Argo_Expr.mk_add ls
    val rhs = Argo_Expr.mk_add rs
    val combined = mk_ineq strict lhs rhs
    val conv = Argo_Rewr.seq [Argo_Rewr.args combine_conv, reduce_conv (ineq_rule_of strict)]
    val (res, prf') = Argo_Rewr.with_proof conv (combined, p) prf
  in (snd res, prf') end
144
(*
  Builds a lemma from the coefficient/literal-proof pairs nlps:
  - the literals are simplified and scaled;
  - their proofs are combined by Argo_Proof.mk_linear_comb;
  - the combined inequality is reduced to false.
*)
fun linear_combination nlps prf =
  let
    val (simplified, prf1) = fold_map simp_lit nlps prf
    val (es, ps) = split_list simplified
    val (comb, prf2) = Argo_Proof.mk_linear_comb ps prf1
    val (reduced, prf3) = simp_combine es comb prf2
  in Argo_Proof.mk_lemma [] reduced prf3 end
148
(*
  Pairs a literal with its proof.
  A literal without a proof is introduced as a hypothesis,
  and its negation is added to the collected literals ls.
*)
fun proof_of (lit, proof_opt) (ls, prf) =
  (case proof_opt of
    SOME p => ((lit, p), (ls, prf))
  | NONE =>
      let val (p, prf') = Argo_Proof.mk_hyp lit prf
      in ((lit, p), (Argo_Lit.negate lit :: ls, prf')) end)
153
154
155(* tableau *)
156
(*
  The tableau consists of equations x_i = a_i1 * x_1 + ... + a_ik * x_k where
  the variable on the left-hand side is called a basic variable and
  the variables on the right-hand side are called non-basic variables.

  For each basic variable, the polynomial on the right-hand side is stored as a map
  from variables to coefficients. Only variables with non-zero coefficients are stored.
  The map is sorted by the term order of the variables for a deterministic order when
  analyzing a polynomial.

  Additionally, for each basic variable a boolean flag is kept that, when false,
  indicates that the current value of the basic variable might be outside its bounds.
  The value of a non-basic variable is always within its bounds.

  The tableau is stored as a table indexed by variables. For each variable,
  both basic and non-basic, its current value is stored as an extended rational.
  For basic variables, the entry additionally holds the flag and the polynomial
  of its equation.
*)
175
(* check flag (false: value may be out of bounds) and polynom of a basic variable *)
type basic = bool * (Argo_Term.term * Rat.rat) Ord_List.T
(* current value and, only for basic variables, the flagged polynom *)
type entry = erat * basic option
type tableau = entry Argo_Termtab.table

fun dirty ms = SOME (false, ms)
fun checked ms = SOME (true, ms)

(* new basic variables start with value zero and still need to be checked *)
fun basic_entry ms = (erat_zero, dirty ms)
val non_basic_entry: entry = (erat_zero, NONE)

(* current value of x; zero if x has no entry *)
fun value_of tableau x =
  (case Argo_Termtab.lookup tableau x of
    SOME (v, _) => v
  | NONE => erat_zero)

(* some basic variable whose flag is false, with its value and polynom (first found by get_first) *)
fun first_unchecked_basic tableau =
  let
    fun unchecked (y, (v, SOME (false, ms))) = SOME (y, v, ms)
      | unchecked _ = NONE
  in Argo_Termtab.get_first unchecked tableau end
193
194local
195
val eq_var = Argo_Term.eq_term

(* polynoms are sorted by the term order of their variables; coefficients are ignored *)
fun monom_ord ((x1, _), (x2, _)) = Argo_Term.term_ord (x1, x2)

(* coefficient of x in the polynom ms; fails (via the) if x does not occur in ms *)
fun coeff_of ms x = the (AList.lookup eq_var ms x)

fun add_monom m ms = Ord_List.insert monom_ord m ms

(* sets the coefficient of x, dropping the monomial when the coefficient is zero *)
fun update_monom (x, a) ms =
  if a = @0 then AList.delete eq_var x ms
  else AList.update eq_var (x, a) ms

(* adds n * a to the coefficient of x in ms *)
fun add_scaled_monom n (x, a) ms =
  let val na = n * a
  in
    (case AList.lookup eq_var ms x of
      SOME b => update_monom (x, na + b) ms
    | NONE => add_monom (x, na) ms)
  end

(* replaces x, which has coefficient n in ms, by n times the polynom ms' *)
fun replace_polynom x n ms' ms =
  fold (add_scaled_monom n) ms' (AList.delete eq_var x ms)

(* applies f to value and polynom of a basic entry; other entries stay unchanged *)
fun map_basic f e =
  (case e of
    (v, SOME (_, ms)) => f v ms
  | _ => e)
213
(* applies f to every basic entry whose polynom mentions x *)
fun map_basic_entries x f =
  let
    fun mentions ms = AList.defined eq_var ms x
    fun apply (e as (v, SOME (_, ms))) = if mentions ms then f v ms else e
      | apply e = e
  in Argo_Termtab.map (fn _ => apply) end

fun put_entry x e tableau = Argo_Termtab.update (x, e) tableau

(*
  A sum becomes a basic variable whose polynom is built from its monomials.
  Each variable of the sum is (re)declared as non-basic with value zero.
  Any other term becomes a non-basic variable.
  NOTE(review): existing entries of the sum's variables are overwritten; this is
  harmless only while they are still non-basic with value zero.
*)
fun add_new_entry x tableau =
  (case x of
    Argo_Term.T (_, Argo_Expr.Add, ts) =>
      let
        val ms = Ord_List.make monom_ord (map dest_monom ts)
        val with_basic = put_entry x (basic_entry ms) tableau
      in fold (fn (z, _) => put_entry z non_basic_entry) ms with_basic end
  | _ => put_entry x non_basic_entry tableau)

(*
  Applies f to the value of x if x is non-basic.
  If x is basic, it is marked for checking only when update_basic holds.
  Unknown variables leave the tableau unchanged.
*)
fun with_non_basic update_basic x f tableau =
  (case Argo_Termtab.lookup tableau x of
    SOME (v, NONE) => f v tableau
  | SOME (v, SOME (_, ms)) =>
      if update_basic then put_entry x (v, dirty ms) tableau else tableau
  | NONE => tableau)
232
233in
234
(* adds entries for x (and, for sums, its variables) unless x already has an entry *)
fun add_entry x tableau =
  if not (Argo_Termtab.defined tableau x) then add_new_entry x tableau
  else tableau

(* sets the check flag of the basic variable y; non-basic entries are unaffected *)
fun basic_within_bounds y =
  let fun mark v ms = (v, checked ms)
  in Argo_Termtab.map_entry y (map_basic mark) end

(* placeholder: variable elimination is not implemented, the tableau is returned unchanged *)
fun eliminate _ tableau = tableau

(*
  If x is non-basic and pred holds for its current value v, x is moved to v'.
  Every basic variable whose polynom contains x is then shifted by its coefficient
  of x times (v' - v) and marked for checking.
  A basic x is only marked for checking; an unknown x changes nothing.
*)
fun update_non_basic pred x v' =
  with_non_basic true x (fn v => fn tableau =>
    if pred v then
      let
        val delta = sub v' v
        fun shift w ms = (add w (mul (coeff_of ms x) delta), dirty ms)
      in put_entry x (v', NONE) (map_basic_entries x shift tableau) end
    else tableau)
246
(*
  Pivots the basic variable y with the non-basic variable x, so that afterwards
  y has the value v.
  - y has current value vy and polynom ms; x has coefficient c in ms.
  - x changes by (v - vy) / c and becomes basic, with the polynom obtained by
    solving y's equation for x.
  - y becomes non-basic with value v.
  - Every other basic variable whose polynom contains x gets x substituted,
    its value shifted accordingly, and is marked for checking.
  Nothing changes if x has no entry or is basic.
*)
fun update_pivot y vy ms x c v = with_non_basic false x (fn vx =>
  let
    val a = Rat.inv c
    (* change of x that moves y from vy to v *)
    val v' = mul a (sub v vy)

    (* x = (1/c) * y - sum over the other monomials b * x' of y's polynom of (b/c) * x' *)
    fun scale_or_drop (x', b) = if Argo_Term.eq_term (x', x) then NONE else SOME (x', ~ a * b)
    val ms = add_monom (y, a) (map_filter scale_or_drop ms)

    (* substitutes x's new polynom into a basic polynom ms' containing x and shifts the value v *)
    fun update_basic v ms' =
      let val n = coeff_of ms' x
      in (add v (mul n v'), dirty (replace_polynom x n ms ms')) end
  in
    put_entry x (add vx v', dirty ms) #>
    put_entry y (v, NONE) #>
    map_basic_entries x update_basic
  end)
263
264end
265
266
267(* bounds *)
268
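(*
  For each variable, the bounds table stores two things:
  - its current lower and upper bound, each as an extended rational together
    with the literal that justifies it;
  - the atoms over the variable that have not been assumed yet, split into
    lower-bound and upper-bound atoms.
*)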
272
(* a bound value together with the literal justifying it *)
type bound = (erat * Argo_Common.literal) option
(* unassigned atoms, each with the bound value it imposes *)
type atoms = (erat * Argo_Term.term) list
type bounds_atoms = ((bound * bound) * (atoms * atoms))
type bounds = bounds_atoms Argo_Termtab.table

val empty_bounds_atoms: bounds_atoms = ((NONE, NONE), ([], []))

(* pred on the value of an existing bound; false for a missing bound *)
fun on_some pred b =
  (case b of
    SOME (n, _) => pred n
  | NONE => false)

(* pred on the value of an existing bound; true for a missing bound *)
fun none_or_some pred b =
  (case b of
    SOME (n, _) => pred n
  | NONE => true)

fun bound_of b =
  (case b of
    SOME (n, _) => n
  | NONE => raise Fail "bad bound")

fun reason_of b =
  (case b of
    SOME (_, r) => r
  | NONE => raise Fail "bad reason")
291
(* bounds and atoms of x; empty if x is unknown *)
fun bounds_atoms_of bounds x =
  (case Argo_Termtab.lookup bounds x of
    SOME ba => ba
  | NONE => empty_bounds_atoms)

fun bounds_of bounds x =
  let val (bs, _) = bounds_atoms_of bounds x
  in bs end

(* replaces the bounds of x, keeping its atoms *)
fun put_bounds x bs bounds =
  Argo_Termtab.map_default (x, empty_bounds_atoms) (fn (_, atoms) => (bs, atoms)) bounds

(* whether any unassigned atom mentions x *)
fun has_bound_atoms bounds x =
  (case Argo_Termtab.lookup bounds x of
    SOME (_, (lows, upps)) => not (null lows andalso null upps)
  | NONE => false)

(* adds the atom t with bound value n to the lower (f = apfst) or upper (f = apsnd) atoms of x *)
fun add_new_atom f x n t =
  let fun ins atoms = insert (eq_snd Argo_Term.eq_term) (n, t) atoms
  in Argo_Termtab.map_default (x, empty_bounds_atoms) (apsnd (f ins)) end

(* removes the atom t from both atom lists of x *)
fun del_atom x t =
  let
    fun same_term (t1, (_, t2)) = Argo_Term.eq_term (t1, t2)
    fun without atoms = remove same_term t atoms
  in Argo_Termtab.map_entry x (apsnd (fn (lows, upps) => (without lows, without upps))) end
310
311
312(* context *)
313
type context = {
  tableau: tableau,        (* values of variables and tableau entries for each variable *)
  bounds: bounds,          (* bounds and unassigned atoms for each variable *)
  prf: Argo_Proof.context, (* proof context *)
  back: bounds list}       (* stack storing previous bounds and unassigned atoms *)

fun mk_context tableau bounds prf back: context =
  {back = back, bounds = bounds, prf = prf, tableau = tableau}

(* initial context: no variables, no bounds, no backtracking levels *)
val context =
  let val no_entries = Argo_Termtab.empty
  in mk_context no_entries no_entries Argo_Proof.simplex_context [] end
324
325
326(* declaring atoms *)
327
fun add_ineq_atom f t x n ({tableau, bounds, prf, back}: context) =
  (* TODO: check whether the atom is already known to hold *)
  let
    val tableau' = add_entry x tableau
    val bounds' = add_new_atom f x n t bounds
  in (NONE, mk_context tableau' bounds' prf back) end

(* registers an inequality atom as a lower- or upper-bound atom; other atoms are ignored *)
fun add_atom t cx =
  (case ineq_of true t of
    NONE => (NONE, cx)
  | SOME (Lower (x, n)) => add_ineq_atom apfst t x n cx
  | SOME (Upper (x, n)) => add_ineq_atom apsnd t x n cx)
337
338
339(* preparing the solver after new atoms have been added *)
340
341(*
342  Variables that do not directly occur in atoms can be eliminated from the tableau
343  since no bounds will ever limit their value. This can reduce the tableau size
344  substantially.
345*)
346
(* tries to eliminate every tableau variable without unassigned atoms (eliminate is currently a placeholder) *)
fun prepare ({tableau, bounds, prf, back}: context) =
  let
    fun drop (entry as (x, _)) tab =
      if has_bound_atoms bounds x then tab else eliminate entry tab
    val tableau' = Argo_Termtab.fold drop tableau tableau
  in mk_context tableau' bounds prf back end
350
351
352(* assuming external knowledge *)
353
(*
  Conflict between a lower bound (reason r1) and a smaller upper bound (reason r2)
  of the same variable.
  The proof scales the lower bound by ~1 and the upper bound by 1.
*)
fun bounds_conflict r1 r2 ({tableau, bounds, prf, back}: context) =
  let
    val (lp2, acc) = proof_of r2 ([], prf)
    val (lp1, (lits, prf')) = proof_of r1 acc
    val (p, prf'') = linear_combination [(@~1, lp1), (@1, lp2)] prf'
  in (Argo_Common.Conflict (lits, p), mk_context tableau bounds prf'' back) end
359
(*
  Records the bounds bs of x.
  If x is non-basic and its value compares to the new bound c as order,
  x is moved to c. No literals are propagated.
*)
fun assume_bounds order x c bs ({tableau, bounds, prf, back}: context) =
  let
    fun violates v = erat_ord (v, c) = order
    val bounds' = put_bounds x bs bounds
    val tableau' = update_non_basic violates x c tableau
  in (Argo_Common.Implied [], mk_context tableau' bounds' prf back) end
366
(*
  Assumes the lower bound c for x justified by r:
  - a bound not above the current lower bound changes nothing;
  - a bound above the current upper bound is a conflict.
*)
fun assume_lower r x c (low, upp) cx =
  let
    fun redundant l = less_eq c l
    fun conflicting u = less u c
  in
    if on_some redundant low then (Argo_Common.Implied [], cx)
    else if on_some conflicting upp then bounds_conflict r (reason_of upp) cx
    else assume_bounds LESS x c (SOME (c, r), upp) cx
  end

(*
  Assumes the upper bound c for x justified by r:
  - a bound not below the current upper bound changes nothing;
  - a bound below the current lower bound is a conflict.
*)
fun assume_upper r x c (low, upp) cx =
  let
    fun redundant u = less_eq u c
    fun conflicting l = less c l
  in
    if on_some redundant upp then (Argo_Common.Implied [], cx)
    else if on_some conflicting low then bounds_conflict (reason_of low) r cx
    else assume_bounds GREATER x c (low, SOME (c, r)) cx
  end
376
(* removes the assumed atom t and hands the previous bounds of x to f *)
fun with_bounds r t f x n ({tableau, bounds, prf, back}: context) =
  let val cx = mk_context tableau (del_atom x t bounds) prf back
  in f r x n (bounds_of bounds x) cx end

fun choose f ineq cx =
  (case ineq of
    SOME (Lower (x, n)) => f assume_lower x n cx
  | SOME (Upper (x, n)) => f assume_upper x n cx
  | NONE => (Argo_Common.Implied [], cx))

(* assumes a literal; literals that are not bounds on a variable are ignored *)
fun assume (r as (lit, _)) cx =
  let
    val (t, pol) = Argo_Lit.dest lit
    val ineq = ineq_of pol t
  in choose (with_bounds r t) ineq cx end
387
388
389(* checking for consistency and pending implications *)
390
(*
  Builds a conflict for the basic variable y with polynom ms. It applies when the
  value of y violates its lower bound (lower = true) or upper bound (lower = false)
  and no variable of ms can compensate.

  The proof is a linear combination of bound literals:
  - the violated bound of y is scaled by a (1 for lower, ~1 for upper);
  - for each monomial b * x of ms, the bound of x that keeps y from reaching its
    bound is scaled by ~ a * b.
  With these coefficients the terms over y and over the x cancel, and the
  remaining inequality is rewritten to false. Scaling the monomials by b alone
  (independent of a) only cancels in the upper-bound case.
*)
fun basic_bounds_conflict lower y ms ({tableau, bounds, prf, back}: context) =
  let
    val (a, low, upp) = if lower then (@1, fst, snd) else (@~1, snd, fst)
    fun coeff_proof f a x = apfst (pair a) o proof_of (reason_of (f (bounds_of bounds x)))
    (* the coefficient must depend on a so that b * x cancels against the scaled y *)
    fun monom_proof (x, b) = coeff_proof (if b < @0 then low else upp) (~ a * b) x
    val ((alp, alps), (lits, prf)) = ([], prf) |> coeff_proof low a y ||>> fold_map monom_proof ms
    val (p, prf) = linear_combination (alp :: alps) prf
  in (Argo_Common.Conflict (lits, p), mk_context tableau bounds prf back) end
399
(*
  Whether x, with coefficient a, can move the basic variable in direction ord
  (GREATER: increase, LESS: decrease).
  - If x has to increase, it must be strictly below its upper bound.
  - If x has to decrease, it must be strictly above its lower bound.
  A missing bound never blocks.
*)
fun can_compensate ord tableau bounds (x, a) =
  let
    val (low, upp) = bounds_of bounds x
    val v = value_of tableau x
  in
    if Rat.ord (a, @0) = ord then none_or_some (less v) upp
    else none_or_some (fn l => less l v) low
  end
406
(*
  Processes the basic variables whose check flag is false, one at a time:
  - a basic variable within its bounds gets its flag set;
  - a basic variable below its lower bound (or above its upper bound) is adjusted.
  Returns Implied [] once no unchecked basic variable remains.
*)
fun check (cx as {tableau, bounds, prf, back}: context) =
  (case first_unchecked_basic tableau of
    NONE => (Argo_Common.Implied [], cx)
  | SOME (y, v, ms) =>
      let val (low, upp) = bounds_of bounds y
      in
        if on_some (fn l => less v l) low then adjust GREATER true y v ms (bound_of low) cx
        else if on_some (fn u => less u v) upp then adjust LESS false y v ms (bound_of upp) cx
        else check (mk_context (basic_within_bounds y tableau) bounds prf back)
      end)

(*
  Pivots y with the first variable of its polynom ms that can move y in direction
  ord (GREATER: increase, LESS: decrease), so that y takes the bound value v.
  Checking then continues. If no such variable exists, a conflict is built; lower
  tells whether the lower or the upper bound of y is violated.
*)
and adjust ord lower y vy ms v (cx as {tableau, bounds, prf, back}: context) =
  (case find_first (can_compensate ord tableau bounds) ms of
    NONE => basic_bounds_conflict lower y ms cx
  | SOME (x, a) => check (mk_context (update_pivot y vy ms x a v tableau) bounds prf back))
422
423
424(* explanations *)
425
(* explaining implied literals is not implemented: there is never an explanation *)
fun explain (_: Argo_Lit.literal) (_: context) = NONE
427
428
429(* backtracking *)
430
(* saves the current bounds on the backtracking stack *)
fun add_level ({tableau, bounds, prf, back}: context) =
  let val back' = bounds :: back
  in mk_context tableau bounds prf back' end

(*
  Restores the most recently saved bounds; the tableau is kept.
  Raises Empty without a saved level.
*)
fun backtrack ({tableau, prf, back, ...}: context) =
  (case back of
    bounds :: back' => mk_context tableau bounds prf back'
  | [] => raise Empty)
437
438end
439