#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# pseudo-compiler for use of aggregate types in C-derived function code

import syntax
from syntax import structs, get_vars, get_expr_typ, get_node_vars, Expr, Node
import logic


(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq,
mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8,
mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs,
mk_if, mk_meta_typ, mk_pvalid) = syntax.mks

from syntax import word32T, word8T

from syntax import fresh_name, foldr1

from target_objects import symbols, trace

def compile_field_acc (name, expr, replaces):
    '''pseudo-compile access to field (named name) of expr

    expr is a struct-valued expression; the result is an expression for
    its field `name`. replaces maps a variable name to a list of
    (field name, replacement variable name, type) triples.

    The shape of expr decides what is returned:
      - StructCons: the constructor argument stored under `name`.
      - FieldUpd: the updated value if the update targets `name`,
        otherwise the field of the struct being updated.
      - Var: the replacement variable recorded for `name`; the variable
        must be in replaces with exactly one entry for the field.
      - MemAcc: a memory access at the pointer plus the field's offset,
        taken from the global structs table.
      - Field / ArrayIndex: the inner access is compiled first, then the
        field is taken from the result.
    Any other shape is printed and triggers an assertion failure.
    '''
    if expr.kind == 'StructCons':
        return expr.vals[name]
    elif expr.kind == 'FieldUpd':
        if expr.field[0] == name:
            return expr.val
        else:
            return compile_field_acc (name, expr.struct, replaces)
    elif expr.kind == 'Var':
        assert expr.name in replaces
        # the single-element pattern asserts exactly one matching field
        [(v_nm, typ)] = [(v_nm, typ) for (f_nm, v_nm, typ)
            in replaces[expr.name] if f_nm == name]
        return mk_var (v_nm, typ)
    elif expr.is_op ('MemAcc'):
        assert expr.typ.kind == 'Struct'
        (typ, offs, _) = structs[expr.typ.name].fields[name]
        [m, p] = expr.vals
        return mk_memacc (m, mk_plus (p, mk_word32 (offs)), typ)
    elif expr.kind == 'Field':
        expr2 = compile_field_acc (expr.field[0], expr.struct, replaces)
        return compile_field_acc (name, expr2, replaces)
    elif expr.is_op ('ArrayIndex'):
        [arr, i] = expr.vals
        # must = False: we want None (and the assertion below) rather
        # than a residual ArrayIndex we could not take a field of.
        expr2 = compile_array_acc (i, arr, replaces, False)
        assert expr2, (arr, i)
        return compile_field_acc (name, expr2, replaces)
    else:
        print (expr)
        # a non-empty string is truthy, so this assertion always fails
        assert not 'field acc compilable'

def compile_array_acc (i, expr, replaces, must = True):
    '''pseudo-compile access to array element i of expr

    i is either a Python integer (as judged by logic.is_int) or an
    index expression; a 'Num' expression of type word32T is converted
    to its integer value first. replaces has the same layout as in
    compile_field_acc, with element indices in place of field names.

    When the index is not a known integer, the Array and Var cases
    build a chain of mk_if tests over all candidate elements.

    If expr has none of the handled shapes, the result is None when
    must is false, and an explicit mk_arr_index expression otherwise.
    Note the recursive calls below use the default must = True.
    '''
    if not logic.is_int (i) and i.kind == 'Num':
        assert i.typ == word32T
        i = i.val
    if expr.kind == 'Array':
        if logic.is_int (i):
            return expr.vals[i]
        else:
            # the last element is the fall-through case of the chain
            expr2 = expr.vals[-1]
            for (j, v) in enumerate (expr.vals[:-1]):
                expr2 = mk_if (mk_eq (i, mk_word32 (j)), v, expr2)
            return expr2
    elif expr.is_op ('ArrayUpdate'):
        [arr, j, v] = expr.vals
        if j.kind == 'Num' and logic.is_int (i):
            # both indices are known, so the update is resolved here
            if i == j.val:
                return v
            else:
                return compile_array_acc (i, arr, replaces)
        else:
            return mk_if (mk_eq (j, mk_word32_maybe (i)), v,
                compile_array_acc (i, arr, replaces))
    elif expr.is_op ('MemAcc'):
        [m, p] = expr.vals
        return mk_memacc (m, mk_arroffs (p, expr.typ, i), expr.typ.el_typ)
    elif expr.is_op ('IfThenElse'):
        [cond, left, right] = expr.vals
        return mk_if (cond, compile_array_acc (i, left, replaces),
            compile_array_acc (i, right, replaces))
    elif expr.kind == 'Var':
        assert expr.name in replaces
        if logic.is_int (i):
            (_, v_nm, typ) = replaces[expr.name][i]
            return mk_var (v_nm, typ)
        else:
            vs = [(mk_word32 (j), mk_var (v_nm, typ))
                for (j, v_nm, typ)
                in replaces[expr.name]]
            # element 0 is the fall-through case of the chain
            expr2 = vs[0][1]
            for (j, v) in vs[1:]:
                expr2 = mk_if (mk_eq (i, j), v, expr2)
            return expr2
    else:
        if not must:
            return None
        return mk_arr_index (expr, mk_word32_maybe (i))

