1structure IntDP_Munge :> IntDP_Munge =
2struct
3
(* Local replacement for Parse: it re-exports everything in the global
   Parse structure, but rebinds the Type and Term parsers to the ones built
   from int_arithTheory's grammars.
   NOTE(review): presumably this is so that the quotations later in this
   file (e.g. in Num_lemma) are parsed with those grammars -- confirm. *)
structure Parse = struct
  open Parse
  val (Type, Term) =
      Parse.parse_from_grammars int_arithTheory.int_arith_grammars
end
8open Parse
9
10open HolKernel boolLib intSyntax boolSyntax CooperSyntax integerTheory
11     int_arithTheory intReduce
12
(* [ERR fname msg] builds a HOL_ERR attributed to this structure, where
   [fname] names the reporting function and [msg] is the message. *)
fun ERR fname msg = mk_HOL_ERR "IntDP_Munge" fname msg

(* OmegaMath's multiplication normaliser, used below (under
   ONCE_DEPTH_CONV) on both the natural-number and the integer goals. *)
val normalise_mult = OmegaMath.NORMALISE_MULT
16
17(* this draws on similar code in Richard Boulton's natural number
18   arithmetic decision procedure *)
19
(* [land t] is the left-hand argument [l] of a binary application [f l r]. *)
fun land t = rand (rator t)

(* [contains_var tm] holds when [tm] mentions anything the arithmetic
   procedures must treat as unknown: a variable (free or bound), or a
   constant whose type is :num or :int.  A natural-number numeral is not
   looked into and counts as variable-free. *)
fun contains_var tm =
    not (numSyntax.is_numeral tm) andalso
    (case dest_term tm of
       VAR _ => true
     | CONST {Ty, ...} => Ty = numSyntax.num orelse Ty = int_ty
     | COMB (f, x) => contains_var f orelse contains_var x
     | LAMB (_, b) => contains_var b)

(* [is_linear_mult tm] holds for an integer multiplication with at least
   one variable-free factor (according to [contains_var]). *)
fun is_linear_mult tm =
    is_mult tm andalso
    (not (contains_var (rand tm)) orelse not (contains_var (land tm)))
32
(* [non_zero tm] holds unless [tm], once every leading integer negation
   has been stripped off, is syntactically zero_tm. *)
fun non_zero tm = let
  fun strip t = if is_negated t then strip (rand t) else t
in
  not (strip tm = zero_tm)
end
36
(* [non_presburger_subterms0 ctxt tm] returns a list of pairs (t, free_p).
   Each t is a subterm of [tm] that the traversal below does not accept as
   Presburger.  free_p is true iff no variable of [ctxt] occurs free in t;
   [ctxt] accumulates the variables bound by the enclosing :int quantifiers.

   Traversed:
     - quantifiers over :int;
     - boolean negation, integer negation and absolute value;
     - conditionals, connectives, comparisons, +, - and linear *;
     - divisibility by an integer literal;
     - div/mod by a non-zero integer literal.
   Contributing nothing: integer literals, :int variables, T and F.
   Everything else (including a quantifier over a non-:int variable) is
   reported whole. *)
fun non_presburger_subterms0 ctxt tm = let
  val recur = non_presburger_subterms0 ctxt
  fun is_int_quant t =
      (is_forall t orelse is_exists1 t orelse is_exists t) andalso
      type_of (bvar (rand t)) = int_ty
  fun is_binary t =
      is_great t orelse is_geq t orelse is_eq t orelse
      is_less t orelse is_leq t orelse is_conj t orelse
      is_disj t orelse is_imp t orelse is_plus t orelse
      is_minus t orelse is_linear_mult t
  fun is_literal_divisibility t =
      is_divides t andalso is_int_literal (land t)
  fun is_literal_divmod t =
      (is_div t orelse is_mod t) andalso
      is_int_literal (rand t) andalso
      non_zero (rand t)
  fun is_harmless_leaf t =
      is_int_literal t orelse
      (is_var t andalso type_of t = int_ty) orelse
      t = true_tm orelse t = false_tm
