1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7(* Methods for managing subgoals collectively. *)
8
9theory Subgoal_Methods
10imports Main
11begin
12ML \<open>
signature SUBGOAL_METHODS =
sig
  (* Fold all subgoals of a proof state into a single subgoal whose conclusion
     is their meta-conjunction (&&&). The bool flag enables lifting of the
     parameters/premises shared by all subgoals ("prefix" mode). *)
  val fold_subgoals: Proof.context -> bool -> thm -> thm
  (* Recover separate subgoals after folding: split &&& in subgoal 1, then
     normalize the proof state. *)
  val unfold_subgoals_tac: Proof.context -> tactic
  (* Remove subgoals that can be discharged by assuming some other subgoal;
     only the subgoals actually used remain as premises. *)
  val distinct_subgoals: Proof.context -> thm -> thm
end;
19
20structure Subgoal_Methods: SUBGOAL_METHODS =
21struct
22
(* Longest common prefix of a list of lists, comparing elements with eq.
   The returned elements are taken from the first list, and eq is always
   called as eq (element of the first list, element of another list), in list
   order, stopping at the first mismatch. No lists at all yields [].
   All lists are walked in lock-step instead of being indexed with nth/length
   at every position, so the cost is linear in the total input size. *)
fun max_common_prefix _ [] = []
  | max_common_prefix eq (ls :: lss) =
      let
        (* rests: the not-yet-consumed suffixes of the other lists *)
        fun go (x :: xs) rests =
              if List.all (fn r => not (null r) andalso eq (x, hd r)) rests
              then x :: go xs (map tl rests)
              else []
          | go [] _ = [];
      in go ls lss end;
31
(* Rewrite th using Drule.norm_hhf_eq as the sole simp rule (starting from an
   empty simpset), i.e. move !!-bound parameters outward past ==> premises. *)
fun push_outer_params ctxt th =
  let
    val hhf_ctxt =
      Simplifier.add_simp Drule.norm_hhf_eq (Simplifier.empty_simpset ctxt);
    val rewr = Raw_Simplifier.rewrite_cterm (true, false, false) (K (K NONE)) hhf_ctxt;
  in Conv.fconv_rule rewr th end;
41
(* Replace the schematic type variables of raw_st (Variable.importT) and then
   its schematic term variables (Variable.import_inst) by fixed variables.
   Returns the instantiated theorem together with the enlarged context that
   declares them; callers later use Variable.export to return to ctxt. *)
fun fix_schematics ctxt raw_st =
  let
    val ((instT, [st_T]), ctxt_T) = Variable.importT [raw_st] ctxt;
    val ((_, inst), ctxt_TV) = Variable.import_inst true [Thm.prop_of st_T] ctxt_T;
    val cinst = map (fn (v, t) => (v, Thm.cterm_of ctxt_TV t)) inst;
  in
    (Thm.instantiate (instT, cinst) st_T, ctxt_TV)
  end;
52
(* Decomposition of a subgoal  !!xs. A1 ==> ... ==> An ==> C,
   looking only through its outermost !!-binders. *)

(* The outermost parameters xs, as (name, type) pairs. *)
val strip_params = Term.strip_all_vars;
(* The premises A1 .. An; may contain loose bound variables for xs. *)
fun strip_prems t = Logic.strip_imp_prems (Term.strip_all_body t);
(* The conclusion C; may contain loose bound variables for xs. *)
fun strip_concl t = Logic.strip_imp_concl (Term.strip_all_body t);
56
57
58
(* Fold all subgoals of raw_st into a single subgoal whose conclusion is the
   meta-conjunction (&&&) of the original subgoals.

   If prefix is true, the longest prefix of !!-parameters shared by all
   subgoals (compared by type only; the names of the first subgoal are used)
   and the longest prefix of premises shared by all subgoals are lifted out:
     !!common. common_prems ==> (G1' &&& ... &&& Gn')
   where each Gi' keeps only subgoal i's remaining parameters and premises.
   If prefix is false, nothing is lifted.
   Proof states with fewer than two subgoals are returned unchanged. *)