def num_fields (container, typ):
    '''count the values of type typ contained in type container

    A type counts once for itself; arrays multiply the count of their
    element type by their length; structs sum over their field_list.
    Any other type contributes 0.
    '''
    if container == typ:
        return 1
    elif container.kind == 'Array':
        return container.num * num_fields (container.el_typ, typ)
    elif container.kind == 'Struct':
        struct = structs[container.name]
        return sum ([num_fields (typ2, typ)
            for (nm, typ2) in struct.field_list])
    else:
        return 0

def get_const_global_acc_offset (expr, offs, typ):
    '''find the ConstGlobal underneath a chain of accesses

    Walks through ArrayIndex and Field accesses in expr, adding to the
    expression offs the position of the accessed value, counted in
    values of type typ (see num_fields).

    Returns a pair (const global expression, offset expression), or
    None if the chain does not end at a 'ConstGlobal' (including a
    Field access whose field is not found in the struct's field_list).
    '''
    if expr.kind == 'ConstGlobal':
        return (expr, offs)
    elif expr.is_op ('ArrayIndex'):
        [expr2, offs2] = expr.vals
        # expr.typ is the type of one element, so this scales the index
        # by the number of typ values per element
        offs = mk_plus (offs, mk_times (offs2,
            mk_word32 (num_fields (expr.typ, typ))))
        return get_const_global_acc_offset (expr2, offs, typ)
    elif expr.kind == 'Field':
        struct = structs[expr.struct.typ.name]
        # running count of typ values in the fields before this one
        offs2 = 0
        for (nm, typ2) in struct.field_list:
            if (nm, typ2) == expr.field:
                offs = mk_plus (offs, mk_word32 (offs2))
                return get_const_global_acc_offset (
                    expr.struct, offs, typ)
            else:
                offs2 += num_fields (typ2, typ)
        # no field matched: falls off the end, returning None
    else:
        return None

def compile_const_global_acc (expr):
    '''rewrite a nested access into a ConstGlobal as a single index

    Returns mk_arr_index (const global, offset), or None when nothing
    is to be done: expr is itself a ConstGlobal or a direct ArrayIndex
    of one, expr is not of 'Word' type, or expr is not an access chain
    ending at a ConstGlobal.
    '''
    if expr.kind == 'ConstGlobal' or (expr.is_op ('ArrayIndex')
            and expr.vals[0].kind == 'ConstGlobal'):
        return None
    if expr.typ.kind != 'Word':
        return None
    r = get_const_global_acc_offset (expr, mk_word32 (0), expr.typ)
    if r is None:
        return None
    (cg, offs) = r
    return mk_arr_index (cg, offs)