in
  if is_int_quant tm then let
    val abst = rand tm
  in
    non_presburger_subterms0 (Lib.union [bvar abst] ctxt) (body abst)
  end
  (* is_neg must be tried before is_imp (inside is_binary) *)
  else if is_neg tm orelse is_absval tm orelse is_negated tm then
    recur (rand tm)
  else if is_cond tm then let
    val (b, t1, t2) = dest_cond tm
  in
    Lib.U [recur b, recur t1, recur t2]
  end
  else if is_binary tm then
    Lib.union (recur (land tm)) (recur (rand tm))
  else if is_literal_divisibility tm then
    recur (rand tm)
  else if is_literal_divmod tm then
    recur (land tm)
  else if is_harmless_leaf tm then []
  else [(tm, not (List.exists (fn v => free_in v tm) ctxt))]
end
74
(* [is_presburger tm] holds when [tm] has no non-Presburger subterms. *)
fun is_presburger tm = null (non_presburger_subterms0 [] tm)

(* [non_presburger_subterms tm] lists the non-Presburger subterms of [tm],
   without their bound-variable flags. *)
fun non_presburger_subterms tm =
    map (fn (t, _) => t) (non_presburger_subterms0 [] tm)
77
(* [is_natlin_mult tm] holds for a natural-number multiplication with at
   least one variable-free factor (according to [contains_var]). *)
fun is_natlin_mult tm =
    numSyntax.is_mult tm andalso
    (not (contains_var (land tm)) orelse not (contains_var (rand tm)))
81
(* [nat_nonpresburgers tm] returns the set of subterms of [tm] that are
   neither handled by the linear integer / natural-number traversal below
   nor numerals or variables.

   Quantifiers are looked through without tracking the bound variable, so
   the results may mention variables bound in [tm].

   The integer-level operators traversed here are kept in line with
   non_presburger_subterms0.  Integer +, negation and ABS are traversed
   like -, so natural-number subterms underneath them are found instead of
   the whole integer expression being reported. *)
fun nat_nonpresburgers tm =
    if is_forall tm orelse is_exists tm orelse is_exists1 tm then
      nat_nonpresburgers (body (rand tm))
    else if is_conj tm orelse is_disj tm orelse
            (is_imp tm andalso not (is_neg tm)) orelse
            is_great tm orelse is_leq tm orelse is_eq tm orelse
            is_plus tm orelse is_minus tm orelse
            is_less tm orelse is_geq tm orelse
            is_linear_mult tm
    then
      HOLset.union (nat_nonpresburgers (land tm), nat_nonpresburgers (rand tm))
    else if is_neg tm orelse is_negated tm orelse is_absval tm orelse
            is_injected tm orelse is_Num tm orelse numSyntax.is_suc tm
    then
      nat_nonpresburgers (rand tm)
    else if is_cond tm then
      (* condition, then-branch and else-branch *)
      HOLset.union
      (HOLset.union (nat_nonpresburgers (rand (rator (rator tm))),
                     nat_nonpresburgers (land tm)),
       nat_nonpresburgers (rand tm))
    else
      let open numSyntax
      in
        (* natural-number arithmetic (numSyntax versions of the tests) *)
        if is_greater tm orelse is_geq tm orelse is_less tm orelse
           is_leq tm orelse is_plus tm orelse is_minus tm orelse
           is_natlin_mult tm
        then
          HOLset.union (nat_nonpresburgers (land tm),
                        nat_nonpresburgers (rand tm))
        else if is_numeral tm then empty_tmset
        else if is_var tm then empty_tmset
        else HOLset.add(empty_tmset, tm)
      end
114
(* Pattern variables of the INT_DIV_P / INT_MOD_P family of theorems;
   they are instantiated to the dividend and divisor below. *)
val x_var = mk_var("x", int_ty)
val c_var = mk_var("c", int_ty)

(* [elim_div_mod0 exp t] eliminates integer division/modulus terms
   (intSyntax is_div / is_mod) that occur free in [t].

   Each step eliminates one such term whose divisor REDUCE_CONV can prove
   non-zero.  It abstracts the term out and rewrites with:
     - INT_DIV_P / INT_MOD_P when [exp] is true;
     - INT_DIV_FORALL_P / INT_MOD_FORALL_P otherwise.
   It then tidies the quantified body with REDUCE_CONV and BETA_CONV.

   Steps repeat while some eliminable term remains; terms with other
   divisors are left in place.  With no free div/mod terms this is
   ALL_CONV.  It fails (HOL_ERR) only when free div/mod terms exist but
   none of them can be eliminated. *)
