(* ========================================================================= *)
(* FILE          : mlNeuralNetwork.sml                                       *)
(* DESCRIPTION   : Feed-forward neural network                               *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure mlNeuralNetwork :> mlNeuralNetwork =
struct

open HolKernel Abbrev boolLib aiLib mlMatrix smlParallel

val ERR = mk_HOL_ERR "mlNeuralNetwork"

(* -------------------------------------------------------------------------
   Activation and derivatives (with an optimization)
   ------------------------------------------------------------------------- *)

fun idactiv (x:real) = x:real
fun didactiv (x:real) = 1.0
fun tanh x = Math.tanh x
fun dtanh fx = 1.0 - fx * fx
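
(* Note on the optimization: each derivative takes the activated output as its
   argument, e.g. dtanh expects fx = tanh x because the derivative of tanh at x
   is 1.0 - (tanh x) * (tanh x). This is why bp_layer below applies #da to
   #outnv rather than to #outv. A minimal REPL check (illustrative only, not
   part of the library), comparing dtanh against a finite difference:

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   val x = 0.5;
   val h = 1E~6;
   val approx = (tanh (x + h) - tanh (x - h)) / (2.0 * h);
   val exact = dtanh (tanh x); (* both are about 0.7864 *)
*)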

(* -------------------------------------------------------------------------
   Types
   ------------------------------------------------------------------------- *)

type layer = {a : real -> real, da : real -> real, w : real vector vector}
type nn = layer list
type trainparam =
  {ncore: int, verbose: bool,
   learning_rate: real, batch_size: int, nepoch: int}

fun string_of_trainparam {ncore,verbose,learning_rate,batch_size,nepoch} =
  its ncore ^ " " ^ bts verbose ^ " " ^ rts learning_rate ^ " " ^
  its batch_size ^ " " ^ its nepoch

fun trainparam_of_string s =
  let val (a,b,c,d,e) = quintuple_of_list (String.tokens Char.isSpace s) in
    {
    ncore = string_to_int a,
    verbose = string_to_bool b,
    learning_rate = (valOf o Real.fromString) c,
    batch_size = string_to_int d,
    nepoch = string_to_int e
    }
  end

type schedule = trainparam list

fun write_schedule file schedule =
  writel file (map string_of_trainparam schedule)
fun read_schedule file =
  map trainparam_of_string (readl file)

(* inv includes the bias *)
type fpdata = {layer : layer, inv : vect, outv : vect, outnv : vect}
type bpdata = {doutnv : vect, doutv : vect, dinv : vect, dw : mat}

(*---------------------------------------------------------------------------
  Initialization
  ---------------------------------------------------------------------------*)

fun diml_aux insize sizel = case sizel of
    [] => []
  | outsize :: m => (outsize, insize) :: diml_aux outsize m

fun diml_of sizel = case sizel of
    [] => []
  | a :: m => diml_aux a m

fun dimin_nn nn = ((snd o mat_dim o #w o hd) nn) - 1
fun dimout_nn nn = (fst o mat_dim o #w o last) nn

fun random_nn (a,da) sizel =
  let
    val l = diml_of sizel
    fun biais_dim (i,j) = (i,j+1)
    fun f x = {a = a, da = da, w = mat_random (biais_dim x)}
  in
    map f l
  end
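
(* Illustrative sizes (a sketch, assuming dimin_nn and dimout_nn are exported
   by the signature): random_nn (tanh,dtanh) [2,3,1] builds two layers whose
   weight matrices have dimensions (3,3) and (1,4), since biais_dim adds one
   input column per layer for the bias.

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   val nn = random_nn (tanh,dtanh) [2,3,1];
   val din = dimin_nn nn;   (* 2 *)
   val dout = dimout_nn nn; (* 1 *)
*)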

(* -------------------------------------------------------------------------
   I/O (assumes tanh activation functions)
   ------------------------------------------------------------------------- *)

local open HOLsexp in
fun enc_nn nn = list_encode enc_mat (map #w nn)
fun dec_nn sexp =
  let
    val matl = valOf (list_decode dec_mat sexp)
      handle Option => raise ERR "dec_nn" ""
    fun f m = {a = tanh, da = dtanh, w = m}
  in
    SOME (map f matl)
  end
end

fun write_nn file nn = write_data enc_nn file nn
fun read_nn file = read_data dec_nn file

fun string_of_ex (l1,l2) = reall_to_string l1 ^ "," ^ reall_to_string l2
fun ex_of_string s =
  let val (a,b) = pair_of_list (String.tokens (fn x => x = #",") s) in
    (string_to_reall a, string_to_reall b)
  end

fun write_exl file exl = writel file (map string_of_ex exl)
fun read_exl file = map ex_of_string (readl file)
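
(* Illustrative I/O round trip (a sketch; the file names are placeholders, and
   read_nn rebuilds every layer with tanh/dtanh as noted above):

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   val nn = random_nn (tanh,dtanh) [2,4,2];
   val _ = write_nn "/tmp/nn_example" nn;
   val nn' = read_nn "/tmp/nn_example";
   val exl = [([0.1,0.2],[0.2,0.1])];
   val _ = write_exl "/tmp/exl_example" exl;
   val exl' = read_exl "/tmp/exl_example";
*)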

(* -------------------------------------------------------------------------
   Bias
   ------------------------------------------------------------------------- *)

val biais = Vector.fromList [1.0]
fun add_biais v = Vector.concat [biais,v]
fun rm_biais v = Vector.fromList (tl (vector_to_list v))

(* -------------------------------------------------------------------------
  Forward propagation (fp) with memory of the steps
  -------------------------------------------------------------------------- *)

fun mat_dims m =
  let val (a,b) = mat_dim m in "(" ^ its a ^ "," ^ its b ^ ")" end
fun vect_dims v = its (Vector.length v + 1)

fun fp_layer (layer : layer) inv =
  let
    val new_inv = add_biais inv
    val outv = mat_mult (#w layer) new_inv
    val outnv = Vector.map (#a layer) outv
  in
    {layer = layer, inv = new_inv, outv = outv, outnv = outnv}
  end
  handle Subscript =>
    raise ERR "fp_layer" ("dimension: mat-" ^
                          mat_dims (#w layer) ^ " vect-" ^ vect_dims inv)

fun fp_nn nn v = case nn of
    [] => []
  | layer :: m =>
    let val fpdata = fp_layer layer v in fpdata :: fp_nn m (#outnv fpdata) end
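
(* Illustrative forward pass (a sketch): fp_nn returns one fpdata record per
   layer, and the network output is the normalized output #outnv of the last
   record.

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   val nn = random_nn (tanh,dtanh) [3,4,2];
   val fpdatal = fp_nn nn (Vector.fromList [0.1,0.2,0.3]);
   val outv = #outnv (List.last fpdatal); (* a vector of length 2 *)
*)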

(* -------------------------------------------------------------------------
   Backward propagation (bp)
   Takes the data from the forward pass and computes the loss and the weight
   updates by gradient descent.
   The input has size j, the output has size i, and the weight matrix has
   i rows and j columns.
   ------------------------------------------------------------------------- *)

fun bp_layer (fpdata:fpdata) doutnv =
  let
    val doutv =
      (* trick: uses (#outnv fpdata) instead of (#outv fpdata) *)
      let val dav = Vector.map (#da (#layer fpdata)) (#outnv fpdata) in
        mult_rvect dav doutnv
      end
    val w        = #w (#layer fpdata)
    fun dw_f i j = Vector.sub (#inv fpdata,j) * Vector.sub (doutv,i)
    val dw       = mat_tabulate dw_f (mat_dim w)
    val dinv     = rm_biais (mat_mult (mat_transpose w) doutv)
  in
   {doutnv = doutnv, doutv = doutv, dinv = dinv, dw = dw}
  end

fun bp_nn_aux rev_fpdatal doutnv =
  case rev_fpdatal of
    [] => []
  | fpdata :: m =>
    let val bpdata = bp_layer fpdata doutnv in
      bpdata :: bp_nn_aux m (#dinv bpdata)
    end

fun bp_nn_doutnv fpdatal doutnv = rev (bp_nn_aux (rev fpdatal) doutnv)

fun bp_nn fpdatal expectv =
  let
    val rev_fpdatal = rev fpdatal
    val outnv = #outnv (hd rev_fpdatal)
    val doutnv = diff_rvect expectv outnv
  in
    rev (bp_nn_aux rev_fpdatal doutnv)
  end
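
(* Illustrative backward pass (a sketch): bp_nn pairs the forward data with a
   target vector and returns one bpdata record per layer; each #dw has the same
   dimensions as the corresponding weight matrix, bias column included.

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   load "mlMatrix"; open mlMatrix;
   val nn = random_nn (tanh,dtanh) [3,4,2];
   val fpdatal = fp_nn nn (Vector.fromList [0.1,0.2,0.3]);
   val bpdatal = bp_nn fpdatal (Vector.fromList [0.5,0.5]);
   val dims = map (mat_dim o #dw) bpdatal; (* [(4,4),(2,5)] *)
*)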

(* -------------------------------------------------------------------------
   Update weights and calculate loss
   ------------------------------------------------------------------------- *)

fun train_nn_one nn (inputv,expectv) = bp_nn (fp_nn nn inputv) expectv

fun transpose_ll ll = case ll of
    [] :: _ => []
  | _ => map hd ll :: transpose_ll (map tl ll)

fun sum_dwll dwll = case dwll of
     [dwl] => dwl
   | _ => map matl_add (transpose_ll dwll)

fun smult_dwl k dwl = map (mat_smult k) dwl

fun mean_square_error v =
  let fun square x = (x:real) * x in
    Math.sqrt (average_rvect (Vector.map square v))
  end

fun bp_loss bpdatal = mean_square_error (#doutnv (last bpdatal))

fun average_loss bpdatall = average_real (map bp_loss bpdatall)

fun clip (a,b) m =
  let fun f x = if x < a then a else (if x > b then b else x) in
    mat_map f m
  end

fun update_layer param (layer, layerwu) =
  let
    val coeff = #learning_rate param / Real.fromInt (#batch_size param)
    val w0 = mat_smult coeff layerwu
    val w1 = mat_add (#w layer) w0
    val w2 = clip (~4.0,4.0) w1
  in
    {a = #a layer, da = #da layer, w = w2}
  end

fun update_nn param nn wu = map (update_layer param) (combine (nn,wu))
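
(* Illustrative single update step (a sketch of what train_nn_batch below does
   for a whole batch): bp_nn propagates expectv - outnv, so update_nn adds
   learning_rate / batch_size times the accumulated dw to each weight matrix
   and then clips every weight to [~4.0,4.0].

   load "mlNeuralNetwork"; open mlNeuralNetwork;
   val param : trainparam =
     {ncore = 1, verbose = false,
      learning_rate = 0.02, batch_size = 1, nepoch = 1};
   val nn = random_nn (tanh,dtanh) [2,4,2];
   val ex = (Vector.fromList [0.1,0.2], Vector.fromList [0.2,0.1]);
   val bpdatal = train_nn_one nn ex;
   val nn' = update_nn param nn (map #dw bpdatal);
*)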

(* -------------------------------------------------------------------------
   Statistics
   ------------------------------------------------------------------------- *)

fun sr r = pad 7 "0" (rts_round 5 r)

fun stats_exl exl =
  let
    val ll = list_combine (map snd exl)
    fun f l =
      print_endline (sr (average_real l) ^ " " ^ sr (absolute_deviation l))
  in
    print_endline "mean deviation"; app f ll; print_endline ""
  end

(* -------------------------------------------------------------------------
   Training
   ------------------------------------------------------------------------- *)

fun train_nn_batch param pf nn batch =
  let
    val bpdatall = pf (train_nn_one nn) batch
    val dwll = map (map #dw) bpdatall
    val dwl = sum_dwll dwll
    val newnn = update_nn param nn dwl
  in
    (newnn, average_loss bpdatall)
  end

fun train_nn_epoch param pf lossl nn batchl = case batchl of
    [] => (nn, average_real lossl)
  | batch :: m =>
    let val (newnn,loss) = train_nn_batch param pf nn batch in
      train_nn_epoch param pf (loss :: lossl) newnn m
    end

fun train_nn_nepoch param pf i nn exl =
  if i >= #nepoch param then nn else
  let
    val batchl = mk_batch (#batch_size param) (shuffle exl)
    val (new_nn,loss) = train_nn_epoch param pf [] nn batchl
    val _ =
      if #verbose param then print_endline (its i ^ " " ^ sr loss) else ()
  in
    train_nn_nepoch param pf (i+1) new_nn exl
  end

(* -------------------------------------------------------------------------
   Interface:
   - Scaling from [0,1] to [-1,1] to match the range of the activation
     functions.
   - Converting lists to vectors.
   ------------------------------------------------------------------------- *)

fun scale_real x = x * 2.0 - 1.0
fun descale_real x = (x + 1.0) * 0.5
fun scale_in l = Vector.fromList (map scale_real l)
fun scale_out l = Vector.fromList (map scale_real l)
fun descale_out v = map descale_real (vector_to_list v)
fun scale_ex (l1,l2) = (scale_in l1, scale_out l2)
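
(* Scaling example (illustrative): training examples are expected to lie in
   [0,1] and are mapped affinely to [~1,1]; descale_real inverts the map after
   inference.

   scale_real 0.0;                 (* ~1.0 *)
   scale_real 1.0;                 (* 1.0 *)
   descale_real (scale_real 0.25); (* 0.25 *)
*)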

fun train_nn param nn exl =
  let
    val (pf,close_threadl) = parmap_gen (#ncore param)
    val _ = if #verbose param then stats_exl exl else ()
    val newexl = map scale_ex exl
    val r = train_nn_nepoch param pf 0 nn newexl
  in
    close_threadl (); r
  end

fun infer_nn nn l = (descale_out o #outnv o last o (fp_nn nn) o scale_in) l


end (* struct *)

(* -------------------------------------------------------------------------
   Identity example
   ------------------------------------------------------------------------- *)

(*
load "mlNeuralNetwork"; open mlNeuralNetwork;
load "aiLib"; open aiLib;

(* examples *)
fun gen_idex dim =
  let fun f () = List.tabulate (dim, fn _ => random_real ()) in
    let val x = f () in (x,x) end
  end
;
val dim = 10;
val exl = List.tabulate (1000, fn _ => gen_idex dim);

(* training *)
val nn = random_nn (tanh,dtanh) [dim,4*dim,4*dim,dim];
val param : trainparam =
  {ncore = 1, verbose = true,
   learning_rate = 0.02, batch_size = 16, nepoch = 100}
;
val (newnn,t) = add_time (train_nn param nn) exl;

(* testing *)
val inv = fst (gen_idex dim);
val outv = infer_nn newnn inv;
*)