1(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
2    Author:     Jasmin Blanchette, TU Muenchen
3    Author:     Cezary Kaliszyk, University of Innsbruck
4
5Sledgehammer's machine-learning-based relevance filter (MaSh).
6*)
7
(* Interface of the MaSh relevance filter: configuration options, filter and command names,
   string encoding helpers, the algorithm selector, fact meshing and weighting, feature and
   dependency extraction, learning entry points, and the top-level relevance filter. *)
signature SLEDGEHAMMER_MASH =
sig
  (* abbreviations for types declared in other Sledgehammer structures *)
  type stature = ATP_Problem_Generate.stature
  type raw_fact = Sledgehammer_Fact.raw_fact
  type fact = Sledgehammer_Fact.fact
  type fact_override = Sledgehammer_Fact.fact_override
  type params = Sledgehammer_Prover.params
  type prover_result = Sledgehammer_Prover.prover_result

  (* Boolean configuration options *)
  val trace : bool Config.T
  val duplicates : bool Config.T

  (* names of the fact filters (capitalized and lowercase spellings) *)
  val MePoN : string
  val MaShN : string
  val MeShN : string
  val mepoN : string
  val mashN : string
  val meshN : string

  (* names of the learning-related commands *)
  val unlearnN : string
  val learn_isarN : string
  val learn_proverN : string
  val relearn_isarN : string
  val relearn_proverN : string
  val fact_filters : string list

  (* escaping of strings for the persistent state file; lists are blank-separated *)
  val encode_str : string -> string
  val encode_strs : string list -> string
  val decode_str : string -> string
  val decode_strs : string -> string list

  (* learning algorithms selectable through the "MaSh" system option *)
  datatype mash_algorithm =
    MaSh_NB
  | MaSh_kNN
  | MaSh_NB_kNN
  | MaSh_NB_Ext
  | MaSh_kNN_Ext

  val is_mash_enabled : unit -> bool
  val the_mash_algorithm : unit -> mash_algorithm
  val str_of_mash_algorithm : mash_algorithm -> string

  (* combination of weighted suggestion lists, and helpers on theorems and facts *)
  val mesh_facts : ('a list -> 'a list) -> ('a * 'a -> bool) -> int ->
    (real * (('a * real) list * 'a list)) list -> 'a list
  val nickname_of_thm : thm -> string
  val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
  val crude_thm_ord : Proof.context -> thm ord
  val thm_less : thm * thm -> bool
  val goal_of_thm : theory -> thm -> thm
  val run_prover_for_mash : Proof.context -> params -> string -> string -> fact list -> thm ->
    prover_result
  val features_of : Proof.context -> string -> stature -> term list -> string list
  val trim_dependencies : string list -> string list option
  val isar_dependencies_of : string Symtab.table * string Symtab.table -> thm -> string list option
  val prover_dependencies_of : Proof.context -> params -> string -> int -> raw_fact list ->
    string Symtab.table * string Symtab.table -> thm -> bool * string list
  val attach_parents_to_facts : ('a * thm) list -> ('a * thm) list ->
    (string list * ('a * thm)) list
  val num_extra_feature_facts : int
  val extra_feature_factor : real
  val weight_facts_smoothly : 'a list -> ('a * real) list
  val weight_facts_steeply : 'a list -> ('a * real) list
  val find_mash_suggestions : Proof.context -> int -> string list -> ('a * thm) list ->
    ('a * thm) list -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
  val mash_suggested_facts : Proof.context -> string -> params -> int -> term list -> term ->
    raw_fact list -> fact list * fact list

  (* learning and unlearning *)
  val mash_unlearn : Proof.context -> unit
  val mash_learn_proof : Proof.context -> params -> term -> thm list -> unit
  val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
    raw_fact list -> string
  val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
  val mash_can_suggest_facts : Proof.context -> bool
  val mash_can_suggest_facts_fast : Proof.context -> bool

  (* top-level relevance filter *)
  val generous_max_suggestions : int -> int
  val mepo_weight : real
  val mash_weight : real
  val relevant_facts : Proof.context -> params -> string -> int -> fact_override -> term list ->
    term -> raw_fact list -> (string * fact list) list
end;
86
87structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
88struct
89
90open ATP_Util
91open ATP_Problem_Generate
92open Sledgehammer_Util
93open Sledgehammer_Fact
94open Sledgehammer_Prover
95open Sledgehammer_Prover_Minimize
96open Sledgehammer_MePo
97
(* prefix string for anonymous proofs (its use lies outside this excerpt) *)
val anonymous_proof_prefix = "."

(* Boolean configuration options; both default to false *)
val trace = Attrib.setup_config_bool \<^binding>\<open>sledgehammer_mash_trace\<close> (K false)
val duplicates = Attrib.setup_config_bool \<^binding>\<open>sledgehammer_fact_duplicates\<close> (K false)
102
(* Emits the lazily computed message via "tracing", but only if the "trace" option is set. *)
fun trace_msg ctxt msg =
  (case Config.get ctxt trace of
    true => tracing (msg ())
  | false => ())
104
(* Theorem equality test: strict if the "duplicates" option is set, else up to proposition. *)
fun gen_eq_thm ctxt =
  (case Config.get ctxt duplicates of
    true => Thm.eq_thm_strict
  | false => Thm.eq_thm_prop)
106
(* capitalized names of the three fact filters *)
val MePoN = "MePo"
val MaShN = "MaSh"
val MeShN = "MeSh"

(* lowercase names of the three fact filters *)
val mepoN = "mepo"
val mashN = "mash"
val meshN = "mesh"

(* all lowercase filter names *)
val fact_filters = [meshN, mepoN, mashN]

(* names of the learning-related commands *)
val unlearnN = "unlearn"
val learn_isarN = "learn_isar"
val learn_proverN = "learn_prover"
val relearn_isarN = "relearn_isar"
val relearn_proverN = "relearn_prover"
122
(* Replaces the element of "ary" at index "i" by the result of applying "f" to it. *)
fun map_array_at ary f i =
  let val old = Array.sub (ary, i) in
    Array.update (ary, i, f old)
  end
124
(* A counter paired with a table that maps string keys to the counter value at insertion. *)
type xtab = int * int Symtab.table

val empty_xtab = (0, Symtab.empty)

(* Registers "key" under the next free index; fails if the key is already present. *)
fun add_to_xtab key (next, tab) =
  let val tab' = Symtab.update_new (key, next) tab in
    (next + 1, tab')
  end

(* Like add_to_xtab, but returns the xtab unchanged if the insertion fails. *)
fun maybe_add_to_xtab key xtab =
  (case try (add_to_xtab key) xtab of
    SOME xtab' => xtab'
  | NONE => xtab)
131
(* Expanded path of the persistent MaSh state file. *)
fun state_file () =
  let val raw_path = Path.explode "$ISABELLE_HOME_USER/mash_state" in
    Path.expand raw_path
  end

(* Tries to delete the state file; the result is NONE if the deletion fails. *)
fun remove_state_file () = try File.rm (state_file ())
134
(* Learning algorithms selectable through the "MaSh" system option. The "_Ext" variants are
   dispatched to an external tool by query_external; the others are handled by
   query_internal. *)
datatype mash_algorithm =
  MaSh_NB
| MaSh_kNN
| MaSh_NB_kNN
| MaSh_NB_Ext
| MaSh_kNN_Ext
141
(* Reads the "MaSh" system option and maps it to an algorithm. "none" and the empty string
   yield NONE silently; any other unrecognized value yields NONE after a warning. *)
fun mash_algorithm () =
  let
    val choice = Options.default_string \<^system_option>\<open>MaSh\<close>
    fun is s = (choice = s)
  in
    if is "yes" orelse is "sml" orelse is "nb_knn" then SOME MaSh_NB_kNN
    else if is "nb" then SOME MaSh_NB
    else if is "knn" then SOME MaSh_kNN
    else if is "nb_ext" then SOME MaSh_NB_Ext
    else if is "knn_ext" then SOME MaSh_kNN_Ext
    else if is "none" orelse is "" then NONE
    else (warning ("Unknown MaSh algorithm: " ^ quote choice); NONE)
  end
154
(* MaSh is enabled iff the system option selects some algorithm. *)
fun is_mash_enabled () = is_some (mash_algorithm ())

(* The selected algorithm, falling back to MaSh_NB_kNN if none is selected. *)
fun the_mash_algorithm () = the_default MaSh_NB_kNN (mash_algorithm ())
157
(* Option-style name of an algorithm. *)
fun str_of_mash_algorithm algorithm =
  (case algorithm of
    MaSh_NB => "nb"
  | MaSh_kNN => "knn"
  | MaSh_NB_kNN => "nb_knn"
  | MaSh_NB_Ext => "nb_ext"
  | MaSh_kNN_Ext => "knn_ext")
163
(* Integer average of the given reals after scaling their sum by 10^8 and rounding it up;
   0 for the empty list. *)
fun scaled_avg [] = 0
  | scaled_avg xs =
    let
      val total = fold (fn x => fn acc => x + acc) xs 0.0
      val scaled_total = Real.ceil (100000000.0 * total)
    in
      scaled_total div length xs
    end
166
(* Arithmetic mean of a list of reals; 0.0 for the empty list. *)
fun avg [] = 0.0
  | avg xs =
    let val total = fold (fn x => fn acc => x + acc) xs 0.0 in
      total / Real.fromInt (length xs)
    end
169
(* Rescales all scores by the reciprocal of the mean score of the first "max_facts" entries. *)
fun normalize_scores _ [] = []
  | normalize_scores max_facts xs =
    let
      val mean = avg (map snd (take max_facts xs))
      val factor = 1.0 / mean
    in
      map (fn (x, score) => (x, factor * score)) xs
    end
173
(* Combines the suggestions of one or more fact filters into a single list of at most
   "max_facts" facts (before "maybe_distinct" in the single-filter case).

   Each entry of the last argument is (global_weight, (sels, unks)): "sels" are the facts
   selected by a filter together with their scores, "unks" are facts the filter has no
   opinion on.

   - Single filter: the first "max_facts" selected facts, topped up with unknown facts, then
     passed through "maybe_distinct".
   - Several filters: scores are normalized per filter, each candidate gets the scaled
     average of its weighted scores over the filters that know it (0.0 for a filter that
     neither selected it nor lists it as unknown), and the candidates are returned by
     decreasing weight. *)
fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] =
    let
      val chosen = map fst (take max_facts sels)
      (* "sels" may be longer than "max_facts"; never hand a negative count to "take" *)
      val num_unks = Integer.max 0 (max_facts - length sels)
    in
      maybe_distinct (chosen @ take num_unks unks)
    end
  | mesh_facts _ fact_eq max_facts mess =
    let
      val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))

      fun score_in fact (global_weight, (sels, unks)) =
        let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in
          (case find_index (curry fact_eq fact o fst) sels of
            ~1 => if member fact_eq unks fact then NONE else SOME 0.0
          | rank => score_at rank)
        end

      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
    in
      (* candidates: union of each filter's first "max_facts" selected facts *)
      fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
      |> map (`weight_of) |> sort (int_ord o apply2 fst o swap)
      |> map snd |> take max_facts
    end
194
(* Weight of the fact at the given rank: 1.3 ^ (15.5 - 0.2 * rank) + 15.0. *)
fun smooth_weight_of_fact rank =
  let val exponent = 15.5 - 0.2 * Real.fromInt rank in
    Math.pow (1.3, exponent) + 15.0 (* FUDGE *)
  end

(* Weight of the fact at the given rank: 0.62 ^ log2 (rank + 1). *)
fun steep_weight_of_fact rank =
  let val exponent = log2 (Real.fromInt (rank + 1)) in
    Math.pow (0.62, exponent) (* FUDGE *)
  end
197
(* Pairs each fact with the smooth weight of its position in the list. *)
fun weight_facts_smoothly facts =
  let val ranks = 0 upto length facts - 1 in
    facts ~~ map smooth_weight_of_fact ranks
  end

(* Pairs each fact with the steep weight of its position in the list. *)
fun weight_facts_steeply facts =
  let val ranks = 0 upto length facts - 1 in
    facts ~~ map steep_weight_of_fact ranks
  end
200
(* In-place partial sort of array "a" with respect to "cmp", based on a heap in which node i
   has the children 3 * i + 1, 3 * i + 2 and 3 * i + 3.

   NOTE(review): judging from the callers in this file, which read the entries from index
   "length - needed" up to the end, the intent is that the "needed" greatest elements end up
   at the end of the array -- confirm before relying on the exact order. *)
fun sort_array_suffix cmp needed a =
  let
    (* raised by "maxson" when node i has no child below the limit l *)
    exception BOTTOM of int

    val al = Array.length a

    (* index of the greatest child of node i among the indices below l *)
    fun maxson l i =
      let val i31 = i + i + i + 1 in
        if i31 + 2 < l then
          let val x = Unsynchronized.ref i31 in
            if is_less (cmp (Array.sub (a, i31), Array.sub (a, i31 + 1))) then x := i31 + 1 else ();
            if is_less (cmp (Array.sub (a, !x), Array.sub (a, i31 + 2))) then x := i31 + 2 else ();
            !x
          end
        else
          if i31 + 1 < l andalso is_less (cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)))
          then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
      end

    (* moves e down from node i: greater children are pulled up until e fits *)
    fun trickledown l i e =
      let val j = maxson l i in
        if is_greater (cmp (Array.sub (a, j), e)) then
          (Array.update (a, i, Array.sub (a, j)); trickledown l j e)
        else
          Array.update (a, i, e)
      end

    (* like "trickledown", but stores e at the leaf reached when the bottom is hit *)
    fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)

    (* pulls the greatest child up along the path from node i; terminates only via BOTTOM *)
    fun bubbledown l i =
      let val j = maxson l i in
        Array.update (a, i, Array.sub (a, j));
        bubbledown l j
      end

    (* index of the leaf position freed by "bubbledown" *)
    fun bubble l i = bubbledown l i handle BOTTOM i => i

    (* moves e up from node i while its father is less than e *)
    fun trickleup i e =
      let val father = (i - 1) div 3 in
        if is_less (cmp (Array.sub (a, father), e)) then
          (Array.update (a, i, Array.sub (a, father));
           if father > 0 then trickleup father e else Array.update (a, 0, e))
        else
          Array.update (a, i, e)
      end

    (* heap construction: trickle down the nodes i, i - 1, ..., 0 *)
    fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))

    (* extraction: for i from the last index down to max 2 (al - needed), moves the root to
       position i and reinserts the element previously stored there *)
    fun for2 i =
      if i < Integer.max 2 (al - needed) then
        ()
      else
        let val e = Array.sub (a, i) in
          Array.update (a, i, Array.sub (a, 0));
          trickleup (bubble i 0) e;
          for2 (i - 1)
        end
  in
    for (((al + 1) div 3) - 1);
    for2 (al - 1);
    (* finally, the elements at positions 0 and 1 are swapped *)
    if al > 1 then
      let val e = Array.sub (a, 1) in
        Array.update (a, 1, Array.sub (a, 0));
        Array.update (a, 0, e)
      end
    else
      ()
  end
269
(* Copies the list into an array, runs sort_array_suffix on it, and returns the array
   contents in reverse order. *)
fun rev_sort_list_prefix cmp needed xs =
  let
    val ary = Array.fromList xs
    val () = sort_array_suffix cmp needed ary
  in
    Array.foldl (fn (x, acc) => x :: acc) [] ary
  end
275
276
277(*** Convenience functions for synchronized access ***)
278
(* Reads the content of a synchronized variable within the time limit, leaving it as is. *)
fun synchronized_timed_value var time_limit =
  let fun keep current = SOME (current, current) in
    Synchronized.timed_access var time_limit keep
  end

(* Applies "f" to the content within the time limit; "f" yields a result and new content. *)
fun synchronized_timed_change_result var time_limit f =
  Synchronized.timed_access var time_limit (fn current => SOME (f current))

(* Replaces the content by its image under "f" within the time limit. *)
fun synchronized_timed_change var time_limit f =
  let fun update current = ((), f current) in
    synchronized_timed_change_result var time_limit update
  end

(* Time limit for accesses to the MaSh state: always 0.1 seconds. *)
fun mash_time_limit _ =
  let val limit = seconds 0.1 in
    SOME limit
  end
287
288
289(*** Isabelle-agnostic machine learning ***)
290
291structure MaSh =
292struct
293
(* Adds "big_number" to the score stored in "recommends" at each of the given positions. *)
fun select_fact_idxs (big_number : real) recommends =
  let
    fun boost at =
      let val (j, ov) = Array.sub (recommends, at) in
        Array.update (recommends, at, (j, big_number + ov))
      end
  in
    List.app boost
  end
299
(* Creates an array from "init" (a size and a default element) and copies "vec" into its
   front. *)
fun wider_array_of_vector init vec =
  let
    val ary = Array.array init
    val () = Array.copyVec {src = vec, dst = ary, di = 0}
  in
    ary
  end
305
(* weight with which a fact counts for itself when it is learned *)
val nb_def_prior_weight = 1000 (* FUDGE *)

(* Extends the frequency vectors (tfreq0, sfreq0, dffreq0) to "num_facts" facts and
   "num_feats" features and learns the facts with indices num_facts0, ..., num_facts - 1.
   For each such fact: the fact itself is credited with nb_def_prior_weight and each of its
   dependencies with 1, both in the per-fact counter and in the per-fact feature table for
   every feature of the learned fact; then the counter of each of its features is
   incremented. Returns the updated vectors. *)
fun learn_facts (tfreq0, sfreq0, dffreq0) num_facts0 num_facts num_feats depss featss =
  let
    val fact_counts = wider_array_of_vector (num_facts, 0) tfreq0
    val feat_counts_of_fact = wider_array_of_vector (num_facts, Inttab.empty) sfreq0
    val feat_counts = wider_array_of_vector (num_feats, 0) dffreq0

    (* credits "fact" with "weight", in general and for each of "feats" *)
    fun credit feats weight fact =
      let
        fun bump feat = Inttab.map_default (feat, 0) (fn n => weight + n)
        val feat_tab = fold bump feats (Array.sub (feat_counts_of_fact, fact))
      in
        map_array_at fact_counts (fn n => weight + n) fact;
        Array.update (feat_counts_of_fact, fact, feat_tab)
      end

    fun learn_fact fact =
      let
        val feats = Vector.sub (featss, fact)
        val deps = Vector.sub (depss, fact)
      in
        credit feats nb_def_prior_weight fact;
        List.app (credit feats 1) deps;
        List.app (map_array_at feat_counts (fn n => 1 + n)) feats
      end

    fun learn_from fact =
      if fact = num_facts then () else (learn_fact fact; learn_from (fact + 1))
  in
    learn_from num_facts0;
    (Array.vector fact_counts, Array.vector feat_counts_of_fact, Array.vector feat_counts)
  end
339
(* Scores every fact index 0, ..., num_facts - 1 against the weighted goal features and
   returns a list of (fact index, score) pairs taken from the end of the sorted score array
   (at most "max_suggs" of them).

   tfreq, sfreq and dffreq are the frequency vectors produced by learn_facts; "fact_idxs"
   are fact indices whose score is raised by a large constant before sorting; "goal_feats"
   is a list of (feature index, weight) pairs. *)
fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs fact_idxs goal_feats =
  let
    val tau = 0.2 (* FUDGE *)
    val pos_weight = 5.0 (* FUDGE *)
    val def_val = ~18.0 (* FUDGE *)
    val init_val = 30.0 (* FUDGE *)

    (* per-feature weight: ln num_facts - ln (counter of the feature) *)
    val ln_afreq = Math.ln (Real.fromInt num_facts)
    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq

    fun tfidf feat = Vector.sub (idf, feat)

    fun log_posterior i =
      let
        val tfreq = Real.fromInt (Vector.sub (tfreq, i))

        (* a goal feature recorded for fact i contributes a term based on its count and is
           removed from the table of remaining features; otherwise it contributes def_val *)
        fun add_feat (f, fw0) (res, sfh) =
          (case Inttab.lookup sfh f of
            SOME sf =>
            (res + fw0 * tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
             Inttab.delete f sfh)
          | NONE => (res + fw0 * tfidf f * def_val, sfh))

        val (res, sfh) = fold add_feat goal_feats (init_val * Math.ln tfreq, Vector.sub (sfreq, i))

        (* features recorded for fact i that do not occur among the goal features *)
        fun fold_sfh (f, sf) sow =
          sow + tfidf f * Math.ln (1.0 - Real.fromInt (sf - 1) / tfreq)

        val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
      in
        res + tau * sum_of_weights
      end

    val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j)))

    (* collects the entries from position "at" to the end; the last entry comes first *)
    fun ret at acc =
      if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
  in
    select_fact_idxs 100000.0 posterior fact_idxs;
    sort_array_suffix (Real.compare o apply2 snd) max_suggs posterior;
    ret (Integer.max 0 (num_facts - max_suggs)) []
  end
382
(* number of neighbors beyond the first that "while1" below always takes into account; it is
   also passed on to the external kNN tool *)
val initial_k = 0

(* Scores facts by their feature overlap with the goal and returns a list of
   (fact index, score) pairs taken from the end of the sorted "recommends" array (at most
   "max_suggs" of them).

   dffreq holds the per-feature counters; depss and featss hold the dependency and feature
   indices of each fact; "fact_idxs" are fact indices whose score is raised by a large
   constant before the final sort; "goal_feats" is a list of (feature index, weight)
   pairs. *)
fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs goal_feats =
  let
    (* raised by "do_k" once all facts have been considered as neighbors *)
    exception EXIT of unit

    val ln_afreq = Math.ln (Real.fromInt num_facts)
    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))

    (* (fact index, overlap score) for each fact, initially 0.0 *)
    val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)

    (* for each feature, the facts having that feature *)
    val feat_facts = Array.array (num_feats, [])
    val _ = Vector.foldl (fn (feats, fact) =>
      (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) 0 featss

    (* adds (weight * tfidf) ^ 6 of a goal feature to the overlap of every fact having it *)
    fun do_feat (s, sw0) =
      let
        val sw = sw0 * tfidf s
        val w6 = Math.pow (sw, 6.0 (* FUDGE *))

        fun inc_overlap j =
          let val (_, ov) = Array.sub (overlaps_sqr, j) in
            Array.update (overlaps_sqr, j, (j, w6 + ov))
          end
      in
        List.app inc_overlap (Array.sub (feat_facts, s))
      end

    val _ = List.app do_feat goal_feats
    val _ = sort_array_suffix (Real.compare o apply2 snd) num_facts overlaps_sqr
    (* number of facts that have received a recommendation so far *)
    val no_recommends = Unsynchronized.ref 0
    val recommends = Array.tabulate (num_facts, rpair 0.0)
    (* bonus given to a fact on its first recommendation; lowered by "while2" each round *)
    val age = Unsynchronized.ref 500000000.0

    (* first recommendation of fact j: count it and add the current age; later ones add v *)
    fun inc_recommend v j =
      let val (_, ov) = Array.sub (recommends, j) in
        if ov <= 0.0 then
          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
        else
          Array.update (recommends, j, (j, v + ov))
      end

    val k = Unsynchronized.ref 0
    (* recommends the k-th neighbor, counted from the end of the sorted overlap array, with
       its overlap, and each of its dependencies with a share of deps_factor * overlap *)
    fun do_k k =
      if k >= num_facts then
        raise EXIT ()
      else
        let
          val deps_factor = 2.7 (* FUDGE *)
          val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
          val _ = inc_recommend o2 j
          val ds = Vector.sub (depss, j)
          val l = Real.fromInt (length ds)
        in
          List.app (inc_recommend (deps_factor * o2 / l)) ds
        end

    (* processes neighbors until k reaches initial_k + 1 or the facts are exhausted *)
    fun while1 () =
      if !k = initial_k + 1 then () else (do_k (!k); k := !k + 1; while1 ())
      handle EXIT () => ()

    (* processes further neighbors until "max_suggs" facts have been recommended or the
       facts are exhausted *)
    fun while2 () =
      if !no_recommends >= max_suggs then ()
      else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
      handle EXIT () => ()

    (* collects the entries from position "at" to the end; the last entry comes first *)
    fun ret acc at =
      if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
  in
    while1 ();
    while2 ();
    select_fact_idxs 1000000000.0 recommends fact_idxs;
    sort_array_suffix (Real.compare o apply2 snd) max_suggs recommends;
    ret [] (Integer.max 0 (num_facts - max_suggs))
  end
458
(* experimental *)
(* Writes the learning data and the goal features to four files whose relative names end in
   a fresh serial number, runs the external tool on them through the shell, and returns the
   blank-separated words of its output.

   NOTE(review): the output streams are not closed if writing raises an exception, and
   nothing visible here deletes the files afterwards -- confirm whether that matters for
   this experimental code. *)
fun external_tool tool max_suggs learns goal_feats =
  let
    val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
    val ocs = TextIO.openOut ("adv_syms" ^ ser)
    val ocd = TextIO.openOut ("adv_deps" ^ ser)
    val ocq = TextIO.openOut ("adv_seq" ^ ser)
    val occ = TextIO.openOut ("adv_conj" ^ ser)

    (* writes string s to stream oc *)
    fun os oc s = TextIO.output (oc, s)

    (* applies f to each list element, writing sep to oc between consecutive elements *)
    fun ol _ _ _ [] = ()
      | ol _ f _ [e] = f e
      | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)

    (* one line per learned fact in each of the three files: "name:" followed by the quoted
       features, "name:" followed by the dependencies, and the bare name *)
    fun do_learn (name, feats, deps) =
      (os ocs name; os ocs ":"; ol ocs (os ocs o quote) ", " feats; os ocs "\n";
       os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")

    (* runs the tool, which is looked up under "~/misc/", with the goal file as its input *)
    fun forkexec no =
      let
        val cmd =
          "~/misc/" ^ tool ^ " adv_syms" ^ ser ^ " adv_deps" ^ ser ^ " " ^ string_of_int no ^
          " adv_seq" ^ ser ^ " < adv_conj" ^ ser
      in
        fst (Isabelle_System.bash_output cmd)
        |> space_explode " "
        |> filter_out (curry (op =) "")
      end
  in
    (List.app do_learn learns; ol occ (os occ o quote) ", " (map fst goal_feats);
     TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
     forkexec max_suggs)
  end
493
(* External kNN tool; initial_k is passed as its first command-line argument. *)
fun k_nearest_neighbors_ext max_suggs =
  let val tool = "newknn/knn" ^ " " ^ string_of_int initial_k in
    external_tool tool max_suggs
  end

(* External naive Bayes tool. *)
fun naive_bayes_ext max_suggs =
  let val tool = "predict/nbayes" in
    external_tool tool max_suggs
  end
497
(* Traces the goal features and queries the external tool that belongs to the algorithm.
   Only the two "_Ext" algorithms are handled here. *)
fun query_external ctxt algorithm max_suggs learns goal_feats =
  let
    val _ = trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats))
  in
    (case algorithm of
      MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats
    | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats)
  end
503
(* Traces the query, runs the built-in algorithm(s), and maps the suggested fact indices
   back to fact names. For MaSh_NB_kNN, the naive Bayes and kNN suggestions are weighted
   steeply by rank and meshed with equal weights. *)
fun query_internal ctxt algorithm num_facts num_feats (fact_names, featss, depss)
    (freqs as (_, _, dffreq)) fact_idxs max_suggs goal_feats int_goal_feats =
  let
    fun nb_idxs () =
      map fst (naive_bayes freqs num_facts max_suggs fact_idxs int_goal_feats)

    fun knn_idxs () =
      map fst
        (k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs
           int_goal_feats)

    fun name_of idx = Vector.sub (fact_names, idx)

    val _ =
      trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
        elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}")

    val idxs =
      (case algorithm of
        MaSh_NB => nb_idxs ()
      | MaSh_kNN => knn_idxs ()
      | MaSh_NB_kNN =>
        mesh_facts I (op =) max_suggs
          [(0.5 (* FUDGE *), (weight_facts_steeply (nb_idxs ()), [])),
           (0.5 (* FUDGE *), (weight_facts_steeply (knn_idxs ()), []))])
  in
    map name_of idxs
  end
525
526end;
527
528
529(*** Persistent, stringly-typed state ***)
530
(* Escapes a single character: letters, digits and a few punctuation characters are kept,
   everything else becomes "%" followed by the character code. *)
fun meta_char c =
  let
    val is_plain = Char.isAlphaNum c orelse Char.contains "_.(),'" c
  in
    if is_plain then String.str c
    else
      (* fixed width, in case more digits follow *)
      "%" ^ stringN_of_int 3 (Char.ord c)
  end
538
(* Inverse of the escaping done by meta_char, on a list of characters. "accum" holds the
   characters decoded so far, in reverse order. Each "%" must be followed by exactly three
   decimal digits denoting a valid character code; on any malformed escape the result is
   the empty string. *)
fun unmeta_chars accum [] = String.implode (rev accum)
  | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
    let
      val digits = [d1, d2, d3]
      (* reject signs, blanks and other junk that Int.fromString would tolerate *)
      val code =
        if List.all Char.isDigit digits then Int.fromString (String.implode digits) else NONE
    in
      (case code of
        SOME n =>
        (* Char.chr raises Chr for codes above Char.maxOrd *)
        if n <= Char.maxOrd then unmeta_chars (Char.chr n :: accum) cs else "" (* error *)
      | NONE => "" (* error *))
    end
  | unmeta_chars _ (#"%" :: _) = "" (* error *)
  | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
546
(* Escapes every character of a string with meta_char. *)
fun encode_str s = String.translate meta_char s

(* Escapes each string and joins the results with single blanks. *)
fun encode_strs ss = space_implode " " (map encode_str ss)

(* Undoes encode_str; strings without "%" are returned as they are. *)
fun decode_str s =
  (case String.isSubstring "%" s of
    true => unmeta_chars [] (String.explode s)
  | false => s)

(* Splits at blanks and decodes each word; decoding is skipped if there is no "%" at all. *)
fun decode_strs s =
  let val words = space_explode " " s in
    if String.isSubstring "%" s then map decode_str words else words
  end
555
(* Kind of proof recorded for a fact, with its one-letter code in the state file. *)
datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop

fun str_of_proof_kind kind =
  (case kind of
    Isar_Proof => "i"
  | Automatic_Proof => "a"
  | Isar_Proof_wegen_Prover_Flop => "x")

(* Any code other than "a" and "x" is read as Isar_Proof. *)
fun proof_kind_of_str s =
  if s = "a" then Automatic_Proof
  else if s = "x" then Isar_Proof_wegen_Prover_Flop
  else Isar_Proof (* "i" *)
565
(* Adds an edge from "parent" to "name", creating a default node for the parent if it is
   missing. *)
fun add_edge_to name parent graph =
  graph
  |> Graph.default_node (parent, (Isar_Proof, [], []))
  |> Graph.add_edge (parent, name)
569
(* Adds a fact to the access graph (overwriting the data of an existing node of that name),
   links it to its parents, registers its name and features in the index tables, and
   records it in the list of learned facts. If the name is already registered in the fact
   table, the accumulator is returned unchanged. *)
fun add_node kind name parents feats deps (accum as (access_G, (fact_xtab, feat_xtab), learns)) =
  let
    val fact_xtab' = add_to_xtab name fact_xtab
    val info = (kind, feats, deps)
    val G_with_node =
      (Graph.new_node (name, info) access_G
       handle Graph.DUP _ => Graph.map_node name (fn _ => info) access_G)
    val access_G' = fold (add_edge_to name) parents G_with_node
    val feat_xtab' = fold maybe_add_to_xtab feats feat_xtab
  in
    (access_G', (fact_xtab', feat_xtab'), (name, feats, deps) :: learns)
  end
  handle Symtab.DUP _ => accum (* robustness (in case the state file violates the invariant) *)
579
(* Runs "f ()". If it raises anything other than an interrupt, a trace message mentioning
   "when" is emitted and "def" is returned instead; interrupts are reraised. *)
fun try_graph ctxt when def f =
  let
    fun report msg = (trace_msg ctxt msg; def)
  in
    f ()
    handle
      Graph.CYCLES (cycle :: _) =>
      report (fn () => "Cycle involving " ^ commas cycle ^ " when " ^ when)
    | Graph.DUP name =>
      report (fn () => "Duplicate fact " ^ quote name ^ " when " ^ when)
    | Graph.UNDEF name =>
      report (fn () => "Unknown fact " ^ quote name ^ " when " ^ when)
    | exn =>
      if Exn.is_interrupt exn then Exn.reraise exn
      else
        report (fn () => "Internal error when " ^ when ^ ":\n" ^ Runtime.exn_message exn)
  end
595
(* One-line summary of a graph: number of nodes, edges and maximal elements. *)
fun graph_info G =
  let
    val num_nodes = length (Graph.keys G)
    val num_edges = fold (fn (_, succs) => fn n => length succs + n) (Graph.dest G) 0
    val num_maximals = length (Graph.maximals G)
  in
    string_of_int num_nodes ^ " node(s), " ^
    string_of_int num_edges ^ " edge(s), " ^
    string_of_int num_maximals ^ " maximal"
  end
600
(* fact names, and the feature and dependency indices of each fact *)
type ffds = string vector * int list vector * int list vector
(* the three frequency vectors computed by MaSh.learn_facts *)
type freqs = int vector * int Inttab.table vector * int vector

(* In-memory MaSh state. In access_G, each node carries the proof kind, features and
   dependencies of a fact; xtabs are the index tables for fact and feature names. For
   dirty_facts, "SOME names" makes save_state append only those facts to the state file,
   whereas NONE makes it rewrite the whole file. *)
type mash_state =
  {access_G : (proof_kind * string list * string list) Graph.T,
   xtabs : xtab * xtab,
   ffds : ffds,
   freqs : freqs,
   dirty_facts : string list option}

val empty_xtabs = (empty_xtab, empty_xtab)
val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList []) : ffds
val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList []) : freqs

(* state with no facts and nothing to save *)
val empty_state =
  {access_G = Graph.empty,
   xtabs = empty_xtabs,
   ffds = empty_ffds,
   freqs = empty_freqs,
   dirty_facts = SOME []} : mash_state
621
(* Appends the newly learned facts to the name, feature and dependency vectors (feature and
   dependency names are translated to indices; unknown names are dropped) and updates the
   frequencies for the facts from index "num_facts0" onwards. *)
fun recompute_ffds_freqs_from_learns (learns : (string * string list * string list) list)
    ((num_facts, fact_tab), (num_feats, feat_tab)) num_facts0 (fact_names0, featss0, depss0) freqs0 =
  let
    fun idxs_in tab names = map_filter (Symtab.lookup tab) names
    fun extend old new = Vector.concat [old, Vector.fromList new]

    val fact_names = extend fact_names0 (map (fn (name, _, _) => name) learns)
    val featss = extend featss0 (map (fn (_, feats, _) => idxs_in feat_tab feats) learns)
    val depss = extend depss0 (map (fn (_, _, deps) => idxs_in fact_tab deps) learns)
    val freqs = MaSh.learn_facts freqs0 num_facts0 num_facts num_feats depss featss
  in
    ((fact_names, featss, depss), freqs)
  end
634
(* Arranges the learned facts by the index that the fact table assigns to their names.
   Every fact name must be present in the table. *)
fun reorder_learns (num_facts, fact_tab) learns =
  let
    val slots = Array.array (num_facts, ("", [], []))

    fun place (learn as (fact, _, _)) =
      let val idx = the (Symtab.lookup fact_tab fact) in
        Array.update (slots, idx, learn)
      end

    val () = List.app place learns
  in
    Array.foldr (op ::) [] slots
  end
642
(* Recomputes the fact vectors and frequencies from scratch, from the nodes of the access
   graph arranged by fact index. *)
fun recompute_ffds_freqs_from_access_G access_G (xtabs as (fact_xtab, _)) =
  let
    fun learn_of_node _ (fact, (_, feats, deps)) = (fact, feats, deps)
    val scheduled = Graph.schedule learn_of_node access_G
    val learns = reorder_learns fact_xtab scheduled
  in
    recompute_ffds_freqs_from_learns learns xtabs 0 empty_ffds empty_freqs
  end
651
652local
653
(* banner on the first line of the state file; load_state compares it as a string with the
   first line of the file *)
val version = "*** MaSh version 20190121 ***"

(* raised by load_state when the first line of the file is greater than "version" *)
exception FILE_VERSION_TOO_NEW of unit
657
(* Parses one line of the state file, of the form

     kind name: parents; feats; deps

   where each of the three lists consists of blank-separated encoded strings. Returns NONE
   for a line that does not have this shape, including a line in which one of the three
   fields lacks its leading blank: such a line is skipped by the caller instead of aborting
   the loading of the entire state through an exception. *)
fun extract_node line =
  (case space_explode ":" line of
    [head, tail] =>
    (case (space_explode " " head, try (map (unprefix " ")) (space_explode ";" tail)) of
      ([kind, name], SOME [parents, feats, deps]) =>
      SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_strs feats,
        decode_strs deps)
    | _ => NONE)
  | _ => NONE)
667
(* Tells whether the state file on disk is newer than the in-memory state; false if its
   modification time cannot be determined. *)
fun would_load_state (memory_time, _) =
  (case try OS.FileSys.modTime (File.platform_path (state_file ())) of
    SOME disk_time => memory_time < disk_time
  | NONE => false);
674
(* Reloads the MaSh state from the state file if the file is newer than the in-memory state
   (given as a pair of a time stamp and a state); otherwise the argument is returned
   unchanged. The result carries the modification time of the file as its time stamp.

   Raises FILE_VERSION_TOO_NEW if the version line of the file is greater than "version". *)
fun load_state ctxt (time_state as (memory_time, _)) =
  let val path = state_file () in
    (case try OS.FileSys.modTime (File.platform_path path) of
      (* the modification time cannot be determined: keep the in-memory state *)
      NONE => time_state
    | SOME disk_time =>
      if memory_time >= disk_time then
        time_state
      else
        (disk_time,
         (case try File.read_lines path of
           SOME (version' :: node_lines) =>
           let
             (* lines that cannot be parsed are skipped *)
             fun extract_line_and_add_node line =
               (case extract_node line of
                 NONE => I (* should not happen *)
               | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)

             val empty_G_etc = (Graph.empty, empty_xtabs, [])

             (* an error while building the graph results in an empty graph *)
             val (access_G, xtabs, rev_learns) =
               (case string_ord (version', version) of
                 EQUAL =>
                 try_graph ctxt "loading state" empty_G_etc
                   (fn () => fold extract_line_and_add_node node_lines empty_G_etc)
               | LESS => (remove_state_file (); empty_G_etc) (* cannot parse old file *)
               | GREATER => raise FILE_VERSION_TOO_NEW ())

             (* add_node accumulates the learned facts in reverse order *)
             val (ffds, freqs) =
               recompute_ffds_freqs_from_learns (rev rev_learns) xtabs 0 empty_ffds empty_freqs
           in
             trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
             {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
           end
         (* unreadable or empty file *)
         | _ => empty_state)))
  end
710
(* Line of the state file for one fact: "kind name: parents; feats; deps" plus newline. *)
fun str_of_entry (kind, name, parents, feats, deps) =
  String.concat
    [str_of_proof_kind kind, " ", encode_str name, ": ", encode_strs parents, "; ",
     encode_strs feats, "; ", encode_strs deps, "\n"]
714
(* Writes the state to the state file and returns it with the current time as its time
   stamp and without dirty facts. Nothing is done if there are no dirty facts. If the file
   is missing or newer than the in-memory state, or if dirty_facts is NONE, the whole graph
   is written after the version banner; otherwise only the dirty facts are appended. I/O
   errors are ignored. *)
fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
  | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) =
    let
      (* turns a graph entry into the tuple expected by str_of_entry *)
      fun append_entry (name, ((kind, feats, deps), (parents, _))) =
        cons (kind, name, Graph.Keys.dest parents, feats, deps)

      val path = state_file ()
      (* NONE forces a complete rewrite of the file *)
      val dirty_facts' =
        (case try OS.FileSys.modTime (File.platform_path path) of
          NONE => NONE
        | SOME disk_time => if disk_time <= memory_time then dirty_facts else NONE)
      val (banner, entries) =
        (case dirty_facts' of
          SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
        | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
    in
      (* writing the banner replaces the old file; entries are appended in chunks of 500 *)
      (case banner of SOME s => File.write path s | NONE => ();
       entries |> chunk_list 500 |> List.app (File.append path o implode o map str_of_entry))
      handle IO.Io _ => ();
      trace_msg ctxt (fn () =>
        "Saved fact graph (" ^ graph_info access_G ^
        (case dirty_facts of
          SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
        | _ => "") ^  ")");
      (Time.now (),
       {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []})
    end
742
(* the in-memory MaSh state together with its time stamp, initially empty with time zero *)
val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
744
745in
746
(* Within the MaSh time limit: reloads the state from disk if necessary, applies "f" to the
   state (presumably to the state component only, leaving the time stamp alone -- see the
   "##>" combinator), and saves the result. The outcome of the timed access is ignored, and
   a state file with too new a version leaves everything unchanged. *)
fun map_state ctxt f =
  (trace_msg ctxt (fn () => "Changing MaSh state");
   synchronized_timed_change global_state mash_time_limit
     (load_state ctxt ##> f #> save_state ctxt))
  |> ignore
  handle FILE_VERSION_TOO_NEW () => ()
753
(* Returns the in-memory state with its time stamp, without touching the disk state. The
   result is NONE if the timed access yields nothing or if the state file is newer than
   the in-memory state. *)
fun peek_state ctxt =
  let
    val _ = trace_msg ctxt (fn () => "Peeking at MaSh state")
  in
    (case synchronized_timed_value global_state mash_time_limit of
      SOME state => if would_load_state state then NONE else SOME state
    | NONE => NONE)
  end
759
(* Reloads the state from disk if necessary (keeping the old state if loading fails) and
   returns the resulting state, all within the MaSh time limit. *)
fun get_state ctxt =
  let
    fun reload time_state = the_default time_state (try (load_state ctxt) time_state)

    fun access time_state =
      let val time_state' = reload time_state in
        (snd time_state', time_state')
      end
  in
    trace_msg ctxt (fn () => "Retrieving MaSh state");
    synchronized_timed_change_result global_state mash_time_limit access
  end
764
(* Deletes the state file and resets the in-memory state to the empty state. *)
fun clear_state ctxt =
  let
    fun reset _ = (remove_state_file (); (Time.zeroTime, empty_state))
  in
    trace_msg ctxt (fn () => "Clearing MaSh state");
    Synchronized.change global_state reset
  end
768
769end
770
771
772(*** Isabelle helpers ***)
773
(* Prints a term crudely into a string. Each abstraction and each atom uses up one unit of
   the "size" budget; once the budget is exhausted, the remaining subterms are omitted
   (the parentheses and blanks of applications are still printed). *)
fun crude_printed_term size t =
  let
    (* appends the text of an atom and uses up one unit of the budget *)
    fun atom s (acc, budget) = (acc ^ s, budget - 1)

    fun pr _ (state as (_, 0)) = state
      | pr (t1 $ t2) (acc, budget) =
        (acc ^ "(", budget)
        |> pr t1
        |> (fn (acc', budget') => (acc' ^ " ", budget'))
        |> pr t2
        |> (fn (acc', budget') => (acc' ^ ")", budget'))
      | pr (Abs (x, _, body)) (acc, budget) = pr body (acc ^ "%" ^ x ^ ".", budget - 1)
      | pr (Bound n) state = atom ("#" ^ string_of_int n) state
      | pr (Const (c, _)) state = atom (Long_Name.base_name c) state
      | pr (Free (x, _)) state = atom x state
      | pr (Var ((x, _), _)) state = atom x state
  in
    fst (pr t ("", size))
  end
792
(* Name under which MaSh knows a theorem: its name hint if it has a non-local one; for a
   local hint, the theory name, the local suffix and a crude printout of the proposition;
   without any hint, just a crude printout of the proposition. *)
fun nickname_of_thm th =
  if not (Thm.has_name_hint th) then
    crude_printed_term 50 (Thm.prop_of th)
  else
    let val hint = Thm.get_name_hint th in
      (* There must be a better way to detect local facts. *)
      (case Long_Name.dest_local hint of
        NONE => hint
      | SOME suf =>
        Long_Name.implode [Thm.theory_name th, suf, crude_printed_term 25 (Thm.prop_of th)])
    end
804
(* Maps suggested nicknames to the corresponding facts. For facts sharing a nickname, the
   first one in "facts" is used; nicknames without a fact are traced and dropped. *)
fun find_suggested_facts ctxt facts =
  let
    val tab =
      fold (fn fact as (_, th) => Symtab.default (nickname_of_thm th, fact)) facts Symtab.empty

    fun lookup nick =
      (case Symtab.lookup tab nick of
        NONE => (trace_msg ctxt (fn () => "Cannot find " ^ quote nick); NONE)
      | some => some)
  in map_filter lookup end
813
(* Feature names carry a one-letter prefix that tells their kind. *)
val free_feature_of = prefix "f"
val thy_feature_of = prefix "y"
val type_feature_of = prefix "t"
val class_feature_of = prefix "s"
val local_feature = "local"
819
(* Crude order on theorems. Theorems are compared by their theories first: a proper
   subtheory comes first; unrelated theories are compared by their number of ancestors (if
   known for both) and then by name. Theorems of the same theory are compared by nickname,
   with "_def" nicknames first. *)
fun crude_thm_ord ctxt =
  let
    fun add_length thy =
      Symtab.update (Context.theory_name thy, length (Context.ancestors_of thy))
    val ancestor_lengths =
      fold add_length (Theory.nodes_of (Proof_Context.theory_of ctxt)) Symtab.empty
    fun ancestor_length id = Symtab.lookup ancestor_lengths (Context.theory_id_name id)

    fun name_ord (id1, id2) =
      string_ord (Context.theory_id_name id1, Context.theory_id_name id2)

    fun crude_theory_ord (ids as (id1, id2)) =
      if Context.eq_thy_id ids then EQUAL
      else if Context.proper_subthy_id ids then LESS
      else if Context.proper_subthy_id (id2, id1) then GREATER
      else
        (case (ancestor_length id1, ancestor_length id2) of
          (SOME m, SOME n) =>
            (case int_ord (m, n) of
              EQUAL => name_ord ids
            | ord => ord)
        | _ => name_ord ids)

    (* Hack to put "xxx_def" before "xxxI" and "xxxE" *)
    fun nickname_ord (nick1, nick2) =
      (case bool_ord (String.isSuffix "_def" nick2, String.isSuffix "_def" nick1) of
        EQUAL => string_ord (nick1, nick2)
      | ord => ord)
  in
    fn (th1, th2) =>
      (case crude_theory_ord (Thm.theory_id th1, Thm.theory_id th2) of
        EQUAL =>
        (* The nickname comparison is necessary because of odd dependencies that are not
           reflected in the theory comparison. *)
        nickname_ord (nickname_of_thm th1, nickname_of_thm th2)
      | ord => ord)
  end;
852
(* Preorder on theorems induced by the subtheory relation of their theories, and its strict part. *)
fun thm_less_eq (th1, th2) = Context.subthy_id (Thm.theory_id th1, Thm.theory_id th2)
fun thm_less (th1, th2) = thm_less_eq (th1, th2) andalso not (thm_less_eq (th2, th1))
855
val freezeT = Type.legacy_freeze_type

(* Turns schematic variables into free variables (dropping their index) and freezes all types
   occurring in the term. *)
fun freeze tm =
  (case tm of
    t $ u => freeze t $ freeze u
  | Abs (s, T, t) => Abs (s, freezeT T, freeze t)
  | Var ((s, _), T) => Free (s, freezeT T)
  | Const (s, T) => Const (s, freezeT T)
  | Free (s, T) => Free (s, freezeT T)
  | Bound _ => tm)
864
(* Builds an initial goal state from the frozen proposition of a theorem. *)
fun goal_of_thm thy th = Goal.init (Thm.global_cterm_of thy (freeze (Thm.prop_of th)))
866
(* Runs the minimizing variant of "prover" in MaSh mode on a one-subgoal problem built from "goal"
   and "facts". *)
fun run_prover_for_mash ctxt params prover goal_name facts goal =
  {comment = "Goal: " ^ goal_name, state = Proof.init ctxt, goal = goal, subgoal = 1,
   subgoal_count = 1, factss = [("", facts)], found_proof = I}
  |> get_minimizing_prover ctxt MaSh (K ()) prover params
875
(* Type constructors that are never turned into type features. *)
val bad_types = [\<^type_name>\<open>prop\<close>, \<^type_name>\<open>bool\<close>, \<^type_name>\<open>fun\<close>]

(* Comma-separated base names of the classes of a sort, ignoring the default sort "type". *)
fun crude_str_of_sort S =
  S
  |> subtract (op =) \<^sort>\<open>type\<close>
  |> map Long_Name.base_name
  |> space_implode ","

(* Crude rendering of a type: constructor base names are concatenated in prefix order, and type
   variables are represented by their sort. *)
fun crude_str_of_typ (Type (s, Ts)) =
    (* for a nullary constructor, "implode []" contributes the empty string *)
    Long_Name.base_name s ^ implode (map crude_str_of_typ Ts)
  | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
  | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
884
(* Wraps a nonempty string in a singleton list; the empty string yields the empty list. *)
fun maybe_singleton_str s = if s = "" then [] else [s]
887
(* Maximum number of alternative patterns kept at each level when combining the patterns of a
   function (or type constructor) with those of its argument. *)
val max_pat_breadth = 5 (* FUDGE *)

(* Collects the features (as a duplicate-free list of strings) of the terms "ts": term patterns up
   to depth "term_max_depth", type patterns up to depth "type_max_depth", and type class features
   for the sorts occurring in the types. "thy_name" is only used to qualify the names of fixed
   free variables. *)
fun term_features_of ctxt thy_name term_max_depth type_max_depth ts =
  let
    val thy = Proof_Context.theory_of ctxt

    val fixes = map snd (Variable.dest_fixes ctxt)
    val classes = Sign.classes_of thy

    (* Adds a class feature for each class of the sort and for each of its superclasses, except
       for the classes of the default sort "type". *)
    fun add_classes \<^sort>\<open>type\<close> = I
      | add_classes S =
        fold (`(Sorts.super_classes classes)
          #> swap #> op ::
          #> subtract (op =) \<^sort>\<open>type\<close>
          #> map class_feature_of
          #> union (op =)) S

    (* Patterns for a type, down to the given depth. An applied type constructor is handled one
       argument at a time: the patterns of the constructor applied to the remaining arguments
       are combined with the patterns of the first argument (or with "", meaning that the
       argument is left out of the pattern). *)
    fun pattify_type 0 _ = []
      | pattify_type _ (Type (s, [])) = if member (op =) bad_types s then [] else [s]
      | pattify_type depth (Type (s, U :: Ts)) =
        let
          val T = Type (s, Ts)
          val ps = take max_pat_breadth (pattify_type depth T)
          val qs = take max_pat_breadth ("" :: pattify_type (depth - 1) U)
        in
          map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs
        end
      | pattify_type _ (TFree (_, S)) = maybe_singleton_str (crude_str_of_sort S)
      | pattify_type _ (TVar (_, S)) = maybe_singleton_str (crude_str_of_sort S)

    fun add_type_pat depth T =
      union (op =) (map type_feature_of (pattify_type depth T))

    (* Adds the type patterns for every depth from "depth" down to 1. *)
    fun add_type_pats 0 _ = I
      | add_type_pats depth t = add_type_pat depth t #> add_type_pats (depth - 1) t

    fun add_type T =
      add_type_pats type_max_depth T
      #> fold_atyps_sorts (add_classes o snd) T

    fun add_subtypes (T as Type (_, Ts)) = add_type T #> fold add_subtypes Ts
      | add_subtypes T = add_type T

    (* Patterns for a term, down to the given depth; "Ts" holds the types of the enclosing bound
       variables. Constants contribute their name, whereas variables (free, schematic, or bound)
       contribute a crude rendering of their type; fixed free variables additionally contribute
       a feature qualified by "thy_name". Abstractions contribute nothing (last clause). *)
    fun pattify_term _ 0 _ = []
      | pattify_term _ _ (Const (s, _)) =
        if is_widely_irrelevant_const s then [] else [s]
      | pattify_term _ _ (Free (s, T)) =
        maybe_singleton_str (crude_str_of_typ T)
        |> (if member (op =) fixes s then cons (free_feature_of (Long_Name.append thy_name s))
            else I)
      | pattify_term _ _ (Var (_, T)) =
        maybe_singleton_str (crude_str_of_typ T)
      | pattify_term Ts _ (Bound j) =
        maybe_singleton_str (crude_str_of_typ (nth Ts j))
      | pattify_term Ts depth (t $ u) =
        let
          val ps = take max_pat_breadth (pattify_term Ts depth t)
          val qs = take max_pat_breadth ("" :: pattify_term Ts (depth - 1) u)
        in
          map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs
        end
      | pattify_term _ _ _ = []

    fun add_term_pat Ts = union (op =) oo pattify_term Ts

    (* Adds the term patterns for every depth from "depth" down to 1. *)
    fun add_term_pats _ 0 _ = I
      | add_term_pats Ts depth t = add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t

    fun add_term Ts = add_term_pats Ts term_max_depth

    (* Traverses a term by head and arguments. Term patterns are added only for applications
       headed by a constant (unless widely irrelevant) or by a free variable; type features are
       added for the types of constant, free, and schematic heads and of abstracted variables. *)
    fun add_subterms Ts t =
      (case strip_comb t of
        (Const (s, T), args) =>
        (not (is_widely_irrelevant_const s) ? add_term Ts t)
        #> add_subtypes T #> fold (add_subterms Ts) args
      | (head, args) =>
        (case head of
           Free (_, T) => add_term Ts t #> add_subtypes T
         | Var (_, T) => add_subtypes T
         | Abs (_, T, body) => add_subtypes T #> add_subterms (T :: Ts) body
         | _ => I)
        #> fold (add_subterms Ts) args)
  in
    fold (add_subterms []) ts []
  end
973
val term_max_depth = 2
val type_max_depth = 1

(* Features of the terms "ts": the theory feature for "thy_name", the term and type features of
   "ts", and, for any scope other than "Global", the "local" feature in front. *)
(* TODO: Generate type classes for types? *)
fun features_of ctxt thy_name (scope, _) ts =
  let
    val feats =
      thy_feature_of thy_name :: term_features_of ctxt thy_name term_max_depth type_max_depth ts
  in
    if scope = Global then feats else local_feature :: feats
  end
982
(* A proof with too many dependencies was most likely found by a decision procedure. There is not
   much to learn from such proofs. *)
val max_dependencies = 20 (* FUDGE *)

val prover_default_max_facts = 25 (* FUDGE *)

(* Facts of the form "type_definition_xxx" are characterized by their use of "CollectI". *)
val typedef_dep = nickname_of_thm @{thm CollectI}
(* Mysterious parts of the class machinery create lots of proofs that refer exclusively to
   "someI_ex" (and to some internal constructions). *)
val class_some_dep = nickname_of_thm @{thm someI_ex}

val fundef_ths =
  map nickname_of_thm
    @{thms fundef_ex1_existence fundef_ex1_uniqueness fundef_ex1_iff fundef_default_value}

(* Facts such as "Rep_xxx_inject" and "Abs_xxx_inverse" are derived using these facts. *)
val typedef_ths =
  map nickname_of_thm
    @{thms type_definition.Abs_inverse type_definition.Rep_inverse type_definition.Rep
        type_definition.Rep_inject type_definition.Abs_inject type_definition.Rep_cases
        type_definition.Abs_cases type_definition.Rep_induct type_definition.Abs_induct
        type_definition.Rep_range type_definition.Abs_image}
1006
(* Holds if the only dependency has the form "<prefix>.rec..." and the nickname of "th" has the
   form "<prefix>.size..." for the same prefix. *)
fun is_size_def deps th =
  (case deps of
    [dep] =>
      (case first_field ".rec" dep of
        NONE => false
      | SOME (rec_pref, _) =>
          (case first_field ".size" (nickname_of_thm th) of
            NONE => false
          | SOME (size_pref, _) => rec_pref = size_pref))
  | _ => false)
1015
(* Rejects (with NONE) dependency lists that are longer than "max_dependencies". *)
fun trim_dependencies deps =
  if length deps <= max_dependencies then SOME deps else NONE
1018
(* Dependencies of "th" according to its proof, if "thms_in_proof" can determine them. Proofs that
   are recognized as stemming from the typedef, class, function, or size machinery are reported
   as having no dependencies. *)
fun isar_dependencies_of name_tabs th =
  let
    fun mentions_any nicks = exists (member (op =) nicks)

    fun is_uninformative deps =
      deps = [typedef_dep] orelse deps = [class_some_dep] orelse
      mentions_any fundef_ths deps orelse mentions_any typedef_ths deps orelse
      is_size_def deps th
  in
    thms_in_proof max_dependencies (SOME name_tabs) th
    |> Option.map (fn deps => if is_uninformative deps then [] else deps)
  end
1028
(* Tries to find the dependencies of "th" by running "prover" on it. The result is a pair
   "(found, deps)": "(true, names of the facts used by the prover)" if the prover's outcome is
   NONE, and "(false, Isar dependencies)" otherwise. If the Isar dependencies are known to be
   empty, the prover is not run at all and "(false, [])" is returned. *)
fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover auto_level facts
    name_tabs th =
  (case isar_dependencies_of name_tabs th of
    SOME [] => (false, [])
  | isar_deps0 =>
    let
      val isar_deps = these isar_deps0
      val thy = Proof_Context.theory_of ctxt
      val goal = goal_of_thm thy th
      val name = nickname_of_thm th
      val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt
      (* Only facts from strictly smaller theories (w.r.t. "thm_less") may be used. *)
      val facts = facts |> filter (fn (_, th') => thm_less (th', th))

      (* Replaces the name of a fact by the nickname of its theorem. *)
      fun nickify ((_, stature), th) = ((nickname_of_thm th, stature), th)

      fun is_dep dep (_, th) = (nickname_of_thm th = dep)

      (* Appends the fact nicknamed "dep" to "accum" unless it is already there. *)
      fun add_isar_dep facts dep accum =
        if exists (is_dep dep) accum then
          accum
        else
          (case find_first (is_dep dep) facts of
            SOME ((_, status), th) => accum @ [(("", status), th)]
          | NONE => accum (* should not happen *))

      val mepo_facts =
        facts
        |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE
             hyp_ts concl_t
      (* The prover gets the MePo selection extended with the Isar dependencies. *)
      val facts =
        mepo_facts
        |> fold (add_isar_dep facts) isar_deps
        |> map nickify
      val num_isar_deps = length isar_deps
    in
      if verbose andalso auto_level = 0 then
        writeln ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
          string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
          " facts")
      else
        ();
      (case run_prover_for_mash ctxt params prover name facts goal of
        {outcome = NONE, used_facts, ...} =>
        (if verbose andalso auto_level = 0 then
           let val num_facts = length used_facts in
             writeln ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
               plural_s num_facts)
           end
         else
           ();
         (true, map fst used_facts))
      | _ => (false, isar_deps))
    end)
1082
1083
1084(*** High-level communication with MaSh ***)
1085
1086(* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
1087
(* Given the chunks seen so far and a theorem "th", returns the chunks, regrouped where necessary
   and with a fresh empty chunk in front, together with the nicknames of the parents computed
   for "th". *)
fun chunks_and_parents_for chunks th =
  let
    (* Drops the parents that are below "new" and adds "new" itself unless it is below one of the
       remaining parents (both w.r.t. "thm_less_eq"). *)
    fun insert_parent new parents =
      let val parents = parents |> filter_out (fn p => thm_less_eq (p, new)) in
        parents |> forall (fn p => not (thm_less_eq (new, p))) parents ? cons new
      end

    (* Splits a chunk in front of its first element that is below "th". There is no clause for
       the empty list: the only caller has checked that the last element is below "th". *)
    fun rechunk seen (rest as th' :: ths) =
      if thm_less_eq (th', th) then (rev seen, rest)
      else rechunk (th' :: seen) ths

    (* Empty chunks are dropped. A chunk whose head is below "th" contributes its head as a
       parent; a chunk of which only the last element is known to be below "th" is split first;
       any other chunk is kept unchanged and contributes no parent. *)
    fun do_chunk [] accum = accum
      | do_chunk (chunk as hd_chunk :: _) (chunks, parents) =
        if thm_less_eq (hd_chunk, th) then
          (chunk :: chunks, insert_parent hd_chunk parents)
        else if thm_less_eq (List.last chunk, th) then
          let val (front, back as hd_back :: _) = rechunk [] chunk in
            (front :: back :: chunks, insert_parent hd_back parents)
          end
        else
          (chunk :: chunks, parents)
  in
    fold_rev do_chunk chunks ([], [])
    |>> cons []
    ||> map nickname_of_thm
  end
1114
(* Pairs each of the new "facts" with the nicknames of its parents. The old facts are traversed
   as well, so that the chunks are set up, but their entries are dropped from the result. *)
fun attach_parents_to_facts _ [] = []
  | attach_parents_to_facts old_facts (facts as (_, th) :: _) =
    let
      fun do_facts _ [] = []
        | do_facts (_, parents) [fact] = [(parents, fact)]
        | do_facts (chunks, parents)
                   ((fact as (_, th)) :: (facts as (_, th') :: _)) =
          let
            (* The current theorem joins the chunk at the front. *)
            val chunks = app_hd (cons th) chunks
            (* Shortcut: if the next theorem is above the current one and belongs to the same
               theory, the current theorem is its only parent; otherwise the chunks and parents
               are recomputed. *)
            val chunks_and_parents' =
              if thm_less_eq (th, th') andalso
                Thm.theory_name th = Thm.theory_name th'
              then (chunks, [nickname_of_thm th])
              else chunks_and_parents_for chunks th'
          in
            (parents, fact) :: do_facts chunks_and_parents' facts
          end
    in
      old_facts @ facts
      |> do_facts (chunks_and_parents_for [[]] th)
      |> drop (length old_facts)
    end
1137
(* Checks whether the access graph has a node for the nickname of the theorem. *)
fun is_fact_in_graph access_G th = can (Graph.get_node access_G) (nickname_of_thm th)
1139
val chained_feature_factor = 0.5 (* FUDGE *)
val extra_feature_factor = 0.1 (* FUDGE *)
val num_extra_feature_facts = 10 (* FUDGE *)

val max_proximity_facts = 100 (* FUDGE *)

(* Meshes three weighted rankings -- the unknown chained facts, the unknown facts among the first
   "max_proximity_facts" facts, and the MaSh suggestions (backed by all unknown facts) -- into at
   most "max_facts" facts. The second component of the result consists of the unknown facts that
   are neither chained nor proximate. *)
fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
  let
    val same_prop = eq_snd Thm.eq_thm_prop

    val raw_mash = find_suggested_facts ctxt facts suggs
    val unknown_chained = inter same_prop raw_unknown chained
    val unknown_proximate = inter same_prop raw_unknown (take max_proximity_facts facts)

    val chained_ranking = (0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, []))
    val proximate_ranking = (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, []))
    val mash_ranking = (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))

    val unknown =
      raw_unknown
      |> subtract same_prop unknown_chained
      |> subtract same_prop unknown_proximate
  in
    (mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts
       [chained_ranking, proximate_ranking, mash_ranking],
     unknown)
  end
1162
(* Queries MaSh for up to "max_suggs" facts that are relevant to the goal given by "hyp_ts" and
   "concl_t". The result is a pair of fact lists as produced by "find_mash_suggestions": the
   meshed suggestions and the remaining facts that are unknown to MaSh. If no MaSh state is
   available, both lists are empty. *)
fun mash_suggested_facts ctxt thy_name ({debug, ...} : params) max_suggs hyp_ts concl_t facts =
  let
    val algorithm = the_mash_algorithm ()

    (* Reorders the facts w.r.t. "crude_thm_ord"; later code only relies on a prefix of the
       result (for the extra features and for the proximate facts). *)
    val facts = facts
      |> rev_sort_list_prefix (crude_thm_ord ctxt o apply2 snd)
        (Int.max (num_extra_feature_facts, max_proximity_facts))

    val chained = filter (fn ((_, (scope, _)), _) => scope = Chained) facts

    fun fact_has_right_theory (_, th) = thy_name = Thm.theory_name th

    (* Features of a weighted fact, each paired with the fact's weight scaled by "factor". *)
    fun chained_or_extra_features_of factor (((_, stature), th), weight) =
      [Thm.prop_of th]
      |> features_of ctxt (Thm.theory_name th) stature
      |> map (rpair (weight * factor))
  in
    (case get_state ctxt of
      NONE => ([], [])
    | SOME {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} =>
      let
        val goal_feats0 =
          features_of ctxt thy_name (Local, General) (concl_t :: hyp_ts)
        val chained_feats = chained
          |> map (rpair 1.0)
          |> map (chained_or_extra_features_of chained_feature_factor)
          |> rpair [] |-> fold (union (eq_fst (op =)))
        (* The more chained facts there are, the fewer extra facts are considered. *)
        val extra_feats = facts
          |> take (Int.max (0, num_extra_feature_facts - length chained))
          |> filter fact_has_right_theory
          |> weight_facts_steeply
          |> map (chained_or_extra_features_of extra_feature_factor)
          |> rpair [] |-> fold (union (eq_fst (op =)))

        (* The goal's own features have weight 1.0; features are compared by name only, so a
           feature that also comes from a chained or extra fact is not added a second time. *)
        val goal_feats =
          fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
          |> debug ? sort (Real.compare o swap o apply2 snd)

        val fact_idxs = map_filter (Symtab.lookup fact_tab o nickname_of_thm o snd) facts

        val suggs =
          if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then
            let
              val learns =
                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps))
                  access_G
            in
              MaSh.query_external ctxt algorithm max_suggs learns goal_feats
            end
          else
            let
              (* Features without an entry in the feature table are dropped. *)
              val int_goal_feats =
                map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
            in
              MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs fact_idxs max_suggs
                goal_feats int_goal_feats
            end

        val unknown = filter_out (is_fact_in_graph access_G o snd) facts
      in
        find_mash_suggestions ctxt max_suggs suggs facts chained unknown
        |> apply2 (map fact_of_raw_fact)
      end)
  end
1227
(* Clears the MaSh state and tells the user about it. *)
fun mash_unlearn ctxt =
  let val _ = clear_state ctxt
  in writeln "Reset MaSh" end
1229
(* Adds the fact "name" to the access graph and to the fact and feature tables. Returns the learned
   entry, restricted to the parents and dependencies for which "try_graph" accepted an edge to
   "name", together with the updated graph and tables. If "Symtab.DUP" is raised, nothing is
   learned and the accumulator is returned unchanged. *)
fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
    (accum as (access_G, (fact_xtab, feat_xtab))) =
  let
    fun maybe_learn_from from (accum as (parents, access_G)) =
      try_graph ctxt "updating graph" accum (fn () =>
        (from :: parents, Graph.add_edge_acyclic (from, name) access_G))

    val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps))
    (* Only the edges from the parents are kept in the graph. *)
    val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents
    (* For the dependencies, the graph with the extra edges is discarded; only the list of
       accepted dependencies is kept. *)
    val (deps, _) = ([], access_G) |> fold maybe_learn_from deps

    val fact_xtab = add_to_xtab name fact_xtab
    val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
  in
    (SOME (name, parents, feats, deps), (access_G, (fact_xtab, feat_xtab)))
  end
  handle Symtab.DUP _ => (NONE, accum) (* facts sometimes have the same name, confusingly *)
1247
(* Marks the node "name" as having an automatic proof with dependencies "deps". The returned
   dependencies are those for which "try_graph" accepted an edge to "name"; the edges themselves
   are not kept in the returned graph. *)
fun relearn_wrt_access_graph ctxt (name, deps) access_G =
  let
    fun link_dep dep (acc as (linked, G)) =
      try_graph ctxt "updating graph" acc (fn () =>
        (dep :: linked, Graph.add_edge_acyclic (dep, name) G))

    val access_G' =
      Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps)) access_G
    val linked_deps = fst (fold link_dep deps ([], access_G'))
  in
    ((name, linked_deps), access_G')
  end
1259
(* Marks the node "name" as "Isar_Proof_wegen_Prover_Flop", keeping its features and
   dependencies. *)
fun flop_wrt_access_graph name access_G =
  access_G
  |> Graph.map_node name (fn (_, feats, deps) => (Isar_Proof_wegen_Prover_Flop, feats, deps))
1262
val learn_timeout_slack = 20.0

(* Runs "task" in a legacy asynchronous thread registered under "MaShN". The thread's death time
   is "learn_timeout_slack" times "timeout" after its birth. *)
fun launch_thread timeout task =
  let
    val birth_time = Time.now ()
    val death_time = birth_time + time_mult learn_timeout_slack timeout
  in
    Async_Manager_Legacy.thread MaShN birth_time death_time
      ("Machine learner for Sledgehammer", "") task
  end
1274
(* Fresh name for an anonymous proof: the anonymous proof prefix, a local time stamp, and a serial
   number. *)
fun anonymous_proof_name () =
  let
    val format = anonymous_proof_prefix ^ "%Y%m%d.%H%M%S."
    val stamp = Date.fmt format (Date.fromTimeLocal (Time.now ()))
  in
    stamp ^ serial_string ()
  end
1278
(* Learns from a proof of "t" that used the theorems "used_ths". The proof is recorded under a
   fresh anonymous name as an automatic proof; the work is done via "launch_thread". Nothing
   happens if "used_ths" is empty or MaSh is disabled. *)
fun mash_learn_proof ctxt ({timeout, ...} : params) t used_ths =
  if not (null used_ths) andalso is_mash_enabled () then
    launch_thread timeout (fn () =>
      let
        val thy = Proof_Context.theory_of ctxt
        val feats = features_of ctxt (Context.theory_name thy) (Local, General) [t]
      in
        map_state ctxt
          (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
             let
               (* Only theorems that are already in the graph can serve as dependencies. *)
               val deps = used_ths
                 |> filter (is_fact_in_graph access_G)
                 |> map nickname_of_thm

               val name = anonymous_proof_name ()
               val (access_G', xtabs', rev_learns) =
                 add_node Automatic_Proof name [] (* ignore parents *) feats deps
                   (access_G, xtabs, [])

               val (ffds', freqs') =
                 recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs
             in
               {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
                dirty_facts = Option.map (cons name) dirty_facts}
             end);
        (true, "")
      end)
  else
    ()
1308
(* Markup for a clickable "sledgehammer <sub>" command. *)
fun sendback sub =
  let val command = sledgehammerN ^ " " ^ sub
  in Active.sendback_markup_command command end

(* Interval at which "mash_learn_facts" commits intermediate results to the MaSh state. *)
val commit_timeout = seconds 30.0
1312
(* Learns the facts among "facts" that are not yet in the access graph (from their Isar proofs or,
   if "run_prover" is set, from proofs found by "prover") and, if "run_prover" is set, relearns the
   facts that are already in the graph. Intermediate results are committed to the MaSh state
   whenever more than "commit_timeout" has passed since the previous commit. Returns a message for
   the user; the message is empty if "auto_level" is at least 2, unless "verbose" is set and
   something was attempted.

   The timeout is understood in a very relaxed fashion: "learn_timeout" is only checked after
   each processed fact. *)
fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover
    learn_timeout facts =
  let
    val timer = Timer.startRealTimer ()
    fun next_commit_time () = Timer.checkRealTimer timer + commit_timeout
  in
    (case get_state ctxt of
      NONE => "MaSh is busy\nPlease try again later"
    | SOME {access_G, ...} =>
      let
        val is_in_access_G = is_fact_in_graph access_G o snd
        val no_new_facts = forall is_in_access_G facts
      in
        if no_new_facts andalso not run_prover then
          (* "run_prover" is false in this branch, so the message can only be about Isar
             proofs. *)
          if auto_level < 2 then
            "No new Isar proofs to learn" ^
            (if auto_level = 0 then
               "\n\nHint: Try " ^ sendback learn_proverN ^ " to learn from an automatic prover"
             else
               "")
          else
            ""
        else
          let
            val name_tabs = build_name_tables nickname_of_thm facts

            (* Dependencies of a fact: none for definitions; otherwise those found by the prover
               (NONE if the prover failed or found too many) or those of the Isar proof. *)
            fun deps_of status th =
              if status = Non_Rec_Def orelse status = Rec_Def then
                SOME []
              else if run_prover then
                prover_dependencies_of ctxt params prover auto_level facts name_tabs th
                |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
              else
                isar_dependencies_of name_tabs th

            (* State update that incorporates the given learns, relearns, and flops. *)
            fun do_commit [] [] [] state = state
              | do_commit learns relearns flops
                  {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =
                let
                  val was_empty = Graph.is_empty access_G

                  val (learns, (access_G', xtabs')) =
                    fold_map (learn_wrt_access_graph ctxt) learns (access_G, xtabs)
                    |>> map_filter I
                  val (relearns, access_G'') =
                    fold_map (relearn_wrt_access_graph ctxt) relearns access_G'

                  val access_G''' = access_G'' |> fold flop_wrt_access_graph flops
                  val dirty_facts' =
                    (case (was_empty, dirty_facts) of
                      (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
                    | _ => NONE)

                  (* The tables are updated incrementally unless something was relearned. *)
                  val (ffds', freqs') =
                    if null relearns then
                      recompute_ffds_freqs_from_learns
                        (map (fn (name, _, feats, deps) => (name, feats, deps)) learns) xtabs'
                        num_facts0 ffds freqs
                    else
                      recompute_ffds_freqs_from_access_G access_G''' xtabs'
                in
                  {access_G = access_G''', xtabs = xtabs', ffds = ffds', freqs = freqs',
                   dirty_facts = dirty_facts'}
                end

            (* "learns" is accumulated in reverse order, hence the "rev". *)
            fun commit last learns relearns flops =
              (if debug andalso auto_level = 0 then writeln "Committing..." else ();
               map_state ctxt (do_commit (rev learns) relearns flops);
               if not last andalso auto_level = 0 then
                 let val num_proofs = length learns + length relearns in
                   writeln ("Learned " ^ string_of_int num_proofs ^ " " ^
                     (if run_prover then "automatic" else "Isar") ^ " proof" ^
                     plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout)
                 end
               else
                 ())

            (* Folding step for new facts; the accumulator is
               "(learns, (num_nontrivial, next_commit, timed_out))". Once "timed_out" is set, the
               remaining facts are skipped. *)
            fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
              | learn_new_fact (parents, ((_, stature as (_, status)), th))
                  (learns, (num_nontrivial, next_commit, _)) =
                let
                  val name = nickname_of_thm th
                  val feats = features_of ctxt (Thm.theory_name th) stature [Thm.prop_of th]
                  val deps = these (deps_of status th)
                  val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1
                  val learns = (name, parents, feats, deps) :: learns
                  val (learns, next_commit) =
                    if Timer.checkRealTimer timer > next_commit then
                      (commit false learns [] []; ([], next_commit_time ()))
                    else
                      (learns, next_commit)
                  val timed_out = Timer.checkRealTimer timer > learn_timeout
                in
                  (learns, (num_nontrivial, next_commit, timed_out))
                end

            val (num_new_facts, num_nontrivial) =
              if no_new_facts then
                (0, 0)
              else
                let
                  val new_facts = facts
                    |> sort (crude_thm_ord ctxt o apply2 snd)
                    |> map (pair []) (* ignore parents *)
                    |> filter_out (is_in_access_G o snd)
                  val (learns, (num_nontrivial, _, _)) =
                    ([], (0, next_commit_time (), false))
                    |> fold learn_new_fact new_facts
                in
                  commit true learns [] []; (length new_facts, num_nontrivial)
                end

            (* Folding step for old facts, analogous to "learn_new_fact"; facts for which no
               dependencies could be determined are recorded as flops. *)
            fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
              | relearn_old_fact ((_, (_, status)), th)
                  ((relearns, flops), (num_nontrivial, next_commit, _)) =
                let
                  val name = nickname_of_thm th
                  val (num_nontrivial, relearns, flops) =
                    (case deps_of status th of
                      SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops)
                    | NONE => (num_nontrivial, relearns, name :: flops))
                  val (relearns, flops, next_commit) =
                    if Timer.checkRealTimer timer > next_commit then
                      (commit false [] relearns flops; ([], [], next_commit_time ()))
                    else
                      (relearns, flops, next_commit)
                  val timed_out = Timer.checkRealTimer timer > learn_timeout
                in
                  ((relearns, flops), (num_nontrivial, next_commit, timed_out))
                end

            val num_nontrivial =
              if not run_prover then
                num_nontrivial
              else
                let
                  val max_isar = 1000 * max_dependencies

                  (* Facts with a smaller priority value are relearned first. The random
                     summand is smaller than the offsets for earlier prover flops and for
                     automatic proofs, so these are tried after the Isar proofs. *)
                  fun priority_of th =
                    Random.random_range 0 max_isar +
                    (case try (Graph.get_node access_G) (nickname_of_thm th) of
                      SOME (Isar_Proof, _, deps) => ~100 * length deps
                    | SOME (Automatic_Proof, _, _) => 2 * max_isar
                    | SOME (Isar_Proof_wegen_Prover_Flop, _, _) => max_isar
                    | NONE => 0)

                  val old_facts = facts
                    |> filter is_in_access_G
                    |> map (`(priority_of o snd))
                    |> sort (int_ord o apply2 fst)
                    |> map snd
                  val ((relearns, flops), (num_nontrivial, _, _)) =
                    (([], []), (num_nontrivial, next_commit_time (), false))
                    |> fold relearn_old_fact old_facts
                in
                  commit true [] relearns flops; num_nontrivial
                end
          in
            if verbose orelse auto_level < 2 then
              "Learned " ^ string_of_int num_new_facts ^ " fact" ^ plural_s num_new_facts ^
              " and " ^ string_of_int num_nontrivial ^ " nontrivial " ^
              (if run_prover then "automatic and " else "") ^ "Isar proof" ^
              plural_s num_nontrivial ^
              (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer) else "")
            else
              ""
          end
      end)
  end
1483
(* Entry point of the explicit learning commands: learns from the Isar proofs of nearly all facts
   and, if "run_prover" is set, afterwards from automatic proofs found by the first prover in
   "provers". Progress and results are reported via "writeln". *)
fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained run_prover =
  let
    val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
    val ctxt = Config.put instantiate_inducts false ctxt
    val all_facts =
      nearly_all_facts ctxt false fact_override Keyword.empty_keywords css chained [] \<^prop>\<open>True\<close>
    val facts = sort (crude_thm_ord ctxt o apply2 snd o swap) all_facts
    val num_facts = length facts
    val prover = hd provers
    val intro = "MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts

    fun learn auto_level run_prover =
      writeln (mash_learn_facts ctxt params prover auto_level run_prover one_year facts)
  in
    if not run_prover then
      (writeln (intro ^ " for Isar proofs...");
       learn 0 false)
    else
      (writeln (intro ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
         string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
       learn 1 false;
       writeln "Now collecting automatic proofs\n\
         \This may take several hours; you can safely stop the learning process at any point";
       learn 0 true)
  end
1511
(* MaSh can suggest facts if its state is available and the access graph is nonempty. *)
fun mash_can_suggest_facts ctxt =
  (case get_state ctxt of
    SOME {access_G, ...} => not (Graph.is_empty access_G)
  | NONE => false)

(* Variant of "mash_can_suggest_facts" that inspects the state via "peek_state" instead of
   "get_state". *)
fun mash_can_suggest_facts_fast ctxt =
  (case peek_state ctxt of
    SOME (_, {access_G, ...}) => not (Graph.is_empty access_G)
  | NONE => false)
1521
(* Ask for more suggestions than were requested, since some of them might be discarded later on
   for various reasons (e.g., duplicates). *)
fun generous_max_suggestions max_facts = 25 + 2 * max_facts (* FUDGE *)
1525
(* Weights for the MePo and MaSh rankings. NOTE(review): presumably used when meshing the two
   filters; the use site is not part of this excerpt -- confirm there. *)
val mepo_weight = 0.5 (* FUDGE *)
val mash_weight = 0.5 (* FUDGE *)

(* Maximum number of not yet learned facts that "relevant_facts" learns right away before
   querying; larger backlogs are delegated to a background thread. *)
val max_facts_to_learn_before_query = 100 (* FUDGE *)

(* Minimum timeout (in seconds) for which background learning is started. *)
(* The threshold should be large enough so that MaSh does not get activated for Auto Sledgehammer. *)
val min_secs_for_learning = 10
1533
(* Selects facts from "facts" for the goal given by "hyp_ts" and "concl_t".

   The result is a list of (filter name, selected facts) pairs:
   - an unknown "fact_filter" is rejected with an error;
   - if "only" is set, all the given facts are returned under the name "";
   - if "max_facts <= 0" or "facts" is empty, an empty selection is returned under the name "";
   - if no "fact_filter" was specified and both MePo and MaSh were consulted, three selections
     are returned (MeSh, MePo, and MaSh, in that order);
   - otherwise, a single selection named after the effective fact filter is returned.

   Facts matching the "add" override are put in front of each selection, which is then cut down
   to "max_facts" elements.

   Side effect: if "learn" is set, MaSh is enabled, and MePo was not explicitly requested, facts
   may first be learned, either directly or via a separately launched thread. *)
fun relevant_facts ctxt (params as {verbose, learn, fact_filter, timeout, ...}) prover
    max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  if not (subset (op =) (the_list fact_filter, fact_filters)) then
    error ("Unknown fact filter: " ^ quote (the fact_filter))
  else if only then
    [("", map fact_of_raw_fact facts)]
  else if max_facts <= 0 orelse null facts then
    [("", [])]
  else
    let
      val thy_name = Context.theory_name (Proof_Context.theory_of ctxt)

      (* learning run over all the given facts, bounded by "time_limit" *)
      fun learn_all_facts time_limit =
        mash_learn_facts ctxt params prover 2 false time_limit facts

      (* Launches a learning thread, unless a MaSh thread is already running or the timeout is
         below "min_secs_for_learning". *)
      fun launch_learner exact num_to_learn =
        if Async_Manager_Legacy.has_running_threads MaShN orelse
           Time.toSeconds timeout < min_secs_for_learning then
          ()
        else
          let
            val slack_timeout = time_mult learn_timeout_slack timeout
            val _ =
              if verbose then
                writeln ("Started MaShing through " ^
                  (if exact then "" else "up to ") ^ string_of_int num_to_learn ^
                  " fact" ^ plural_s num_to_learn ^ " in the background")
              else
                ()
          in
            launch_thread slack_timeout (fn () => (true, learn_all_facts slack_timeout))
          end

      (* both flags are determined before any learning takes place *)
      val mash_enabled = is_mash_enabled ()
      val mash_fast = mash_can_suggest_facts_fast ctxt

      (* Learns right away if only few facts are missing from the access graph; otherwise
         delegates to "launch_learner". *)
      fun learn_if_needed () =
        if not mash_fast then
          launch_learner false (length facts)
        else
          (case get_state ctxt of
            NONE => launch_learner false (length facts)
          | SOME {access_G, xtabs = ((num_facts0, _), _), ...} =>
            let
              val is_known = is_fact_in_graph access_G o snd
              val min_to_learn = length facts - num_facts0
            in
              if min_to_learn > max_facts_to_learn_before_query then
                launch_learner false min_to_learn
              else
                let val num_unknown = length (filter_out is_known facts) in
                  if num_unknown = 0 then
                    ()
                  else if num_unknown > max_facts_to_learn_before_query then
                    launch_learner true num_unknown
                  else
                    (* few enough facts: learn them now and report any message *)
                    (case learn_all_facts timeout of
                      "" => ()
                    | msg => writeln (MaShN ^ ": " ^ msg))
                end
            end)

      val _ =
        if learn andalso mash_enabled andalso fact_filter <> SOME mepoN then
          learn_if_needed ()
        else
          ()

      val effective_fact_filter =
        (case fact_filter of
          NONE => if mash_enabled andalso mash_fast then meshN else mepoN
        | SOME chosen => chosen)

      val unique_facts = drop_duplicate_facts facts
      val add_ths = Attrib.eval_thms ctxt add

      fun is_added (_, th) = member Thm.eq_thm_prop add_ths th

      (* puts the facts from the "add" override first, then truncates to "max_facts" *)
      fun add_and_take accepts =
        let
          val merged =
            if null add_ths then
              accepts
            else
              map fact_of_raw_fact (filter is_added unique_facts) @ filter_out is_added accepts
        in
          take max_facts merged
        end

      fun mepo () =
        let
          val suggs =
            mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
        in
          (weight_facts_steeply suggs, [])
        end

      fun mash () =
        let
          val (suggs, others) =
            mash_suggested_facts ctxt thy_name params (generous_max_suggestions max_facts)
              hyp_ts concl_t facts
        in
          (weight_facts_steeply suggs, others)
        end

      val use_mash = effective_fact_filter <> mepoN
      val use_mepo = effective_fact_filter <> mashN

      val mess =
        (* MePo must come before MaSh: the "case" expression below depends on this order *)
        ((if use_mepo then [(mepo_weight, mepo)] else []) @
         (if use_mash then [(mash_weight, mash)] else []))
        |> Par_List.map (fn (weight, suggest) => (weight, suggest ()))

      val mesh =
        add_and_take
          (mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts mess)
    in
      (case (fact_filter, mess) of
        (NONE, [(_, (mepo_suggs, _)), (_, (mash_suggs, _))]) =>
        [(meshN, mesh),
         (mepoN, add_and_take (map fst mepo_suggs)),
         (mashN, add_and_take (map fst mash_suggs))]
      | _ => [(effective_fact_filter, mesh)])
    end
1634
1635end;
1636