(* ========================================================================= *)
(* FILE          : mlTreeNeuralNetwork.sml                                   *)
(* DESCRIPTION   : Tree neural network                                       *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure mlTreeNeuralNetwork :> mlTreeNeuralNetwork =
struct

open HolKernel boolLib Abbrev aiLib mlMatrix mlNeuralNetwork smlParallel
  smlParser mlTacticData

val ERR = mk_HOL_ERR "mlTreeNeuralNetwork"
fun msg param s = if #verbose param then print_endline s else ()
fun msg_err fs es = (print_endline (fs ^ ": " ^ es); raise ERR fs es)

(* -------------------------------------------------------------------------
   Tools for computing the dimensions of neural network operators
   ------------------------------------------------------------------------- *)

type tnndim = (term * int list) list

fun operl_of_term tm =
  let
    val (oper,argl) = strip_comb tm
    val arity = length argl
  in
    (oper,arity) :: List.concat (map operl_of_term argl)
  end

val oper_compare = cpl_compare Term.compare Int.compare;

fun dim_std_arity (nlayer,dim) (oper,a) =
  let
    val dim_alt =
      if is_var oper andalso String.isPrefix "head_" (fst (dest_var oper))
      then 1
      else dim
  in
    (if a = 0 then [0] else List.tabulate (nlayer, fn _ => a * dim)) @
    [dim_alt]
  end

fun dim_std (nlayer,dim) oper =
  dim_std_arity (nlayer,dim) (oper,arity_of oper)
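
(* Example (illustrative): with nlayer = 1 and dim = 8, dim_std gives a binary
   operator the dimensions [16,8] (its input is the concatenation of its two
   argument embeddings), a 0-ary operator [0,8], and an arity-1 "head_"
   variable [8,1], since heads project to a 1-dimensional output. *)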

(* -------------------------------------------------------------------------
   Random TNN
   ------------------------------------------------------------------------- *)

type tnn = (term,nn) Redblackmap.dict

fun oper_nn diml = case diml of
    [] => raise ERR "oper_nn" ""
  | a :: m =>
    if a = 0
    then random_nn (idactiv,didactiv) [0,last m]
    else random_nn (tanh,dtanh) diml

fun random_tnn tnndim = dnew Term.compare (map_snd oper_nn tnndim)

fun random_tnn_std (nlayer,dim) operl =
  random_tnn (map_assoc (dim_std (nlayer,dim)) operl)
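
(* Example (illustrative sketch; the names are arbitrary):
     val vx = mk_var ("x", alpha);
     val vf = mk_var ("f", ``:'a -> 'a -> 'a``);
     val vhead = mk_var ("head_", ``:'a -> 'a``);
     val tnn = random_tnn_std (2,12) [vhead,vf,vx];
   builds one randomly initialized network per operator, with layer sizes
   given by dim_std. *)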

(* -------------------------------------------------------------------------
   TNN I/O
   ------------------------------------------------------------------------- *)

local open SharingTables HOLsexp in
fun enc_opernn enc_tm = pair_encode (enc_tm, enc_nn)
fun enc_tnn enc_tm = list_encode (enc_opernn enc_tm)
fun dec_opernn dec_tm = pair_decode (dec_tm, dec_nn)
fun dec_tnn dec_tm = list_decode (dec_opernn dec_tm)
end

fun write_tnn file tnn = write_tmdata (enc_tnn, map fst) file (dlist tnn)
fun read_tnn file = dnew Term.compare (read_tmdata dec_tnn file)

local open SharingTables HOLsexp in
fun enc_tnndime enc_tm = pair_encode (enc_tm, list_encode Integer)
fun enc_tnndim enc_tm = list_encode (enc_tnndime enc_tm)
fun dec_tnndime dec_tm = pair_decode (dec_tm, list_decode int_decode)
fun dec_tnndim dec_tm = list_decode (dec_tnndime dec_tm)
end

fun write_tnndim file tnndim = write_tmdata (enc_tnndim, map fst) file tnndim
fun read_tnndim file = read_tmdata dec_tnndim file
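
(* Note (informal): write_tnn/read_tnn and write_tnndim/read_tnndim are meant
   as inverse pairs, so read_tnn file after write_tnn file tnn should rebuild
   an equivalent dictionary. *)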

(* -------------------------------------------------------------------------
   TNN Examples: I/O
   ------------------------------------------------------------------------- *)

type tnnex = ((term * real list) list) list
type tnnbatch = (term list * (term * mlMatrix.vect) list) list

fun basicex_to_tnnex ex = map (fn (tm,r) => [(tm,[r])]) ex
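
(* Example: basicex_to_tnnex [(tm1,1.0),(tm2,0.0)] is
   [[(tm1,[1.0])], [(tm2,[0.0])]]: each basic (term,real) example becomes a
   one-element list of (term, output list) samples. *)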

local open SharingTables HOLsexp in
val enc_real = String o Real.toString
val dec_real = Option.mapPartial Real.fromString o string_decode
fun enc_sample enc_tm = pair_encode (enc_tm, list_encode enc_real)
fun dec_sample dec_tm = pair_decode (dec_tm, list_decode dec_real)
fun enc_tnnex enc_tm = list_encode (list_encode (enc_sample enc_tm))
fun dec_tnnex dec_tm = list_decode (list_decode (dec_sample dec_tm))
fun tml_of_tnnex l = map fst (List.concat l)
end

fun write_tnnex file ex =
  write_tmdata (enc_tnnex, tml_of_tnnex) file ex
fun read_tnnex file =
  read_tmdata dec_tnnex file

(* -------------------------------------------------------------------------
   TNN Examples: ordering subterms and scaling output values
   ------------------------------------------------------------------------- *)

fun order_subtm tml =
  let
    val d = ref (dempty (cpl_compare Int.compare Term.compare))
    fun traverse tm =
      let
        val (oper,argl) = strip_comb tm
        val nl = map traverse argl
        val n = 1 + sum_int nl
      in
        d := dadd (n, tm) () (!d); n
      end
    val subtml = (app (ignore o traverse) tml; dkeys (!d))
  in
    map snd subtml
  end
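
(* Example (informal): order_subtm [``f x y``] lists the distinct subterms
   bottom-up, e.g. [``x``, ``y``, ``f x y``]: arguments always precede the
   application containing them, and subterms of equal size are ordered by
   Term.compare. *)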

fun prepare_tnnex tnnex =
  let fun f x = (order_subtm (map fst x), map_snd scale_out x) in
    map f tnnex
  end

(* -------------------------------------------------------------------------
   Fixed embedding
   ------------------------------------------------------------------------- *)

val embedding_prefix = "embedding_"

fun is_embedding v =
  is_var v andalso String.isPrefix embedding_prefix (fst (dest_var v))

fun embed_nn v =
  if is_embedding v then
    let
      val vs = fst (dest_var v)
      val n1 = String.size embedding_prefix
      val ntot = String.size vs
      val es = String.substring (vs,n1,ntot-n1)
      val e1 = string_to_reall es
      val e2 = map (fn x => Vector.fromList [x]) e1
    in
      [{a = idactiv, da = didactiv, w = Vector.fromList e2}]
    end
  else msg_err "embed_nn" (tts v)

fun mk_embedding_var (rv,ty) =
  mk_var (embedding_prefix ^ reall_to_string (vector_to_list rv), ty)
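
(* Note (informal): mk_embedding_var stores a vector of reals in the name of a
   variable "embedding_..."; embed_nn turns such a variable back into a small
   fixed network that reproduces the stored vector, so precomputed embeddings
   can be mixed with trained operators during forward propagation. *)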

(* -------------------------------------------------------------------------
   Forward propagation
   ------------------------------------------------------------------------- *)

fun fp_oper tnn fpdict tm =
  let
    val (f,argl) = strip_comb tm
    val nn = dfind f tnn handle NotFound => embed_nn f
    val invl = map (fn x => #outnv (last (dfind x fpdict))) argl
    val inv = Vector.concat invl
  in
    fp_nn nn inv
  end
  handle Subscript => msg_err "fp_oper" (tts tm)

fun fp_tnn_aux tnn fpdict tml = case tml of
    []      => fpdict
  | tm :: m =>
    let val fpdatal = fp_oper tnn fpdict tm in
      fp_tnn_aux tnn (dadd tm fpdatal fpdict) m
    end

fun fp_tnn tnn tml = fp_tnn_aux tnn (dempty Term.compare) tml
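
(* Note: fp_tnn expects its term list in the order produced by order_subtm,
   i.e. every argument appears before the term containing it, so that its
   embedding is already available in fpdict when fp_oper looks it up. *)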

(* -------------------------------------------------------------------------
   Backward propagation
   ------------------------------------------------------------------------- *)

fun sum_operdwll (oper,dwll) = [sum_dwll dwll]

fun dimout_fpdatal fpdatal = Vector.length (#outnv (last fpdatal))
fun dimout_tm fpdict tm = dimout_fpdatal (dfind tm fpdict)

fun bp_tnn_aux doutnvdict fpdict bpdict revtml = case revtml of
    []      => dmap sum_operdwll bpdict
  | tm :: m =>
    let
      val (oper,argl) = strip_comb tm
      val diml = map (dimout_tm fpdict) argl
      val doutnvl = dfind tm doutnvdict
      val doutnvsum = add_vectl doutnvl
      fun f doutnv =
        let
          val fpdatal = dfind tm fpdict
          val bpdatal = bp_nn_doutnv fpdatal doutnv
          val dinv = vector_to_list (#dinv (hd bpdatal))
          val dinvl = map Vector.fromList (part_group diml dinv)
        in
          (map #dw bpdatal, combine (argl,dinvl))
        end
      val (operdwl,tmdinvl) = f doutnvsum
      val newdoutnvdict = dappendl tmdinvl doutnvdict
      val newbpdict = dappend (oper,operdwl) bpdict
    in
      bp_tnn_aux newdoutnvdict fpdict newbpdict m
    end

fun bp_tnn fpdict (tml,tmevl) =
  let
    fun f (tm,ev) =
      let
        val fpdatal = dfind tm fpdict
        val doutnv = diff_rvect ev (#outnv (last fpdatal))
      in
        (tm,[doutnv])
      end
    val doutnvdict = dnew Term.compare (map f tmevl)
  in
    bp_tnn_aux doutnvdict fpdict (dempty Term.compare) (rev tml)
  end
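
(* Note (informal): the dictionary returned by bp_tnn maps each operator to
   its weight gradients summed over all occurrences of that operator in the
   example; gradients with respect to subterm outputs are accumulated in
   doutnvdict as the term list is processed in reverse order. *)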

(* -------------------------------------------------------------------------
   Inference
   ------------------------------------------------------------------------- *)

fun infer_tnn tnn tml =
  let
    val fpdict = fp_tnn tnn (order_subtm tml)
    fun f x = descale_out (#outnv (last (dfind x fpdict)))
  in
    map_assoc f tml
  end

fun infer_tnn_basic tnn tm =
  singleton_of_list (snd (singleton_of_list (infer_tnn tnn [tm])))

fun precomp_embed tnn tm =
  let
    val fpdict = fp_tnn tnn (order_subtm [tm])
    val embedv = #outnv (last (dfind tm fpdict))
  in
    mk_embedding_var (embedv, type_of tm)
  end
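
(* Note (informal): precomp_embed evaluates the embedding of tm once and
   freezes it as an embedding variable; that variable can then stand for tm
   inside larger terms so the shared subterm is not recomputed. *)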

(* -------------------------------------------------------------------------
   Training
   ------------------------------------------------------------------------- *)

fun se_of fpdict (tm,ev) =
  let
    val fpdatal = dfind tm fpdict
    val doutnv = diff_rvect ev (#outnv (last fpdatal))
    val r1 = vector_to_list doutnv
    val r2 = map (fn x => x * x) r1
  in
    Math.sqrt (average_real r2)
  end

fun mse_of fpdict tmevl = average_real (map (se_of fpdict) tmevl)

fun fp_loss tnn (tml,tmevl) = mse_of (fp_tnn tnn tml) tmevl

fun train_tnn_one tnn (tml,tmevl) =
  let
    val fpdict = fp_tnn tnn tml
    val bpdict = bp_tnn fpdict (tml,tmevl)
  in
    (bpdict, mse_of fpdict tmevl)
  end

fun train_tnn_subbatch tnn subbatch =
  let val (bpdictl,lossl) = split (map (train_tnn_one tnn) subbatch) in
    (dmap sum_operdwll (dconcat Term.compare bpdictl), lossl)
  end

fun update_oper param ((oper,dwll),tnn) =
  if is_embedding oper then tnn else
  let
    val nn = dfind oper tnn
    val dwl = sum_dwll dwll
    val newnn = update_nn param nn dwl
  in
    dadd oper newnn tnn
  end

fun train_tnn_batch param pf tnn batch =
  let
    val subbatchl = cut_modulo (#ncore param) batch
    val (bpdictl,lossll) = split (pf (train_tnn_subbatch tnn) subbatchl)
    val bpdict = dconcat Term.compare bpdictl
  in
    (foldl (update_oper param) tnn (dlist bpdict),
     average_real (List.concat lossll))
  end
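
(* Note (informal): train_tnn_batch splits each batch into #ncore subbatches,
   computes gradients for the subbatches through pf (possibly in parallel),
   sums the gradients per operator, and performs one weight update per
   batch. *)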

fun train_tnn_epoch param pf lossl tnn batchl = case batchl of
    [] => (tnn, average_real lossl)
  | batch :: m =>
    let val (newtnn,loss) = train_tnn_batch param pf tnn batch in
      train_tnn_epoch param pf (loss :: lossl) newtnn m
    end

fun train_tnn_epoch_nopar param lossl tnn batchl = case batchl of
    [] => (tnn, average_real lossl)
  | batch :: m =>
    let val (newtnn,loss) = train_tnn_batch param map tnn batch in
      train_tnn_epoch_nopar param (loss :: lossl) newtnn m
    end

fun train_tnn_nepoch param pf i tnn (train,test) =
  if i >= #nepoch param then tnn else
  let
    val batchl = mk_batch (#batch_size param) (shuffle train)
    val _ = if null batchl then msg_err "train_tnn_nepoch" "empty" else ()
    val (newtnn,loss) = train_tnn_epoch param pf [] tnn batchl
    val testloss = if null test then "" else
      (" test: " ^ pretty_real (average_real (map (fp_loss newtnn) test)))
    val _ = msg param (its i ^ " train: " ^ pretty_real loss ^ testloss)
  in
    train_tnn_nepoch param pf (i+1) newtnn (train,test)
  end

fun train_tnn_schedule schedule tnn (train,test) =
  case schedule of
    [] => tnn
  | param :: m =>
    let
      val _ = msg param ("learning rate: " ^ rts (#learning_rate param))
      val _ = msg param ("ncore: " ^ its (#ncore param))
      val (pf,close_threadl) = parmap_gen (#ncore param)
      val newtnn = train_tnn_nepoch param pf 0 tnn (train,test)
      val r = train_tnn_schedule m newtnn (train,test)
    in
      (close_threadl (); r)
    end
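
(* Note (informal): a schedule is a list of training parameter records; each
   record runs for its own number of epochs with its own learning rate and
   core count, so a two-element schedule can, for instance, implement a simple
   learning-rate decay. *)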

fun stats_head (oper,rll) =
  let
    val s0 = "  objective: " ^ tts oper
    val s1 = "length: " ^ its (length rll)
    val rll' = list_combine rll
    val s2 = "means: " ^
      String.concatWith " " (map (pretty_real o average_real) rll')
    val s3 = "standard deviations: " ^
      String.concatWith " " (map (pretty_real o standard_deviation) rll')
  in
    String.concatWith "\n  " [s0,s1,s2,s3]
  end

fun stats_tnnex ex =
  if null ex then " empty" else
  let
    fun head_of tm = fst (strip_comb tm)
    val d = dregroup Term.compare (map_fst head_of (List.concat ex))
  in
    its (length ex) ^ " examples\n" ^
    String.concatWith "\n" (map stats_head (dlist d))
  end

fun train_tnn schedule randtnn (trainex,testex) =
  let
    val _ = print_endline ("\ntraining set: " ^ stats_tnnex trainex)
    val _ = print_endline ("testing set: " ^ stats_tnnex testex)
    val _ = print_endline ""
    val (tnn,t) = add_time (train_tnn_schedule schedule randtnn)
      (prepare_tnnex trainex, prepare_tnnex testex)
  in
    print_endline ("Tree neural network training time: " ^ rts t); tnn
  end

(* -------------------------------------------------------------------------
   Accuracy
   ------------------------------------------------------------------------- *)

fun is_accurate_one (rl1,rl2) =
  let
    val rl3 = combine (rl1,rl2)
    fun test (x,y) = Real.abs (x - y) < 0.5
  in
    all test rl3
  end

fun is_accurate tnn e =
  let
    val rll1 = map snd (infer_tnn tnn (map fst e))
    val rll2 = map snd e
  in
    all is_accurate_one (combine (rll1,rll2))
  end

fun tnn_accuracy tnn set =
  let val correct = filter (is_accurate tnn) set in
    Real.fromInt (length correct) / Real.fromInt (length set)
  end
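
(* Note (informal): an example counts as accurate when every predicted output
   is within 0.5 of its target, so tnn_accuracy is the fraction of examples
   that are fully correct under this tolerance. *)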

(* -------------------------------------------------------------------------
   Object for training different TNNs in parallel
   ------------------------------------------------------------------------- *)

fun train_tnn_fun () (ex,schedule,tnndim) =
  let
    val randtnn = random_tnn tnndim
    val (tnn,t) = add_time (train_tnn schedule randtnn) (ex,[])
  in
    print_endline ("Training time: " ^ rts t); tnn
  end

fun write_noparam file (_:unit) = ()
fun read_noparam file = ()

fun write_tnnarg file (ex,schedule,tnndim) =
  (
  write_tnnex (file ^ "_tnnex") ex;
  write_schedule (file ^ "_schedule") schedule;
  write_tnndim (file ^ "_tnndim") tnndim
  )
fun read_tnnarg file =
  let
    val ex = read_tnnex (file ^ "_tnnex")
    val schedule = read_schedule (file ^ "_schedule")
    val tnndim = read_tnndim (file ^ "_tnndim")
  in
    (ex,schedule,tnndim)
  end

val traintnn_extspec =
  {
  self = "mlTreeNeuralNetwork.traintnn_extspec",
  parallel_dir = default_parallel_dir ^ "_train",
  reflect_globals = fn () => "()",
  function = train_tnn_fun,
  write_param = write_noparam,
  read_param = read_noparam,
  write_arg = write_tnnarg,
  read_arg = read_tnnarg,
  write_result = write_tnn,
  read_result = read_tnn
  }
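
(* Note (informal): traintnn_extspec packages train_tnn_fun together with the
   I/O functions above so that several TNNs can be trained in separate
   processes through the external parallelism mechanism of smlParallel. *)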

(* -------------------------------------------------------------------------
   Toy example: learning to guess whether a term contains the variable "x"
   ------------------------------------------------------------------------- *)

(*
load "aiLib"; open aiLib;
load "psTermGen"; open psTermGen;
load "mlTreeNeuralNetwork"; open mlTreeNeuralNetwork;

(* terms *)
val vx = mk_var ("x",alpha);
val vy = mk_var ("y",alpha);
val vz = mk_var ("z",alpha);
val vf = ``f:'a->'a->'a``;
val vg = ``g:'a -> 'a``;
val vhead = mk_var ("head_", ``:'a -> 'a``);
val varl = [vx,vy,vz,vf,vg];

(* examples *)
fun contain_x tm = can (find_term (fn x => term_eq x vx)) tm;
fun mk_dataset n =
  let
    val pxl = mk_term_set (random_terml varl (n,alpha) 1000);
    val (px,pnotx) = partition contain_x pxl
  in
    (first_n 100 (shuffle px), first_n 100 (shuffle pnotx))
  end
val (l1,l2) = split (List.tabulate (20, fn n => mk_dataset (n + 1)));
val (l1',l2') = (List.concat l1, List.concat l2);
val (pos,neg) = (map_assoc (fn x => [1.0]) l1', map_assoc (fn x => [0.0]) l2');
val ex0 = shuffle (pos @ neg);
val ex1 = map (fn (a,b) => [(mk_comb (vhead,a),b)]) ex0;
val (trainex,testex) = part_pct 0.9 ex1;

(* TNN *)
val nlayer = 1;
val dim = 16;
val randtnn = random_tnn_std (nlayer,dim) (vhead :: varl);

(* training *)
val trainparam =
  {ncore = 1, verbose = true,
   learning_rate = 0.02, batch_size = 16, nepoch = 20};
val schedule = [trainparam];
val tnn = train_tnn schedule randtnn (trainex,testex);

(* testing *)
val acc = tnn_accuracy tnn testex;
*)

end (* struct *)