def compile_val_fields (expr, replaces):
    '''flatten expr into a list of compiled non-aggregate components

    Arrays are expanded element by element and structs field by field
    (in field_list order), recursively; any other expression yields a
    one-element list holding its compile_accs result.
    '''
    if expr.typ.kind == 'Array':
        res = []
        for i in range (expr.typ.num):
            acc = compile_array_acc (i, expr, replaces)
            res.extend (compile_val_fields (acc, replaces))
        return res
    elif expr.typ.kind == 'Struct':
        res = []
        for (nm, typ2) in structs[expr.typ.name].field_list:
            acc = compile_field_acc (nm, expr, replaces)
            res.extend (compile_val_fields (acc, replaces))
        return res
    else:
        return [compile_accs (replaces, expr)]

def compile_val_fields_of_typ (expr, replaces, typ):
    '''the components from compile_val_fields whose type equals typ'''
    return [e for e in compile_val_fields (expr, replaces)
        if e.typ == typ]

# it helps in this compilation to know that the outermost expression we are
# trying to fetch is always of basic type, never struct or array.
# sort of fudged in the array indexing case here
def compile_accs (replaces, expr):
    '''compile away struct and array accesses inside expr

    replaces maps a variable name to a list of
    (field name or index, replacement variable name, type) triples.
    Returns an expression in which field accesses, array indexing,
    memory updates of aggregate values and equalities between
    aggregates have been expanded into per-component operations.
    'Symbol' expressions are replaced by a word32 built from the first
    entry of symbols[name]. Expression kinds that are not understood
    are printed and trigger an assertion failure.
    '''
    r = compile_const_global_acc (expr)
    if r:
        return compile_accs (replaces, r)

    if expr.kind == 'Field':
        expr = compile_field_acc (expr.field[0], expr.struct, replaces)
        return compile_accs (replaces, expr)
    elif expr.is_op ('ArrayIndex'):
        [arr, n] = expr.vals
        # first attempt on the raw operands; if the array has no
        # handled shape, compile the operands and try once more.
        expr2 = compile_array_acc (n, arr, replaces, False)
        if expr2:
            return compile_accs (replaces, expr2)
        arr = compile_accs (replaces, arr)
        n = compile_accs (replaces, n)
        expr2 = compile_array_acc (n, arr, replaces, False)
        if expr2:
            return compile_accs (replaces, expr2)
        else:
            return mk_arr_index (arr, n)
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].is_op ('MemAcc')
            and expr.vals[2].vals[0] == expr.vals[0]
            and expr.vals[2].vals[1] == expr.vals[1]):
        # null memory copy. probably created by ops below
        return compile_accs (replaces, expr.vals[0])
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].kind == 'FieldUpd'):
        # store the underlying struct, then overwrite the one field
        [m, p, f_upd] = expr.vals
        assert f_upd.typ.kind == 'Struct'
        (typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]]
        assert f_upd.val.typ == typ
        return compile_accs (replaces,
            mk_memupd (mk_memupd (m, p, f_upd.struct),
                mk_plus (p, mk_word32 (offs)), f_upd.val))
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].typ.kind == 'Struct'):
        # one store per field, at the field's offset
        [m, p, s_val] = expr.vals
        struct = structs[s_val.typ.name]
        for (nm, (typ, offs, _)) in struct.fields.iteritems ():
            f = compile_field_acc (nm, s_val, replaces)
            assert f.typ == typ
            m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f)
        return compile_accs (replaces, m)
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].is_op ('ArrayUpdate')):
        # store the underlying array, then overwrite the one element
        [m, p, arr_upd] = expr.vals
        [arr, i, v] = arr_upd.vals
        return compile_accs (replaces,
            mk_memupd (mk_memupd (m, p, arr),
                mk_arroffs (p, arr.typ, i), v))
    elif (expr.is_op ('MemUpdate')
            and expr.vals[2].typ.kind == 'Array'):
        # one store per element
        [m, p, arr] = expr.vals
        n = arr.typ.num
        typ = arr.typ.el_typ
        for i in range (n):
            offs = i * typ.size ()
            # holds when the element size is 1 or a multiple of 4
            assert offs == i or offs % 4 == 0
            e = compile_array_acc (i, arr, replaces)
            m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e)
        return compile_accs (replaces, m)
    elif expr.is_op ('Equals') \
            and expr.vals[0].typ.kind in ['Struct', 'Array']:
        # aggregate equality becomes a conjunction over components
        [x, y] = expr.vals
        assert x.typ == y.typ
        xs = compile_val_fields (x, replaces)
        ys = compile_val_fields (y, replaces)
        eq = foldr1 (mk_and, map (mk_eq, xs, ys))
        return compile_accs (replaces, eq)
    elif expr.is_op ('PAlignValid'):
        [typ, p] = expr.vals
        p = compile_accs (replaces, p)
        assert typ.kind == 'Type'
        return logic.mk_align_valid_ineq (('Type', typ.val), p)
    elif expr.kind == 'Op':
        vals = [compile_accs (replaces, v) for v in expr.vals]
        return syntax.adjust_op_vals (expr, vals)
    elif expr.kind == 'Symbol':
        return mk_word32 (symbols[expr.name][0])
    else:
        if expr.kind not in ('Var', 'ConstGlobal', 'Num', 'Invent',
                'Type'):
            print (expr)
            # a non-empty string is truthy, so this always fails
            assert not 'field acc compiled'
        return expr

