# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

# pseudo-compiler for use of aggregate types in C-derived function code

import syntax
from syntax import structs, get_vars, get_expr_typ, get_node_vars, Expr, Node
import logic


(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq,
mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8,
mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs,
mk_if, mk_meta_typ, mk_pvalid) = syntax.mks

from syntax import word32T, word8T

from syntax import fresh_name, foldr1

from target_objects import symbols, trace

def compile_field_acc (name, expr, replaces):
    '''pseudo-compile access to field (named name) of expr

    dispatches on the form of expr: struct constructors and field updates
    are looked through, struct variables are replaced by the per-field
    variable listed in replaces, and struct-typed memory accesses become
    an access at the field's offset. nested Field / ArrayIndex
    expressions are compiled first and then accessed.
    any other form is printed and an AssertionError is raised.'''
    if expr.kind == 'StructCons':
        return expr.vals[name]
    elif expr.kind == 'FieldUpd':
        if expr.field[0] == name:
            # the update wrote exactly this field
            return expr.val
        else:
            # some other field was updated, look at the struct beneath
            return compile_field_acc (name, expr.struct, replaces)
    elif expr.kind == 'Var':
        assert expr.name in replaces
        # exactly one replacement entry is expected for this field name,
        # the pattern match fails otherwise
        [(v_nm, typ)] = [(v_nm, typ) for (f_nm, v_nm, typ)
            in replaces[expr.name] if f_nm == name]
        return mk_var (v_nm, typ)
    elif expr.is_op ('MemAcc'):
        assert expr.typ.kind == 'Struct'
        (typ, offs, _) = structs[expr.typ.name].fields[name]
        [m, p] = expr.vals
        return mk_memacc (m, mk_plus (p, mk_word32 (offs)), typ)
    elif expr.kind == 'Field':
        expr2 = compile_field_acc (expr.field[0], expr.struct, replaces)
        return compile_field_acc (name, expr2, replaces)
    elif expr.is_op ('ArrayIndex'):
        [arr, i] = expr.vals
        expr2 = compile_array_acc (i, arr, replaces, False)
        assert expr2, (arr, i)
        return compile_field_acc (name, expr2, replaces)
    else:
        print (expr)
        assert not 'field acc compilable'

def compile_array_acc (i, expr, replaces, must = True):
    '''pseudo-compile access to array element i of expr

    i is either a python integer or an index expression. a 'Num'
    expression (asserted to be of type word32T) is converted to its
    integer value first.
    if the form of expr is not handled, returns None when must is false
    and an explicit array index expression otherwise. note the recursive
    calls below do not pass must on, so they use its default.'''
    if not logic.is_int (i) and i.kind == 'Num':
        assert i.typ == word32T
        i = i.val
    if expr.kind == 'Array':
        if logic.is_int (i):
            return expr.vals[i]
        else:
            # unknown index: chain of if-then-else over the elements,
            # with the last element as the default case
            expr2 = expr.vals[-1]
            for (j, v) in enumerate (expr.vals[:-1]):
                expr2 = mk_if (mk_eq (i, mk_word32 (j)), v, expr2)
            return expr2
    elif expr.is_op ('ArrayUpdate'):
        [arr, j, v] = expr.vals
        if j.kind == 'Num' and logic.is_int (i):
            # both indices known, the comparison is decided here
            if i == j.val:
                return v
            else:
                return compile_array_acc (i, arr, replaces)
        else:
            return mk_if (mk_eq (j, mk_word32_maybe (i)), v,
                compile_array_acc (i, arr, replaces))
    elif expr.is_op ('MemAcc'):
        [m, p] = expr.vals
        return mk_memacc (m, mk_arroffs (p, expr.typ, i), expr.typ.el_typ)
    elif expr.is_op ('IfThenElse'):
        [cond, left, right] = expr.vals
        return mk_if (cond, compile_array_acc (i, left, replaces),
            compile_array_acc (i, right, replaces))
    elif expr.kind == 'Var':
        assert expr.name in replaces
        if logic.is_int (i):
            (_, v_nm, typ) = replaces[expr.name][i]
            return mk_var (v_nm, typ)
        else:
            # unknown index: select among the per-element variables,
            # with the first one as the default case
            vs = [(mk_word32 (j), mk_var (v_nm, typ))
                for (j, v_nm, typ)
                in replaces[expr.name]]
            expr2 = vs[0][1]
            for (j, v) in vs[1:]:
                expr2 = mk_if (mk_eq (i, j), v, expr2)
            return expr2
    else:
        if not must:
            return None
        return mk_arr_index (expr, mk_word32_maybe (i))

