1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7(*
8 * Term generation.
9 *
10 * Construct a term-generator at compile time.
11 *
12 *    @{mk_term "foo"} ()
13 *      ==> @{term "foo"}
14 *
15 *    @{mk_term "?f a b c" (f)}
 *      ==> (fn t1 => t1 $ @{term a} $ @{term b} $ @{term c})
17 *
18 *    @{mk_term "a::(?'a::plus) + b" ('a)}
19 *      ==> (fn t1 => Const (@{const_name plus}, t1 --> t1 --> t1) $ Free ("a", t1) $ Free ("b", t1))
20 *)
21structure MkTermAntiquote =
22struct
23
24local
25  open ML_Syntax
26
27  (*
28   * Find the name and type of all schematic variables in the given term.
29   *
30   *   @{term "?x ?y"} ==> [("x", 'a => 'b), ("y", 'b)]
31   *)
32  fun get_schematic_types (a $ b) = get_schematic_types a @ get_schematic_types b
33    | get_schematic_types (t as (Abs _)) = get_schematic_types (snd (Term.strip_abs_eta (~1) t))
34    | get_schematic_types (Var ((name, _), T)) = [(name, T)]
35    | get_schematic_types _ = []
36
37  (*
38   * Generate ML code to perform variable capture of the given type.
39   *
40   * In particular, all type variables will be captured into ML variables. The
41   * returned dictionary indicates the mapping from type variables to ML variable
42   * names.
43   *)
44  fun capture_type prefix (Type (_, Ts)) dict =
45      let
46        val (strings, new_dict) = fold_map (capture_type prefix) Ts dict
47      in
48        if length (Symtab.dest dict) = length (Symtab.dest new_dict) then
49          ("_", dict)
50        else
51          ("Type (_,  " ^ (ML_Syntax.print_list I strings) ^ ")", new_dict)
52      end
53    | capture_type prefix (T as (TVar ((var_name, _), _))) dict =
54      let
55        val name = prefix ^ "_" ^ string_of_int (length (Symtab.dest dict))
56      in
57        case Symtab.lookup dict var_name of
58          SOME _ => ("_", dict)
59        | NONE => (name, Symtab.update_new (var_name, (name, T)) dict)
60      end
61    | capture_type _ (TFree _) dict = ("_", dict)
62
63    (* Parse a list of the form "(x, y, z)". "inner" parses each of the indivdual items. *)
64    fun comma_list inner =
65      (inner >> (fn a => [a])) ||
66          (Args.parens (inner -- (Scan.repeat (Args.$$$ "," -- inner >> snd)) >> (fn (a, b) => a :: b)))
67
68  (* Write ML code for generating the given term, replacing schematic variables
69   * with the ML code in the "replacements" dictionary. *)
70  fun write_term_constructor replacements term =
71  let
72    fun print_typ (Type arg) = "Type " ^ print_pair print_string (print_list print_typ) arg
73      | print_typ (TFree arg) = "TFree " ^ print_pair print_string print_sort arg
74      | print_typ (TVar (arg as ((name, _), _))) =
75          (case Symtab.lookup replacements name of
76             NONE => "TVar " ^ print_pair print_indexname print_sort arg
77           | SOME ml => atomic ml)
78
79    fun print_term (Const arg) = "Const " ^ print_pair print_string print_typ arg
80      | print_term (Free arg) = "Free " ^ print_pair print_string print_typ arg
81      | print_term (Var (arg as ((name, _), _))) =
82          (case Symtab.lookup replacements name of
83             NONE => "Var " ^ print_pair print_indexname print_typ arg
84           | SOME ml => atomic ml)
85      | print_term (Bound i) = "Bound " ^ print_int i
86      | print_term (Abs (s, T, t)) =
87          "Abs (" ^ print_string s ^ ", " ^ print_typ T ^ ", " ^ print_term t ^ ")"
88      | print_term (t1 $ t2) = atomic (print_term t1) ^ " $ " ^ atomic (print_term t2);
89  in
90    print_term term
91  end
92
93  (* Print ML code for rendering a tuple. *)
94  val print_tuple = enclose "(" ")" o commas
95
96  (*
97   * Generate ML code for a lambda function that replaces variables and types
98   * in a term with parameters.
99   *
100   *   print_lambda ["a", "'b"] "xxx"
101   *     ==> ("(fn (t1, t2) => (xxx))", {"t1 => a", "t2 => 'b"})
102   *)
103  fun print_lambda vars =
104  let
105    val temps = 1 upto (length vars)
106      |> map (fn x => "t" ^ (string_of_int x))
107    val lambda_term = (fn x => atomic ("fn " ^ print_tuple temps ^ " => " ^ (atomic x)))
108    val dict = Symtab.make (vars ~~ temps)
109  in
110    (lambda_term,  dict)
111  end
112
113  (* Generate ML code for constructing the given pattern with the given
114   * template variables. *)
115  fun gen_constructor ((ctxt, pattern), params : string list ) =
116  let
117    (* Parse user term. *)
118    val term = Proof_Context.read_term_pattern ctxt pattern
119
120    (*
121     * Generate the outer shell of our final result:
122     *
123     *    (fn (t1, t2, t3) => ...)
124     *)
125    val (outer_fn, var_dict) = print_lambda params
126
127    (*
128     * For each parameter passed in by the user, generate ML code
129     * to extract relevant parts of its type.
130     *
131     * For example, if the user wants to replace "?X" (having type "?'a =>
132     * ?'b"), then when the user finally fills us in with a concrete term, we
133     * want to substitute "?'a" and "?'b" with their concrete values.
134     *)
135    val schematic_types =
136    let
137      val typ_table =
138        get_schematic_types term
139        |> distinct (op =)
140        |> Symtab.make
141    in
142      (params ~~ map (Symtab.lookup typ_table) params)
143      |> filter (fn (_, b) => b <> NONE)
144      |> map (fn (a, b) => (a, the b))
145    end
146    val (type_patterns, typ_dict) =
147        fold_map (fn (v, T) => capture_type ("T__" ^ v) T) schematic_types Symtab.empty
148
149    (* Merge the dictionary generated above (designed to capture types from
150     * the input term) with the user-provided definitions (which may also
151     * attempt to define types). *)
152    val replacement_dict = Symtab.join
153        (fn k => fn _ => error ("Key " ^ k ^ " used twice. Did you specify a type "
154            ^ "twice in the parameter list (possibly implicity)?"))
155        (Symtab.map (K fst) typ_dict, var_dict)
156
157    (* Generate code to determine types of variables. *)
158    val typ_match =
159      outer_fn (
160        "(let "
161        ^ (cat_lines (map (fn (pattern, param) =>
162            "val " ^ pattern ^ " = fastype_of " ^ param ^ "; ")
163            (type_patterns ~~ map (the o Symtab.lookup var_dict o fst) schematic_types)))
164        ^ " in "
165          ^ (write_term_constructor replacement_dict term)
166        ^ " end)"
167      )
168  in
169    typ_match
170  end
171in
172  val _ = Context.>> (Context.map_theory (
173    ML_Antiquotation.inline @{binding "mk_term"}
174      ((Args.context -- Scan.lift Args.embedded_inner_syntax -- (Scan.optional (Scan.lift ((comma_list Args.name))) []))
175         >>  gen_constructor)))
176end
177
178end
179