(* ========================================================================= *)
(* FILE          : mlNeuralNetwork.sml                                       *)
(* DESCRIPTION   : Feed forward neural network                               *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure mlNeuralNetwork :> mlNeuralNetwork =
struct

open HolKernel Abbrev boolLib aiLib mlMatrix smlParallel

val ERR = mk_HOL_ERR "mlNeuralNetwork"

(* -------------------------------------------------------------------------
   Activation and derivatives (with an optimization)
   The derivative functions are applied to the activated output f(x) and
   not to x itself (see bp_layer), hence dtanh fx = 1 - fx^2.
   ------------------------------------------------------------------------- *)

(* Identity activation. *)
fun idactiv (x:real) = x:real
(* Derivative of the identity activation: constantly 1.0. *)
fun didactiv (x:real) = 1.0
(* Hyperbolic tangent activation. *)
fun tanh x = Math.tanh x
(* Derivative of tanh expressed in terms of fx = tanh x. *)
fun dtanh fx = 1.0 - fx * fx

(* -------------------------------------------------------------------------
   Types
   ------------------------------------------------------------------------- *)

(* a: activation, da: its derivative (applied to the activated output),
   w: weight matrix. The weights created by random_nn have one extra column
   for the bias input. *)
type layer = {a : real -> real, da : real -> real, w : real vector vector}
type nn = layer list
type trainparam =
  {ncore: int, verbose: bool,
   learning_rate: real, batch_size: int, nepoch: int}

(* Prints the five fields separated by single spaces, in the order
   ncore verbose learning_rate batch_size nepoch. *)
fun string_of_trainparam {ncore,verbose,learning_rate,batch_size,nepoch} =
  its ncore ^ " " ^ bts verbose ^ " " ^ rts learning_rate ^ " " ^
  its batch_size ^ " " ^ its nepoch

(* Inverse of string_of_trainparam: splits on whitespace and reads the five
   fields in the same order. Raises Option if the learning rate is not a
   readable real; quintuple_of_list presumably rejects any other number of
   tokens (defined outside this file). *)
fun trainparam_of_string s =
  let val (a,b,c,d,e) = quintuple_of_list (String.tokens Char.isSpace s) in
    {
    ncore = string_to_int a,
    verbose = string_to_bool b,
    learning_rate = (valOf o Real.fromString) c,
    batch_size = string_to_int d,
    nepoch = string_to_int e
    }
  end

type schedule = trainparam list

(* One training parameter record per line. *)
fun write_schedule file schedule =
  writel file (map string_of_trainparam schedule)
fun read_schedule file =
  map trainparam_of_string (readl file)

(* Data recorded for one layer during the forward pass.
   inv includes the bias component, outv is the weighted sum and outnv the
   activated output. *)
type fpdata = {layer : layer, inv : vect, outv : vect, outnv : vect}
(* Data computed for one layer during the backward pass: the incoming
   gradient (doutnv), the gradient before activation (doutv), the gradient
   propagated to the previous layer (dinv) and the weight update (dw). *)
type bpdata = {doutnv : vect, doutv : vect, dinv : vect, dw : mat}

(*---------------------------------------------------------------------------
  Initialization
  ---------------------------------------------------------------------------*)

(* Pairs every layer size with the size that precedes it, as
   (outsize, insize). *)
fun diml_aux insize sizel = case sizel of
    [] => []
  | outsize :: m => (outsize, insize) :: diml_aux outsize m

(* Weight matrix dimensions (without bias) for a list of layer sizes:
   diml_of [a,b,c] = [(b,a),(c,b)]. *)
fun diml_of sizel = case sizel of
    [] => []
  | a :: m => diml_aux a m

(* Input dimension of a network: the second component of the dimension of
   the first weight matrix, minus the bias column.
   Raises Empty on a network without layers. *)
fun dimin_nn nn = ((snd o mat_dim o #w o hd) nn) - 1
(* Output dimension of a network: the first component of the dimension of
   the last weight matrix. *)
fun dimout_nn nn = (fst o mat_dim o #w o last) nn

(* Builds a network whose layer sizes are given by sizel (input size first).
   Every layer uses the activation a and derivative da; each weight matrix
   is produced by mat_random with one additional column for the bias. *)
fun random_nn (a,da) sizel =
  let
    val l = diml_of sizel
    fun biais_dim (i,j) = (i,j+1)
    fun f x = {a = a, da = da, w = mat_random (biais_dim x)}
  in
    map f l
  end

(* -------------------------------------------------------------------------
   I/O (assumes tanh activation functions)
   Only the weight matrices are stored; activations are not serialized.
   ------------------------------------------------------------------------- *)

local open HOLsexp in
(* Encodes the list of weight matrices of the network. *)
fun enc_nn nn = list_encode enc_mat (map #w nn)
(* Decodes a list of weight matrices and attaches tanh/dtanh to each layer.
   Raises ERR "dec_nn" when the list cannot be decoded; otherwise the result
   is always SOME. *)
fun dec_nn sexp =
  let
    val matl = valOf (list_decode dec_mat sexp)
      handle Option => raise ERR "dec_nn" ""
    fun f m = {a = tanh, da = dtanh, w = m}
  in
    SOME (map f matl)
  end
end

fun write_nn file nn = write_data enc_nn file nn
fun read_nn file = read_data dec_nn file

(* An example is printed as "<inputs>,<outputs>" using reall_to_string. *)
fun string_of_ex (l1,l2) = reall_to_string l1 ^ "," ^ reall_to_string l2
(* Inverse of string_of_ex: splits the line on commas; pair_of_list
   presumably requires exactly two fields (defined outside this file). *)
fun ex_of_string s =
  let val (a,b) = pair_of_list (String.tokens (fn x => x = #",") s) in
    (string_to_reall a, string_to_reall b)
  end

(* One example per line. *)
fun write_exl file exl = writel file (map string_of_ex exl)
fun read_exl file = map ex_of_string (readl file)

(* -------------------------------------------------------------------------
   Bias
   ------------------------------------------------------------------------- *)

(* The constant bias input, prepended to every layer input. *)
val biais = Vector.fromList [1.0]
(* Prepends the bias component 1.0 to v. *)
fun add_biais v = Vector.concat [biais,v]
(* Drops the first (bias) component of v. Raises Empty on an empty vector. *)
fun rm_biais v = Vector.fromList (tl (vector_to_list v))

(* -------------------------------------------------------------------------
   Forward propagation (fp) with memory of the steps
   -------------------------------------------------------------------------- *)

(* Dimension of a matrix printed as "(a,b)", for error messages. *)
fun mat_dims m =
  let val (a,b) = mat_dim m in "(" ^ its a ^ "," ^ its b ^ ")" end
(* Length of a vector once the bias component is added, for error
   messages. *)
fun vect_dims v = its (Vector.length v + 1)

(* Forward pass through one layer: adds the bias to inv, multiplies by the
   weights and applies the activation component-wise. All intermediate
   vectors are kept for the backward pass.
   A Subscript exception (expected when the dimensions do not match) is
   turned into ERR "fp_layer" reporting both dimensions. *)
fun fp_layer (layer : layer) inv =
  let
    val new_inv = add_biais inv
    val outv = mat_mult (#w layer) new_inv
    val outnv = Vector.map (#a layer) outv
  in
    {layer = layer, inv = new_inv, outv = outv, outnv = outnv}
  end
  handle Subscript =>
    raise ERR "fp_layer" ("dimension: mat-" ^
      mat_dims (#w layer) ^ " vect-" ^ vect_dims inv)

(* Forward pass through the network: the activated output of each layer is
   the input of the next one. Returns one fpdata per layer, in layer
   order. *)
fun fp_nn nn v = case nn of
    [] => []
  | layer :: m =>
    let val fpdata = fp_layer layer v in fpdata :: fp_nn m (#outnv fpdata) end

(* -------------------------------------------------------------------------
   Backward propagation (bp)
   Takes the data from the forward pass, computes the loss and weight updates
   by gradient descent.
   Input has size j. Output has size i. Matrix has i rows and j columns.
   ------------------------------------------------------------------------- *)

(* Backward pass through one layer, given the gradient doutnv on its
   activated output.
   dw(i,j) = inv(j) * doutv(i); dinv is the transposed weight matrix applied
   to doutv, with the bias component removed. *)
fun bp_layer (fpdata:fpdata) doutnv =
  let
    val doutv =
      (* trick: uses (#outnv fpdata) instead of (#outv fpdata),
         the derivative being expressed in terms of the activated output *)
      let val dav = Vector.map (#da (#layer fpdata)) (#outnv fpdata) in
        mult_rvect dav doutnv
      end
    val w = #w (#layer fpdata)
    fun dw_f i j = Vector.sub (#inv fpdata,j) * Vector.sub (doutv,i)
    val dw = mat_tabulate dw_f (mat_dim w)
    val dinv = rm_biais (mat_mult (mat_transpose w) doutv)
  in
    {doutnv = doutnv, doutv = doutv, dinv = dinv, dw = dw}
  end

(* Backward pass over the layers taken from last to first: the dinv of a
   layer is the doutnv of the layer before it. The result is in the same
   (reversed) order as the argument. *)
fun bp_nn_aux rev_fpdatal doutnv =
  case rev_fpdatal of
    [] => []
  | fpdata :: m =>
    let val bpdata = bp_layer fpdata doutnv in
      bpdata :: bp_nn_aux m (#dinv bpdata)
    end

(* Backward pass starting from an explicitly given output gradient.
   Returns one bpdata per layer, in layer order. *)
fun bp_nn_doutnv fpdatal doutnv = rev (bp_nn_aux (rev fpdatal) doutnv)

(* Backward pass starting from the expected output: the output gradient is
   diff_rvect expectv outnv where outnv is the output of the last layer.
   Returns one bpdata per layer, in layer order.
   Raises Empty when fpdatal is empty. *)
fun bp_nn fpdatal expectv =
  let
    val rev_fpdatal = rev fpdatal
    val outnv = #outnv (hd rev_fpdatal)
    val doutnv = diff_rvect expectv outnv
  in
    rev (bp_nn_aux rev_fpdatal doutnv)
  end

(* -------------------------------------------------------------------------
   Update weights and calculate loss
   ------------------------------------------------------------------------- *)

(* Forward then backward pass on a single example. *)
fun train_nn_one nn (inputv,expectv) = bp_nn (fp_nn nn inputv) expectv

(* Transposes a list of lists; the length of the result is the length of
   the first list. The explicit [] case is required: without it the empty
   list falls into the last case and the recursion never terminates. *)
fun transpose_ll ll = case ll of
    [] => []
  | [] :: _ => []
  | _ => map hd ll :: transpose_ll (map tl ll)

(* Sums, layer by layer, the weight updates computed for each example of a
   batch. Returns [] for an empty batch. *)
fun sum_dwll dwll = case dwll of
    [dwl] => dwl
  | _ => map matl_add (transpose_ll dwll)

(* Multiplies every matrix of dwl by the scalar k. *)
fun smult_dwl k dwl = map (mat_smult k) dwl

(* Square root of the average (as computed by average_rvect) of the squared
   components of v. Note that, despite the name, the square root is taken. *)
fun mean_square_error v =
  let fun square x = (x:real) * x in
    Math.sqrt (average_rvect (Vector.map square v))
  end

(* Loss of one example: mean_square_error of the gradient received by the
   last layer. *)
fun bp_loss bpdatal = mean_square_error (#doutnv (last bpdatal))

(* Average of the losses of the examples of a batch. *)
fun average_loss bpdatall = average_real (map bp_loss bpdatall)

(* Bounds every coefficient of m between a and b. *)
fun clip (a,b) m =
  let fun f x = if x < a then a else (if x > b then b else x) in
    mat_map f m
  end

(* New weights of a layer: the summed update layerwu is scaled by
   learning_rate / batch_size, added to the current weights and the result
   is clipped to the interval from -4.0 to 4.0. *)
fun update_layer param (layer, layerwu) =
  let
    val coeff = #learning_rate param / Real.fromInt (#batch_size param)
    val w0 = mat_smult coeff layerwu
    val w1 = mat_add (#w layer) w0
    val w2 = clip (~4.0,4.0) w1
  in
    {a = #a layer, da = #da layer, w = w2}
  end

(* Updates every layer of nn with the corresponding matrix of wu. *)
fun update_nn param nn wu = map (update_layer param) (combine (nn,wu))

(* -------------------------------------------------------------------------
   Statistics
   ------------------------------------------------------------------------- *)

(* Real number rounded with rts_round 5 and padded with pad 7 "0". *)
fun sr r = pad 7 "0" (rts_round 5 r)

(* Prints, for each output position, the average and the absolute deviation
   of the expected values found in the examples. *)
fun stats_exl exl =
  let
    val ll = list_combine (map snd exl)
    fun f l =
      print_endline (sr (average_real l ) ^ " " ^ sr (absolute_deviation l))
  in
    print_endline "mean deviation"; app f ll; print_endline ""
  end

(* -------------------------------------------------------------------------
   Training
   ------------------------------------------------------------------------- *)

(* Trains on one batch: pf maps train_nn_one over the examples, the weight
   updates are summed and applied once.
   Returns the new network and the average loss on the batch. *)
fun train_nn_batch param pf nn batch =
  let
    val bpdatall = pf (train_nn_one nn) batch
    val dwll = map (map #dw) bpdatall
    val dwl = sum_dwll dwll
    val newnn = update_nn param nn dwl
  in
    (newnn, average_loss bpdatall)
  end

(* Trains successively on every batch of batchl, accumulating the batch
   losses in lossl. Returns the final network and the average loss. *)
fun train_nn_epoch param pf lossl nn batchl = case batchl of
    [] => (nn, average_real lossl)
  | batch :: m =>
    let val (newnn,loss) = train_nn_batch param pf nn batch in
      train_nn_epoch param pf (loss :: lossl) newnn m
    end

(* Runs epochs i, i+1, ... up to nepoch - 1. The examples are reshuffled
   and split into batches at each epoch; the epoch number and its loss are
   printed when verbose is set. *)
fun train_nn_nepoch param pf i nn exl =
  if i >= #nepoch param then nn else
  let
    val batchl = mk_batch (#batch_size param) (shuffle exl)
    val (new_nn,loss) = train_nn_epoch param pf [] nn batchl
    val _ =
      if #verbose param then print_endline (its i ^ " " ^ sr loss) else ()
  in
    train_nn_nepoch param pf (i+1) new_nn exl
  end

(* -------------------------------------------------------------------------
   Interface:
   - Scaling from [0,1] to [-1,1] to match activation functions range.
   - Converting lists to vectors
   ------------------------------------------------------------------------- *)

(* Maps [0,1] to [-1,1]. *)
fun scale_real x = x * 2.0 - 1.0
(* Maps [-1,1] back to [0,1]; inverse of scale_real. *)
fun descale_real x = (x + 1.0) * 0.5
fun scale_in l = Vector.fromList (map scale_real l)
fun scale_out l = Vector.fromList (map scale_real l)
fun descale_out v = map descale_real (vector_to_list v)
fun scale_ex (l1,l2) = (scale_in l1, scale_out l2)

(* Trains nn on the examples exl (lists of reals in [0,1]) according to
   param and returns the trained network.
   The workers obtained from parmap_gen are released with close_threadl on
   every exit path: if training raises, they are closed before the
   exception is re-raised, so that a failure does not leave them behind. *)
fun train_nn param nn exl =
  let
    val (pf,close_threadl) = parmap_gen (#ncore param)
    fun run () =
      let
        val _ = if #verbose param then stats_exl exl else ()
        val newexl = map scale_ex exl
      in
        train_nn_nepoch param pf 0 nn newexl
      end
    val r = run () handle e => (close_threadl (); raise e)
  in
    close_threadl (); r
  end

(* Output of the network on the input l: the input is scaled, propagated
   forward, and the output of the last layer is descaled. *)
fun infer_nn nn l = (descale_out o #outnv o last o (fp_nn nn) o scale_in) l


end (* struct *)

(* -------------------------------------------------------------------------
   Identity example
   ------------------------------------------------------------------------- *)

(*
load "mlNeuralNetwork"; open mlNeuralNetwork;
load "aiLib"; open aiLib;

(* examples *)
fun gen_idex dim =
  let fun f () = List.tabulate (dim, fn _ => random_real ()) in
    let val x = f () in (x,x) end
  end
;
val dim = 10;
val exl = List.tabulate (1000, fn _ => gen_idex dim);

(* training *)
val nn = random_nn (tanh,dtanh) [dim,4*dim,4*dim,dim];
val param : trainparam =
  {ncore = 1, verbose = true,
   learning_rate = 0.02, batch_size = 16, nepoch = 100}
;
val (newnn,t) = add_time (train_nn param nn) exl;

(* testing *)
val inv = fst (gen_idex dim);
val outv = infer_nn newnn inv;
*)