fun elim_div_mod0 exp t = let
  val divmods =
      HOLset.listItems (find_free_terms (fn t => is_mod t orelse is_div t) t)
  fun elim_t to_elim = let
    val ((num,divisor), c1, c2, thm) = let
      val (c1, c2) = if exp then (RAND_CONV o LAND_CONV, RAND_CONV o RAND_CONV)
                     else (LAND_CONV o RAND_CONV, RAND_CONV)
    in
      (dest_div to_elim, c1, c2, if exp then INT_DIV_P else INT_DIV_FORALL_P)
      handle HOL_ERR _ => (dest_mod to_elim, c1, c2,
                           if exp then INT_MOD_P else INT_MOD_FORALL_P)
    end
    (* fails (HOL_ERR from EQT_ELIM) unless the divisor reduces to a
       provably non-zero value *)
    val div_nzero = EQT_ELIM (REDUCE_CONV (mk_neg(mk_eq(divisor, zero_tm))))
    val rwt = MP (Thm.INST [x_var |-> num, c_var |-> divisor] (SPEC_ALL thm))
                 div_nzero
  in
    UNBETA_CONV to_elim THENC REWR_CONV rwt THENC
    STRIP_QUANT_CONV (c1 REDUCE_CONV THENC c2 BETA_CONV)
  end
in
  case divmods of
    [] => ALL_CONV
  | _ =>
    (* After one successful elimination, being unable to eliminate the
       remaining terms must not throw that progress away, so the
       recursive call is wrapped in TRY_CONV. *)
    FIRST_CONV (map elim_t divmods) THENC TRY_CONV (elim_div_mod0 exp)
end t
142
(* [elim_div_mod t] applies elim_div_mod0 at the top of [t] and at the body
   of every binder (quantifier or lambda) inside it.

   elim_div_mod0 cannot simply be run once on the whole term: eliminating
   x/c needs x to be free, so the term under each quantifier has to be
   visited as well.  It may also help to have quantifiers scope over as
   little of the term as possible.

   The existential elimination theorems are used when goal_qtype says [t]
   is existential. *)
fun elim_div_mod t = let
  val exp = goal_qtype t = qsEXISTS
  fun walk under_binder tm =
      if is_exists tm orelse is_forall tm orelse is_exists1 tm then
        BINDER_CONV (walk true) tm
      else if is_abs tm then
        ABS_CONV (walk true) tm
      else let
        val here = if under_binder then TRY_CONV (elim_div_mod0 exp)
                   else ALL_CONV
      in
        (here THENC SUB_CONV (walk false)) tm
      end
in
  walk true t
end
162
163
(* [decide_fv_presburger DPname DP tm] decides [tm] after first
   universally generalising its free variables and its :int constants.
   Constants are replaced by fresh genvars.  The decision step is
   elim_div_mod, then rewriting with INT_ABS, then [DP].

   Without anything to generalise, the result of that step on [tm] is
   returned as-is.  Otherwise the generalised goal must be proved equal to
   T; the theorem is specialised back to the original atoms and returned
   as |- tm = T.  Any HOL_ERR on that path becomes an ERR DPname
   mentioning the first generalised atom. *)
fun decide_fv_presburger DPname DP tm = let
  fun is_int_const t = is_const t andalso type_of t = int_ty
  val atoms = free_vars tm @ Lib.mk_set (find_terms is_int_const tm)
  fun atom_name t = #1 (dest_const t handle HOL_ERR _ => dest_var t)
  fun generalise (a, goal) =
      if is_var a then mk_forall (a, goal)
      else let
        val v = genvar int_ty
      in
        mk_forall (v, subst [a |-> v] goal)
      end
  val decide = elim_div_mod THENC REWRITE_CONV [INT_ABS] THENC DP
