1(*  Title:      Pure/Isar/calculation.ML
2    Author:     Markus Wenzel, TU Muenchen
3
4Generic calculational proofs.
5*)
6
signature CALCULATION =
sig
  (*print the transitivity and symmetry rules declared in the context*)
  val print_rules: Proof.context -> unit
  (*result of the calculation recorded at the current proof level, if any*)
  val check: Proof.state -> thm list option
  (*add / delete a transitivity rule*)
  val trans_add: attribute
  val trans_del: attribute
  (*add / delete a symmetry rule; also updates the Context_Rules declaration*)
  val sym_add: attribute
  val sym_del: attribute
  (*resolve a theorem with the declared symmetry rules; the result must be unique*)
  val symmetric: attribute
  (*calculational steps: optional explicit rules, then a flag that makes the
    command print the resulting calculation, then the proof state;
    the "finally" variants conclude the calculation and chain its result*)
  val also: thm list option -> bool -> Proof.state -> Proof.state Seq.result Seq.seq
  val also_cmd: (Facts.ref * Token.src list) list option ->
    bool -> Proof.state -> Proof.state Seq.result Seq.seq
  val finally: thm list option -> bool -> Proof.state -> Proof.state Seq.result Seq.seq
  val finally_cmd: (Facts.ref * Token.src list) list option -> bool ->
    Proof.state -> Proof.state Seq.result Seq.seq
  (*accumulate the current facts into the calculation (bool: print the result);
    "ultimately" concludes the calculation and chains its result*)
  val moreover: bool -> Proof.state -> Proof.state
  val ultimately: bool -> Proof.state -> Proof.state
end;
25
26structure Calculation: CALCULATION =
27struct
28
29(** calculation data **)
30
(*a calculation in progress: the accumulated result, the proof level at which it
  was started, and the serial / position used for its entity markup reports*)
type calculation = {result: thm list, level: int, serial: serial, pos: Position.T};
32
(*generic context data: (transitivity rules, symmetry rules) paired with the
  optional calculation in progress*)
structure Data = Generic_Data
(
  type T = (thm Item_Net.T * thm list) * calculation option;
  val empty = ((Thm.elim_rules, []), NONE);
  val extend = I;
  (*rules of both sides are merged; a calculation in progress is never carried over*)
  fun merge ((rules1, _), (rules2, _)) =
    let
      val (trans1, sym1) = rules1;
      val (trans2, sym2) = rules2;
      val trans = Item_Net.merge (trans1, trans2);
      val sym = Thm.merge_thms (sym1, sym2);
    in ((trans, sym), NONE) end;
);
41
(*the (transitivity, symmetry) rules of a proof context*)
fun get_rules ctxt = #1 (Data.get (Context.Proof ctxt));
(*the calculation stored in a proof context, regardless of its level*)
fun get_calculation ctxt = #2 (Data.get (Context.Proof ctxt));
44
(*print the declared transitivity and symmetry rules as two titled lists*)
fun print_rules ctxt =
  let
    val (trans_net, sym_rules) = get_rules ctxt;
    fun section title thms = Pretty.big_list title (map (Thm.pretty_thm_item ctxt) thms);
  in
    Pretty.writeln_chunks
     [section "transitivity rules:" (Item_Net.content trans_net),
      section "symmetry rules:" sym_rules]
  end;
53
54
55(* access calculation *)
56
(*the stored calculation, provided it was recorded at the current proof level*)
fun check_calculation state =
  let
    fun at_current_level (calculation: calculation) =
      if #level calculation = Proof.level state then SOME calculation else NONE;
  in
    (case get_calculation (Proof.context_of state) of
      SOME calculation => at_current_level calculation
    | NONE => NONE)
  end;
62
(*result theorems of the calculation at the current proof level, if any*)
fun check state = Option.map #result (check_calculation state);
64
(*name of the fact that holds the calculation; also used as its entity markup kind*)
val calculationN = "calculation";
66
(*Store (SOME result) or clear (NONE) the calculation of the proof state.

  A calculation that continues one at the current level keeps the level, serial
  and position of that one and is reported as a reference; otherwise a fresh
  serial and the current thread position are used and reported as a definition.
  The fact named calculationN is updated accordingly.*)
