1(* ========================================================================= *)
2(* FILE          : psTermGen.sml                                             *)
3(* DESCRIPTION   : Term generation algorithms                                *)
4(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
5(* DATE          : 2018                                                      *)
6(* ========================================================================= *)
7
structure psTermGen :> psTermGen =
struct

open HolKernel Abbrev boolLib aiLib

val ERR = mk_HOL_ERR "psTermGen"

(* -------------------------------------------------------------------------
   Real-list helpers. Both fold from the right, so the arithmetic is done
   in the same order as the naive recursion a1 op (a2 op (... (an op u))),
   which keeps the floating-point results identical to that definition.
   ------------------------------------------------------------------------- *)

fun product_real rl = foldr (fn (a : real, acc) => a * acc) 1.0 rl

fun sum_real rl = foldr (fn (a : real, acc) => a + acc) 0.0 rl

(* Result type of an operator once all of its arguments are supplied. *)
fun im_ty oper = snd (strip_type (type_of oper))

(* -------------------------------------------------------------------------
   Number of terms of each type and size.

   Only fully applied (first-order) terms are counted: an operator
   contributes to type ty only when its final result type is ty. A constant
   has size 1; an operator applied to k arguments has size 1 plus the sum
   of the argument sizes. Non-positive sizes have no terms.

   ntermc memoises its results in cache under the key (size,ty) only, so a
   cache must not be shared between different operator lists.
   ------------------------------------------------------------------------- *)

fun ntermc cache operl (size,ty) =
  if size <= 0 then 0.0 else
  dfind (size,ty) (!cache) handle NotFound =>
  let val n = sum_real (map (ntermc_oper cache operl (size,ty)) operl) in
    cache := dadd (size,ty) n (!cache); n
  end
(* Number of terms of the given size and type whose head is oper. *)
and ntermc_oper cache operl (size,ty) oper =
  let val (tyargl,im) = strip_type (type_of oper) in
    if ty <> im orelse size <= 0 then 0.0
    else if null tyargl then (if size = 1 then 1.0 else 0.0) (* constant *)
    else
      let
        (* Ways of distributing the remaining size-1 among the arguments.
           Presumably number_partition (aiLib) raises HOL_ERR when no such
           split exists; that case counts as zero terms. *)
        val nll = number_partition (length tyargl) (size - 1)
                  handle HOL_ERR _ => []
        fun f nl = product_real (map (ntermc cache operl) (combine (nl,tyargl)))
      in
        sum_real (map f nll)
      end
  end

(* -------------------------------------------------------------------------
   Random terms. Generate with respect to uniform probability over
   all possible terms of certain size and type.
   Each choice (head operator, then split of the size among the arguments)
   is weighted by the number of terms it leads to, as computed by ntermc.
   select_in_distrib (aiLib) is assumed to sample proportionally to the
   weights. Raises HOL_ERR when no term of that size and type exists.
   ------------------------------------------------------------------------- *)

fun random_termc cache operl (size,ty) =
  (* Same "<= epsilon" test as random_termc_oper, so both layers agree on
     what counts as "no term of this size and type". *)
  if ntermc cache operl (size,ty) <= epsilon
    then raise ERR "random_term" "" else
  let
    val freql1 = map_assoc (ntermc_oper cache operl (size,ty)) operl
    val freql2 = filter (fn x => snd x > epsilon) freql1
  in
    random_termc_oper cache operl (size,ty) (select_in_distrib freql2)
  end
and random_termc_oper cache operl (size,ty) oper =
  let val (tyargl,im) = strip_type (type_of oper) in
    if ntermc_oper cache operl (size,ty) oper <= epsilon
      then raise ERR "random_term_oper" "" else
    (* A positive count for a constant implies size = 1. *)
    if null tyargl then oper else
    let
      val nll = number_partition (length tyargl) (size - 1)
                handle HOL_ERR _ => raise ERR "random_term_oper" ""
      fun f nl =
        (product_real (map (ntermc cache operl) (combine (nl,tyargl))))
      val freql1 = map_assoc f nll
      val freql2 = filter (fn x => snd x > epsilon) freql1
      val nl_chosen = select_in_distrib freql2
      val argl = map (random_termc cache operl) (combine (nl_chosen,tyargl))
    in
      list_mk_comb (oper,argl)
    end
  end

(* -------------------------------------------------------------------------
   Functions with no cache (a fresh memo table is created per call)
   ------------------------------------------------------------------------- *)

(* Fresh, empty memo table keyed by (size,type), as used by ntermc. *)
fun new_ntermc_cache () = ref (dempty (cpl_compare Int.compare Type.compare))

fun nterm operl (size,ty) = ntermc (new_ntermc_cache ()) operl (size,ty)

fun random_term operl (size,ty) =
  random_termc (new_ntermc_cache ()) operl (size,ty)

(* n independent samples; the count table is shared between samples. *)
fun random_terml operl (size,ty) n =
  let val cache = new_ntermc_cache () in
    List.tabulate (n, fn _ => random_termc cache operl (size,ty))
  end

(* -------------------------------------------------------------------------
   All terms up to a fixed size with a certain type.
   Here terms are built by curried application (mk_comb), so partial
   applications are generated as intermediate terms. The size of a term is
   its number of operator occurrences, i.e. size (f x) = size f + size x.
   ------------------------------------------------------------------------- *)

(* Whether a term of type ty1 can be applied to a term of type ty2;
   decided by attempting mk_comb on two variables of those types. *)
fun is_applicable (ty1,ty2) =
  can mk_comb (mk_var ("x",ty1), mk_var ("y",ty2))

(* All applications of a ty1-term from d1 to a ty2-term from d2. *)
fun all_mk_comb d1 d2 (ty1,ty2) =
  map mk_comb (cartesian_product (dfind ty1 d1) (dfind ty2 d2))

(* All terms of size n, grouped by type, memoised in cache. Size 1 (the
   operators themselves) must already be in the cache, as gen_term does;
   sizes n <= 0 yield an empty dictionary. *)
fun gen_size cache n =
  (if n <= 0 then dempty Type.compare else dfind n (!cache))
  handle NotFound =>
  let
    (* Every way of writing n as (size of function) + (size of argument). *)
    val l = map pair_of_list (number_partition 2 n)
    fun all_comb (n1,n2) =
      let
        val d1     = gen_size cache n1
        val d2     = gen_size cache n2
        val tytyl  = cartesian_product (dkeys d1) (dkeys d2)
        val tytyl' = filter is_applicable tytyl
      in
        List.concat (map (all_mk_comb d1 d2) tytyl')
      end
    val tml1 = List.concat (map all_comb l)
    val tml2 = map (fn x => (type_of x, x)) tml1
    val d3   = dregroup Type.compare tml2
  in
    cache := dadd n d3 (!cache); d3
  end

(* All terms of type ty with size between 1 and size, smallest first.
   A non-positive size yields [], matching nterm, which counts no terms for
   such sizes; List.tabulate would otherwise raise Size on a negative
   size. *)
fun gen_term operl (size,ty) =
  if size <= 0 then [] else
  let
    val d = dregroup Type.compare (map (fn x => (type_of x, x)) operl)
    val cache = ref (dnew Int.compare [(1,d)])
    fun terms_of_size n = dfind ty (gen_size cache n) handle NotFound => []
  in
    List.concat (List.tabulate (size, fn i => terms_of_size (i + 1)))
  end

end (* struct *)
145