1(* ========================================================================= *) 2(* FILE : mlTreeNeuralNetwork.sml *) 3(* DESCRIPTION : Tree neural network *) 4(* AUTHOR : (c) Thibault Gauthier, Czech Technical University *) 5(* DATE : 2018 *) 6(* ========================================================================= *) 7 8structure mlTreeNeuralNetwork :> mlTreeNeuralNetwork = 9struct 10 11open HolKernel boolLib Abbrev aiLib mlMatrix mlNeuralNetwork smlParallel 12smlParallel smlParser mlTacticData 13 14val ERR = mk_HOL_ERR "mlTreeNeuralNetwork" 15fun msg param s = if #verbose param then print_endline s else () 16fun msg_err fs es = (print_endline (fs ^ ": " ^ es); raise ERR fs es) 17 18(* ------------------------------------------------------------------------- 19 Tools for computing the dimensions of neural network operators 20 ------------------------------------------------------------------------- *) 21 22type tnndim = (term * int list) list 23 24fun operl_of_term tm = 25 let 26 val (oper,argl) = strip_comb tm 27 val arity = length argl 28 in 29 (oper,arity) :: List.concat (map operl_of_term argl) 30 end 31 32val oper_compare = cpl_compare Term.compare Int.compare; 33 34fun dim_std_arity (nlayer,dim) (oper,a) = 35 let 36 val dim_alt = 37 if is_var oper andalso String.isPrefix "head_" (fst (dest_var oper)) 38 then 1 39 else dim 40 in 41 (if a = 0 then [0] else List.tabulate (nlayer, fn _ => a * dim)) @ 42 [dim_alt] 43 end 44 45fun dim_std (nlayer,dim) oper = 46 dim_std_arity (nlayer,dim) (oper,arity_of oper) 47 48(* ------------------------------------------------------------------------- 49 Random TNN 50 ------------------------------------------------------------------------- *) 51 52type tnn = (term,nn) Redblackmap.dict 53 54fun oper_nn diml = case diml of 55 [] => raise ERR "oper_nn" "" 56 | a :: m => 57 if a = 0 58 then random_nn (idactiv,didactiv) [0,last m] 59 else random_nn (tanh,dtanh) diml 60 61fun random_tnn tnndim = dnew Term.compare (map_snd oper_nn tnndim) 62 63fun random_tnn_std (nlayer,dim) operl = 64 random_tnn (map_assoc (dim_std (nlayer,dim)) operl) 65 66(* ------------------------------------------------------------------------- 67 TNN I/O 68 ------------------------------------------------------------------------- *) 69 70local open SharingTables HOLsexp in 71fun enc_opernn enc_tm = pair_encode (enc_tm, enc_nn) 72fun enc_tnn enc_tm = list_encode (enc_opernn enc_tm) 73fun dec_opernn dec_tm = pair_decode (dec_tm, dec_nn) 74fun dec_tnn dec_tm = list_decode (dec_opernn dec_tm) 75end 76 77fun write_tnn file tnn = write_tmdata (enc_tnn, map fst) file (dlist tnn) 78fun read_tnn file = dnew Term.compare (read_tmdata dec_tnn file) 79 80local open SharingTables HOLsexp in 81fun enc_tnndime enc_tm = pair_encode (enc_tm, list_encode Integer) 82fun enc_tnndim enc_tm = list_encode (enc_tnndime enc_tm) 83fun dec_tnndime dec_tm = pair_decode (dec_tm, list_decode int_decode) 84fun dec_tnndim dec_tm = list_decode (dec_tnndime dec_tm) 85end 86 87fun write_tnndim file tnndim = write_tmdata (enc_tnndim, map fst) file tnndim 88fun read_tnndim file = read_tmdata dec_tnndim file 89 90(* ------------------------------------------------------------------------- 91 TNN Examples: I/O 92 ------------------------------------------------------------------------- *) 93 94type tnnex = ((term * real list) list) list 95type tnnbatch = (term list * (term * mlMatrix.vect) list) list 96 97fun basicex_to_tnnex ex = map (fn (tm,r) => [(tm,[r])]) ex 98 99local open SharingTables HOLsexp in 100val enc_real = String o Real.toString 101val dec_real = Option.mapPartial Real.fromString o string_decode 102fun enc_sample enc_tm = pair_encode (enc_tm, list_encode enc_real) 103fun dec_sample dec_tm = pair_decode (dec_tm, list_decode dec_real) 104fun enc_tnnex enc_tm = list_encode (list_encode (enc_sample enc_tm)) 105fun dec_tnnex dec_tm = list_decode (list_decode (dec_sample dec_tm)) 106fun tml_of_tnnex l = map fst (List.concat l) 107end 108 109fun write_tnnex file ex = 110 write_tmdata (enc_tnnex, tml_of_tnnex) file ex 111fun read_tnnex file = 112 read_tmdata dec_tnnex file 113 114(* ------------------------------------------------------------------------- 115 TNN Examples: ordering subterms and scaling output values 116 ------------------------------------------------------------------------- *) 117 118fun order_subtm tml = 119 let 120 val d = ref (dempty (cpl_compare Int.compare Term.compare)) 121 fun traverse tm = 122 let 123 val (oper,argl) = strip_comb tm 124 val nl = map traverse argl 125 val n = 1 + sum_int nl 126 in 127 d := dadd (n, tm) () (!d); n 128 end 129 val subtml = (app (ignore o traverse) tml; dkeys (!d)) 130 in 131 map snd subtml 132 end 133 134fun prepare_tnnex tnnex = 135 let fun f x = (order_subtm (map fst x), map_snd scale_out x) in 136 map f tnnex 137 end 138 139(* ------------------------------------------------------------------------- 140 Fixed embedding 141 ------------------------------------------------------------------------- *) 142 143val embedding_prefix = "embedding_" 144 145fun is_embedding v = 146 is_var v andalso String.isPrefix embedding_prefix (fst (dest_var v)) 147 148fun embed_nn v = 149 if is_embedding v then 150 let 151 val vs = fst (dest_var v) 152 val n1 = String.size embedding_prefix 153 val ntot = String.size vs 154 val es = String.substring (vs,n1,ntot-n1) 155 val e1 = string_to_reall es 156 val e2 = map (fn x => Vector.fromList [x]) e1 157 in 158 [{a = idactiv, da = didactiv, w = Vector.fromList e2}] 159 end 160 else msg_err "embed_nn" (tts v) 161 162fun mk_embedding_var (rv,ty) = 163 mk_var (embedding_prefix ^ reall_to_string (vector_to_list rv), ty) 164 165(* ------------------------------------------------------------------------- 166 Forward propagation 167 ------------------------------------------------------------------------- *) 168 169fun fp_oper tnn fpdict tm = 170 let 171 val (f,argl) = strip_comb tm 172 val nn = (dfind f) tnn handle NotFound => embed_nn f 173 val invl = (map (fn x => #outnv (last (dfind x fpdict)))) argl 174 val inv = Vector.concat invl 175 in 176 fp_nn nn inv 177 end 178 handle Subscript => msg_err "fp_oper" (tts tm) 179 180fun fp_tnn_aux tnn fpdict tml = case tml of 181 [] => fpdict 182 | tm :: m => 183 let val fpdatal = fp_oper tnn fpdict tm in 184 fp_tnn_aux tnn (dadd tm fpdatal fpdict) m 185 end 186 187fun fp_tnn tnn tml = fp_tnn_aux tnn (dempty Term.compare) tml 188 189(* ------------------------------------------------------------------------- 190 Backward propagation 191 ------------------------------------------------------------------------- *) 192 193fun sum_operdwll (oper,dwll) = [sum_dwll dwll] 194 195fun dimout_fpdatal fpdatal = Vector.length (#outnv (last fpdatal)) 196fun dimout_tm fpdict tm = dimout_fpdatal (dfind tm fpdict) 197 198fun bp_tnn_aux doutnvdict fpdict bpdict revtml = case revtml of 199 [] => dmap sum_operdwll bpdict 200 | tm :: m => 201 let 202 val (oper,argl) = strip_comb tm 203 val diml = map (dimout_tm fpdict) argl 204 val doutnvl = dfind tm doutnvdict 205 val doutnvsum = add_vectl doutnvl 206 fun f doutnv = 207 let 208 val fpdatal = dfind tm fpdict 209 val bpdatal = bp_nn_doutnv fpdatal doutnv 210 val dinv = vector_to_list (#dinv (hd bpdatal)) 211 val dinvl = map Vector.fromList (part_group diml dinv) 212 in 213 (map #dw bpdatal, combine (argl,dinvl)) 214 end 215 val (operdwl,tmdinvl) = f (add_vectl doutnvl) 216 val newdoutnvdict = dappendl tmdinvl doutnvdict 217 val newbpdict = dappend (oper,operdwl) bpdict 218 in 219 bp_tnn_aux newdoutnvdict fpdict newbpdict m 220 end 221 222fun bp_tnn fpdict (tml,tmevl) = 223 let 224 fun f (tm,ev) = 225 let 226 val fpdatal = dfind tm fpdict 227 val doutnv = diff_rvect ev (#outnv (last fpdatal)) 228 in 229 (tm,[doutnv]) 230 end 231 val doutnvdict = dnew Term.compare (map f tmevl) 232 in 233 bp_tnn_aux doutnvdict fpdict (dempty Term.compare) (rev tml) 234 end 235 236(* ------------------------------------------------------------------------- 237 Inference 238 ------------------------------------------------------------------------- *) 239 240fun infer_tnn tnn tml = 241 let 242 val fpdict = fp_tnn tnn (order_subtm tml) 243 fun f x = descale_out (#outnv (last (dfind x fpdict))) 244 in 245 map_assoc f tml 246 end 247 248fun infer_tnn_basic tnn tm = 249 singleton_of_list (snd (singleton_of_list (infer_tnn tnn [tm]))) 250 251fun precomp_embed tnn tm = 252 let 253 val fpdict = fp_tnn tnn (order_subtm [tm]) 254 val embedv = #outnv (last (dfind tm fpdict)) 255 in 256 mk_embedding_var (embedv, type_of tm) 257 end 258 259(* ------------------------------------------------------------------------- 260 Training 261 ------------------------------------------------------------------------- *) 262 263fun se_of fpdict (tm,ev) = 264 let 265 val fpdatal = dfind tm fpdict 266 val doutnv = diff_rvect ev (#outnv (last fpdatal)) 267 val r1 = vector_to_list doutnv 268 val r2 = map (fn x => x * x) r1 269 in 270 Math.sqrt (average_real r2) 271 end 272 273fun mse_of fpdict tmevl = average_real (map (se_of fpdict) tmevl) 274 275fun fp_loss tnn (tml,tmevl) = mse_of (fp_tnn tnn tml) tmevl 276 277fun train_tnn_one tnn (tml,tmevl) = 278 let 279 val fpdict = fp_tnn tnn tml 280 val bpdict = bp_tnn fpdict (tml,tmevl) 281 in 282 (bpdict, mse_of fpdict tmevl) 283 end 284 285fun train_tnn_subbatch tnn subbatch = 286 let val (bpdictl,lossl) = split (map (train_tnn_one tnn) subbatch) in 287 (dmap sum_operdwll (dconcat Term.compare bpdictl), lossl) 288 end 289 290fun update_oper param ((oper,dwll),tnn) = 291 if is_embedding oper then tnn else 292 let 293 val nn = dfind oper tnn 294 val dwl = sum_dwll dwll 295 val newnn = update_nn param nn dwl 296 in 297 dadd oper newnn tnn 298 end 299 300fun train_tnn_batch param pf tnn batch = 301 let 302 val subbatchl = cut_modulo (#ncore param) batch 303 val (bpdictl,lossll) = split (pf (train_tnn_subbatch tnn) subbatchl) 304 val bpdict = dconcat Term.compare bpdictl 305 in 306 (foldl (update_oper param) tnn (dlist bpdict), 307 average_real (List.concat lossll)) 308 end 309 310fun train_tnn_epoch param pf lossl tnn batchl = case batchl of 311 [] => (tnn, average_real lossl) 312 | batch :: m => 313 let val (newtnn,loss) = train_tnn_batch param pf tnn batch in 314 train_tnn_epoch param pf (loss :: lossl) newtnn m 315 end 316 317fun train_tnn_epoch_nopar param lossl tnn batchl = case batchl of 318 [] => (tnn, average_real lossl) 319 | batch :: m => 320 let val (newtnn,loss) = train_tnn_batch param map tnn batch in 321 train_tnn_epoch_nopar param (loss :: lossl) newtnn m 322 end 323 324fun train_tnn_nepoch param pf i tnn (train,test) = 325 if i >= #nepoch param then tnn else 326 let 327 val batchl = mk_batch (#batch_size param) (shuffle train) 328 val _ = if null batchl then msg_err "train_tnn_nepoch" "empty" else () 329 val (newtnn,loss) = train_tnn_epoch param pf [] tnn batchl 330 val testloss = if null test then "" else 331 (" test: " ^ pretty_real (average_real (map (fp_loss newtnn) test))) 332 val _ = msg param (its i ^ " train: " ^ pretty_real loss ^ testloss) 333 in 334 train_tnn_nepoch param pf (i+1) newtnn (train,test) 335 end 336 337fun train_tnn_schedule schedule tnn (train,test) = 338 case schedule of 339 [] => tnn 340 | param :: m => 341 let 342 val _ = msg param ("learning rate: " ^ rts (#learning_rate param)) 343 val _ = msg param ("ncore: " ^ its (#ncore param)) 344 val (pf,close_threadl) = parmap_gen (#ncore param) 345 val newtnn = train_tnn_nepoch param pf 0 tnn (train,test) 346 val r = train_tnn_schedule m newtnn (train,test) 347 in 348 (close_threadl (); r) 349 end 350 351fun stats_head (oper,rll) = 352 let 353 val s0 = " objective: " ^ tts oper 354 val s1 = "length: " ^ its (length rll) 355 val rll' = list_combine rll 356 val s2 = "means: " ^ 357 String.concatWith " " (map (pretty_real o average_real) rll') 358 val s3 = "standard deviations: " ^ 359 String.concatWith " " (map (pretty_real o standard_deviation) rll') 360 in 361 String.concatWith "\n " [s0,s1,s2,s3] 362 end 363 364fun stats_tnnex ex = 365 if null ex then " empty" else 366 let 367 fun head_of tm = fst (strip_comb tm) 368 val d = dregroup Term.compare (map_fst head_of (List.concat ex)) 369 in 370 its (length ex) ^ " examples\n" ^ 371 String.concatWith "\n" (map stats_head (dlist d)) 372 end 373 374fun train_tnn schedule randtnn (trainex,testex) = 375 let 376 val _ = print_endline ("\ntraining set: " ^ stats_tnnex trainex) 377 val _ = print_endline ("testing set: " ^ stats_tnnex testex) 378 val _ = print_endline "" 379 val (tnn,t) = add_time (train_tnn_schedule schedule randtnn) 380 (prepare_tnnex trainex, prepare_tnnex testex) 381 in 382 print_endline ("Tree neural network training time: " ^ rts t); tnn 383 end 384 385(* ------------------------------------------------------------------------- 386 Accuracy 387 ------------------------------------------------------------------------- *) 388 389fun is_accurate_one (rl1,rl2) = 390 let 391 val rl3 = combine (rl1,rl2) 392 fun test (x,y) = Real.abs (x - y) < 0.5 393 in 394 if all test rl3 then true else false 395 end 396 397fun is_accurate tnn e = 398 let 399 val rll1 = map snd (infer_tnn tnn (map fst e)) 400 val rll2 = map snd e 401 in 402 all is_accurate_one (combine (rll1,rll2)) 403 end 404 405fun tnn_accuracy tnn set = 406 let val correct = filter (is_accurate tnn) set in 407 Real.fromInt (length correct) / Real.fromInt (length set) 408 end 409 410(* ------------------------------------------------------------------------- 411 Object for training different TNN in parallel 412 ------------------------------------------------------------------------- *) 413 414fun train_tnn_fun () (ex,schedule,tnndim) = 415 let 416 val randtnn = random_tnn tnndim 417 val (tnn,t) = add_time (train_tnn schedule randtnn) (ex,[]) 418 in 419 print_endline ("Training time : " ^ rts t); tnn 420 end 421 422fun write_noparam file (_:unit) = () 423fun read_noparam file = () 424 425fun write_tnnarg file (ex,schedule,tnndim) = 426 ( 427 write_tnnex (file ^ "_tnnex") ex; 428 write_schedule (file ^ "_schedule") schedule; 429 write_tnndim (file ^ "_tnndim") tnndim 430 ) 431fun read_tnnarg file = 432 let 433 val ex = read_tnnex (file ^ "_tnnex") 434 val schedule = read_schedule (file ^ "_schedule") 435 val tnndim = read_tnndim (file ^ "_tnndim") 436 in 437 (ex,schedule,tnndim) 438 end 439 440val traintnn_extspec = 441 { 442 self = "mlTreeNeuralNetwork.traintnn_extspec", 443 parallel_dir = default_parallel_dir ^ "_train", 444 reflect_globals = fn () => "()", 445 function = train_tnn_fun, 446 write_param = write_noparam, 447 read_param = read_noparam, 448 write_arg = write_tnnarg, 449 read_arg = read_tnnarg, 450 write_result = write_tnn, 451 read_result = read_tnn 452 } 453 454 455 456(* ------------------------------------------------------------------------- 457 Toy example: learning to guess if a term contains the variable "x" 458 ------------------------------------------------------------------------- *) 459 460(* 461load "aiLib"; open aiLib; 462load "psTermGen"; open psTermGen; 463load "mlTreeNeuralNetwork"; open mlTreeNeuralNetwork; 464 465(* terms *) 466val vx = mk_var ("x",alpha); 467val vy = mk_var ("y",alpha); 468val vz = mk_var ("z",alpha); 469val vf = ``f:'a->'a->'a``; 470val vg = ``g:'a -> 'a``; 471val vhead = mk_var ("head_", ``:'a -> 'a``); 472val varl = [vx,vy,vz,vf,vg]; 473 474(* examples *) 475fun contain_x tm = can (find_term (fn x => term_eq x vx)) tm; 476fun mk_dataset n = 477 let 478 val pxl = mk_term_set (random_terml varl (n,alpha) 1000); 479 val (px,pnotx) = partition contain_x pxl 480 in 481 (first_n 100 (shuffle px), first_n 100 (shuffle pnotx)) 482 end 483val (l1,l2) = split (List.tabulate (20, fn n => mk_dataset (n + 1))); 484val (l1',l2') = (List.concat l1, List.concat l2); 485val (pos,neg) = (map_assoc (fn x => [1.0]) l1', map_assoc (fn x => [0.0]) l2'); 486val ex0 = shuffle (pos @ neg); 487val ex1 = map (fn (a,b) => [(mk_comb (vhead,a),b)]) ex0; 488val (trainex,testex) = part_pct 0.9 ex1; 489 490(* TNN *) 491val nlayer = 1; 492val dim = 16; 493val randtnn = random_tnn_std (nlayer,dim) (vhead :: varl); 494 495(* training *) 496val trainparam = 497 {ncore = 1, verbose = true, 498 learning_rate = 0.02, batch_size = 16, nepoch = 20}; 499val schedule = [trainparam]; 500val tnn = train_tnn schedule randtnn (trainex,testex); 501 502(* testing *) 503val acc = tnn_accuracy tnn testex; 504*) 505 506end (* struct *) 507