def num_fields (container, typ):
    '''count the fields of type typ found in type container, looking
    inside arrays and structs recursively. a container equal to typ
    counts as one field.'''
    if container == typ:
        return 1
    elif container.kind == 'Array':
        return container.num * num_fields (container.el_typ, typ)
    elif container.kind == 'Struct':
        struct = structs[container.name]
        return sum ([num_fields (typ2, typ)
            for (nm, typ2) in struct.field_list])
    else:
        return 0

def get_const_global_acc_offset (expr, offs, typ):
    '''walk down a chain of ArrayIndex / Field accesses toward a
    'ConstGlobal' expression, adding to the offset expression offs
    the number of typ-typed fields skipped at each step.

    returns a pair (const global expr, offset expr), or None if the
    chain contains some other form of expression or names a field that
    is not in the struct's field list.'''
    if expr.kind == 'ConstGlobal':
        return (expr, offs)
    elif expr.is_op ('ArrayIndex'):
        [expr2, offs2] = expr.vals
        offs = mk_plus (offs, mk_times (offs2,
            mk_word32 (num_fields (expr.typ, typ))))
        return get_const_global_acc_offset (expr2, offs, typ)
    elif expr.kind == 'Field':
        struct = structs[expr.struct.typ.name]
        offs2 = 0
        for (nm, typ2) in struct.field_list:
            if (nm, typ2) == expr.field:
                offs = mk_plus (offs, mk_word32 (offs2))
                return get_const_global_acc_offset (
                    expr.struct, offs, typ)
            else:
                # count the fields that sit before the one we want
                offs2 += num_fields (typ2, typ)
        # field not found in the field list
        return None
    else:
        return None

def compile_const_global_acc (expr):
    '''rewrite a word-typed access path into a const global as a single
    array index on that global.

    returns None if expr is itself a const global, is already a direct
    index of one, is not of 'Word' type, or is not an access path into
    a const global at all.'''
    if expr.kind == 'ConstGlobal' or (expr.is_op ('ArrayIndex')
            and expr.vals[0].kind == 'ConstGlobal'):
        return None
    if expr.typ.kind != 'Word':
        return None
    r = get_const_global_acc_offset (expr, mk_word32 (0), expr.typ)
    if r is None:
        return None
    (cg, offs) = r
    return mk_arr_index (cg, offs)

def compile_val_fields (expr, replaces):
    '''list the compiled accesses to all the non-aggregate components of
    expr, recursing through array elements and struct fields in order.'''
    if expr.typ.kind == 'Array':
        res = []
        for i in range (expr.typ.num):
            acc = compile_array_acc (i, expr, replaces)
            res.extend (compile_val_fields (acc, replaces))
        return res
    elif expr.typ.kind == 'Struct':
        res = []
        for (nm, typ2) in structs[expr.typ.name].field_list:
            acc = compile_field_acc (nm, expr, replaces)
            res.extend (compile_val_fields (acc, replaces))
        return res
    else:
        return [compile_accs (replaces, expr)]

def compile_val_fields_of_typ (expr, replaces, typ):
    '''as compile_val_fields, keeping only the components of type typ'''
    return [e for e in compile_val_fields (expr, replaces)
        if e.typ == typ]