def expand_arg_fields (replaces, args):
    '''flatten a list of argument expressions

    Struct arguments are replaced by their fields (in field_list order)
    and array arguments by their elements, recursively; every other
    argument is passed through compile_accs. Returns the flat list.
    '''
    xs = []
    for arg in args:
        if arg.typ.kind == 'Struct':
            ys = [compile_field_acc (nm, arg, replaces)
                for (nm, _) in structs[arg.typ.name].field_list]
            xs.extend (expand_arg_fields (replaces, ys))
        elif arg.typ.kind == 'Array':
            ys = [compile_array_acc (i, arg, replaces)
                for i in range (arg.typ.num)]
            xs.extend (expand_arg_fields (replaces, ys))
        else:
            xs.append (compile_accs (replaces, arg))
    return xs

def expand_lval_list (replaces, lvals):
    '''flatten a list of (name, type) variables

    A variable whose name is in replaces is substituted, recursively,
    by its replacement variables. Any other variable is kept, and must
    not be of 'Struct' or 'Array' kind. Returns the flat list.
    '''
    xs = []
    for (nm, typ) in lvals:
        if nm in replaces:
            xs.extend (expand_lval_list (replaces, [(v_nm, typ)
                for (f_nm, v_nm, typ) in replaces[nm]]))
        else:
            assert typ.kind not in ['Struct', 'Array']
            xs.append ((nm, typ))
    return xs

def mk_acc (idx, expr, replaces):
    '''access a component of expr: array element for an integer idx
    (as judged by logic.is_int), struct field otherwise'''
    if logic.is_int (idx):
        assert expr.typ.kind == 'Array'
        return compile_array_acc (idx, expr, replaces)
    else:
        assert expr.typ.kind == 'Struct'
        return compile_field_acc (idx, expr, replaces)

def compile_upds (replaces, upds):
    '''flatten a list of (lvalue, value) updates

    Left-hand sides go through expand_lval_list and right-hand sides
    through expand_arg_fields; the two flat lists must agree in type,
    position by position. Updates that assign a variable to itself
    (v.is_var (lv)) are dropped from the result.
    '''
    lvs = expand_lval_list (replaces, [lv for (lv, v) in upds])
    vs = expand_arg_fields (replaces, [v for (lv, v) in upds])

    assert ([typ for (nm, typ) in lvs]
        == [get_expr_typ (v) for v in vs]), (lvs, vs)

    return [(lv, v) for (lv, v) in zip (lvs, vs)
        if not v.is_var (lv)]

