1(*  Title:       HOL/Tools/Function/scnp_reconstruct.ML
2    Author:      Armin Heller, TU Muenchen
3    Author:      Alexander Krauss, TU Muenchen
4
Proof reconstruction for SCNP termination (the NP subset of size-change termination).
6*)
7
signature SCNP_RECONSTRUCT =
sig
  (* Size-change termination prover trying the orders MAX, MS and MIN.  The
     tactic argument is handed on to Termination.TERMINATION. *)
  val sizechange_tac : Proof.context -> tactic -> tactic

  (* Size-change termination prover restricted to the given orders; the
     automated tactic is auto_tac extended with the termination_simp rules. *)
  val decomp_scnp_tac : ScnpSolve.label list -> Proof.context -> tactic

  (* Multiset-specific operations and theorems needed to reconstruct proofs
     for the MS (multiset) order. *)
  datatype multiset_setup =
    Multiset of
    {
     msetT : typ -> typ,
     mk_mset : typ -> term list -> term,
     mset_regroup_conv : Proof.context -> int list -> conv,
     mset_member_tac : Proof.context -> int -> int -> tactic,
     mset_nonempty_tac : Proof.context -> int -> tactic,
     mset_pwleq_tac : Proof.context -> int -> tactic,
     set_of_simps : thm list,
     smsI' : thm,
     wmsI2'' : thm,
     wmsI1 : thm,
     reduction_pair : thm
    }

  (* Register a multiset setup in a theory; the MS order is only attempted
     when one is registered. *)
  val multiset_setup : multiset_setup -> theory -> theory
end
32
33structure ScnpReconstruct : SCNP_RECONSTRUCT =
34struct
35
(* Local alias for the function package's profiling wrapper; it wraps the
   proof reconstruction tactic below. *)
val PROFILE = Function_Common.PROFILE

(* Unqualified solver names used below (order labels MAX/MS/MIN, graph
   constructors G/GP, edge kinds GTR/GEQ, generate_certificate) are presumably
   provided by this structure. *)
open ScnpSolve

(* HOL types of the tagged level mappings: nat, and pairs (measure value, tag)
   of type nat * nat. *)
val natT = HOLogic.natT
val nat_pairT = HOLogic.mk_prodT (natT, natT)
42
43
44(* Theory dependencies *)
45
(* Everything the MS (multiset) order needs from the multiset theory; this
   file does not construct one itself, it must be registered via
   multiset_setup. *)