# it helps in this compilation to know that the outermost expression we are
# trying to fetch is always of basic type, never struct or array.
# sort of fudged in the array indexing case here
def compile_accs (replaces, expr):
    '''rewrite expr so that struct/array variables, field accesses and
    aggregate memory updates are replaced by operations on their
    non-aggregate components. replaces maps each aggregate variable name
    to a list of (field name or index, new variable name, type).

    expressions of a kind not handled here are printed and an
    AssertionError is raised.'''
    r = compile_const_global_acc (expr)
    if r:
        return compile_accs (replaces, r)

    if expr.kind == 'Field':
        expr = compile_field_acc (expr.field[0], expr.struct, replaces)
        return compile_accs (replaces, expr)
    elif expr.is_op ('ArrayIndex'):
        [arr, n] = expr.vals
        expr2 = compile_array_acc (n, arr, replaces, False)
        if expr2:
            return compile_accs (replaces, expr2)
        # not directly compilable: compile the array and the index
        # separately and try once more
        arr = compile_accs (replaces, arr)
        n = compile_accs (replaces, n)
        expr2 = compile_array_acc (n, arr, replaces, False)
        if expr2:
            return compile_accs (replaces, expr2)
        else:
            return mk_arr_index (arr, n)
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].is_op ('MemAcc')
            and expr.vals[2].vals[0] == expr.vals[0]
            and expr.vals[2].vals[1] == expr.vals[1]):
        # null memory copy. probably created by ops below
        return compile_accs (replaces, expr.vals[0])
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].kind == 'FieldUpd'):
        # write the underlying struct, then the updated field on top
        [m, p, f_upd] = expr.vals
        assert f_upd.typ.kind == 'Struct'
        (typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]]
        assert f_upd.val.typ == typ
        return compile_accs (replaces,
            mk_memupd (mk_memupd (m, p, f_upd.struct),
                mk_plus (p, mk_word32 (offs)), f_upd.val))
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].typ.kind == 'Struct'):
        # struct write becomes one write per field
        [m, p, s_val] = expr.vals
        struct = structs[s_val.typ.name]
        for (nm, (typ, offs, _)) in struct.fields.items ():
            f = compile_field_acc (nm, s_val, replaces)
            assert f.typ == typ
            m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f)
        return compile_accs (replaces, m)
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].is_op ('ArrayUpdate')):
        # write the underlying array, then the updated element on top
        [m, p, arr_upd] = expr.vals
        [arr, i, v] = arr_upd.vals
        return compile_accs (replaces,
            mk_memupd (mk_memupd (m, p, arr),
                mk_arroffs (p, arr.typ, i), v))
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].typ.kind == 'Array'):
        # array write becomes one write per element
        [m, p, arr] = expr.vals
        n = arr.typ.num
        typ = arr.typ.el_typ
        for i in range (n):
            offs = i * typ.size ()
            assert offs == i or offs % 4 == 0
            e = compile_array_acc (i, arr, replaces)
            m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e)
        return compile_accs (replaces, m)
    elif expr.is_op ('Equals') \
            and expr.vals[0].typ.kind in ['Struct', 'Array']:
        # aggregate equality becomes a conjunction of component equalities
        [x, y] = expr.vals
        assert x.typ == y.typ
        xs = compile_val_fields (x, replaces)
        ys = compile_val_fields (y, replaces)
        # the two sides have the same type, so they must flatten to the
        # same number of components. check rather than pair up silently.
        assert len (xs) == len (ys), (x, y)
        eq = foldr1 (mk_and, [mk_eq (a, b) for (a, b) in zip (xs, ys)])
        return compile_accs (replaces, eq)
    elif expr.is_op ('PAlignValid'):
        [typ, p] = expr.vals
        p = compile_accs (replaces, p)
        assert typ.kind == 'Type'
        return logic.mk_align_valid_ineq (('Type', typ.val), p)
    elif expr.kind == 'Op':
        vals = [compile_accs (replaces, v) for v in expr.vals]
        return syntax.adjust_op_vals (expr, vals)
    elif expr.kind == 'Symbol':
        return mk_word32 (symbols[expr.name][0])
    else:
        if expr.kind not in ('Var', 'ConstGlobal', 'Num', 'Invent',
                'Type'):
            print (expr)
            assert not 'field acc compiled'
        return expr

def expand_arg_fields (replaces, args):
    '''flatten a list of argument expressions: struct and array typed
    arguments are expanded (recursively) into their fields / elements,
    other arguments are compiled with compile_accs.'''
    xs = []
    for arg in args:
        if arg.typ.kind == 'Struct':
            ys = [compile_field_acc (nm, arg, replaces)
                for (nm, _) in structs[arg.typ.name].field_list]
            xs.extend (expand_arg_fields (replaces, ys))
        elif arg.typ.kind == 'Array':
            ys = [compile_array_acc (i, arg, replaces)
                for i in range (arg.typ.num)]
            xs.extend (expand_arg_fields (replaces, ys))
        else:
            xs.append (compile_accs (replaces, arg))
    return xs