def compile_struct_use (function):
    '''replace struct and array variables of function, in place

    Builds a replacement table giving every struct or array variable
    one fresh variable per field or element (recursively, for nested
    aggregates), then rewrites the function's nodes, inputs and outputs
    to use the replacement variables.

    Returns True if no variable needed replacing, False otherwise.
    '''
    trace ('Compiling in %s.' % function.name)
    vs = get_vars (function)

    # worklist of variable names; the fresh names are pushed as well
    # so that nested aggregates are expanded too.
    visit_vs = list (vs.keys ())
    replaces = {}
    while visit_vs:
        v = visit_vs.pop ()
        # presumably fresh_name records each new name in vs, since the
        # new names are looked up here as well - confirm in syntax
        typ = vs[v]
        if typ.kind == 'Struct':
            fields = structs[typ.name].field_list
        elif typ.kind == 'Array':
            fields = [(i, typ.el_typ) for i in range (typ.num)]
        else:
            continue
        new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ)
            for (nm, f_typ) in fields]
        replaces[v] = new_vs
        visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs])

    for n in function.nodes:
        node = function.nodes[n]
        if node.kind == 'Basic':
            node.upds = compile_upds (replaces, node.upds)
        elif node.kind == 'Call':
            node.args = expand_arg_fields (replaces, node.args)
            node.rets = expand_lval_list (replaces, node.rets)
        elif node.kind == 'Cond':
            node.cond = compile_accs (replaces, node.cond)
        else:
            # a non-empty string is truthy, so this always fails
            assert not 'node kind understood'

    function.inputs = expand_lval_list (replaces, function.inputs)
    function.outputs = expand_lval_list (replaces, function.outputs)
    return len (replaces) == 0
def check_compile (func):
    '''check that no node of func still mentions an aggregate variable

    Collects the variables of every node with get_node_vars; a variable
    of 'Struct' or 'Array' kind is reported on stdout together with the
    node, followed by an assertion failure.
    '''
    for node in func.nodes.itervalues ():
        vs = {}
        get_node_vars (node, vs)
        for (v_nm, typ) in vs.iteritems ():
            if typ.kind == 'Struct':
                print ('Failed to compile struct %s in %s' % (v_nm, func))
                print (node)
                # a non-empty string is truthy, so this always fails
                assert not 'compiled'
            if typ.kind == 'Array':
                print ('Failed to compile array %s in %s' % (v_nm, func))
                print (node)
                assert not 'compiled'

def subst_expr (expr):
    '''substitution callback used by substitute_simple

    Returns a replacement expression, or None where it has none to
    offer:
      - a 'Symbol' found in symbols becomes a word32 built from the
        first entry of symbols[name]; unknown symbols give None.
      - a 'PAlignValid' op becomes logic.mk_align_valid_ineq of its
        type and pointer operands.
      - other 'Op', 'Var', 'Num' and 'Type' expressions give None.
    Any other expression kind is an assertion failure.
    '''
    if expr.kind == 'Symbol':
        if expr.name in symbols:
            return mk_word32 (symbols[expr.name][0])
        else:
            return None
    elif expr.is_op ('PAlignValid'):
        [typ, p] = expr.vals
        assert typ.kind == 'Type'
        return logic.mk_align_valid_ineq (('Type', typ.val), p)
    elif expr.kind in ['Op', 'Var', 'Num', 'Type']:
        return None
    else:
        # a non-empty string is truthy, so this always fails
        assert not 'expression simple-substitutable', expr

def substitute_simple (func):
    '''apply subst_expr to every node of func, replacing the nodes in
    func.nodes in place'''
    for (n, node) in func.nodes.items ():
        func.nodes[n] = node.subst_exprs (subst_expr,
            ss = set (['Symbol', 'PAlignValid']))

def nodes_symbols (nodes):
    '''the set of names of all 'Symbol' expressions visited in nodes'''
    symbols_needed = set()
    def visitor (expr):
        if expr.kind == 'Symbol':
            symbols_needed.add(expr.name)
    for node in nodes:
        # the first callback does nothing; only visitor collects
        node.visit (lambda l: (), visitor)
    return symbols_needed