fun update_calculation calc state =
  let
    fun report def id pos =
      Context_Position.report (Proof.context_of state)
        (Position.thread_data ())
          (Markup.properties (Position.entity_properties_of def id pos)
            (Markup.entity calculationN ""));

    (*start a new calculation at the current proof level*)
    fun fresh result =
      let
        val level = Proof.level state;
        val id = serial ();
        val here = Position.thread_data ();
        val _ = report true id here;
      in {result = result, level = level, serial = id, pos = here} end;

    (*continue an existing calculation, replacing only its result*)
    fun continued result ({level, serial = id, pos, ...}: calculation) =
      (report false id pos;
        {result = result, level = level, serial = id, pos = pos});

    val calculation =
      (case calc of
        SOME result =>
          (case check_calculation state of
            SOME previous => SOME (continued result previous)
          | NONE => SOME (fresh result))
      | NONE => NONE);
  in
    state
    |> (Proof.map_context o Context.proof_map o Data.map o apsnd) (K calculation)
    |> Proof.map_context (Proof_Context.put_thms false (calculationN, calc))
  end;
94
95
96
97(** attributes **)
98
99(* add/del rules *)
100
(*declare a transitivity rule; the theorem is stored with its context trimmed*)
val trans_add =
  Thm.declaration_attribute (fn th =>
    (Data.map o apfst o apfst) (Item_Net.update (Thm.trim_context th)));
103
(*remove a transitivity rule*)
val trans_del =
  Thm.declaration_attribute (fn th =>
    (Data.map o apfst o apfst) (Item_Net.remove th));
106
(*declare a symmetry rule: store it (context trimmed) in the symmetry list, then
  apply the Context_Rules.elim_query declaration to the original theorem*)
val sym_add =
  Thm.declaration_attribute (fn th => fn context =>
    context
    |> (Data.map o apfst o apsnd) (Thm.add_thm (Thm.trim_context th))
    |> Thm.attribute_declaration (Context_Rules.elim_query NONE) th);
111
(*remove a symmetry rule from the symmetry list, then apply the
  Context_Rules.rule_del declaration to the theorem*)
val sym_del =
  Thm.declaration_attribute (fn th => fn context =>
    context
    |> (Data.map o apfst o apsnd) (Thm.del_thm th)
    |> Thm.attribute_declaration Context_Rules.rule_del th);
116
117
118(* symmetric *)
119
(*resolve the theorem with the declared symmetry rules; exactly one result is
  accepted, otherwise THM is raised*)