in
  case atoms of
    [] => decide tm
  | first :: _ =>
    (* Callers only pass goals without non-Presburger sub-terms, so all
       of these atoms are expected to be of integer type. *)
    (EQT_INTRO
       (SPECL atoms (EQT_ELIM (decide (List.foldr generalise tm atoms))))
     handle HOL_ERR _ =>
            raise ERR DPname
                  ("Tried to prove generalised goal (generalising "^
                   atom_name first^"...) but it was false"))
end
190
191
(* [abs_inj inj_n tm] abstracts the occurrences of [inj_n] in [tm] over a
   fresh :int genvar v, returning |- tm = (\v. tm[v/inj_n]) inj_n. *)
fun abs_inj inj_n tm = let
  val v = genvar int_ty
  val redex = mk_comb (mk_abs (v, subst [inj_n |-> v] tm), inj_n)
in
  GSYM (BETA_CONV redex)
end
198
(* [eliminate_nat_quants tm] is a conversion that descends through
   quantifiers, negations, conjunctions, disjunctions, equations,
   implications and conditionals.

   Every quantifier (!, ?, ?!) over a :num variable is replaced by one over
   an :int variable of the same name, by rewriting with INT_NUM_FORALL,
   INT_NUM_EXISTS or INT_NUM_UEXISTS.  Other terms are left alone.

   A failure originating in REWR_CONV is reported as
   ERR "IntDP_Munge" "Uneliminable natural number term remains". *)