def missing_symbols (functions):
    '''report symbols used by functions but absent from symbols

    Prints a message if any are missing. Returns the set of missing
    symbol names (empty if none).
    '''
    symbols_needed = nodes_symbols ([node
        for func in functions.itervalues ()
        for node in func.nodes.itervalues ()])
    trouble = symbols_needed - set (symbols)
    if trouble:
        print ('Symbols missing for substitution: %r' % trouble)
    return trouble

def compile_funcs (functions):
    '''substitute symbols in, then check, every function in functions

    Missing symbols are reported by missing_symbols but do not stop
    the run; its result is not used here.
    '''
    missing_symbols (functions)
    for (f, func) in functions.iteritems ():
        substitute_simple (func)
        check_compile (func)

def combine_duplicate_nodes (nodes):
    '''merge equal nodes in the dict nodes, in place

    Whenever two names map to equal nodes, one name is deleted and
    recorded as renamed to the other; the continuations of all
    remaining nodes are then rewritten through the renames. This
    repeats until a pass finds no duplicates, since rewriting
    continuations can make further nodes equal.

    Prints a summary if anything was removed. Returns the dict of
    renames (removed name -> name it was first merged into).
    '''
    orig_size = len (nodes)
    node_renames = {}
    progress = True
    while progress:
        progress = False
        node_names = {}
        # iterate over a snapshot, since entries are deleted below
        for (n, node) in list (nodes.items ()):
            if node in node_names:
                node_renames[n] = node_names[node]
                del nodes[n]
                progress = True
            else:
                node_names[node] = n

        if not progress:
            break

        for (n, node) in nodes.items ():
            nodes[n] = rename_node_conts (node, node_renames)

    if len (nodes) < orig_size:
        print ('Trimmed duplicates %d -> %d' % (orig_size, len (nodes)))
    return node_renames

def rename_node_conts (node, renames):
    '''a copy of node with its continuations mapped through renames;
    continuations not in renames are kept'''
    new_conts = [renames.get (c, c) for c in node.get_conts ()]
    return Node (node.kind, new_conts, node.get_args ())

def recommended_rename (s):
    '''strip a numeric suffix from a variable name

    A name of the form 'base.digits' (exactly one dot, only the
    characters 0-9 after it) gives 'base'. Any other name is returned
    unchanged.
    '''
    bits = s.split ('.')
    if len (bits) != 2:
        return s
    if all ([x in '0123456789' for x in bits[1]]):
        return bits[0]
    else:
        return s

def rename_vars (function):
    '''rename variables of function to their recommended names, in place

    A variable keeps its name if, at some node, another variable in
    that node's var_deps entry has the same recommended name. Inputs,
    outputs and all nodes are rewritten with the remaining renames.
    '''
    preds = logic.compute_preds (function.nodes)
    var_deps = logic.compute_var_deps (function.nodes,
        lambda x: function.outputs, preds)

    vs = set ()
    dont_rename_vs = set ()
    for n in var_deps:
        # recommended name -> variables at this node mapping to it
        rev_renames = {}
        for (v, t) in var_deps[n]:
            v2 = recommended_rename (v)
            rev_renames.setdefault (v2, []).append ((v, t))
            vs.add ((v, t))
        for vlist in rev_renames.values ():
            if len (vlist) > 1:
                dont_rename_vs.update (vlist)

    renames = dict ([(v, recommended_rename (v)) for (v, t) in vs
        if (v, t) not in dont_rename_vs])

    f = function
    f.inputs = [(renames.get (v, v), t) for (v, t) in f.inputs]
    f.outputs = [(renames.get (v, v), t) for (v, t) in f.outputs]
    for n in f.nodes:
        f.nodes[n] = syntax.copy_rename (f.nodes[n], (renames, {}))

def rename_and_combine_function_duplicates (functions):
    '''for every function: rename variables, merge duplicate nodes,
    and redirect the entry point if the entry node was merged away'''
    for (f, fun) in functions.iteritems ():
        rename_vars (fun)
        renames = combine_duplicate_nodes (fun.nodes)
        fun.entry = renames.get (fun.entry, fun.entry)