val symmetric =
  Thm.rule_attribute [] (fn context => fn th =>
    let
      val ctxt = Context.proof_of context;
      val (_, sym_rules) = #1 (Data.get context);
      (*at most two results are demanded: enough to tell unique from ambiguous*)
      val (results, _) = Seq.chop 2 (Drule.multi_resolves (SOME ctxt) [th] sym_rules);
    in
      (case results of
        [] => raise THM ("symmetric: no unifiers", 1, [th])
      | [th'] => Drule.zero_var_indexes th'
      | _ => raise THM ("symmetric: multiple unifiers", 1, [th]))
    end);
127
128
129(* concrete syntax *)
130
(*theory setup: register the attributes "trans" and "sym" (each with add/del
  variants) and "symmetric", then declare transitive_thm as transitivity rule
  and symmetric_thm as symmetry rule; both theorems are added under an empty
  binding and the resulting theorem list is discarded*)
val _ = Theory.setup
 (Attrib.setup \<^binding>\<open>trans\<close> (Attrib.add_del trans_add trans_del)
    "declaration of transitivity rule" #>
  Attrib.setup \<^binding>\<open>sym\<close> (Attrib.add_del sym_add sym_del)
    "declaration of symmetry rule" #>
  Attrib.setup \<^binding>\<open>symmetric\<close> (Scan.succeed symmetric)
    "resolution with symmetry rule" #>
  Global_Theory.add_thms
   [((Binding.empty, transitive_thm), [trans_add]),
    ((Binding.empty, symmetric_thm), [sym_add])] #> snd);
141
142
143
144(** proof commands **)
145
(*Mode check for calculational commands: a concluding command (final) must pass
  Proof.assert_forward; any other one must pass Proof.assert_forward_or_chain,
  and is reported as improper if the state also passes Proof.assert_chain.*)
fun assert_sane final =
  let
    fun report_improper state =
      if can Proof.assert_chain state then
        Context_Position.report (Proof.context_of state) (Position.thread_data ()) Markup.improper
      else ();
  in
    if final then Proof.assert_forward
    else Proof.assert_forward_or_chain #> tap report_improper
  end;
154
(*Record calc as the current calculation and reset the facts of the state.
  If int is set, the calculation is printed under its full fact name.
  If final is set, the calculation is cleared again and calc is chained.*)
fun maintain_calculation int final calc state =
  let
    val state' = Proof.improper_reset_facts (update_calculation (SOME calc) state);
    val ctxt' = Proof.context_of state';
    val _ =
      if int then
        let
          val name = Proof_Context.full_name ctxt' (Binding.name calculationN);
          val prt = Proof_Context.pretty_fact ctxt' (name, calc);
        in writeln (Pretty.string_of prt) end
      else ();
  in
    if final then Proof.chain_facts calc (update_calculation NONE state')
    else state'
  end;
168
169
170(* also and finally *)
171
(*Common implementation of "also" and "finally".

  prep_rules: turns the optional raw rules into theorems within the context
  final: conclude the calculation (see maintain_calculation)
  raw_rules: optional explicit rules; without them the declared transitivity
    rules are used
  int: print the resulting calculation

  Without a calculation at the current level the present facts become the
  initial calculation; this is an error if final is set or rules were given.
  Otherwise the previous calculation and the present facts are combined via
  the rules, yielding a sequence of results that ends with an error entry.*)
fun calculate prep_rules final raw_rules int state =
  let
    val ctxt = Proof.context_of state;
    val show_thm = Thm.pretty_thm ctxt;
    val show_item = Thm.pretty_thm_item ctxt;

    fun concl_of th = Logic.strip_assums_concl (Thm.prop_of th);
    fun norm_concl th = Envir.beta_eta_contract (concl_of th);
    fun same_concl th1 th2 = norm_concl th1 aconv norm_concl th2;

    (*error entry for a result that merely repeats premise number i (0-based)*)
    fun vacuous_error ths th i =
      Seq.Error (fn () =>
        let
          val header =
            Pretty.block [Pretty.str "Vacuous calculation result:", Pretty.brk 1, show_thm th];
          val origin =
            Pretty.block (Pretty.fbreaks
              (Pretty.str ("derived as projection (" ^ string_of_int (i + 1) ^ ") from:") ::
                map show_item ths));
        in Pretty.string_of (Pretty.chunks [header, origin]) end);

    (*reject a result whose conclusion coincides with that of one of the premises*)
    fun check_projection ths th =
      let val i = find_index (same_concl th) ths
      in if i < 0 then Seq.Result [th] else vacuous_error ths th i end;

    val opt_rules = Option.map (prep_rules ctxt) raw_rules;

    (*declared transitivity rules: all of them without premises, otherwise those
      retrieved for the conclusion of the first premise*)
    fun default_rules ths =
      let val trans = #1 (get_rules ctxt) in
        (case ths of
          th :: _ => Item_Net.retrieve trans (concl_of th)
        | [] => Item_Net.content trans)
      end;

    fun no_match_error ths =
      Seq.Error (fn () =>
        Pretty.string_of (Pretty.block (Pretty.fbreaks
          (Pretty.str "No matching trans rules for calculation:" ::
            map show_item ths))));

    fun combine ths =
      let
        val rules = (case opt_rules of SOME rules => rules | NONE => default_rules ths);
        val results =
          Seq.of_list rules
          |> Seq.maps (Drule.multi_resolve (SOME ctxt) ths)
          |> Seq.map (check_projection ths);
      in Seq.append results (Seq.single (no_match_error ths)) end;

    val facts = Proof.the_facts (assert_sane final state);
    val previous = check state;
    val initial = not (is_some previous);
    val calculations =
      (case previous of
        SOME calc => combine (calc @ facts)
      | NONE => Seq.single (Seq.Result facts));

    val _ = if initial andalso final then error "No calculation yet" else ();
    val _ =
      if initial andalso is_some opt_rules then
        error "Initial calculation -- no rules to be given"
      else ();
  in
    Seq.map_result (fn calc => maintain_calculation int final calc state) calculations
  end;
219
(*"also" / "finally" on rules given as theorems (used as they are) or as fact
  references with attributes (evaluated by Attrib.eval_thms)*)
fun also raw_rules int state = calculate (K I) false raw_rules int state;
fun also_cmd raw_rules int state = calculate Attrib.eval_thms false raw_rules int state;
fun finally raw_rules int state = calculate (K I) true raw_rules int state;
fun finally_cmd raw_rules int state = calculate Attrib.eval_thms true raw_rules int state;
224
225
226(* moreover and ultimately *)
227
(*Common implementation of "moreover" and "ultimately": append the present facts
  to the calculation at the current level (or start one from them). Concluding
  (final) without a previous calculation is an error.*)
fun collect final int state =
  let
    val facts = Proof.the_facts (assert_sane final state);
    val previous = check state;
    val earlier = (case previous of SOME thms => thms | NONE => []);
    val calc = earlier @ facts;
    val _ =
      if not (is_some previous) andalso final then error "No calculation yet" else ();
  in maintain_calculation int final calc state end;
238
(*"moreover" accumulates facts; "ultimately" accumulates and concludes*)
fun moreover int state = collect false int state;
fun ultimately int state = collect true int state;
241
242end;
243