1(* ========================================================================= *)
2(* PREDICATE SUBTYPE PROVER                                                  *)
3(* ========================================================================= *)
4
5structure subtypeTools :> subtypeTools =
6struct
7
8open HolKernel Parse boolLib bossLib res_quanTools;
9
(* Exception builders for this structure.  ERR makes a HOL_ERR tagged with
   the structure name, Bug wraps mlibUseful.Bug for internal invariant
   failures, and Error is ERR with an empty function name. *)
fun ERR origin message = mk_HOL_ERR "subtypeTools" origin message;
fun Bug message = mlibUseful.Bug message;
fun Error message = ERR "" message;
13
14(* ------------------------------------------------------------------------- *)
15(* Helper proof tools.                                                       *)
16(* ------------------------------------------------------------------------- *)
17
(* Total order on booleans in which true sorts before false. *)
fun bool_compare (x, y) =
    if x = y then EQUAL
    else if x then LESS
    else GREATER;
21
(* Split a set-membership term x IN s into its two arguments; raises a
   HOL_ERR from dest_in on any other term. *)
val dest_in =
    let
      val not_in = ERR "dest_in" ""
    in
      dest_binop pred_setSyntax.in_tm not_in
    end;

(* Test whether a term is a set-membership term. *)
fun is_in tm = can dest_in tm;
25
(* The Abbrev marker constant. *)
val abbrev_tm = ``Abbrev``;

(* Split Abbrev (v = t) into (v, t); raises a HOL_ERR if the term is not an
   application of Abbrev, or if its argument is not an equation. *)
fun dest_abbrev tm =
    let
      val (marker, body) = dest_comb tm
    in
      if same_const marker abbrev_tm then dest_eq body
      else raise ERR "dest_abbrev" ""
    end;

(* Test whether a term is an Abbrev-marked equation. *)
fun is_abbrev tm = can dest_abbrev tm;
37
(* Normalise a theorem for use by match_tac and cond_rewr_conv.  It is
   simplified with pure_ss plus resq_SS, presumably to expand restricted
   quantifiers.  Universal quantifiers are then pulled out of implications,
   and curried antecedents are merged into a single right-nested
   conjunction. *)
local
  val norm_ss = simpLib.++ (pureSimps.pure_ss, resq_SS)

  val norm_ths =
      [GSYM LEFT_FORALL_IMP_THM, GSYM RIGHT_FORALL_IMP_THM,
       AND_IMP_INTRO, GSYM CONJ_ASSOC]
in
  val norm_rule = SIMP_RULE norm_ss norm_ths
end;
42
(* Turn a theorem into a tactic.  After normalising with norm_rule, a
   theorem whose body (under its universal quantifiers) is an implication is
   used for backward chaining with MATCH_MP_TAC.  Anything else is matched
   directly against the goal with MATCH_ACCEPT_TAC. *)
fun match_tac th =
    let
      val normed = norm_rule th
      val (_, body) = strip_forall (concl normed)
      val tactic = if is_imp body then MATCH_MP_TAC else MATCH_ACCEPT_TAC
    in
      tactic normed
    end;
50
(* Use [solver] to prove the condition [cond].  The solver may return either
   |- cond or |- cond = T; the latter is converted with EQT_ELIM.
   Conclusions are compared up to alpha-equivalence (aconv) rather than
   structural equality: MP, which consumes the result, only needs the
   antecedent to be alpha-equivalent.  A solver that renames bound
   variables therefore must not be rejected.
   Raises Bug if the solver proved something else. *)
fun flexible_solver solver cond =
    let
      val cond_th = solver cond
      val cond_tm = concl cond_th
    in
      if aconv cond_tm cond then cond_th
      else if aconv cond_tm (mk_eq (cond,T)) then EQT_ELIM cond_th
      else raise Bug "flexible_solver: solver didn't prove condition"
    end;