fun eliminate_nat_quants tm = let
  val recur = eliminate_nat_quants
  fun quantifier_conv () = let
    val (v, _) = dest_abs (rand tm)
  in
    if type_of v <> num_ty then BINDER_CONV recur
    else let
      val nat_to_int = if is_forall tm then INT_NUM_FORALL
                       else if is_exists tm then INT_NUM_EXISTS
                       else INT_NUM_UEXISTS
    in
      BINDER_CONV (abs_inj (mk_comb (int_injection, v))) THENC
      REWR_CONV nat_to_int THENC
      BINDER_CONV (RAND_CONV BETA_CONV) THENC
      RENAME_VARS_CONV [#1 (dest_var v)] THENC
      BINDER_CONV recur
    end
  end
  val conv =
      if is_forall tm orelse is_exists tm orelse is_exists1 tm then
        quantifier_conv ()
      else if is_neg tm then (* must test for is_neg before is_imp *)
        RAND_CONV recur
      else if is_conj tm orelse is_disj tm orelse is_eq tm orelse
              is_imp tm then
        BINOP_CONV recur
      else if is_cond tm then
        (* else-branch, then-branch, condition *)
        RAND_CONV recur THENC
        LAND_CONV recur THENC
        RATOR_CONV (RATOR_CONV (RAND_CONV recur))
      else ALL_CONV
in
  conv tm
end handle HOL_ERR {origin_function = "REWR_CONV", ...} =>
  raise ERR "IntDP_Munge" "Uneliminable natural number term remains"
231
232
(* Term-level "tactics".  A tactic maps a goal term tm to a pair
   (tm', validation), where the validation turns a theorem about tm' (in
   practice |- tm' = T) into the corresponding theorem about tm. *)

(* Run [t1], then [t2] on its new goal; the validations compose in
   reverse order. *)
fun tacTHEN t1 t2 tm = let
  val (mid_goal, back_to_tm) = t1 tm
  val (final_goal, back_to_mid) = t2 mid_goal
in
  (final_goal, fn th => back_to_tm (back_to_mid th))
end

(* The identity tactic. *)
fun tacALL tm = (tm, fn th => th)

(* Run each tactic of the list in sequence, left to right. *)
fun tacMAP_EVERY tlist =
    List.foldr (fn (t, rest) => tacTHEN t rest) tacALL tlist

(* Rewrite the goal with conversion [c]; the validation chains the
   conversion's theorem with TRANS.  An UNCHANGED conversion leaves the
   goal as it is. *)
fun tacCONV c tm =
    (let
       val th = c tm
     in
       (rhs (concl th), TRANS th)
     end)
    handle UNCHANGED => (tm, I)

(* Strip the outer universal quantifiers; the validation re-generalises
   over them. *)
fun tacRGEN t = let
  val (vs, matrix) = strip_forall t
in
  (matrix, fn th => EQT_INTRO (GENL vs (EQT_ELIM th)))
end

fun op tTHEN (t1, t2) = tacTHEN t1 t2
infix tTHEN
257
258
(* Orders terms by decreasing term_size, so that sorting with it puts the
   largest terms first; terms of equal size compare EQUAL. *)
fun subtm_rel (t1, t2) = Int.compare (term_size t2, term_size t1)
264
(* Natural-number preprocessing for the integer decision procedures.
   COND_ABS_CONV, NBOOL_COND_RATOR_CONV, NBOOL_COND_RAND_CONV,
   nat_rewrites and dealwith_nats are visible after this local. *)
local
  open arithmeticTheory numSyntax

  (* |- &(Num i) = if 0 <= i then i else &((Num o I) i)
     NOTE(review): the else-branch uses (Num o I) rather than plain Num,
     presumably so that rewriting with this lemma cannot fire again on its
     own result -- confirm. *)
  val Num_lemma = prove(
    ``&(Num i) = if 0 <= i then i else & ((Num o I) i)``,
    COND_CASES_TAC THEN
    ASM_REWRITE_TAC [combinTheory.o_THM, integerTheory.INT_OF_NUM,
                     combinTheory.I_THM])

  (* Rewrites for the last phase of dealwith_nats, applied just before
     eliminate_nat_quants.  The GSYM'd INT_* theorems are used right to
     left, turning natural-number forms into integer ones. *)
  val rewrites = [GSYM INT_INJ, GSYM INT_LT, GSYM INT_LE,
                  GREATER_DEF, GREATER_EQ, GSYM INT_ADD,
                  GSYM INT_MUL, INT, INT_NUM_COND, Num_lemma]

  (* [elim_div_mod0 exp t]: natural-number analogue of the integer
     elim_div_mod0.

     Each step eliminates one DIV or MOD term occurring free in the goal
     whose divisor REDUCE_CONV proves positive.  It uses DIV_P / MOD_P
     when [exp] holds and DIV_P_UNIV / MOD_P_UNIV otherwise.

     Steps repeat while such a term remains; terms with other divisors are
     left in place.  With no free DIV/MOD terms this is ALL_CONV.  It fails
     only when free DIV/MOD terms exist but none of them can be
     eliminated. *)
  fun elim_div_mod0 exp t = let
    val divmods =
        HOLset.listItems (find_free_terms (fn t => is_mod t orelse is_div t) t)
    fun elim_t to_elim = let
      val ((_, divisor), (thm, c)) =
          (dest_div to_elim, if exp then (DIV_P, RAND_CONV)
                             else (DIV_P_UNIV, I))
          handle HOL_ERR _ => (dest_mod to_elim, if exp then (MOD_P, RAND_CONV)
                                                 else (MOD_P_UNIV, I))
      (* fails (HOL_ERR from EQT_ELIM) unless 0 < divisor reduces to T *)
      val div_nzero = EQT_ELIM (REDUCE_CONV (mk_less(zero_tm, divisor)))
      (* instantiate [thm] so the DIV/MOD on the left of its consequent
         is exactly to_elim *)
      fun findinst thm =
          Thm.INST (#1 (match_term (rand (lhs (#2 (dest_imp (concl thm)))))
                                   to_elim))
                   thm
      val rwt = MP (findinst (SPEC_ALL thm)) div_nzero
    in
      UNBETA_CONV to_elim THENC REWR_CONV rwt THENC
      STRIP_QUANT_CONV (RAND_CONV (c BETA_CONV))
    end
  in
    case divmods of
      [] => ALL_CONV
    | _ =>
      (* After one successful elimination, being unable to eliminate the
         remaining terms must not throw that progress away, so the
         recursive call is wrapped in TRY_CONV. *)
      FIRST_CONV (map elim_t divmods) THENC TRY_CONV (elim_div_mod0 exp)
  end t

  (* Natural-number version of elim_div_mod.  It applies elim_div_mod0 at
     the top of the term and at every binder body.  The existential
     theorems are used only when goal_qtype says the goal is existential
     and nat_nonpresburgers finds nothing in it. *)
  fun elim_div_mod t = let
    val exp = goal_qtype t = qsEXISTS andalso
              HOLset.isEmpty (nat_nonpresburgers t)
    fun recurse passed_a_binder tm =
        if is_exists tm orelse is_forall tm orelse is_exists1 tm then
          BINDER_CONV (recurse true) tm
        else if is_abs tm then
          ABS_CONV (recurse true) tm
        else
          ((if passed_a_binder then TRY_CONV (elim_div_mod0 exp)
            else ALL_CONV) THENC
           SUB_CONV (recurse false)) tm
  in
    recurse true t
  end

  (* Head operator of an iterated application (the term itself if it is
     not an application).  Derived from RJB's Sub_and_cond.sml. *)
  fun op_of_app tm = #1 (strip_comb tm)
in
  (* Derived from RJB's Sub_and_cond.sml.
     Given \v. if b then x else y, where v is not free in b and x is not
     boolean, return the COND_ABS instance
       |- (\v. if b then x else y) = if b then (\v. x) else (\v. y)
     Any HOL_ERR is reported as failwith "COND_ABS_CONV". *)
  fun COND_ABS_CONV tm = let
    open Rsyntax
    val {Bvar=v,Body=bdy} = dest_abs tm
    val {cond,larm=x,rarm=y} = Rsyntax.dest_cond bdy
    val b = assert (not o Lib.mem v o free_vars) cond
    val _ = assert (fn t => type_of t <> bool) x
    val xf = mk_abs{Bvar=v,Body=x}
    val yf = mk_abs{Bvar=v,Body=y}
    val th1 = INST_TYPE [alpha |-> type_of v, beta |-> type_of x] COND_ABS
    val th2 = SPECL [b,xf,yf] th1
  in
    (* beta-reduce both branches on the left-hand side and rename its
       bound variable back to v *)
    CONV_RULE (RATOR_CONV
                 (RAND_CONV (ABS_CONV
                               (RATOR_CONV (RAND_CONV BETA_CONV) THENC
                                RAND_CONV BETA_CONV) THENC
                               ALPHA_CONV v))) th2
  end handle HOL_ERR _ => failwith "COND_ABS_CONV"

  (* Rewrite with COND_RATOR (applying a conditional function). *)
  val NBOOL_COND_RATOR_CONV = REWR_CONV COND_RATOR

  (* Rewrite f (COND ...) with COND_RAND.  This is done when f is COND
     itself, or when the argument's rand is not boolean and f's head is
     not COND; otherwise it fails (NO_CONV). *)
  fun NBOOL_COND_RAND_CONV tm = let
    val (f, cnd) = dest_comb tm
  in
    if same_const f conditional orelse
       (type_of (rand cnd) <> bool andalso
        not (same_const (op_of_app f) conditional))
    then
      (* guard above allows rewrite of
           COND (COND p q r) x y
         which will go to
           (COND p (COND q) (COND r)) x y
         COND q and COND r will get exposed to x and y; term duplicates
         x and y; hope this doesn't happen too often. *)
      REWR_CONV COND_RAND tm
    else
      NO_CONV tm
  end

  (* Rewrites applied first by dealwith_nats. *)
  val nat_rewrites =
      [arithmeticTheory.LEFT_ADD_DISTRIB, arithmeticTheory.RIGHT_ADD_DISTRIB,
       arithmeticTheory.MAX_DEF, arithmeticTheory.MIN_DEF,
       arithmeticTheory.ODD_EXISTS, arithmeticTheory.EVEN_EXISTS]

  (* [dealwith_nats tm] is a term-level tactic returning (goal, validation)
     with the natural-number parts of [tm] translated into integer
     arithmetic.  It runs three phases:
       1. rewrite with nat_rewrites, normalise multiplications, eliminate
          DIV/MOD, and normalise subtractions and conditionals;
       2. generalise the :num-typed results of nat_nonpresburgers, largest
          first, after stripping outer universal quantifiers (moving them
          up first when goal_qtype is qsUNIV);
       3. rewrite with [rewrites] and eliminate_nat_quants.
     The validation turns |- goal = T back into |- tm = T. *)
  val dealwith_nats = let
    val phase1 =
        tacCONV (PURE_REWRITE_CONV nat_rewrites THENC
                 ONCE_DEPTH_CONV normalise_mult THENC
                 elim_div_mod THENC
                 (* eliminate nasty subtractions *)
                 TOP_DEPTH_CONV (Thm_convs.SUB_NORM_CONV ORELSEC
                                 NBOOL_COND_RATOR_CONV ORELSEC
                                 NBOOL_COND_RAND_CONV ORELSEC
                                 COND_ABS_CONV))
    fun do_pbs tm = let
      val non_pbs0 = HOLset.listItems (nat_nonpresburgers tm)
      val non_pbs = Listsort.sort subtm_rel
                                  (List.filter (equal num_ty o type_of) non_pbs0)
      val initially =
          if null non_pbs then tacALL
          else if goal_qtype tm = qsUNIV then
            tacCONV move_quants_up tTHEN tacRGEN
          else tacRGEN
      fun tactic subtm tm = let
        (* return both a newtm and a function that will convert a theorem
           of the form <new term> = T into tm = T *)
        val gv = genvar numSyntax.num
        val newterm = mk_forall (gv, Term.subst [subtm |-> gv] tm)
        fun prove_it thm =
            EQT_INTRO (SPEC subtm (EQT_ELIM thm))
            handle HOL_ERR _ =>
                   raise ERR "COOPER_CONV"
                             ("Tried to prove generalised goal (generalising "^
                              Parse.term_to_string subtm^"...) but it was false")
      in
        (newterm, prove_it)
      end
    in
      initially tTHEN tacMAP_EVERY (map tactic non_pbs)
    end tm
  in
    phase1 tTHEN do_pbs tTHEN
    tacCONV (PURE_REWRITE_CONV rewrites THENC eliminate_nat_quants)
  end
end (* local *)
411
(* [decide_nonpbints_presburger DPname DP subterms tm]: [subterms] are
   non-Presburger subterms of [tm], all of integer type.

   Each is generalised, in order, to a fresh universally quantified
   variable, and the resulting goal is decided with decide_fv_presburger.
   The verdict is then transferred back to [tm].

   A subterm that is the injection & applied to some n is generalised to
   a :num genvar, then rewritten with INT_NUM_FORALL.  Any other subterm
   becomes an :int genvar.

   If the validation raises HOL_ERR (e.g. the generalised goal was not
   proved true), ERR DPname is raised instead. *)
fun decide_nonpbints_presburger DPname DP subterms tm = let
  (* Each generaliser maps a goal to (new goal, fn turning
     |- new goal = T into |- goal = T). *)
  fun generalise_injection subtm goal = let
    val n = rand subtm
    val beta_thm = abs_inj subtm goal          (* |- goal = P subtm *)
    val P = rator (rhs (concl beta_thm))
    val gv = genvar num_ty
    val nat_goal = mk_forall (gv, mk_comb (P, mk_comb (int_injection, gv)))
    (* |- (!gv. P gv) = !x. 0 <= x ==> P x *)
    val to_int = (REWR_CONV INT_NUM_FORALL THENC
                  BINDER_CONV (RAND_CONV BETA_CONV)) nat_goal
    fun validate th =
        EQT_INTRO
          (EQ_MP (SYM beta_thm)
                 (SPEC n (EQ_MP (SYM to_int) (EQT_ELIM th))))
  in
    (rhs (concl to_int), validate)
  end
  fun generalise_int subtm goal = let
    val gv = genvar int_ty
  in
    (mk_forall (gv, subst [subtm |-> gv] goal),
     EQT_INTRO o SPEC subtm o EQT_ELIM)
  end
  fun tactic subtm =
      if is_comb subtm andalso rator subtm = int_injection then
        generalise_injection subtm
      else
        generalise_int subtm
  val (goal, vfn) = tacMAP_EVERY (map tactic subterms) tm
  val thm = decide_fv_presburger DPname DP goal
in
  vfn thm handle HOL_ERR _ =>
    raise ERR DPname
      ("Tried to prove generalised goal (generalising "^
       Parse.term_to_string (hd subterms)^"...) but it was false")
end
450
(* Integer-level rewrites applied by BASIC_CONV before the Presburger
   check. *)
val int_rewrites =
  [INT_LDISTRIB, INT_RDISTRIB, INT_MAX, INT_MIN]

(* [BASIC_CONV DPname DP tm] decides [tm] using the Presburger procedure
   [DP].

   First dealwith_nats translates natural-number content.  The goal is
   then rewritten with int_rewrites and its multiplications normalised.

   A Presburger result goes straight to decide_fv_presburger.  Otherwise
   every non-Presburger subterm must have type :int, or ERR DPname is
   raised.  If any of them mentions bound variables, outer universal
   quantifiers are stripped (moved up first for qsUNIV goals).  If that
   does not free them, ERR DPname is raised.

   The remaining subterms, largest first, are generalised by
   decide_nonpbints_presburger.  The result is mapped back through the
   validations. *)
fun BASIC_CONV DPname DP tm = let
  val (natgoal, natvalidation) = dealwith_nats tm
  val normalise = PURE_REWRITE_CONV int_rewrites THENC
                  ONCE_DEPTH_CONV normalise_mult
  fun is_bound (_, free_p) = not free_p
  fun not_allowed t =
      raise ERR DPname
            ("Not in the allowed subset; consider "^Parse.term_to_string t)
  fun cannot_free t =
      raise ERR DPname
            ("Couldn't free quantification over non-Presburger "^
             "sub-term "^Parse.term_to_string t)
  fun decide_int_nonpbs g non_pbs = let
    val (igoal, initvfn) =
        if not (List.exists is_bound non_pbs) then (g, I)
        else if goal_qtype g = qsUNIV then
          (tacCONV move_quants_up tTHEN tacRGEN) g
        else tacRGEN g
    val init_nonpbs =
        Listsort.sort (inv_img_cmp #1 subtm_rel)
                      (non_presburger_subterms0 [] igoal)
  in
    case List.find is_bound init_nonpbs of
      SOME (t, _) => cannot_free t
    | NONE =>
      initvfn (decide_nonpbints_presburger DPname DP
                                           (map #1 init_nonpbs) igoal)
  end
  fun decide g =
      case non_presburger_subterms0 [] g of
        [] => decide_fv_presburger DPname DP g
      | non_pbs =>
        (case List.find (fn (t, _) => type_of t <> int_ty) non_pbs of
           SOME (t, _) => not_allowed t
         | NONE => decide_int_nonpbs g non_pbs)
in
  natvalidation ((normalise THENC decide) natgoal)
end
493
(* [ok_asm th] decides whether assumption [th] may be passed on to the
   decision procedure.

   Each non-Presburger subterm of its conclusion is broken down further
   with nat_nonpresburgers, keeping its "no bound variables" flag.  The
   assumption is accepted when every resulting term has type :int or :num
   and either carries that flag or the conclusion is existential
   (goal_qtype = qsEXISTS). *)
fun ok_asm th = let
  val c = concl th
  val existential = goal_qtype c = qsEXISTS
  fun acceptable (t, free_p) =
      mem (type_of t) [intSyntax.int_ty, numSyntax.num] andalso
      (existential orelse free_p)
  (* orders false before true *)
  fun bool_cmp (false, true) = LESS
    | bool_cmp (true, false) = GREATER
    | bool_cmp _ = EQUAL
  fun add_pieces ((t, free_p), acc) =
      HOLset.foldl (fn (nt, s) => HOLset.add (s, (nt, free_p)))
                   acc (nat_nonpresburgers t)
  val pieces =
      List.foldl add_pieces
                 (HOLset.empty (pair_compare (Term.compare, bool_cmp)))
                 (non_presburger_subterms0 [] c)
in
  List.all acceptable (HOLset.listItems pieces)
end
514
(* [conv_tac c] repeatedly takes an assumption accepted by ok_asm off the
   assumption list (FIRST_X_ASSUM) into the goal (MP_TAC), then applies
   the conversion [c] to the goal. *)
fun conv_tac c = let
  val pull_asm = FIRST_X_ASSUM (MP_TAC o assert ok_asm)
in
  REPEAT pull_asm THEN CONV_TAC c
end
518
519end; (* struct *)
520