fun fold_subgoals ctxt prefix raw_st =
  if Thm.nprems_of raw_st < 2 then raw_st
  else
    let
      val (st, inner_ctxt) = fix_schematics ctxt raw_st;

      val subgoals = Thm.prems_of st;
      val paramss = map strip_params subgoals;
      val common_params = max_common_prefix (eq_snd (op =)) paramss;

      (* The leading premises of subgoal that mention only common parameters,
         with de Bruijn indices shifted to be relative to the common
         parameters. A premise mentioning a subgoal-local parameter would get a
         negative index after the shift. It could then be aconv to a premise of
         another subgoal that mentions a different local parameter, and it
         cannot be lifted outside the local binders. So the candidate prefix
         stops at the first such premise. *)
      fun strip_shift subgoal =
        let
          val nlocal = length (strip_params subgoal) - length common_params;
          fun only_common prem = forall (fn j => j >= nlocal) (Term.loose_bnos prem);
        in
          strip_prems subgoal
          |> take_prefix only_common
          |> map (Term.incr_boundvars (~ nlocal))
        end;

      val premss = map strip_shift subgoals;
      val common_prems = max_common_prefix (op aconv) premss;

      val common_params = if prefix then common_params else [];
      val common_prems = if prefix then common_prems else [];

      (* The subgoal with the lifted parameters and premises removed. Bound
         variables referring to lifted parameters stay loose here and are
         captured by the outer binders of goal below. *)
      fun mk_concl subgoal =
        let
          val params = Term.strip_all_vars subgoal;
          val local_params = drop (length common_params) params;
          val prems = strip_prems subgoal;
          val local_prems = drop (length common_prems) prems;
          val concl = strip_concl subgoal;
        in Logic.list_all (local_params, Logic.list_implies (local_prems, concl)) end;

      val goal =
        Logic.list_all (common_params,
          (Logic.list_implies (common_prems, Logic.mk_conjunction_list (map mk_concl subgoals))));

      val chyp = Thm.cterm_of inner_ctxt goal;

      (* Fix the common parameters as free variables (names of the first subgoal). *)
      val (common_params', inner_ctxt') =
        Variable.add_fixes (map fst common_params) inner_ctxt
        |>> map2 (fn (_, T) => fn x => Thm.cterm_of inner_ctxt (Free (x, T))) common_params;

      (* Split a rule concluding A &&& B into rules concluding A and B. *)
      fun dest_conj rule =
        (@{thm conjunctionD1} OF [rule], @{thm conjunctionD2} OF [rule]);

      (* Discharge the first premise of st with rule, after re-generalizing the
         common parameters and moving parameters outward (push_outer_params). *)
      fun solve_headgoal rule =
        let
          val rule' = rule
            |> Drule.forall_intr_list common_params'
            |> push_outer_params inner_ctxt';
        in
          (fn st => Thm.implies_elim st rule')
        end;

      (* The folded conclusion is a right-nested conjunction with exactly n
         conjuncts, one per subgoal. Splitting by count rather than by shape
         keeps the last subgoal intact even when its own conclusion is itself
         a conjunction. *)
      fun solve_subgoals n rule' st =
        if n <= 1 then solve_headgoal rule' st
        else
          let val (this, rest) = dest_conj rule'
          in solve_subgoals (n - 1) rest (solve_headgoal this st) end;

      val rule = Drule.forall_elim_list common_params' (Thm.assume chyp);
    in
      st
      |> push_outer_params inner_ctxt
      |> solve_subgoals (length subgoals) rule
      |> Thm.implies_intr chyp
      |> singleton (Variable.export inner_ctxt' ctxt)
    end;
127
(* Remove subgoals of raw_st that follow from other subgoals.
   Every subgoal is assumed, normalized with norm_hhf and premise-atomized, and
   each subgoal is then closed with solve_tac from these assumptions. Only the
   assumptions that were actually used (the remaining hypotheses of the
   result) are re-introduced as premises, each once.
   The assumptions are sorted by premise count, presumably so that ones with
   fewer premises are tried first. *)
fun distinct_subgoals ctxt raw_st =
  let
    val (st, ctxt') = fix_schematics ctxt raw_st;
    val goals = Drule.cprems_of st;

    fun prepare_rule cgoal =
      Thm.assume cgoal
      |> Raw_Simplifier.norm_hhf ctxt'
      |> Conv.fconv_rule (Object_Logic.atomize_prems ctxt');

    fun by_nprems (r1, r2) = int_ord (Thm.nprems_of r1, Thm.nprems_of r2);
    val rules = sort by_nprems (map prepare_rule goals);

    fun solve_one i =
      Object_Logic.atomize_prems_tac ctxt' i THEN solve_tac ctxt' rules i;
    val solved = Seq.hd (ALLGOALS solve_one st);

    val used = distinct (op aconvc) (inter (op aconvc) (Thm.chyps_of solved) goals);
  in
    singleton (Variable.export ctxt' ctxt) (Drule.implies_intr_list used solved)
  end;
150
(* Variant of filter_prems_tac that keeps the surviving premises in their
   original order.
   pred decides, per hypothesis of the subgoal, whether it is kept. Each
   rejected hypothesis is rotated to the front and removed with thin_rl.
   A final rotation by the total amount rotated so far restores the original
   order of the kept hypotheses.
   Fails (no_tac) when pred accepts every hypothesis. *)
fun filter_prems_tac' ctxt pred =
  let
    fun append_tac NONE next = SOME next
      | append_tac (SOME acc) next = SOME (acc THEN' next);

    (* state: (accumulated tactic, kept hyps not yet rotated away, total rotation) *)
    fun step hyp (acc, kept, rotated) =
      if pred hyp then (acc, kept + 1, rotated)
      else
        (append_tac acc (rotate_tac kept THEN' eresolve_tac ctxt [thin_rl]),
         0, rotated + kept);

    fun filter_goal (goal, i) =
      (case fold step (Logic.strip_assums_hyp goal) (NONE, 0, 0) of
        (SOME tac, _, rotated) => tac i THEN rotate_tac (~ rotated) i
      | (NONE, _, _) => no_tac);
  in
    SUBGOAL filter_goal
  end;
167
(* Remove every hypothesis of the subgoal that is matched by the proposition
   of one of rules (Unify.matches_list, after fixing the hypothesis' own
   !!-parameters with Variable.focus). Remaining hypotheses keep their order.
   Fails if no hypothesis is matched.
   Each hypothesis is focused once rather than once per rule; the focused
   term does not depend on the rule.
   NOTE(review): hypotheses come from Logic.strip_assums_hyp and may contain
   loose bound variables for the subgoal's parameters; confirm matching is
   intended to behave as it does on those. *)
fun trim_prems_tac ctxt rules =
  let
    val rule_props = map Thm.prop_of rules;

    fun matched prem =
      not (null rule_props) andalso
        let
          val ((_, prem'), ctxt') = Variable.focus NONE prem ctxt;
          val context = Context.Proof ctxt';
        in exists (fn prop => Unify.matches_list context [prop] [prem']) rule_props end;
  in
    filter_prems_tac' ctxt (not o matched)
  end;
177
(* Split a subgoal whose conclusion (after its ==> premises) is a
   meta-conjunction by resolving with Conjunction.conjunctionI. Repeat on all
   resulting subgoals until none is a conjunction. *)
val adhoc_conjunction_tac =
  let
    fun split_conj (goal, i) =
      if can Logic.dest_conjunction (Logic.strip_imp_concl goal)
      then resolve0_tac [Conjunction.conjunctionI] i
      else no_tac;
  in REPEAT_ALL_NEW (SUBGOAL split_conj) end;
183
(* Recover subgoals after folding: split meta-conjunctions starting at
   subgoal 1 (if there are any), then normalize the whole proof state with
   norm_hhf. *)
fun unfold_subgoals_tac ctxt =
  let
    val split_tac = TRY (adhoc_conjunction_tac 1);
    val normalize_tac = PRIMITIVE (Raw_Simplifier.norm_hhf ctxt);
  in split_tac THEN normalize_tac end;
187
(* Isar method setup:
     fold_subgoals      -- fold_subgoals; optional mode flag "prefix" (Args.mode)
                           enables lifting of common params/premises
     unfold_subgoals    -- unfold_subgoals_tac
     distinct_subgoals  -- distinct_subgoals
     trim thms          -- trim_prems_tac on the first subgoal *)
val _ =
  let
    val fold_method =
      Scan.lift (Args.mode "prefix") >> (fn prefix => fn ctxt =>
        SIMPLE_METHOD (PRIMITIVE (fold_subgoals ctxt prefix)));
    val unfold_method =
      Scan.succeed (fn ctxt => SIMPLE_METHOD (unfold_subgoals_tac ctxt));
    val distinct_method =
      Scan.succeed (fn ctxt => SIMPLE_METHOD (PRIMITIVE (distinct_subgoals ctxt)));
    val trim_method =
      Attrib.thms >> (fn thms => fn ctxt =>
        SIMPLE_METHOD (HEADGOAL (trim_prems_tac ctxt thms)));
  in
    Theory.setup
     (Method.setup @{binding fold_subgoals} fold_method
        "lift all subgoals over common premises/params" #>
      Method.setup @{binding unfold_subgoals} unfold_method
        "recover subgoals after folding" #>
      Method.setup @{binding distinct_subgoals} distinct_method
        "trim all subgoals to be (logically) distinct" #>
      Method.setup @{binding trim} trim_method
        "trim all premises that match the given rules")
  end;
204
205end;
206\<close>
207
208end
209