def expand_lval_list (replaces, lvals):
    '''flatten a list of (name, type) lvalues: any name found in replaces
    is expanded (recursively) into its replacement variables. names
    without a replacement must not be of struct or array type.'''
    xs = []
    for (nm, typ) in lvals:
        if nm in replaces:
            xs.extend (expand_lval_list (replaces, [(v_nm, typ)
                for (f_nm, v_nm, typ) in replaces[nm]]))
        else:
            assert typ.kind not in ['Struct', 'Array']
            xs.append ((nm, typ))
    return xs

def mk_acc (idx, expr, replaces):
    '''access component idx of expr: an array element if idx is an
    integer, a struct field otherwise.'''
    if logic.is_int (idx):
        assert expr.typ.kind == 'Array'
        return compile_array_acc (idx, expr, replaces)
    else:
        assert expr.typ.kind == 'Struct'
        return compile_field_acc (idx, expr, replaces)

def compile_upds (replaces, upds):
    '''flatten a list of (lvalue, value) updates into updates on
    non-aggregate variables, dropping any update that assigns a
    variable to itself.'''
    lvs = expand_lval_list (replaces, [lv for (lv, v) in upds])
    vs = expand_arg_fields (replaces, [v for (lv, v) in upds])

    # both sides must flatten to the same sequence of types
    assert [typ for (nm, typ) in lvs] == [get_expr_typ (v) for v in vs], (
        lvs, vs)

    return [(lv, v) for (lv, v) in zip (lvs, vs)
        if not v.is_var (lv)]

def compile_struct_use (function):
    '''replace the struct and array typed variables of function by fresh
    per-field / per-element variables, rewriting its nodes, inputs and
    outputs in place.

    returns True if there was nothing to replace.'''
    trace ('Compiling in %s.' % function.name)
    vs = get_vars (function)

    visit_vs = list (vs)
    replaces = {}
    while visit_vs:
        v = visit_vs.pop ()
        # NOTE(review): the new names pushed below are looked up here, so
        # presumably fresh_name records them in vs -- confirm.
        typ = vs[v]
        if typ.kind == 'Struct':
            fields = structs[typ.name].field_list
        elif typ.kind == 'Array':
            fields = [(i, typ.el_typ) for i in range (typ.num)]
        else:
            continue
        new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ)
            for (nm, f_typ) in fields]
        replaces[v] = new_vs
        # the new variables may themselves be aggregates
        visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs])

    for n in function.nodes:
        node = function.nodes[n]
        # there used to be a second 'Basic' case here handling a single
        # lval/val pair. it sat behind this one and could never run.
        if node.kind == 'Basic':
            node.upds = compile_upds (replaces, node.upds)
        elif node.kind == 'Call':
            node.args = expand_arg_fields (replaces, node.args)
            node.rets = expand_lval_list (replaces, node.rets)
        elif node.kind == 'Cond':
            node.cond = compile_accs (replaces, node.cond)
        else:
            assert not 'node kind understood'

    function.inputs = expand_lval_list (replaces, function.inputs)
    function.outputs = expand_lval_list (replaces, function.outputs)
    return not replaces
def check_compile (func):
    '''check that no node of func still mentions a variable of struct or
    array type. prints the offending variable and node and raises an
    AssertionError if one is found.'''
    for node in func.nodes.values ():
        vs = {}
        get_node_vars (node, vs)
        for (v_nm, typ) in vs.items ():
            if typ.kind == 'Struct':
                print ('Failed to compile struct %s in %s' % (v_nm, func))
                print (node)
                assert not 'compiled'
            if typ.kind == 'Array':
                print ('Failed to compile array %s in %s' % (v_nm, func))
                print (node)
                assert not 'compiled'

def subst_expr (expr):
    '''substitution function for substitute_simple: returns a replacement
    for a known symbol or a PAlignValid operator, None to leave the
    expression alone.'''
    if expr.kind == 'Symbol':
        if expr.name in symbols:
            return mk_word32 (symbols[expr.name][0])
        else:
            return None
    elif expr.is_op ('PAlignValid'):
        [typ, p] = expr.vals
        assert typ.kind == 'Type'
        return logic.mk_align_valid_ineq (('Type', typ.val), p)
    elif expr.kind in ['Op', 'Var', 'Num', 'Type']:
        return None
    else:
        assert not 'expression simple-substitutable', expr

def substitute_simple (func):
    '''apply subst_expr to the expressions of every node of func,
    storing the resulting nodes back into func.nodes.'''
    for (n, node) in list (func.nodes.items ()):
        func.nodes[n] = node.subst_exprs (subst_expr,
            ss = set (['Symbol', 'PAlignValid']))