datatype multiset_setup =
  Multiset of
  {
   msetT : typ -> typ,                                   (* MS counterpart of HOLogic.mk_setT *)
   mk_mset : typ -> term list -> term,                   (* MS counterpart of HOLogic.mk_set *)
   mset_regroup_conv : Proof.context -> int list -> conv,  (* regroups a multiset by element indices *)
   mset_member_tac : Proof.context -> int -> int -> tactic,  (* not used in this structure *)
   mset_nonempty_tac : Proof.context -> int -> tactic,   (* not used in this structure *)
   mset_pwleq_tac : Proof.context -> int -> tactic,      (* applied after smsI'/wmsI2''/wmsI1 *)
   set_of_simps : thm list,                              (* unfolded before the residual strict MAX step *)
   smsI' : thm,                                          (* intro rule used for a strict MS step *)
   wmsI2'' : thm,                                        (* intro rule for a weak MS step, all elements weakly covered *)
   wmsI1 : thm,                                          (* intro rule for the remaining weak MS steps *)
   reduction_pair : thm                                  (* reduction pair theorem for the MS order *)
  }
61
(* Per-theory storage of the optional multiset setup: NONE until a theory
   registers one; theory merges combine the options with merge_options. *)
structure Multiset_Setup = Theory_Data
(
  type T = multiset_setup option
  val empty = NONE
  val extend = I;
  val merge = merge_options
)
69
(* Register a multiset setup in the theory (enables the MS order). *)
fun multiset_setup setup = Multiset_Setup.put (SOME setup)

(* Stand-in for multiset operations when no setup is registered. *)
fun undef _ = error "undef"

(* The multiset setup registered in the context's theory.  Without one, a
   dummy setup is returned: its operations fail with "undef", its theorem list
   is empty and its theorems are refl.  single_scnp_tac drops the MS order in
   that situation, so the dummy operations are not meant to be reached. *)
fun get_multiset_setup ctxt =
  (case Multiset_Setup.get (Proof_Context.theory_of ctxt) of
    SOME setup => setup
  | NONE =>
      Multiset
       {msetT = undef,
        mk_mset = undef,
        mset_regroup_conv = undef,
        mset_member_tac = undef,
        mset_nonempty_tac = undef,
        mset_pwleq_tac = undef,
        set_of_simps = [],
        smsI' = refl,
        wmsI2'' = refl,
        wmsI1 = refl,
        reduction_pair = refl})
81
(* Reduction-pair theorem for an order label; for MS it is the theorem msrp
   taken from the multiset setup. *)
fun order_rpair msrp label =
  (case label of
    MAX => @{thm max_rpair_set}
  | MS => msrp
  | MIN => @{thm min_rpair_set})
85
(* Introduction rules (empty case, insert case) of the MAX order: the strict
   variants when the flag is set, the weak ones otherwise. *)
fun ord_intros_max strict =
  if strict then (@{thm smax_emptyI}, @{thm smax_insertI})
  else (@{thm wmax_emptyI}, @{thm wmax_insertI})
88
(* Introduction rules (empty case, insert case) of the MIN order: the strict
   variants when the flag is set, the weak ones otherwise. *)
fun ord_intros_min strict =
  if strict then (@{thm smin_emptyI}, @{thm smin_insertI})
  else (@{thm wmin_emptyI}, @{thm wmin_insertI})
91
(* Abstract SCNP problem for the calls cs of D: the number of measures of
   every program point, and one size-change graph G (p, q, edges) per call.
   An edge (i, GTR, j) is recorded when Termination.get_descent reports Less
   between measure i of p and measure j of q, an edge (i, GEQ, j) when it
   reports LessEq; other measure pairs get no edge. *)
fun gen_probl D cs =
  let
    fun arity p = length (Termination.get_measures D p)
    fun measure p i = nth (Termination.get_measures D p) i

    fun mk_graph c =
      let
        val (_, p, _, q, _, _) = Termination.dest_call D c

        (* accumulator-passing edge collection, in fold_product order *)
        fun add_edge i j edges =
          (case Termination.get_descent D c (measure p i) (measure q j) of
            SOME (Termination.Less _) => (i, GTR, j) :: edges
          | SOME (Termination.LessEq _) => (i, GEQ, j) :: edges
          | _ => edges)

        val p_indices = 0 upto arity p - 1
        val q_indices = 0 upto arity q - 1
      in
        G (p, q, fold_product add_edge p_indices q_indices [])
      end

    val arities = map_range arity (Termination.get_num_points D)
  in
    GP (arities, map mk_graph cs)
  end
116
117(* General reduction pair application *)
(* Open a subset goal over an inverse-image relation on subgoal 1: introduce
   an element (subsetI, CollectE), strip existentials, unfold inv_image_def,
   re-enter the Collect set, substitute the element's defining equation and
   simplify split/case constructs. *)
fun rem_inv_img ctxt =
  let
    fun resolve1 ths = resolve_tac ctxt ths 1
    fun eresolve1 ths = eresolve_tac ctxt ths 1
    val unfold = Local_Defs.unfold0_tac ctxt
  in
    EVERY
     [resolve1 @{thms subsetI},
      eresolve1 @{thms CollectE},
      REPEAT (eresolve1 @{thms exE}),
      unfold @{thms inv_image_def},
      resolve1 @{thms CollectI},
      eresolve1 @{thms conjE},
      eresolve1 @{thms ssubst},
      unfold @{thms split_conv triv_forall_equality sum.case}]
  end
127
128
129(* Sets *)
130
(* HOL set type with elements of type T. *)
fun setT T = HOLogic.mk_setT T
132
(* Prove, on subgoal i, membership of the m-th (0-based) element of an
   explicit insert-chain set: skip m elements with insertI2, then close with
   insertI1.  Callers compute m with find_index, which yields ~1 for a missing
   element; a negative index fails with no_tac instead of recursing forever
   (m - 1 would never reach 0). *)
fun set_member_tac ctxt m i =
  if m < 0 then no_tac
  else if m = 0 then resolve_tac ctxt @{thms insertI1} i
  else resolve_tac ctxt @{thms insertI2} i THEN set_member_tac ctxt (m - 1) i
136
(* Prove on subgoal i that an explicit insert-set is non-empty. *)
fun set_nonempty_tac ctxt i = resolve_tac ctxt [@{thm insert_not_empty}] i
138
(* Prove on subgoal i that an explicit set built from {} by insert is finite:
   apply finite.insertI until finite.emptyI closes the goal.  Taking the proof
   state as an argument keeps the recursion lazy. *)
fun set_finite_tac ctxt i st =
  (resolve_tac ctxt @{thms finite.emptyI} i
   ORELSE (resolve_tac ctxt @{thms finite.insertI} i THEN set_finite_tac ctxt i)) st
142
143
144(* Reconstruction *)
145
(* Replay an SCNP certificate (label, lev, sl, covering) for the calls cs as a
   proof of the current goal.
   - label: the order (MAX, MIN or MS) of the level mapping;
   - lev: for each program point, its level as a list of (measure index, tag)
     pairs;
   - sl: indices of the calls proved with the strict order; the remaining
     calls are proved with the weak order;
   - covering g strict x: the certificate's covering element for x in call g.
   May raise Fail "get_desc_thm" or Option when the certificate asks for a
   descent that is not available. *)
fun reconstruct_tac ctxt D cs (GP (_, gs)) certificate =
  let
    val Multiset
          { msetT, mk_mset,
            mset_regroup_conv, mset_pwleq_tac, set_of_simps,
            smsI', wmsI2'', wmsI1, reduction_pair=ms_rp, ...}
        = get_multiset_setup ctxt

    fun measure_fn p = nth (Termination.get_measures D p)

    (* Descent theorem of call cidx between measures m1 and m2.  A stored
       strict theorem is weakened with less_imp_le when a non-strict one is
       requested.  A strict request with only a non-strict theorem stored, or
       a missing descent, raises Fail. *)
    fun get_desc_thm cidx m1 m2 bStrict =
      (case Termination.get_descent D (nth cs cidx) m1 m2 of
        SOME (Termination.Less thm) =>
          if bStrict then thm
          else (thm COMP (Thm.lift_rule (Thm.cprop_of thm) @{thm less_imp_le}))
      | SOME (Termination.LessEq (thm, _))  =>
          if not bStrict then thm
          else raise Fail "get_desc_thm"
      | _ => raise Fail "get_desc_thm")

    val (label, lev, sl, covering) = certificate

    (* Tactic showing that call g is contained in the (strict or weak) order
       between the levels of its points q and p. *)
    fun prove_lev strict g =
      let
        val G (p, q, _) = nth gs g

        (* Compare tagged measure (j, b) of q with (i, a) of p.  When tag_flag
           holds, the descent theorem is requested non-strictly and the tag
           comparison is left to arith_tac. *)
        fun less_proof strict (j, b) (i, a) =
          let
            val tag_flag = b < a orelse (not strict andalso b <= a)

            val stored_thm =
              get_desc_thm g (measure_fn p i) (measure_fn q j)
                             (not tag_flag)
              |> Conv.fconv_rule (Thm.beta_conversion true)

            val rule =
              if strict
              then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
              else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
          in
            resolve_tac ctxt [rule] 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
            THEN (if tag_flag then Arith_Data.arith_tac ctxt 1 else all_tac)
          end

        (* steps_tac order strict lq lp: prove the order between the level
           lq (of q) and the level lp (of p). *)
        fun steps_tac MAX strict lq lp =
              let
                val (empty, step) = ord_intros_max strict
              in
                (case lq of
                  [] =>
                    resolve_tac ctxt [empty] 1 THEN set_finite_tac ctxt 1
                    THEN (if strict then set_nonempty_tac ctxt 1 else all_tac)
                | (j, b) :: rest =>
                    let
                      (* element of lp covering (j, b), located in lp *)
                      val (i, a) = the (covering g strict j)
                      fun choose xs = set_member_tac ctxt (find_index (curry op = (i, a)) xs) 1
                      val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
                    in
                      resolve_tac ctxt [step] 1 THEN solve_tac THEN steps_tac MAX strict rest lp
                    end)
              end
          | steps_tac MIN strict lq lp =
              let
                val (empty, step) = ord_intros_min strict
              in
                (case lp of
                  [] =>
                    resolve_tac ctxt [empty] 1
                    THEN (if strict then set_nonempty_tac ctxt 1 else all_tac)
                | (i, a) :: rest =>
                    let
                      (* element of lq covering (i, a), located in lq *)
                      val (j, b) = the (covering g strict i)
                      fun choose xs = set_member_tac ctxt (find_index (curry op = (j, b)) xs) 1
                      val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
                    in
                      resolve_tac ctxt [step] 1 THEN solve_tac THEN steps_tac MIN strict lq rest
                    end)
              end
          | steps_tac MS strict lq lp =
              let
                fun get_str_cover (j, b) =
                  if is_some (covering g true j) then SOME (j, b) else NONE
                fun get_wk_cover (j, b) = the (covering g false j)

                (* qs: elements of lq without a strict covering; ps: their
                   weak coverings in lp *)
                val qs = subtract (op =) (map_filter get_str_cover lq) lq
                val ps = map get_wk_cover qs

                (* positions of qs in lq and of ps in lp, used to regroup the
                   two multisets in the goal *)
                fun indices xs ys = map (fn y => find_index (curry op = y) xs) ys
                val iqs = indices lq qs
                val ips = indices lp ps

                local open Conv in
                fun t_conv a C =
                  params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
                val goal_rewrite =
                    t_conv arg1_conv (mset_regroup_conv ctxt iqs)
                    then_conv t_conv arg_conv (mset_regroup_conv ctxt ips)
                end
              in
                CONVERSION goal_rewrite 1
                THEN (if strict then resolve_tac ctxt [smsI'] 1
                      else if qs = lq then resolve_tac ctxt [wmsI2''] 1
                      else resolve_tac ctxt [wmsI1] 1)
                THEN mset_pwleq_tac ctxt 1
                THEN EVERY (map2 (less_proof false) qs ps)
                (* the strictly covered rest is handled by a strict MAX step *)
                THEN (if strict orelse qs <> lq
                      then Local_Defs.unfold0_tac ctxt set_of_simps
                           THEN steps_tac MAX true
                           (subtract (op =) qs lq) (subtract (op =) ps lp)
                      else all_tac)
              end
      in
        rem_inv_img ctxt
        THEN steps_tac label strict (nth lev q) (nth lev p)
      end

    (* Sets are used for MAX/MIN, multisets (from the setup) for MS. *)
    val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)

    (* (measure_i x, tag) with x bound by the enclosing abstraction *)
    fun tag_pair p (i, tag) =
      HOLogic.pair_const natT natT $
        (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag

    fun pt_lev (p, lm) =
      Abs ("x", Termination.get_types D p, mk_set nat_pairT (map (tag_pair p) lm))

    (* Level mapping over all program points, combined by sum-type cases. *)
    val level_mapping =
      map_index pt_lev lev
        |> Termination.mk_sumcases D (setT nat_pairT)
        |> Thm.cterm_of ctxt
    in
      (* Regroup the union of call relations according to sl, apply the
         reduction pair of the chosen order instantiated with level_mapping,
         split the union into one subgoal per call, then prove the calls in sl
         strictly and all other calls weakly. *)
      PROFILE "Proof Reconstruction"
        (CONVERSION (Conv.arg_conv (Conv.arg_conv (Function_Lib.regroup_union_conv ctxt sl))) 1
         THEN (resolve_tac ctxt @{thms reduction_pair_lemma} 1)
         THEN (resolve_tac ctxt @{thms rp_inv_image_rp} 1)
         THEN (resolve_tac ctxt [order_rpair ms_rp label] 1)
         THEN PRIMITIVE (Thm.instantiate' [] [SOME level_mapping])
         THEN unfold_tac ctxt @{thms rp_inv_image_def}
         THEN Local_Defs.unfold0_tac ctxt @{thms split_conv fst_conv snd_conv}
         THEN REPEAT (SOMEGOAL (resolve_tac ctxt [@{thm Un_least}, @{thm empty_subsetI}]))
         THEN EVERY (map (prove_lev true) sl)
         THEN EVERY (map (prove_lev false) (subtract (op =) sl (0 upto length cs - 1))))
    end
288
289
(* Try to discharge the calls of subgoal i with one SCNP certificate: build
   the size-change problem, ask the solver for a certificate over the allowed
   orders (MS only when a multiset setup is registered) and replay it.  Fails
   with no_tac when the solver finds no certificate. *)
fun single_scnp_tac use_tags orders ctxt D = Termination.CALLS (fn (cs, i) =>
  let
    val has_multisets = is_some (Multiset_Setup.get (Proof_Context.theory_of ctxt))
    val usable_orders =
      if has_multisets then orders
      else filter_out (fn lbl => lbl = MS) orders
    val problem = gen_probl D cs
  in
    (case generate_certificate use_tags usable_orders problem of
      SOME cert =>
        SELECT_GOAL (reconstruct_tac ctxt D cs problem cert) i
        THEN TRY (resolve_tac ctxt @{thms wf_empty} i)
    | NONE => no_tac)
  end)
305
(* Termination proof loop: on every subgoal, try an SCNP certificate first and
   fall back to graph decomposition, repeating on all new subgoals. *)
fun gen_decomp_scnp_tac orders autom_tac ctxt =
  let
    fun scnp_or_decompose D =
      let
        val decompose_step = Termination.decompose_tac ctxt D
        val scnp_step = single_scnp_tac true orders ctxt D
      in
        REPEAT_ALL_NEW (scnp_step ORELSE' decompose_step)
      end
  in
    Termination.TERMINATION ctxt autom_tac scnp_or_decompose
  end
314
(* Size-change termination prover for the given orders: optionally apply the
   termination rule and split well-founded unions, then either close with
   wf_empty or run the decomposition/SCNP loop. *)
fun gen_sizechange_tac orders autom_tac ctxt =
  let
    val prepare =
      TRY (Function_Common.termination_rule_tac ctxt 1)
      THEN TRY (Termination.wf_union_tac ctxt)
    val finish =
      resolve_tac ctxt @{thms wf_empty} 1
      ORELSE gen_decomp_scnp_tac orders autom_tac ctxt 1
  in
    prepare THEN finish
  end
319
(* Size-change termination prover trying all orders (MAX, MS, MIN). *)
fun sizechange_tac ctxt autom_tac =
  let
    val all_orders = [MAX, MS, MIN]
  in
    gen_sizechange_tac all_orders autom_tac ctxt
  end
322
(* Size-change termination prover for the given orders, with auto_tac plus the
   termination_simp rules as the automated tactic. *)
fun decomp_scnp_tac orders ctxt =
  let
    val simp_ctxt =
      ctxt addsimps Named_Theorems.get ctxt @{named_theorems termination_simp}
  in
    gen_sizechange_tac orders (auto_tac simp_ctxt) ctxt
  end
330
331
332(* Method setup *)
333
(* Method argument: a non-empty list of the keywords max, min, ms; without
   arguments all orders [MAX, MS, MIN] are used. *)
val orders =
  let
    fun keyword name order = Args.$$$ name >> K order
  in
    Scan.repeat1 (keyword "max" MAX || keyword "min" MIN || keyword "ms" MS)
    || Scan.succeed [MAX, MS, MIN]
  end
340
(* Register the proof method size_change: optional order keywords followed by
   the usual classical/simplifier modifiers. *)
val _ =
  let
    val parser =
      Scan.lift orders --| Method.sections clasimp_modifiers
        >> (fn chosen => fn ctxt => SIMPLE_METHOD (decomp_scnp_tac chosen ctxt))
  in
    Theory.setup
      (Method.setup @{binding size_change} parser
        "termination prover with graph decomposition and the NP subset of size change termination")
  end
347
348end
349