1(*  Title:      Provers/Arith/fast_lin_arith.ML
2    Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
3
4A generic linear arithmetic package.
5
6Only take premises and conclusions into account that are already
7(negated) (in)equations. lin_arith_simproc tries to prove or disprove
8the term.
9*)
10
11(*** Data needed for setting up the linear arithmetic package ***)
12
(*Object-logic interface of the package: the rules and term/theorem
  operations used by the tactics and by proof replay (mkthm); the functions
  are described in the comment following this signature.*)
signature LIN_ARITH_LOGIC =
sig
  val conjI       : thm  (* P ==> Q ==> P & Q *)
  val ccontr      : thm  (* (~ P ==> False) ==> P *)
  val notI        : thm  (* (P ==> False) ==> ~ P *)
  val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
  val not_leD     : thm  (* ~(m <= n) ==> n < m *)
  val sym         : thm  (* x = y ==> y = x *)
  val trueI       : thm  (* True *)
  val mk_Eq       : thm -> thm
  val atomize     : thm -> thm list
  val mk_Trueprop : term -> term
  val neg_prop    : term -> term
  val is_False    : thm -> bool  (* used to detect that a derived theorem is False *)
  val is_nat      : typ list * term -> bool
  val mk_nat_thm  : theory -> term -> thm
end;
30(*
31mk_Eq(~in) = `in == False'
32mk_Eq(in) = `in == True'
33where `in' is an (in)equality.
34
35neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
36  (logically) negated version of t (again wrapped up in Trueprop),
37  where the negation of a negative term is the term itself (no
38  double negation!); raises TERM ("neg_prop", [t]) if t is not of
39  the form 'Trueprop $ _'
40
is_nat(parameter-types,t) = true iff t:nat
mk_nat_thm thy t = "0 <= t"
43*)
44
(*Arithmetic-specific data and hooks supplied by an instantiation of the
  package; see the description following this signature.*)
signature LIN_ARITH_DATA =
sig
  (*internal representation of linear (in-)equations:*)
  type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
  val decomp: Proof.context -> term -> decomp option
  val domain_is_nat: term -> bool

  (*abstraction for proof replay*)
  (*both thread a (term * term) list and a Proof.context as state; mkthm uses
    abstract on premises and abstract_arith on the atoms of Nat justifications*)
  val abstract_arith: term -> (term * term) list * Proof.context ->
    term * ((term * term) list * Proof.context)
  val abstract: term -> (term * term) list * Proof.context ->
    term * ((term * term) list * Proof.context)

  (*preprocessing, performed on a representation of subgoals as list of premises:*)
  val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list

  (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
  val pre_tac: Proof.context -> int -> tactic

  (*the limit on the number of ~= allowed; because each ~= is split
    into two cases, this can lead to an explosion*)
  val neq_limit: int Config.T

  (*when set, the package emits tracing output*)
  val trace: bool Config.T
end;
70(*
71decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
   where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=" and
73         p (q, respectively) is the decomposition of the sum term x
74         (y, respectively) into a list of summand * multiplicity
75         pairs and a constant summand and d indicates if the domain
76         is discrete.
77
domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
79
80The relationship between pre_decomp and pre_tac is somewhat tricky.  The
81internal representation of a subgoal and the corresponding theorem must
82be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
83the comment for split_items below.  (This is even necessary for eta- and
beta-equivalent modifications, as some of the lin. arith. code is
sensitive to them.)
86
87Simpset must reduce contradictory <= to False.
88   It should also cancel common summands to keep <= reduced;
89   otherwise <= can grow to massive proportions.
90*)
91
(*Public interface of the linear arithmetic package.*)
signature FAST_LIN_ARITH =
sig
  (*insert the simplifier's premises into the subgoal, then run the procedure*)
  val prems_lin_arith_tac: Proof.context -> int -> tactic
  (*run the procedure in a context whose simpset has been emptied*)
  val lin_arith_tac: Proof.context -> int -> tactic
  val lin_arith_simproc: Proof.context -> cterm -> thm option
  (*modify the package data stored in a generic context*)
  val map_data:
    ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: simpset,
      number_of: (Proof.context -> typ -> int -> cterm) option} ->
     {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: simpset,
      number_of: (Proof.context -> typ -> int -> cterm) option}) ->
      Context.generic -> Context.generic
  (*register additional rules or simplifier entries in the package data*)
  val add_inj_thms: thm list -> Context.generic -> Context.generic
  val add_lessD: thm -> Context.generic -> Context.generic
  val add_simps: thm list -> Context.generic -> Context.generic
  val add_simprocs: simproc list -> Context.generic -> Context.generic
  (*install the function that builds the numeral n at type T as a cterm
    (used when scaling theorems during proof replay)*)
  val set_number_of: (Proof.context -> typ -> int -> cterm) -> Context.generic -> Context.generic
end;
111
112functor Fast_Lin_Arith
113  (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
114struct
115
116
117(** theory data **)
118
structure Data = Generic_Data
(
  type T =
   {add_mono_thms: thm list,
    mult_mono_thms: thm list,
    inj_thms: thm list,
    lessD: thm list,
    neqE: thm list,
    simpset: simpset,
    number_of: (Proof.context -> typ -> int -> cterm) option};

  (*initially: no rules, the empty simpset and no numeral constructor*)
  val empty : T =
   {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
    lessD = [], neqE = [], simpset = empty_ss,
    number_of = NONE};

  val extend = I;

  (*field-wise merge: rule lists with Thm.merge_thms, simpsets with merge_ss,
    and the optional numeral constructor with merge_options*)
  fun merge (data1 : T, data2 : T) : T =
   {add_mono_thms = Thm.merge_thms (#add_mono_thms data1, #add_mono_thms data2),
    mult_mono_thms = Thm.merge_thms (#mult_mono_thms data1, #mult_mono_thms data2),
    inj_thms = Thm.merge_thms (#inj_thms data1, #inj_thms data2),
    lessD = Thm.merge_thms (#lessD data1, #lessD data2),
    neqE = Thm.merge_thms (#neqE data1, #neqE data2),
    simpset = merge_ss (#simpset data1, #simpset data2),
    number_of = merge_options (#number_of data1, #number_of data2)};
);
148
fun map_data f = Data.map f;
fun get_data ctxt = Data.get (Context.Proof ctxt);

(*the stored neqE rules, re-attached to ctxt via Thm.transfer'*)
fun get_neqE ctxt =
  let val {neqE = rules, ...} = get_data ctxt
  in map (Thm.transfer' ctxt) rules end;
153
(*update only the inj_thms field of a data record*)
fun map_inj_thms f
    {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = inj,
     lessD = less_d, neqE = neq_e, simpset = ss, number_of = num} =
  {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = f inj,
   lessD = less_d, neqE = neq_e, simpset = ss, number_of = num};

(*update only the lessD field of a data record*)
fun map_lessD f
    {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = inj,
     lessD = less_d, neqE = neq_e, simpset = ss, number_of = num} =
  {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = inj,
   lessD = f less_d, neqE = neq_e, simpset = ss, number_of = num};
161
(*apply the simplifier-context transformation f to the stored simpset*)
fun map_simpset f context =
  let
    fun update {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = inj,
        lessD = less_d, neqE = neq_e, simpset = ss, number_of = num} =
      {add_mono_thms = add_mono, mult_mono_thms = mult_mono, inj_thms = inj,
       lessD = less_d, neqE = neq_e,
       simpset = simpset_map (Context.proof_of context) f ss,
       number_of = num}
  in map_data update context end;
167
(*new injection rules go in front of the existing ones, context-trimmed*)
fun add_inj_thms thms =
  map_data (map_inj_thms (fn old => map Thm.trim_context thms @ old));

(*a new lessD rule goes after the existing ones, context-trimmed*)
fun add_lessD thm =
  map_data (map_lessD (fn old => old @ [Thm.trim_context thm]));

fun add_simps thms = map_simpset (fn ss_ctxt => ss_ctxt addsimps thms);
fun add_simprocs procs = map_simpset (fn ss_ctxt => ss_ctxt addsimprocs procs);
172
(*install f as the numeral constructor, replacing any previous one*)
fun set_number_of f =
  let
    fun install {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset,
        number_of = _} =
      {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms,
       inj_thms = inj_thms, lessD = lessD, neqE = neqE, simpset = simpset,
       number_of = SOME f}
  in map_data install end;
177
(*the installed numeral constructor specialised to ctxt; without one, the
  returned function raises CTERM ("number_of", []) when fully applied*)
fun number_of ctxt =
  let val {number_of = opt, ...} = get_data ctxt in
    (case opt of
      SOME f => f ctxt
    | NONE => (fn _ => fn _ => raise CTERM ("number_of", [])))
  end;
182
183
184
185(*** A fast decision procedure ***)
186(*** Code ported from HOL Light ***)
187(* possible optimizations:
   use (var,coeff) rep or vector rep to save space;
189   treat non-negative atoms separately rather than adding 0 <= atom
190*)
191
(*Relation of a linear (in)equation: =, <= or <*)
datatype lineq_type = Eq | Le | Lt;

(*Justification of a derived (in)equation; mkthm replays it as a theorem.
  Asm i refers to the i-th premise; LessD/NotLessD/NotLeD/NotLeDD apply the
  lessD, not_lessD, not_leD and (not_leD then lessD) rules; Multiplied scales
  by a constant; Added sums two justified (in)equations.*)
datatype injust = Asm of int
                | Nat of int (* index of atom *)
                | LessD of injust
                | NotLessD of injust
                | NotLeD of injust
                | NotLeDD of injust
                | Multiplied of int * injust
                | Added of injust * injust;

(*Lineq (k, ty, cs, just) reads "k ty cs_0*x_0 + ... + cs_n*x_n" over the
  atoms x_i (cf. is_contradictory and mknat), justified by just.*)
datatype lineq = Lineq of int * lineq_type * int list * injust;
204
205(* ------------------------------------------------------------------------- *)
206(* Finding a (counter) example from the trace of a failed elimination        *)
207(* ------------------------------------------------------------------------- *)
208(* Examples are represented as rational numbers,                             *)
(* Don't blame John Harrison for this code - it is entirely mine. TN         *)
210
(*NOTE(review): not raised or handled anywhere in this part of the file;
  presumably it belongs to the counterexample finder mentioned in the header
  comment above -- confirm before removing.*)
exception NoEx;
212
213(* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
214   In general, true means the bound is included, false means it is excluded.
215   Need to know if it is a lower or upper bound for unambiguous interpretation!
216*)
217
218(* ------------------------------------------------------------------------- *)
219(* End of counterexample finder. The actual decision procedure starts here.  *)
220(* ------------------------------------------------------------------------- *)
221
222(* ------------------------------------------------------------------------- *)
223(* Calculate new (in)equality type after addition.                           *)
224(* ------------------------------------------------------------------------- *)
225
(*Relation of the sum of two (in)equations: an equation contributes nothing,
  two non-strict inequalities stay non-strict, anything else is strict.*)
fun find_add_type (ty1, ty2) =
  (case (ty1, ty2) of
    (Eq, _) => ty2
  | (_, Eq) => ty1
  | (Le, Le) => Le
  | _ => Lt);
231
232(* ------------------------------------------------------------------------- *)
233(* Multiply out an (in)equation.                                             *)
234(* ------------------------------------------------------------------------- *)
235
(*Scale an (in)equation by n.  Scaling by 1 returns the argument itself;
  scaling a strict inequality by 0, or any inequality by a negative number,
  raises Fail "multiply_ineq".*)
fun multiply_ineq n (ineq as Lineq (k, ty, coeffs, just)) =
  if n = 1 then ineq
  else if (n = 0 andalso ty = Lt) orelse (n < 0 andalso ty <> Eq)
  then raise Fail "multiply_ineq"
  else Lineq (n * k, ty, map (fn c => n * c) coeffs, Multiplied (n, just));
241
242(* ------------------------------------------------------------------------- *)
243(* Add together (in)equations.                                               *)
244(* ------------------------------------------------------------------------- *)
245
(*Componentwise sum of two (in)equations; the relation is combined with
  find_add_type and the justification records the addition.*)
fun add_ineq (Lineq (k1, ty1, coeffs1, just1)) (Lineq (k2, ty2, coeffs2, just2)) =
  let
    val coeffs = map2 (fn a => fn b => a + b) coeffs1 coeffs2
    val ty = find_add_type (ty1, ty2)
  in Lineq (k1 + k2, ty, coeffs, Added (just1, just2)) end;
249
250(* ------------------------------------------------------------------------- *)
251(* Elimination of variable between a single pair of (in)equations.           *)
252(* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
253(* ------------------------------------------------------------------------- *)
254
(*Eliminate the variable with index v from i1 and i2 by adding multiples of
  both so that the coefficients of v cancel.  The multipliers are derived
  from the lcm of the two coefficients.  If the coefficients have the same
  sign, one of the two must be an equation (which is then scaled by a
  negative factor); otherwise Fail "elim_var" is raised.*)
fun elim_var v (i1 as Lineq(_,ty1,l1,_)) (i2 as Lineq(_,ty2,l2,_)) =
  let val c1 = nth l1 v and c2 = nth l2 v
      val m = Integer.lcm c1 c2
      val m1 = m div (abs c1) and m2 = m div (abs c2)
      (* same sign: only an equation may be negated; opposite signs: both positive *)
      val (n1,n2) =
        if (c1 >= 0) = (c2 >= 0)
        then if ty1 = Eq then (~m1,m2)
             else if ty2 = Eq then (m1,~m2)
                  else raise Fail "elim_var"
        else (m1,m2)
      (* two equations: if a multiplier is ~1, negate both multipliers instead
         (the resulting equation is merely negated as a whole) *)
      val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
                    then (~n1,~n2) else (n1,n2)
  in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
268
269(* ------------------------------------------------------------------------- *)
270(* The main refutation-finding code.                                         *)
271(* ------------------------------------------------------------------------- *)
272
(*an (in)equation is trivial when every atom coefficient is zero*)
fun is_trivial (Lineq (_, _, coeffs, _)) = List.all (fn c => c = 0) coeffs;
274
(*for a trivial (in)equation "k ty 0": does it fail to hold?*)
fun is_contradictory (Lineq (k, Eq, _, _)) = k <> 0
  | is_contradictory (Lineq (k, Le, _, _)) = k > 0
  | is_contradictory (Lineq (k, Lt, _, _)) = k >= 0;
277
(*Number of new inequalities created by eliminating a variable whose
  coefficients are l: (#positive coefficients) * (#negative coefficients).*)
fun calc_blowup l =
  let
    val nonzero = List.filter (fn c => c <> 0) l
    val (pos, neg) = List.partition (fn c => c > 0) nonzero
  in length pos * length neg end;
281
282(* ------------------------------------------------------------------------- *)
283(* Main elimination code:                                                    *)
284(*                                                                           *)
285(* (1) Looks for immediate solutions (false assertions with no variables).   *)
286(*                                                                           *)
287(* (2) If there are any equations, picks a variable with the lowest absolute *)
288(* coefficient in any of them, and uses it to eliminate.                     *)
289(*                                                                           *)
290(* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
291(* blowup (number of consequences generated) and eliminates it.              *)
292(* ------------------------------------------------------------------------- *)
293
(*Remove the first element satisfying p.  Returns it together with the
  remaining elements: the skipped prefix in reverse order, followed by the
  rest of the list.  Raises List.Empty if no element satisfies p.*)
fun extract_first p xs =
  let
    fun go _ [] = raise List.Empty
      | go seen (y :: ys) =
          if p y then (y, seen @ ys) else go (y :: seen) ys
  in go [] xs end;
299
(*Trace the (in)equations, one per line, when tracing is enabled.*)
fun print_ineqs ctxt ineqs =
  let
    fun rel_string Eq = " =  "
      | rel_string Lt = " <  "
      | rel_string Le = " <= "
    fun show (Lineq (c, t, l, _)) =
      string_of_int c ^ rel_string t ^ commas (map string_of_int l)
  in
    if not (Config.get ctxt LA_Data.trace) then ()
    else tracing (cat_lines ("" :: map show ineqs))
  end;
307
(*Trace of an elimination run: each entry pairs the index of the eliminated
  variable (~1 when no variable could be chosen) with the nontrivial
  (in)equations present at that step.*)
type history = (int * lineq list) list;
(*Outcome of elim: the justification of a contradiction, or the history of
  a run that found none.*)
datatype result = Success of injust | Failure of history;
310
(*Fourier-Motzkin style elimination (see the comment above): returns
  Success j when a contradictory trivial (in)equation with justification j is
  derived, or Failure with the elimination history otherwise.*)
fun elim ctxt (ineqs, hist) =
  let val _ = print_ineqs ctxt ineqs
      val (triv, nontriv) = List.partition is_trivial ineqs in
  (* (1) constant (in)equations: stop on a false one, drop the true ones *)
  if not (null triv)
  then case find_first is_contradictory triv of
         NONE => elim ctxt (nontriv, hist)
       | SOME(Lineq(_,_,_,j)) => Success j
  else
  if null nontriv then Failure hist
  else
  let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
  (* (2) use an equation: c is the nonzero coefficient value of least
     absolute value occurring in any equation; v is its index in the first
     equation that contains it *)
  if not (null eqs) then
     let val c =
           fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
           |> filter (fn i => i <> 0)
           |> sort (int_ord o apply2 abs)
           |> hd
         val (eq as Lineq(_,_,ceq,_),othereqs) =
               extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
         val v = find_index (fn v => v = c) ceq
         val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
                                     (othereqs @ noneqs)
         val others = map (elim_var v eq) roth @ ioth
     in elim ctxt (others,(v,nontriv)::hist) end
  else
  (* (3) only inequalities: eliminate the variable with the smallest nonzero
     blowup; fail if every variable has zero blowup *)
  let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
      val numlist = 0 upto (length (hd lists) - 1)
      val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
      val blows = map calc_blowup coeffs
      val iblows = blows ~~ numlist
      val nziblows = filter_out (fn (i, _) => i = 0) iblows
  in if null nziblows then Failure((~1,nontriv)::hist)
     else
     let val (_,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
         (* here triv and eqs are empty, so ineqs = nontriv = noneqs *)
         val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
         val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
     in elim ctxt (no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
  end
  end
  end;
351
352(* ------------------------------------------------------------------------- *)
353(* Translate back a proof.                                                   *)
354(* ------------------------------------------------------------------------- *)
355
(*Tracing helpers: output only when LA_Data.trace is set.  trace_thm and
  trace_term return their last argument, so they can be used inside
  pipelines.*)
fun trace_msg ctxt msg =
  if Config.get ctxt LA_Data.trace then tracing msg else ();

fun trace_thm ctxt msgs th =
  let
    val () =
      if Config.get ctxt LA_Data.trace
      then tracing (cat_lines (msgs @ [Thm.string_of_thm ctxt th]))
      else ()
  in th end;

fun trace_term ctxt msgs t =
  let
    val () =
      if Config.get ctxt LA_Data.trace
      then tracing (cat_lines (msgs @ [Syntax.string_of_term ctxt t]))
      else ()
  in t end;
368
(*list union up to alpha/eta-equivalence of terms*)
fun union_term xs ys = union Envir.aeconv xs ys;

(*add the atoms of a decomposed (in)equation: right-hand side first, then
  left-hand side (this order determines the atom indices)*)
fun add_atoms (lhs, _, _, rhs, _, _) atoms =
  atoms
  |> union_term (map fst rhs)
  |> union_term (map fst lhs);

fun atoms_of ds = List.foldl (fn (d, acc) => add_atoms d acc) [] ds;
375
376(*
377Simplification may detect a contradiction 'prematurely' due to type
378information: n+1 <= 0 is simplified to False and does not need to be crossed
379with 0 <= n.
380*)
local
  (*raised as soon as a simplified intermediate sum is False; carries that
    theorem plus the hypothesis bookkeeping and context needed by finish*)
  exception FalseE of thm * (int * cterm) list * Proof.context
in

(*Replay the justification just as a theorem, starting from the premises
  asms (Asm i refers to nth asms i).  The final theorem is simplified and is
  expected to be False; otherwise diagnostics and a warning are emitted.
  Locally assumed hypotheses are discharged against asms by finish.*)
fun mkthm ctxt asms (just: injust) =
  let
    val thy = Proof_Context.theory_of ctxt;
    val {add_mono_thms = add_mono_thms0, mult_mono_thms = mult_mono_thms0,
      inj_thms = inj_thms0, lessD = lessD0, simpset, ...} = get_data ctxt;
    (* stored rules may be context-trimmed (cf. Thm.trim_context in
       add_inj_thms/add_lessD), so transfer them to thy before use *)
    val add_mono_thms = map (Thm.transfer thy) add_mono_thms0;
    val mult_mono_thms = map (Thm.transfer thy) mult_mono_thms0;
    val inj_thms = map (Thm.transfer thy) inj_thms0;
    val lessD = map (Thm.transfer thy) lessD0;

    val number_of = number_of ctxt;
    val simpset_ctxt = put_simpset simpset ctxt;
    (* apply f to the conclusion of thm, only if thm has no premises *)
    fun only_concl f thm =
      if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
    (* atoms indexed by Nat i; NOTE(review): presumably in the same order as
       computed by refutes when the justification was found -- confirm *)
    val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);

    (* thm RS th for the first rule th in rules that applies *)
    fun use_first rules thm =
      get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules

    (* combine two (in)equations via conjI and an add_mono_thms rule *)
    fun add2 thm1 thm2 =
      use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
    fun try_add thms thm = get_first (fn th => add2 th thm) thms;

    (* add two theorems; if no rule applies directly, first transform one of
       them with an inj_thms rule; raise Fail if nothing works *)
    fun add_thms thm1 thm2 =
      (case add2 thm1 thm2 of
        NONE =>
          (case try_add ([thm1] RL inj_thms) thm2 of
            NONE =>
              (the (try_add ([thm2] RL inj_thms) thm1)
                handle Option.Option =>
                  (trace_thm ctxt [] thm1; trace_thm ctxt [] thm2;
                   raise Fail "Linear arithmetic: failed to add thms"))
          | SOME thm => thm)
      | SOME thm => thm);

    (* n-fold sum of thm via repeated add_thms; NOTE(review): the recursion
       only terminates for n >= 1 *)
    fun mult_by_add n thm =
      let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
      in mul n thm end;

    val rewr = Simplifier.rewrite simpset_ctxt;
    (* rewrite both operands of the relation in the conclusion *)
    val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
      (Conv.binop_conv rewr)));
    (* rewrite the first premise with the simpset and discharge it with
       trueI, i.e. it is expected to simplify to True *)
    fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
      let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
      in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end

    (* scale thm by n: instantiate the numeral variable of a mult_mono_thms
       rule with number_of T n; fall back to repeated addition *)
    fun mult n thm =
      (case use_first mult_mono_thms thm of
        NONE => mult_by_add n thm
      | SOME mth =>
          let
            val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
              |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
            val T = Thm.typ_of_cterm cv
          in
            mth
            |> Thm.instantiate ([], [(dest_Var (Thm.term_of cv), number_of T n)])
            |> rewrite_concl
            |> discharge_prem
            handle CTERM _ => mult_by_add n thm
                 | THM _ => mult_by_add n thm
          end);

    (* negative factors additionally apply sym; ~1 is just sym *)
    fun mult_thm n thm =
      if n = ~1 then thm RS LA_Logic.sym
      else if n < 0 then mult (~n) thm RS LA_Logic.sym
      else mult n thm;

    (* simplify an intermediate sum; abort early via FalseE if it is False *)
    fun simp thm (cx as (_, hyps, ctxt')) =
      let val thm' = trace_thm ctxt ["Simplified:"] (full_simplify simpset_ctxt thm)
      in if LA_Logic.is_False thm' then raise FalseE (thm', hyps, ctxt') else (thm', cx) end;

    (* premise i as an assumption of its abstracted form; reused if already
       assumed, and recorded in hyps for discharge by finish *)
    fun abs_thm i (cx as (terms, hyps, ctxt)) =
      (case AList.lookup (op =) hyps i of
        SOME ct => (Thm.assume ct, cx)
      | NONE =>
          let
            val thm = nth asms i
            val (t, (terms', ctxt')) = LA_Data.abstract (Thm.prop_of thm) (terms, ctxt)
            val ct = Thm.cterm_of ctxt' t
          in (Thm.assume ct, (terms', (i, ct) :: hyps, ctxt')) end);

    (* the theorem mk_nat_thm produces for the abstracted atom t *)
    fun nat_thm t (terms, hyps, ctxt) =
      let val (t', (terms', ctxt')) = LA_Data.abstract_arith t (terms, ctxt)
      in (LA_Logic.mk_nat_thm thy t', (terms', hyps, ctxt')) end;

    (* step0/step1/step2: build the theorem for a node with 0/1/2
       sub-justifications and trace it; mk dispatches on the node *)
    fun step0 msg (thm, cx) = (trace_thm ctxt [msg] thm, cx)
    fun step1 msg j f cx = mk j cx |>> f |>> trace_thm ctxt [msg]
    and step2 msg j1 j2 f cx = mk j1 cx ||>> mk j2 |>> f |>> trace_thm ctxt [msg]

    and mk (Asm i) cx = step0 ("Asm " ^ string_of_int i) (abs_thm i cx)
      | mk (Nat i) cx = step0 ("Nat " ^ string_of_int i) (nat_thm (nth atoms i) cx)
      | mk (LessD j) cx = step1 "L" j (fn thm => hd ([thm] RL lessD)) cx
      | mk (NotLeD j) cx = step1 "NLe" j (fn thm => thm RS LA_Logic.not_leD) cx
      | mk (NotLeDD j) cx = step1 "NLeD" j (fn thm => hd ([thm RS LA_Logic.not_leD] RL lessD)) cx
      | mk (NotLessD j) cx = step1 "NL" j (fn thm => thm RS LA_Logic.not_lessD) cx
      | mk (Added (j1, j2)) cx = step2 "+" j1 j2 (uncurry add_thms) cx |-> simp
      | mk (Multiplied (n, j)) cx =
          (trace_msg ctxt ("*" ^ string_of_int n); step1 "*" j (mult_thm n) cx)

    (* discharge the assumed hypotheses: implies_intr them, export from ctxt'
       to ctxt, and resolve with the corresponding members of asms *)
    fun finish ctxt' hyps thm =
      thm
      |> fold_rev (Thm.implies_intr o snd) hyps
      |> singleton (Variable.export ctxt' ctxt)
      |> fold (fn (i, _) => fn thm => nth asms i RS thm) hyps
  in
    let
      val _ = trace_msg ctxt "mkthm";
      val (thm, (_, hyps, ctxt')) = mk just ([], [], ctxt);
      val _ = trace_thm ctxt ["Final thm:"] thm;
      val fls = simplify simpset_ctxt thm;
      val _ = trace_thm ctxt ["After simplification:"] fls;
      val _ =
        if LA_Logic.is_False fls then ()
        else
         (tracing (cat_lines
           (["Assumptions:"] @ map (Thm.string_of_thm ctxt) asms @ [""] @
            ["Proved:", Thm.string_of_thm ctxt fls, ""]));
          warning "Linear arithmetic should have refuted the assumptions.\n\
            \Please inform Tobias Nipkow.")
    in finish ctxt' hyps fls end
    handle FalseE (thm, hyps, ctxt') =>
      trace_thm ctxt ["False reached early:"] (finish ctxt' hyps thm)
  end;

end;
511
(*coefficient of atom in poly (up to alpha/eta-equivalence), 0 if absent*)
fun coeff poly atom =
  (case AList.lookup Envir.aeconv poly atom of
    SOME c => c
  | NONE => 0);
514
(*Clear denominators: m is the lcm of all denominators occurring in the
  decomposition; every rational q is replaced by the integer q * m.
  Returns (m, scaled decomposition).*)
fun integ (rlhs, r, rel, rrhs, s, d) =
  let
    val m = Integer.lcms (map (snd o Rat.dest) (r :: s :: map snd rlhs @ map snd rrhs))
    fun scale q =
      let val (num, den) = Rat.dest q
      in num * (m div den) end
    fun mult (t, q) = (t, scale q)
  in (m, (map mult rlhs, scale r, rel, map mult rrhs, scale s, d)) end;
522
(*Turn a numbered, decomposed premise (item, k) into a lineq over atoms.
  Coefficients are made integral by integ; when its scaling factor m is not 1,
  the justification is wrapped in Multiplied (m, _).  The relation is
  normalised to Eq, Le or Lt; over a discrete domain, strict and negated
  relations become non-strict ones by shifting the constant by 1.*)
fun mklineq atoms =
  fn (item, k) =>
  let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
      val lhsa = map (coeff lhs) atoms
      and rhsa = map (coeff rhs) atoms
      val diff = map2 (curry (op -)) rhsa lhsa
      val c = i-j
      val just = Asm k
      fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
  in case rel of
      "<="   => lineq(c,Le,diff,just)
     | "~<=" => if discrete
                then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
                else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
     | "<"   => if discrete
                then lineq(c+1,Le,diff,LessD(just))
                else lineq(c,Lt,diff,just)
     | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
     | "="   => lineq(c,Eq,diff,just)
       (* "~=" is split or dropped by split_items beforehand; defensive *)
     | _     => raise Fail ("mklineq: unexpected relation " ^ rel)
  end;
544
545(* ------------------------------------------------------------------------- *)
546(* Print (counter) example                                                   *)
547(* ------------------------------------------------------------------------- *)
548
549(* ------------------------------------------------------------------------- *)
550
(*For an atom of type nat with index i, the constraint 0 <= atom_i (unit
  coefficient vector over ixs, justified by Nat i); NONE otherwise.*)
fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
  if not (LA_Logic.is_nat (pTs, atom)) then NONE
  else SOME (Lineq (0, Le, map (fn j => if j = i then 1 else 0) ixs, Nat i));
556
557(* This code is tricky. It takes a list of premises in the order they occur
558in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
559ones as NONE. Going through the premises, each numeric one is converted into
560a Lineq. The tricky bit is to convert ~= which is split into two cases < and
561>. Thus split_items returns a list of equation systems. This may blow up if
562there are many ~=, but in practice it does not seem to happen. The really
563tricky bit is to arrange the order of the cases such that they coincide with
564the order in which the cases are in the end generated by the tactic that
565applies the generated refutation thms (see function 'refute_tac').
566
567For variables n of type nat, a constraint 0 <= n is added.
568*)
569
570(* FIXME: To optimize, the splitting of cases and the search for refutations *)
571(*        could be intertwined: separate the first (fully split) case,       *)
572(*        refute it, continue with splitting and refuting.  Terminate with   *)
573(*        failure as soon as a case could not be refuted; i.e. delay further *)
574(*        splitting until after a refutation for other cases has been found. *)
575
(*Preprocess the premises (with pre_decomp if do_pre), decompose each one
  with LA_Data.decomp, split (split_neq) or drop (otherwise) "~=" premises,
  and number the decomposable premises by their position in the premise list
  (undecomposable ones are skipped but still counted).  Returns one
  (parameter types, numbered items) pair per resulting case.*)
fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
let
  (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
  (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
  (* level                                                          *)
  (* FIXME: this is currently sensitive to the order of theorems in *)
  (*        neqE:  The theorem for type "nat" must come first.  A   *)
  (*        better (i.e. less likely to break when neqE changes)    *)
  (*        implementation should *test* which theorem from neqE    *)
  (*        can be applied, and split the premise accordingly.      *)
  (* Two passes: first only the nat-typed "~=" are split, then the rest. *)
  (* A split premise is replaced by its case, moved to the end.         *)
  fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
               (LA_Data.decomp option * bool) list list =
  let
    fun elim_neq' _ ([] : (LA_Data.decomp option * bool) list) :
                  (LA_Data.decomp option * bool) list list =
          [[]]
      | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
          map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
      | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
          if rel = "~=" andalso (not nat_only orelse is_nat) then
            (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
            elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
            elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
          else
            map (cons ineq) (elim_neq' nat_only ineqs)
  in
    ineqs |> elim_neq' true
          |> maps (elim_neq' false)
  end

  (* without splitting, "~=" premises are simply ignored *)
  fun ignore_neq (NONE, bool) = (NONE, bool)
    | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
      if rel = "~=" then (NONE, bool) else (ineq, bool)

  (* pair each decomposed premise with its position; NONE entries still count *)
  fun number_hyps _ []             = []
    | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs

  val result = (Ts, terms)
    |> (* user-defined preprocessing of the subgoal *)
       (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    |> tap (fn subgoals => trace_msg ctxt ("Preprocessing yields " ^
         string_of_int (length subgoals) ^ " subgoal(s) total."))
    |> (* produce the internal encoding of (in-)equalities *)
       map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    |> (* splitting of inequalities *)
       map (apsnd (if split_neq then elim_neq else
                     Library.single o map ignore_neq))
    |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    |> (* numbering of hypotheses, ignoring irrelevant ones *)
       map (apsnd (number_hyps 0))
in
  trace_msg ctxt ("Splitting of inequalities yields " ^
    string_of_int (length result) ^ " subgoal(s) total.");
  result
end;
632
(*Try to refute every case in turn.  For each case, the atoms are collected,
  premises are turned into lineqs and 0 <= x is added for nat atoms.  Returns
  the given justifications followed by one justification per case, in case
  order, or NONE as soon as one case cannot be refuted.*)
fun refutes ctxt :
    (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
  let
    (* justifications found so far are accumulated in reverse order *)
    fun loop [] rev_js = SOME (rev rev_js)
      | loop ((Ts, initems : (LA_Data.decomp * int) list) :: rest) rev_js =
          let
            val atoms = atoms_of (map fst initems)
            val ixs = 0 upto (length atoms - 1)
            val ineqs =
              map (mklineq atoms) initems @ map_filter (mknat Ts ixs) (atoms ~~ ixs)
          in
            (case elim ctxt (ineqs, []) of
              Failure _ => NONE
            | Success j =>
                (trace_msg ctxt ("Contradiction! (" ^ string_of_int (length rev_js + 1) ^ ")");
                 loop rest (j :: rev_js)))
          end
  in fn cases => fn js => loop cases (rev js) end;
654
(*split the premises terms into cases and try to refute each of them*)
fun refute ctxt params do_pre split_neq terms : injust list option =
  let val cases = split_items ctxt do_pre split_neq (map snd params, terms)
  in refutes ctxt cases [] end;
657
(*number of elements of a list satisfying P*)
fun count P = List.foldl (fn (x, n) => if P x then n + 1 else n) 0;
659
(*Try to refute the premises Hs together with the negated conclusion.
  Returns whether "~=" premises were split (only if their number does not
  exceed neq_limit) and the justifications, if any.  Yields (false, NONE)
  when the conclusion cannot be negated.*)
fun prove ctxt params do_pre Hs concl : bool * injust list option =
  let
    val _ = trace_msg ctxt "prove:"
    (* append the negated conclusion to 'Hs' -- this corresponds to     *)
    (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
    (* theorem/tactic level                                             *)
    val Hs' = Hs @ [LA_Logic.neg_prop concl]
    fun is_neq (SOME (_, _, rel, _, _, _)) = (rel = "~=")
      | is_neq NONE = false
    val neq_limit = Config.get ctxt LA_Data.neq_limit
    val neq_count = count is_neq (map (LA_Data.decomp ctxt) Hs')
    val split_neq = neq_count <= neq_limit
    val _ =
      if not split_neq then
        trace_msg ctxt ("neq_limit exceeded (current value is " ^
          string_of_int neq_limit ^ "), ignoring all inequalities")
      else ()
  in
    (split_neq, refute ctxt params do_pre split_neq Hs')
  end handle TERM ("neg_prop", _) =>
    (* since no meta-logic negation is available, we can only fail if   *)
    (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
    (* dropping the conclusion doesn't work either, because even        *)
    (* 'False' does not imply arbitrary 'concl::prop')                  *)
    (trace_msg ctxt "prove failed (cannot negate conclusion).";
     (false, NONE));
684
(*Replay the justifications justs on subgoal i: turn the goal into a
  refutation problem, run pre_tac, then for each justification split "~="
  premises with neqE (if split_neq) and close the subgoal at position i with
  the theorem built by mkthm.  The order of justs must agree with the case
  order produced by split_items (see the comment there).*)
fun refute_tac ctxt (i, split_neq, justs) =
  fn state =>
    let
      val _ = trace_thm ctxt
        ["refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
          string_of_int (length justs) ^ " justification(s)):"] state
      val neqE = get_neqE ctxt;
      (* NOTE(review): each step works on position i, presumably relying on
         the next case moving to position i once the current one is closed *)
      fun just1 j =
        (* eliminate inequalities *)
        (if split_neq then
          REPEAT_DETERM (eresolve_tac ctxt neqE i)
        else
          all_tac) THEN
          PRIMITIVE (trace_thm ctxt ["State after neqE:"]) THEN
          (* use theorems generated from the actual justifications *)
          Subgoal.FOCUS (fn {prems, ...} => resolve_tac ctxt [mkthm ctxt prems j] 1) ctxt i
    in
      (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
      DETERM (resolve_tac ctxt [LA_Logic.notI, LA_Logic.ccontr] i) THEN
      (* user-defined preprocessing of the subgoal *)
      DETERM (LA_Data.pre_tac ctxt i) THEN
      PRIMITIVE (trace_thm ctxt ["State after pre_tac:"]) THEN
      (* prove every resulting subgoal, using its justification *)
      EVERY (map just1 justs)
    end  state;
710
711(*
712Fast but very incomplete decider. Only premises and conclusions
713that are already (negated) (in)equations are taken into account.
714*)
715fun simpset_lin_arith_tac ctxt = SUBGOAL (fn (A, i) =>
716  let
717    val params = rev (Logic.strip_params A)
718    val Hs = Logic.strip_assums_hyp A
719    val concl = Logic.strip_assums_concl A
720    val _ = trace_term ctxt ["Trying to refute subgoal " ^ string_of_int i] A
721  in
722    case prove ctxt params true Hs concl of
723      (_, NONE) => (trace_msg ctxt "Refutation failed."; no_tac)
724    | (split_neq, SOME js) => (trace_msg ctxt "Refutation succeeded.";
725                               refute_tac ctxt (i, split_neq, js))
726  end);
727
(* Variant of simpset_lin_arith_tac that first inserts the simplifier's
   premises (Simplifier.prems_of ctxt) into the subgoal. *)
fun prems_lin_arith_tac ctxt =
  let val simp_prems = Simplifier.prems_of ctxt
  in Method.insert_tac ctxt simp_prems THEN' simpset_lin_arith_tac ctxt end;
731
(* Run simpset_lin_arith_tac in ctxt with its simpset replaced by the empty one. *)
fun lin_arith_tac ctxt =
  let val empty_ctxt = empty_simpset ctxt
  in simpset_lin_arith_tac empty_ctxt end;
734
735
736
737(** Forward proof from theorems **)
738
(* More tricky code: the proofs of the multiple cases (arising from splits
of ~= premises) must be arranged so that their order coincides with the
order of the cases generated by function split_items. *)

(* Case-split structure used for forward proof.
   Tip asms: a leaf; asms are the assumptions available in this case.
   Spl (spl, ct1, t1, ct2, t2): a split by the theorem spl, whose proposition
     has the shape "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R"; t1 continues the
     case that assumes ct1 and t2 the case that assumes ct2. *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree;
745
(* Given the elimination-style proposition
     "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R"
   return its two case hypotheses (ct1, ct2).  Failures of Thm.dest_comb on a
   proposition of a different shape are not caught here. *)

fun extract (imp : cterm) : cterm * cterm =
let
  (* "A ==> B" is the application "Pure.imp $ A $ B" *)
  fun premise ct = #2 (Thm.dest_comb (#1 (Thm.dest_comb ct)))
  fun conclusion ct = #2 (Thm.dest_comb ct)
  val ct1 = premise (premise imp)
  val ct2 = premise (premise (conclusion imp))
in (ct1, ct2) end;
758
(* Build the case-split tree for the assumptions asms.
   Each neqE rule in turn is tried on every assumption.  When a rule applies
   (via COMP), the assumption is dropped and replaced by the assumed
   hypothesis of each resulting case, giving an Spl node.  Assumptions the
   rule does not apply to are kept for the next rule. *)
fun splitasms ctxt (asms : thm list) : splittree =
let
  val neqE = get_neqE ctxt
  (* eliminate asm with the rule neq, if possible *)
  fun try_elim neq asm = SOME (asm COMP neq) handle THM _ => NONE
  (* elim_neq rules (seen, todo): seen holds the already inspected
     assumptions (reversed), todo those still to be tried with the head rule *)
  fun elim_neq [] (seen, todo) = Tip (rev seen @ todo)
    | elim_neq (_ :: rules) (seen, []) = elim_neq rules ([], rev seen)
    | elim_neq (rules as (neq :: _)) (seen, asm :: todo) =
        (case try_elim neq asm of
          NONE => elim_neq rules (asm :: seen, todo)
        | SOME spl =>
            let
              val (ct1, ct2) = extract (Thm.cprop_of spl)
              val hyp1 = Thm.assume ct1
              val hyp2 = Thm.assume ct2
            in
              Spl (spl, ct1, elim_neq rules (seen, todo @ [hyp1]),
                ct2, elim_neq rules (seen, todo @ [hyp2]))
            end)
in elim_neq neqE ([], asms) end;
775
(* Replay the justifications js along a split tree, depth-first and
   left-to-right.  Each Tip consumes one justification, turned into a theorem
   by mkthm from the case's assumptions.  Each Spl discharges the case
   hypotheses ct1/ct2 of its two sub-proofs and combines them with the split
   theorem.  Returns the combined theorem and the unconsumed justifications.
   Failures of COMP raise THM and are left to the caller (prover handles THM).
   Running out of justifications means the refutation and the split tree
   disagree on the number of cases; this is reported as an internal error
   (Fail) instead of an anonymous Match. *)
fun fwdproof ctxt (Tip asms : splittree) (j::js : injust list) = (mkthm ctxt asms j, js)
  | fwdproof _ (Tip _) [] = raise Fail "fwdproof: not enough justifications"
  | fwdproof ctxt (Spl (thm, ct1, tree1, ct2, tree2)) js =
      let
        val (thm1, js1) = fwdproof ctxt tree1 js
        val (thm2, js2) = fwdproof ctxt tree2 js1
        val thm1' = Thm.implies_intr ct1 thm1
        val thm2' = Thm.implies_intr ct2 thm2
      in (thm2' COMP (thm1' COMP thm), js2) end;
785
(* Turn the refutation certificate js for Tconcl into a theorem.
   The negation of Tconcl is assumed and, together with thms, refuted by
   fwdproof (split along "~=" premises when split_neq holds).  The assumption
   is then discharged with ccontr (pos = true) or notI (pos = false), so in
   both cases Tconcl is derived.  The result is returned in the form
   produced by LA_Logic.mk_Eq. *)
fun prover ctxt thms Tconcl (js : injust list) split_neq pos : thm option =
  let
    val neg_goal = LA_Logic.neg_prop Tconcl
    val cneg_goal = Thm.cterm_of ctxt neg_goal
    val neg_hyp = Thm.assume cneg_goal
    val build_tree = if split_neq then splitasms ctxt else Tip
    val (false_thm, _) = fwdproof ctxt (build_tree (thms @ [neg_hyp])) js
    val contr_rule = if pos then LA_Logic.ccontr else LA_Logic.notI
    val result = Thm.implies_intr cneg_goal false_thm COMP contr_rule
  in
    SOME (trace_thm ctxt ["Proved by lin. arith. prover:"] (LA_Logic.mk_Eq result))
  end
  (* Any THM failure gives NONE; in particular Thm.assume fails when the
     goal contains ?-variables.  FIXME Variable.import_terms *)
  handle THM _ => NONE;
798
(* PRE: concl is not negated!
   This assumption is OK because
   1. lin_arith_simproc tries both to prove and disprove concl and
   2. lin_arith_simproc is applied by the Simplifier which
      dives into terms and will thus try the non-negated concl anyway.

   From the simplifier's premises (atomized), first concl and then its
   negation is attempted.  For the first attempt whose refutation succeeds,
   prover's result is returned ("concl == True" resp. "concl == False" as
   built by LA_Logic.mk_Eq, or NONE if the replay fails).  If neither attempt
   succeeds, the result is NONE.
*)
fun lin_arith_simproc ctxt concl =
  let
    val prem_thms = maps LA_Logic.atomize (Simplifier.prems_of ctxt)
    val prem_props = map Thm.prop_of prem_thms
    val goal = LA_Logic.mk_Trueprop (Thm.term_of concl)
    fun refute t = prove ctxt [] false prem_props t
  in
    (case refute goal of (* concl provable? *)
      (split_neq, SOME js) => prover ctxt prem_thms goal js split_neq true
    | (_, NONE) =>
        let val neg_goal = LA_Logic.neg_prop goal
        in
          (case refute neg_goal of (* ~concl provable? *)
            (split_neq, SOME js) => prover ctxt prem_thms neg_goal js split_neq false
          | (_, NONE) => NONE)
        end)
  end;
820
821end;
822