60
(* Turn a possibly conditional rewrite theorem into a solver_conv.  The
   theorem is normalised (norm_rule) and specialised once, as soon as it is
   supplied.  For c ==> (l = r), the returned conversion matches the input
   term against l, instantiates the theorem, and discharges the instantiated
   c with flexible_solver and the given solver.  For an unconditional
   l = r, it only instantiates.  Fails if the term does not match l. *)
fun cond_rewr_conv rewr =
    let
      val spec_th = SPEC_ALL (norm_rule rewr)
      val spec_tm = concl spec_th
      val (conditional, eq_tm) =
          case total dest_imp_only spec_tm of
            SOME (_, body) => (true, body)
          | NONE => (false, spec_tm)
      val pattern = lhs eq_tm

      (* Discharge the antecedent of an instantiated conditional rewrite. *)
      fun discharge solver th =
          if conditional then
            MP th (flexible_solver solver (rand (rator (concl th))))
          else th
    in
      fn solver => fn tm =>
         discharge solver (INST_TY_TERM (match_term pattern tm) spec_th)
    end;
80
(* Combine several conditional rewrites: each theorem is prepared once with
   cond_rewr_conv, and the resulting conversion tries them in order using
   FIRST_CONV. *)
fun cond_rewrs_conv ths =
    let
      val staged = map cond_rewr_conv ths
    in
      fn solver => FIRST_CONV (map (fn c => c solver) staged)
    end;