def nodes_symbols (nodes):
    '''the set of names of all 'Symbol' expressions visited in nodes'''
    symbols_needed = set()
    def visitor (expr):
        if expr.kind == 'Symbol':
            symbols_needed.add(expr.name)
    for node in nodes:
        node.visit (lambda l: (), visitor)
    return symbols_needed

def missing_symbols (functions):
    '''the set of symbol names used in functions but absent from the
    symbols table. a message is printed if there are any.'''
    symbols_needed = nodes_symbols ([node
        for func in functions.values ()
        for node in func.nodes.values ()])
    trouble = symbols_needed - set (symbols)
    if trouble:
        print ('Symbols missing for substitution: %r' % trouble)
    return trouble

def compile_funcs (functions):
    '''report missing symbols, then run substitute_simple and
    check_compile on each function.'''
    missing_symbols (functions)
    for func in functions.values ():
        substitute_simple (func)
        check_compile (func)

def combine_duplicate_nodes (nodes):
    '''remove duplicated (equal) nodes from the dict nodes in place,
    redirecting continuations to the node that was kept. repeats until
    no duplicates remain, since redirecting can make more nodes equal.

    returns a dict mapping each removed node name to the name of a node
    that is still present in nodes.'''
    orig_size = len (nodes)
    node_renames = {}
    while True:
        progress = False
        node_names = {}
        # iterate over a snapshot, entries are deleted as we go
        for (n, node) in list (nodes.items ()):
            if node in node_names:
                node_renames[n] = node_names[node]
                del nodes[n]
                progress = True
            else:
                node_names[node] = n

        if not progress:
            break

        for (n, node) in list (nodes.items ()):
            nodes[n] = rename_node_conts (node, node_renames)

    # a node kept in one pass may be removed in a later one, leaving a
    # chain a -> b -> c. follow the chains so callers are never handed a
    # name that has been removed. removed names are never kept again, so
    # the chains cannot loop.
    for n in list (node_renames):
        target = node_renames[n]
        while target in node_renames:
            target = node_renames[target]
        node_renames[n] = target

    if len (nodes) < orig_size:
        print ('Trimmed duplicates %d -> %d' % (orig_size, len (nodes)))
    return node_renames

def rename_node_conts (node, renames):
    '''a copy of node with its continuations mapped through renames'''
    new_conts = [renames.get (c, c) for c in node.get_conts ()]
    return Node (node.kind, new_conts, node.get_args ())

def recommended_rename (s):
    '''strip a numeric suffix: 'x.12' becomes 'x'. names that do not
    have exactly one '.', or whose suffix contains a non-digit, are
    returned unchanged.'''
    bits = s.split ('.')
    if len (bits) != 2:
        return s
    if all ([x in '0123456789' for x in bits[1]]):
        return bits[0]
    else:
        return s

def rename_vars (function):
    '''rename the variables of function to their recommended_rename
    names, in place, skipping any variable whose new name would be
    shared with another variable in the same var_deps set.'''
    preds = logic.compute_preds (function.nodes)
    var_deps = logic.compute_var_deps (function.nodes,
        lambda x: function.outputs, preds)

    vs = set ()
    dont_rename_vs = set ()
    for n in var_deps:
        # group the variables at this node by the name they would get
        rev_renames = {}
        for (v, t) in var_deps[n]:
            v2 = recommended_rename (v)
            rev_renames.setdefault (v2, []).append ((v, t))
            vs.add ((v, t))
        for vlist in rev_renames.values ():
            if len (vlist) > 1:
                # two variables would collide, leave both alone
                dont_rename_vs.update (vlist)

    renames = dict ((v, recommended_rename (v)) for (v, t) in vs
        if (v, t) not in dont_rename_vs)

    f = function
    f.inputs = [(renames.get (v, v), t) for (v, t) in f.inputs]
    f.outputs = [(renames.get (v, v), t) for (v, t) in f.outputs]
    for n in f.nodes:
        f.nodes[n] = syntax.copy_rename (f.nodes[n], (renames, {}))

def rename_and_combine_function_duplicates (functions):
    '''for each function, rename its variables, merge its duplicate
    nodes and redirect its entry point if the entry node was merged.'''
    for fun in functions.values ():
        rename_vars (fun)
        renames = combine_duplicate_nodes (fun.nodes)
        fun.entry = renames.get (fun.entry, fun.entry)