88
(* Goal cache.  A cache maps goal terms to theorems that proved them.  A
   cached theorem is reused only when every one of its hypotheses is among
   the current goal's assumptions. *)
local
  type cache = (term,thm) Binarymap.dict ref;

  (* Look up a reusable theorem for the goal (asl, g), if any. *)
  fun in_cache cache (asl, g) =
      let
        fun usable th = List.all (fn h => mem h asl) (hyp th)
      in
        Option.mapPartial (Option.filter usable) (Binarymap.peek (cache, g))
      end;
in
  (* A fresh, empty cache keyed with compare (presumably Term.compare). *)
  fun cache_new () =
      let
        val dict = Binarymap.mkDict compare
      in
        ref dict
      end;

  (* On a hit, close the goal immediately with the cached theorem.  On a
     miss, pass the goal through unchanged and record the theorem that
     eventually proves it. *)
  fun cache_tac (cache : cache) (goal as (_, g)) =
      case in_cache (!cache) goal of
        SOME th =>
        let
          fun hit [] = th
            | hit _ = raise Fail "cache_tac: hit"
        in
          ([], hit)
        end
      | NONE =>
        let
          fun miss [th] = (cache := Binarymap.insert (!cache, g, th); th)
            | miss _ = raise Fail "cache_tac: miss"
        in
          ([goal], miss)
        end;
end;
108
(* A tactic that prints [s] and then leaves the goal unchanged. *)
fun print_tac s goal =
    let
      val () = print s
    in
      ALL_TAC goal
    end;
110
111(* ------------------------------------------------------------------------- *)
112(* Solver conversions.                                                       *)
113(* ------------------------------------------------------------------------- *)
114
(* A conversion parameterised by a solver conversion.  In cond_rewr_conv,
   the solver is used to discharge the side conditions of conditional
   rewrites. *)
type solver_conv = Conv.conv -> Conv.conv;
116
(* Build an AC-normalising solver_conv for a binary operator described by
   [info].  Given a solver for side conditions, the resulting conversion:
   - reassociates nested applications of the operator;
   - orders the operands using term_compare on their base terms, i.e. after
     stripping at most one inverse and one exponent;
   - combines adjacent operands whose base terms are equal.
   Every theorem in [info] goes through cond_rewr_conv, so its side
   conditions are discharged by the solver.  The result is wrapped in
   CHANGED_CONV and therefore fails when the term is left unchanged. *)
fun binop_ac_conv info =
    let
      (* From the patterns used below: dest_binop yields a triple whose last
         two components are the operands; dest_inv yields a pair whose second
         component is the inverted term; dest_exp yields a triple whose
         middle component is the base.  The meaning of the ignored
         components is an assumption -- confirm against callers. *)
      val {term_compare,
           dest_binop,
           dest_inv,
           dest_exp,
           assoc_th,
           comm_th,
           comm_th',
           id_ths,
           simplify_ths,
           combine_ths,
           combine_ths'} = info

      (* NOTE(review): is_inv and is_exp are not used in this function. *)
      val is_binop = can dest_binop
      and is_inv = can dest_inv
      and is_exp = can dest_exp

      (* Split an operand into (base, pos, sing).  First one inverse is
         stripped: pos is false if one was found.  Then one exponent is
         stripped: sing is false if one was found. *)
      fun dest tm =
          let
            val (pos,tm) =
                case total dest_inv tm of
                  NONE => (true,tm)
                | SOME (_ : term, tm) => (false,tm)
            val (sing,tm) =
                case total dest_exp tm of
                  NONE => (true,tm)
                | SOME (_ : term, tm, _ : term) => (false,tm)
          in
            (tm,pos,sing)
          end

      (* Compare two operands and return (ok, eq).
         ok: x may stay in front of y.
         eq: the base terms compare EQUAL, so the pair is a candidate for
         combining.
         Ties on the base are broken by pos and then by sing, with true
         ordered before false (bool_compare). *)
      fun cmp (x,y) =
          let
            val (xt,xp,xs) = dest x
            and (yt,yp,ys) = dest y
          in
            case term_compare (xt,yt) of
              LESS => (true,false)
            | EQUAL =>
              (case bool_compare (xp,yp) of
                 LESS => (true,true)
               | EQUAL =>
                 (case bool_compare (xs,ys) of
                    LESS => (true,true)
                  | EQUAL => (true,true)
                  | GREATER => (false,true))
               | GREATER => (false,true))
            | GREATER => (false,false)
          end

      (* Rewrites built from the theorems in [info]; each one still expects
         the solver argument. *)
      val assoc_conv = cond_rewr_conv assoc_th

      val comm_conv = cond_rewr_conv comm_th

      val comm_conv' = cond_rewr_conv comm_th'

      val id_conv = cond_rewrs_conv id_ths

      val term_simplify_conv = cond_rewrs_conv simplify_ths

      (* Combine two like operands using combine_ths, then run
         reduceLib.REDUCE_CONV and try simplify_ths on the result. *)
      val term_combine_conv =
          let
            val conv = cond_rewrs_conv combine_ths
          in
            fn solver =>
               conv solver THENC
               reduceLib.REDUCE_CONV THENC
               TRY_CONV (term_simplify_conv solver)
          end

      (* Variant of term_combine_conv for combine_ths'.  The reduction and
         simplification are applied to the left operand of the result
         (LAND_CONV), and then id_ths are tried on the whole term. *)
      val term_combine_conv' =
          let
            val conv = cond_rewrs_conv combine_ths'
          in
            fn solver =>
               conv solver THENC
               LAND_CONV
                 (reduceLib.REDUCE_CONV THENC
                  TRY_CONV (term_simplify_conv solver)) THENC
               TRY_CONV (id_conv solver)
          end

      (* Insert the left operand a of a op b into the operand chain b.  a is
         commuted past operands that should precede it, and combined with an
         operand whose base compares equal.  The TRY_CONV means failures of
         these rewrites leave the term unchanged.
         NOTE(review): dest_binop tm is evaluated before TRY_CONV is entered,
         so this raises if tm is not an operator application.  The call
         sites in this function appear to pass only such terms. *)
      fun push_conv solver tm =
          TRY_CONV
          let
            val (_,a,b) = dest_binop tm
          in
            case total dest_binop b of
              NONE =>
              let
                val (ok,eq) = cmp (a,b)
              in
                (if ok then ALL_CONV else comm_conv solver) THENC
                (if eq then TRY_CONV (term_combine_conv solver) else ALL_CONV)
              end
            | SOME (_,b,_) =>
              let
                val (ok,eq) = cmp (a,b)
              in
                (if ok then ALL_CONV else comm_conv' solver) THENC
                ((if eq then term_combine_conv' solver else NO_CONV) ORELSEC
                 (if ok then ALL_CONV else push_conv' solver))
              end
          end tm
      (* Continue the insertion one level down, in the right operand, then
         try id_ths on the result. *)
      and push_conv' solver =
          RAND_CONV (push_conv solver) THENC TRY_CONV (id_conv solver)

      (* Main normaliser.
         - A non-operator term is only simplified.
         - If the left operand is itself an operator application, the term
           is reassociated first.
         - Otherwise identity or simplification rewrites are tried on the
           head of the term.  If none applies, the right operand is
           normalised and the left operand is pushed into it. *)
      (* Does not raise an exception *)
      fun ac_conv solver tm =
          (case total dest_binop tm of
             NONE => TRY_CONV (term_simplify_conv solver THENC ac_conv solver)
           | SOME (_,a,b) =>
             if is_binop a then
               TRY_CONV (assoc_conv solver THENC ac_conv solver)
             else
               ((id_conv solver ORELSEC
                 LAND_CONV (term_simplify_conv solver)) THENC
                ac_conv solver) ORELSEC
               (if is_binop b then
                  RAND_CONV (ac_conv solver) THENC push_conv solver
                else
                  (RAND_CONV (term_simplify_conv solver) THENC
                   ac_conv solver) ORELSEC
                  push_conv solver)) tm
    in
      (* CHANGED_CONV makes the exported conversion fail on a no-op. *)
      (***trace_conv "alg_binop_ac_conv" o***) CHANGED_CONV o ac_conv
    end;
245
246(* ------------------------------------------------------------------------- *)
247(* Named conversions.                                                        *)
248(* ------------------------------------------------------------------------- *)
249
(* A named solver conversion.  key is the term used to index it when it is
   installed in a simpset (see named_conv_to_simpset_conv).  conv is the
   conversion itself, parameterised by a solver. *)
type named_conv = {name : string, key : Term.term, conv : solver_conv};
251
(* Convert a named_conv into the record that simpLib expects for a simpset
   conversion:
   - it is keyed on ([], key);
   - the simplifier's context-dependent solver, applied to the context term
     list, is passed through as the conversion's solver;
   - it uses trace level 2. *)
fun named_conv_to_simpset_conv ({name, key, conv = user_conv} : named_conv) =
    {name = name,
     key = SOME ([] : term list, key),
     conv = fn simp_solver => fn ctxt : term list => user_conv (simp_solver ctxt),
     trace = 2};
261
262(* ------------------------------------------------------------------------- *)
263(* Subtype contexts.                                                         *)
264(* ------------------------------------------------------------------------- *)
265
(* When set, the membership decision procedure asserts its goals by oracle
   instead of searching for a proof. *)
val ORACLE = ref false;

(* Assert [goal] with an oracle theorem tagged algebra_dproc, returning
   |- goal = T. *)
fun ORACLE_solver goal =
    let
      val oracle_th = mk_oracle_thm "algebra_dproc" ([], goal)
    in
      EQT_INTRO oracle_th
    end;
270
(* named_conv is declared once, earlier in this structure, with the same
   fields: name : string, key : term, and conv : solver_conv, where
   solver_conv abbreviates conv -> conv.  It is intentionally not
   redeclared here. *)
272
(* A subtype context.
   rewrites    - rewrite theorems, included in the simpset fragment;
   conversions - named conversions, included in the simpset fragment;
   reductions, judgements - theorems that the membership decision procedure
                 turns into backward-chaining tactics;
   dproc_cache - cache of goals proved by that decision procedure. *)
datatype context =
    Context of {rewrites : thm list,
                conversions :  named_conv list,
                reductions : thm list,
                judgements : thm list,
                dproc_cache : (term,thm) Binarymap.dict ref};
279
(* Pretty-print a summary of a context as <Nr, Nc, Nd, Nj>: the numbers of
   rewrites, conversions, reductions and judgements it holds.  Reductions
   are labelled d, not r, so the summary cannot confuse them with rewrites.
   The cache is not printed. *)
fun pp p context =
    let
      val Context {rewrites,conversions,reductions,judgements,...} = context
      fun count xs = int_to_string (length xs)
    in
      PP.begin_block p PP.INCONSISTENT 1;
      PP.add_string p ("<" ^ count rewrites ^ "r" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (count conversions ^ "c" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (count reductions ^ "d" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (count judgements ^ "j>");
      PP.end_block p
    end;
298
(* Render a context with pp, using the current global line width. *)
fun to_string context =
    let
      val width = !Globals.linewidth
    in
      PP.pp_to_string width pp context
    end;
300
(* A context with no rewrites, conversions, reductions or judgements, and a
   freshly created, empty cache. *)
val empty =
    let
      val fresh_cache = cache_new ()
    in
      Context {rewrites = [], conversions = [], reductions = [],
               judgements = [], dproc_cache = fresh_cache}
    end;
305
(* Functional updates that append one item to a context's rewrites,
   conversions, reductions or judgements.  The new context gets a fresh
   reference holding a copy of the old cache contents, so later cache
   updates in one context are not seen by the other. *)
local
  (* Rebuild a context from its components, copying the cache. *)
  fun rebuild (rws, cvs, rds, jds, cache) =
      Context {rewrites = rws, conversions = cvs, reductions = rds,
               judgements = jds, dproc_cache = ref (!cache)}
in
  fun add_rewrite x (Context {rewrites, conversions, reductions,
                              judgements, dproc_cache}) =
      rebuild (rewrites @ [x], conversions, reductions, judgements,
               dproc_cache)

  fun add_conversion x (Context {rewrites, conversions, reductions,
                                 judgements, dproc_cache}) =
      rebuild (rewrites, conversions @ [x], reductions, judgements,
               dproc_cache)

  fun add_reduction x (Context {rewrites, conversions, reductions,
                                judgements, dproc_cache}) =
      rebuild (rewrites, conversions, reductions @ [x], judgements,
               dproc_cache)

  fun add_judgement x (Context {rewrites, conversions, reductions,
                                judgements, dproc_cache}) =
      rebuild (rewrites, conversions, reductions, judgements @ [x],
               dproc_cache)
end;
341
(* The membership decision procedure behind simpset_frag.  The context value
   this reducer receives is an exn (see the State pattern match in
   state_apply_dproc), so the procedure's state is carried in the State
   exception below. *)
local
  (* Procedure state: the assumption terms collected so far, plus the
     reduction and judgement tactics used to prove membership goals. *)
  exception State of
    {assumptions : term list,
     reductions : tactic list,
     judgements : tactic list};

  local
    (* From Abbrev (v = t), membership facts about t transfer to v. *)
    val abbrev_rule = prove
        (``!v t. Abbrev (v = t) ==> (!s. t IN s ==> v IN s)``,
         RW_TAC std_ss [markerTheory.Abbrev_def]);

    (* Backward-chain with th (see match_tac), then split conjunctive
       subgoals. *)
    fun reduce_tac th = match_tac th THEN REPEAT CONJ_TAC;

    (* Record an assumed membership fact as a reduction.  Raises Fail on
       an exn that is not a State. *)
    fun assume_reduction th (State {assumptions,reductions,judgements}) =
        let
(***
          val () = (print "assume_reduction: "; print_thm th; print "\n")
***)
        in
          State {assumptions = concl th :: assumptions,
                 reductions = reduce_tac th :: reductions,
                 judgements = judgements}
        end
      | assume_reduction _ _ = raise Fail "assume_reduction";

    (* Record a derived fact as a judgement.  Raises Fail on an exn that is
       not a State. *)
    fun assume_judgement th (State {assumptions,reductions,judgements}) =
        let
(***
          val () = (print "assume_judgement: "; print_thm th; print "\n")
***)
        in
          State {assumptions = concl th :: assumptions,
                 reductions = reductions,
                 judgements = reduce_tac th :: judgements}
        end
      | assume_judgement _ _ = raise Fail "assume_judgement";
  in
    (* Starting state, built from the context's reductions and judgements. *)
    fun initial_state reductions judgements =
        State {assumptions = [],
               reductions = map reduce_tac reductions,
               judgements = map reduce_tac judgements};

    (* Fold context theorems into the state:
       - membership facts become reductions;
       - Abbrev (v = t) facts become judgements via abbrev_rule;
       - conjunctions are split and processed;
       - anything else is ignored.
       NOTE(review): assumptions records concl th, while the recorded
       tactics use th itself.  For facts produced by MATCH_MP or CONJUNCTS,
       th's hypotheses presumably differ from its conclusion.  The
       hypothesis check in cache_tac may then never succeed for theorems
       that depend on them -- confirm. *)
    fun state_add (s,[]) = s
      | state_add (s, th :: ths) =
        let
          val tm = concl th
        in
          if is_in tm then state_add (assume_reduction th s, ths)
          else if is_abbrev tm then
            state_add (assume_judgement (MATCH_MP abbrev_rule th) s, ths)
          else if is_conj tm then state_add (s, CONJUNCTS th @ ths)
          else state_add (s,ths)
        end;
  end;

  (* Prove a membership goal, returning |- goal = T.
     - A goal that is not a membership term is rejected with a HOL_ERR.
     - If ORACLE is set, the goal is asserted by oracle.
     - Otherwise dproc_tac repeatedly consults the cache and applies the
       first reduction that succeeds.  It then tries each judgement
       (recursing on the subgoals) or falls back to REDUCE_TAC.  The final
       NO_TAC makes the whole tactic fail unless every subgoal is closed.
     NOTE(review): print_tac prints a dash each time the REPEAT loop meets a
     goal not found in the cache.  This looks like leftover progress/debug
     output -- confirm it is wanted. *)
  fun state_apply_dproc dproc_cache dproc_context goal =
      if not (is_in goal) then
        raise ERR "algebra_dproc" "not of form X IN Y"
      else if !ORACLE then ORACLE_solver goal
      else
        let
          val {context, solver = _, conv = _, relation = _, stack = _} = dproc_context
          val {assumptions,reductions,judgements} =
              case context of
                State state => state
              | _ => raise Bug "state_apply_dproc: wrong exception type"

          fun dproc_tac goal =
              (REPEAT (cache_tac dproc_cache
                       THEN print_tac "-"
                       THEN FIRST reductions)
               THEN (FIRST (map (fn tac => tac THEN dproc_tac) judgements)
                     ORELSE reduceLib.REDUCE_TAC)
               THEN NO_TAC) goal

(***
          val _ = (print "algebra_dproc: "; print_term goal; print "\n")
***)
          val th = TAC_PROOF ((assumptions,goal), dproc_tac)
        in
          EQT_INTRO th
        end;

  (* Package the state functions as a Traverse reducer that uses
     dproc_cache. *)
  fun algebra_dproc reductions judgements dproc_cache =
      Traverse.REDUCER {name = NONE,
                        initial = initial_state reductions judgements,
                        addcontext = state_add,
                        apply = state_apply_dproc dproc_cache};
in
  (* Simpset fragment for a context: its rewrites and conversions, plus the
     membership decision procedure built from its reductions and
     judgements. *)
  fun simpset_frag context =
      let
        val Context {rewrites, conversions, reductions,
                     judgements, dproc_cache} = context
        val convs = map named_conv_to_simpset_conv conversions
        val dproc = algebra_dproc reductions judgements dproc_cache
      in
        simpLib.SSFRAG
          {name = NONE, ac = [], congs = [], convs = convs, rewrs = rewrites,
           dprocs = [dproc], filter = NONE}
      end;

  (* std_ss extended with the context's simpset fragment. *)
  fun simpset context = simpLib.++ (std_ss, simpset_frag context);
end;
445
446(* ------------------------------------------------------------------------- *)
447(* Subtype context pairs: one for simplification, the other for              *)
448(* normalization.                                                            *)
449(*                                                                           *)
450(* By convention add_X2 adds to both contexts, add_X2' adds to just the      *)
451(* simplify context, and add_X2'' adds to just the normalize context.        *)
452(* ------------------------------------------------------------------------- *)
453
(* A pair of contexts: one used for simplification, the other for
   normalization. *)
datatype context2 = Context2 of {simplify : context, normalize : context};
455
(* Pretty-print a context pair as {simplify = ..., normalize = ...}, rendering
   each half with to_string. *)
fun pp2 p context2 =
    case context2 of
      Context2 {simplify, normalize} =>
      (PP.begin_block p PP.INCONSISTENT 1;
       PP.add_string p ("{simplify = " ^ to_string simplify ^ ",");
       PP.add_break p (1,0);
       PP.add_string p ("normalize = " ^ to_string normalize ^ "}");
       PP.end_block p);
466
(* Render a context pair with pp2, using the current global line width. *)
fun to_string2 context2 =
    let
      val width = !Globals.linewidth
    in
      PP.pp_to_string width pp2 context2
    end;

(* Expose the record inside a context pair. *)
fun dest2 context2 = case context2 of Context2 info => info;

(* A pair whose two halves are the same empty context value, so they share
   its cache reference. *)
val empty2 = Context2 {normalize = empty, simplify = empty};
473
(* Add an item to one or both halves of a context pair:
   - the primed variant updates only the simplify context;
   - the double-primed variant updates only the normalize context;
   - the unprimed variant updates both, normalize first. *)
local
  (* Apply f to the simplify half only. *)
  fun on_simplify f (Context2 {simplify, normalize}) =
      Context2 {simplify = f simplify, normalize = normalize}

  (* Apply f to the normalize half only. *)
  fun on_normalize f (Context2 {simplify, normalize}) =
      Context2 {simplify = simplify, normalize = f normalize}
in
  fun add_rewrite2' r = on_simplify (add_rewrite r)
  fun add_rewrite2'' r = on_normalize (add_rewrite r)
  fun add_rewrite2 r = add_rewrite2' r o add_rewrite2'' r

  fun add_conversion2' c = on_simplify (add_conversion c)
  fun add_conversion2'' c = on_normalize (add_conversion c)
  fun add_conversion2 c = add_conversion2' c o add_conversion2'' c

  fun add_reduction2' d = on_simplify (add_reduction d)
  fun add_reduction2'' d = on_normalize (add_reduction d)
  fun add_reduction2 d = add_reduction2' d o add_reduction2'' d

  fun add_judgement2' j = on_simplify (add_judgement j)
  fun add_judgement2'' j = on_normalize (add_judgement j)
  fun add_judgement2 j = add_judgement2' j o add_judgement2'' j
end;
505
(* Simpset fragments for both halves of a context pair. *)
fun simpset_frag2 context2 =
    case context2 of
      Context2 {simplify = s, normalize = n} =>
      {simplify = simpset_frag s, normalize = simpset_frag n};

(* Full simpsets for both halves of a context pair. *)
fun simpset2 context2 =
    case context2 of
      Context2 {simplify = s, normalize = n} =>
      {simplify = simpset s, normalize = simpset n};
512
513end
514