1# 2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) 3# 4# SPDX-License-Identifier: BSD-2-Clause 5# 6 7import syntax 8from syntax import word32T, word8T, boolT, builtinTs, Expr, Node 9from syntax import true_term, false_term, mk_num 10from syntax import foldr1 11 12(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq, 13mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8, 14mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs, 15mk_if, mk_meta_typ, mk_pvalid) = syntax.mks 16 17from syntax import structs 18from target_objects import trace, printout 19 20def is_int (n): 21 return hasattr (n, '__int__') 22 23def mk_eq_with_cast (a, c): 24 return mk_eq (a, mk_cast (c, a.typ)) 25 26def mk_rodata (m): 27 assert m.typ == builtinTs['Mem'] 28 return Expr ('Op', boolT, name = 'ROData', vals = [m]) 29 30def cast_pair (((a, a_addr), (c, c_addr))): 31 if a.typ != c.typ and c.typ == boolT: 32 c = mk_if (c, mk_word32 (1), mk_word32 (0)) 33 return ((a, a_addr), (mk_cast (c, a.typ), c_addr)) 34 35ghost_assertion_type = syntax.Type ('WordArray', 50, 32) 36 37def split_scalar_globals (vs): 38 for i in range (len (vs)): 39 if vs[i].typ.kind != 'Word' and vs[i].typ != boolT: 40 break 41 else: 42 i = len (vs) 43 scalars = vs[:i] 44 global_vars = vs[i:] 45 for v in global_vars: 46 if v.typ not in [builtinTs['Mem'], builtinTs['Dom'], 47 builtinTs['HTD'], builtinTs['PMS'], 48 ghost_assertion_type]: 49 assert not "scalar_global split expected", vs 50 memT = builtinTs['Mem'] 51 mems = [v for v in global_vars if v.typ == memT] 52 others = [v for v in global_vars if v.typ != memT] 53 return (scalars, mems, others) 54 55def mk_vars (tups): 56 return [mk_var (nm, typ) for (nm, typ) in tups] 57 58def split_scalar_pairs (var_pairs): 59 return split_scalar_globals (mk_vars (var_pairs)) 60 61def azip (xs, ys): 62 assert len (xs) == len (ys) 63 return zip (xs, ys) 64 65def mk_mem_eqs (a_imem, c_imem, a_omem, c_omem, tags): 66 
[a_imem] = a_imem 67 a_tag, c_tag = tags 68 (c_in, c_out) = (c_tag + '_IN', c_tag + '_OUT') 69 (a_in, a_out) = (a_tag + '_IN', a_tag + '_OUT') 70 if c_imem: 71 [c_imem] = c_imem 72 ieqs = [((a_imem, a_in), (c_imem, c_in)), 73 ((mk_rodata (c_imem), c_in), (true_term, c_in))] 74 else: 75 ieqs = [((mk_rodata (a_imem), a_in), (true_term, c_in))] 76 if c_omem: 77 [a_m] = a_omem 78 [c_omem] = c_omem 79 oeqs = [((a_m, a_out), (c_omem, c_out)), 80 ((mk_rodata (c_omem), c_out), (true_term, c_out))] 81 else: 82 oeqs = [((a_m, a_out), (a_imem, a_in)) for a_m in a_omem] 83 84 return (ieqs, oeqs) 85 86def mk_fun_eqs (as_f, c_f, prunes = None): 87 (var_a_args, a_imem, glob_a_args) = split_scalar_pairs (as_f.inputs) 88 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_f.inputs) 89 (var_a_rets, a_omem, glob_a_rets) = split_scalar_pairs (as_f.outputs) 90 (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_f.outputs) 91 92 (mem_ieqs, mem_oeqs) = mk_mem_eqs (a_imem, c_imem, a_omem, c_omem, 93 ['ASM', 'C']) 94 95 if not prunes: 96 prunes = (var_a_args, var_a_args) 97 assert len (prunes[0]) == len (var_c_args), (params, var_a_args, 98 var_c_args, prunes) 99 a_map = dict (azip (prunes[1], var_a_args)) 100 ivar_pairs = [((a_map[p], 'ASM_IN'), (c, 'C_IN')) for (p, c) 101 in azip (prunes[0], var_c_args) if p in a_map] 102 103 ovar_pairs = [((a_ret, 'ASM_OUT'), (c_ret, 'C_OUT')) for (a_ret, c_ret) 104 in azip (var_a_rets, var_c_rets)] 105 return (map (cast_pair, mem_ieqs + ivar_pairs), 106 map (cast_pair, mem_oeqs + ovar_pairs)) 107 108def mk_var_list (vs, typ): 109 return [syntax.mk_var (v, typ) for v in vs] 110 111def mk_offs_sequence (init, offs, n, do_reverse = False): 112 r = range (n) 113 if do_reverse: 114 r.reverse () 115 def mk_offs (n): 116 return Expr ('Num', init.typ, val = offs * n) 117 return [mk_plus (init, mk_offs (m)) for m in r] 118 119def mk_stack_sequence (sp, offs, stack, typ, n, do_reverse = False): 120 return [(mk_memacc (stack, addr, typ), addr) 121 for 
addr in mk_offs_sequence (sp, offs, n, do_reverse)] 122 123def mk_aligned (w, n): 124 assert w.typ.kind == 'Word' 125 mask = Expr ('Num', w.typ, val = ((1 << n) - 1)) 126 return mk_eq (mk_bwand (w, mask), mk_num (0, w.typ)) 127 128def mk_eqs_arm_none_eabi_gnu (var_c_args, var_c_rets, c_imem, c_omem, 129 min_stack_size): 130 arg_regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T) 131 r0 = arg_regs[0] 132 sp = mk_var ('r13', word32T) 133 st = mk_var ('stack', builtinTs['Mem']) 134 r0_input = mk_var ('ret_addr_input', word32T) 135 sregs = mk_stack_sequence (sp, 4, st, word32T, len (var_c_args) + 1) 136 137 ret = mk_var ('ret', word32T) 138 preconds = [mk_aligned (sp, 2), mk_eq (ret, mk_var ('r14', word32T)), 139 mk_aligned (ret, 2), mk_eq (r0_input, r0), 140 mk_less_eq (min_stack_size, sp)] 141 post_eqs = [(x, x) for x in mk_var_list (['r4', 'r5', 'r6', 'r7', 'r8', 142 'r9', 'r10', 'r11', 'r13'], word32T)] 143 144 arg_seq = [(r, None) for r in arg_regs] + sregs 145 if len (var_c_rets) > 1: 146 # the 'return-too-much' issue. 
147 # instead r0 is a save-returns-here pointer 148 arg_seq.pop (0) 149 preconds += [mk_aligned (r0, 2), mk_less_eq (sp, r0)] 150 save_seq = mk_stack_sequence (r0_input, 4, st, word32T, 151 len (var_c_rets)) 152 save_addrs = [addr for (_, addr) in save_seq] 153 post_eqs += [(r0_input, r0_input)] 154 out_eqs = zip (var_c_rets, [x for (x, _) in save_seq]) 155 out_eqs = [(c, mk_cast (a, c.typ)) for (c, a) in out_eqs] 156 init_save_seq = mk_stack_sequence (r0, 4, st, word32T, 157 len (var_c_rets)) 158 (_, last_arg_addr) = arg_seq[len (var_c_args) - 1] 159 preconds += [mk_less_eq (sp, addr) 160 for (_, addr) in init_save_seq[-1:]] 161 if last_arg_addr: 162 preconds += [mk_less (last_arg_addr, addr) 163 for (_, addr) in init_save_seq[:1]] 164 else: 165 out_eqs = zip (var_c_rets, [r0]) 166 save_addrs = [] 167 arg_seq_addrs = [addr for ((_, addr), _) in zip (arg_seq, var_c_args) 168 if addr != None] 169 swrap = mk_stack_wrapper (sp, st, arg_seq_addrs) 170 swrap2 = mk_stack_wrapper (sp, st, save_addrs) 171 post_eqs += [(swrap, swrap2)] 172 173 mem = mk_var ('mem', builtinTs['Mem']) 174 (mem_ieqs, mem_oeqs) = mk_mem_eqs ([mem], c_imem, [mem], c_omem, 175 ['ASM', 'C']) 176 177 addr = None 178 arg_eqs = [cast_pair (((a_x, 'ASM_IN'), (c_x, 'C_IN'))) 179 for (c_x, (a_x, addr)) in zip (var_c_args, arg_seq)] 180 if addr: 181 preconds += [mk_less_eq (sp, addr)] 182 ret_eqs = [cast_pair (((a_x, 'ASM_OUT'), (c_x, 'C_OUT'))) 183 for (c_x, a_x) in out_eqs] 184 preconds = [((a_x, 'ASM_IN'), (true_term, 'ASM_IN')) for a_x in preconds] 185 asm_invs = [((vin, 'ASM_IN'), (vout, 'ASM_OUT')) for (vin, vout) in post_eqs] 186 187 return (arg_eqs + mem_ieqs + preconds, 188 ret_eqs + mem_oeqs + asm_invs) 189 190known_CPUs = { 191 'arm-none-eabi-gnu': mk_eqs_arm_none_eabi_gnu 192} 193 194def mk_fun_eqs_CPU (cpu_f, c_f, cpu_name, funcall_depth = 1): 195 cpu = known_CPUs[cpu_name] 196 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_f.inputs) 197 (var_c_rets, c_omem, glob_c_rets) = 
split_scalar_pairs (c_f.outputs) 198 199 return cpu (var_c_args, var_c_rets, c_imem, c_omem, 200 (funcall_depth * 256) + 256) 201 202class Pairing: 203 def __init__ (self, tags, funs, eqs, notes = None): 204 [l_tag, r_tag] = tags 205 self.tags = tags 206 assert set (funs) == set (tags) 207 self.funs = funs 208 self.eqs = eqs 209 210 self.l_f = funs[l_tag] 211 self.r_f = funs[r_tag] 212 self.name = 'Pairing (%s (%s) <= %s (%s))' % (self.l_f, 213 l_tag, self.r_f, r_tag) 214 215 self.notes = {} 216 if notes != None: 217 self.notes.update (notes) 218 219 def __str__ (self): 220 return self.name 221 222 def __hash__ (self): 223 return hash (self.name) 224 225 def __eq__ (self, other): 226 return self.name == other.name and self.eqs == other.eqs 227 228 def __ne__ (self, other): 229 return not other or not self == other 230 231def mk_pairing (functions, c_f, as_f, prunes = None, cpu = None): 232 fs = (functions[as_f], functions[c_f]) 233 if cpu: 234 eqs = mk_fun_eqs_CPU (fs[0], fs[1], cpu, 235 funcall_depth = funcall_depth (functions, c_f)) 236 else: 237 eqs = mk_fun_eqs (fs[0], fs[1], prunes = prunes) 238 return Pairing (['ASM', 'C'], {'C': c_f, 'ASM': as_f}, eqs) 239 240def inst_eqs_pattern (pattern, params): 241 (pat_params, inp_eqs, out_eqs) = pattern 242 substs = [((x.name, x.typ), y) 243 for (pat_vs, vs) in azip (pat_params, params) 244 for (x, y) in azip (pat_vs, vs)] 245 substs = dict (substs) 246 subst = lambda x: var_subst (x, substs) 247 return ([(subst (x), subst (y)) for (x, y) in inp_eqs], 248 [(subst (x), subst (y)) for (x, y) in out_eqs]) 249 250def inst_eqs_pattern_tuples (pattern, params): 251 return inst_eqs_pattern (pattern, map (mk_vars, params)) 252 253def inst_eqs_pattern_exprs (pattern, params): 254 (inp_eqs, out_eqs) = inst_eqs_pattern (pattern, params) 255 return (foldr1 (mk_and, [mk_eq (a, c) for (a, c) in inp_eqs]), 256 foldr1 (mk_and, [mk_eq (a, c) for (a, c) in out_eqs])) 257 258def var_match (var_exp, conc_exp, assigns): 259 if var_exp.typ 
!= conc_exp.typ: 260 return False 261 if var_exp.kind == 'Var': 262 key = (var_exp.name, var_exp.typ) 263 if key in assigns: 264 return conc_exp == assigns[key] 265 else: 266 assigns[key] = conc_exp 267 return True 268 elif var_exp.kind == 'Op': 269 if conc_exp.kind != 'Op' or var_exp.name != conc_exp.name: 270 return False 271 return all ([var_match (a, b, assigns) 272 for (a, b) in azip (var_exp.vals, conc_exp.vals)]) 273 else: 274 return False 275 276def var_subst (var_exp, assigns, must_subst = True): 277 def substor (var_exp): 278 if var_exp.kind == 'Var': 279 k = (var_exp.name, var_exp.typ) 280 if must_subst or k in assigns: 281 return assigns[k] 282 else: 283 return None 284 else: 285 return None 286 return syntax.do_subst (var_exp, substor) 287 288def recursive_term_subst (eqs, expr): 289 if expr in eqs: 290 return eqs[expr] 291 if expr.kind == 'Op': 292 vals = [recursive_term_subst (eqs, x) for x in expr.vals] 293 return syntax.adjust_op_vals (expr, vals) 294 return expr 295 296def mk_accum_rewrites (typ): 297 x = mk_var ('x', typ) 298 y = mk_var ('y', typ) 299 z = mk_var ('z', typ) 300 i = mk_var ('i', typ) 301 return [(x, i, mk_plus (x, y), mk_plus (x, mk_times (i, y)), 302 y), 303 (x, i, mk_plus (y, x), mk_plus (x, mk_times (i, y)), 304 y), 305 (x, i, mk_minus (x, y), mk_minus (x, mk_times (i, y)), 306 mk_uminus (y)), 307 (x, i, mk_plus (mk_plus (x, y), z), 308 mk_plus (x, mk_times (i, mk_plus (y, z))), 309 mk_plus (y, z)), 310 (x, i, mk_plus (mk_plus (y, x), z), 311 mk_plus (x, mk_times (i, mk_plus (y, z))), 312 mk_plus (y, z)), 313 (x, i, mk_plus (y, mk_plus (x, z)), 314 mk_plus (x, mk_times (i, mk_plus (y, z))), 315 mk_plus (y, z)), 316 (x, i, mk_plus (y, mk_plus (z, x)), 317 mk_plus (x, mk_times (i, mk_plus (y, z))), 318 mk_plus (y, z)), 319 (x, i, mk_minus (mk_minus (x, y), z), 320 mk_minus (x, mk_times (i, mk_plus (y, z))), 321 mk_plus (y, z)), 322 ] 323 324def mk_all_accum_rewrites (): 325 return [rew for typ in [word8T, word32T, syntax.word16T, 
326 syntax.word64T] 327 for rew in mk_accum_rewrites (typ)] 328 329accum_rewrites = mk_all_accum_rewrites () 330 331def default_val (typ): 332 if typ.kind == 'Word': 333 return Expr ('Num', typ, val = 0) 334 elif typ == boolT: 335 return false_term 336 else: 337 assert not 'default value for type %s created', typ 338 339trace_accumulators = [] 340 341def accumulator_closed_form (expr, (nm, typ), add_mask = None): 342 expr = toplevel_split_out_cast (expr) 343 n = get_bwand_mask (expr) 344 if n and not add_mask: 345 return accumulator_closed_form (expr.vals[0], (nm, typ), 346 add_mask = n) 347 348 for (x, i, pattern, rewrite, offset) in accum_rewrites: 349 var = mk_var (nm, typ) 350 ass = {(x.name, x.typ): var} 351 m = var_match (pattern, expr, ass) 352 if m: 353 x2_def = default_val (typ) 354 i2_def = default_val (word32T) 355 def do_rewrite (x2 = x2_def, i2 = i2_def): 356 ass[(x.name, x.typ)] = x2 357 ass[(i.name, i.typ)] = i2 358 vs = var_subst (rewrite, ass) 359 if add_mask: 360 vs = mk_bwand_mask (vs, add_mask) 361 return vs 362 offs = var_subst (offset, ass) 363 return (do_rewrite, offs) 364 if trace_accumulators: 365 trace ('no accumulator %s' % ((expr, nm, typ), )) 366 return (None, None) 367 368def split_out_cast (expr, target_typ, bits): 369 """given a word-type expression expr (of any word length), 370 compute a simplified expression expr' of the target type, which will 371 have the property that expr' && mask bits = cast expr, 372 where && is bitwise-and (BWAnd), mask n is the bitpattern set at the 373 bottom n bits, e.g. 
(1 << n) - 1, and cast is WordCast.""" 374 if expr.is_op (['WordCast', 'WordCastSigned']): 375 [x] = expr.vals 376 if x.typ.num >= bits and expr.typ.num >= bits: 377 return split_out_cast (x, target_typ, bits) 378 else: 379 return mk_cast (expr, target_typ) 380 elif expr.is_op ('BWAnd'): 381 [x, y] = expr.vals 382 if y.kind == 'Num': 383 val = y.val 384 else: 385 val = 0 386 full_mask = (1 << bits) - 1 387 if val & full_mask == full_mask: 388 return split_out_cast (x, target_typ, bits) 389 else: 390 return mk_cast (expr, target_typ) 391 elif expr.is_op (['Plus', 'Minus']): 392 # rounding issues will appear if this arithmetic is done 393 # at a smaller number of bits than we'll eventually report 394 if expr.typ.num >= bits: 395 vals = [split_out_cast (x, target_typ, bits) 396 for x in expr.vals] 397 if expr.is_op ('Plus'): 398 return mk_plus (vals[0], vals[1]) 399 else: 400 return mk_minus (vals[0], vals[1]) 401 else: 402 return mk_cast (expr, target_typ) 403 else: 404 return mk_cast (expr, target_typ) 405 406def toplevel_split_out_cast (expr): 407 bits = None 408 if expr.is_op (['WordCast', 'WordCastSigned']): 409 bits = min ([expr.typ.num, expr.vals[0].typ.num]) 410 elif expr.is_op ('BWAnd'): 411 bits = get_bwand_mask (expr) 412 413 if bits: 414 expr = split_out_cast (expr, expr.typ, bits) 415 return mk_bwand_mask (expr, bits) 416 else: 417 return expr 418 419two_powers = {} 420 421def get_bwand_mask (expr): 422 """recognise (x && mask) opers, where mask = ((1 << n) - 1) 423 for some n""" 424 if not expr.is_op ('BWAnd'): 425 return 426 [x, y] = expr.vals 427 if not y.kind == 'Num': 428 return 429 val = y.val & ((1 << (y.typ.num)) - 1) 430 if not two_powers: 431 for i in range (129): 432 two_powers[1 << i] = i 433 return two_powers.get (val + 1) 434 435def mk_bwand_mask (expr, n): 436 return mk_bwand (expr, mk_num (((1 << n) - 1), expr.typ)) 437 438def end_addr (p, typ): 439 if typ[0] == 'Array': 440 (_, typ, n) = typ 441 sz = mk_times (mk_word32 (typ.size ()), n) 
442 else: 443 assert typ[0] == 'Type', typ 444 (_, typ) = typ 445 sz = mk_word32 (typ.size ()) 446 return mk_plus (p, mk_minus (sz, mk_word32 (1))) 447 448def pvalid_assertion1 ((typ, k, p, pv), (typ2, k2, p2, pv2)): 449 """first pointer validity assertion: incompatibility. 450 pvalid1 & pvalid2 --> non-overlapping OR somehow-contained. 451 typ/typ2 is ('Type', syntax.Type) or ('Array', Type, Expr) for 452 dynamically sized arrays. 453 """ 454 offs1 = mk_minus (p, p2) 455 cond1 = get_styp_condition (offs1, typ, typ2) 456 offs2 = mk_minus (p2, p) 457 cond2 = get_styp_condition (offs2, typ2, typ) 458 459 out1 = mk_less (end_addr (p, typ), p2) 460 out2 = mk_less (end_addr (p2, typ2), p) 461 return mk_implies (mk_and (pv, pv2), foldr1 (mk_or, 462 [cond1, cond2, out1, out2])) 463 464def pvalid_assertion2 ((typ, k, p, pv), (typ2, k2, p2, pv2)): 465 """second pointer validity assertion: implication. 466 pvalid1 & strictly-contained --> pvalid2 467 """ 468 if typ[0] == 'Array' and typ2[0] == 'Array': 469 # this is such a vague notion it's not worth it 470 return true_term 471 offs1 = mk_minus (p, p2) 472 cond1 = get_styp_condition (offs1, typ, typ2) 473 imp1 = mk_implies (mk_and (cond1, pv2), pv) 474 offs2 = mk_minus (p2, p) 475 cond2 = get_styp_condition (offs2, typ2, typ) 476 imp2 = mk_implies (mk_and (cond2, pv), pv2) 477 return mk_and (imp1, imp2) 478 479def sym_distinct_assertion ((typ, p, pv), (start, end)): 480 out1 = mk_less (mk_plus (p, mk_word32 (typ.size () - 1)), mk_word32 (start)) 481 out2 = mk_less (mk_word32 (end), p) 482 return mk_implies (pv, mk_or (out1, out2)) 483 484def norm_array_type (t): 485 if t[0] == 'Type' and t[1].kind == 'Array': 486 (_, atyp) = t 487 return ('Array', atyp.el_typ_symb, mk_word32 (atyp.num), 'Strong') 488 elif t[0] == 'Array' and len (t) == 3: 489 (_, typ, l) = t 490 # these derive from PArrayValid assertions. we know the array is 491 # at least this long, but it might be longer. 
492 return ('Array', typ, l, 'Weak') 493 else: 494 return t 495 496stored_styp_conditions = {} 497 498def get_styp_condition (offs, inner_typ, outer_typ): 499 r = get_styp_condition_inner1 (inner_typ, outer_typ) 500 if not r: 501 return false_term 502 else: 503 return r (offs) 504 505def get_styp_condition_inner1 (inner_typ, outer_typ): 506 inner_typ = norm_array_type (inner_typ) 507 outer_typ = norm_array_type (outer_typ) 508 k = (inner_typ, outer_typ) 509 if k in stored_styp_conditions: 510 return stored_styp_conditions[k] 511 r = get_styp_condition_inner2 (inner_typ, outer_typ) 512 stored_styp_conditions[k] = r 513 return r 514 515def array_typ_size ((kind, el_typ, num, _)): 516 el_size = mk_word32 (el_typ.size ()) 517 return mk_times (num, el_size) 518 519def get_styp_condition_inner2 (inner_typ, outer_typ): 520 if inner_typ[0] == 'Array' and outer_typ[0] == 'Array': 521 (_, ityp, inum, _) = inner_typ 522 (_, otyp, onum, outer_bound) = outer_typ 523 # array fits in another array if the starting element is 524 # a sub-element, and if the size of the left array plus 525 # the offset fits in the right array 526 cond = get_styp_condition_inner1 (('Type', ityp), outer_typ) 527 isize = array_typ_size (inner_typ) 528 osize = array_typ_size (outer_typ) 529 if outer_bound == 'Strong' and cond: 530 return lambda offs: mk_and (cond (offs), 531 mk_less_eq (mk_plus (isize, offs), osize)) 532 else: 533 return cond 534 elif inner_typ == outer_typ: 535 return lambda offs: mk_eq (offs, mk_word32 (0)) 536 elif outer_typ[0] == 'Type' and outer_typ[1].kind == 'Struct': 537 conds = [(get_styp_condition_inner1 (inner_typ, 538 ('Type', sf_typ)), mk_word32 (offs2)) 539 for (_, offs2, sf_typ) 540 in structs[outer_typ[1].name].fields.itervalues()] 541 conds = [cond for cond in conds if cond[0]] 542 if conds: 543 return lambda offs: foldr1 (mk_or, 544 [c (mk_minus (offs, offs2)) 545 for (c, offs2) in conds]) 546 else: 547 return None 548 elif outer_typ[0] == 'Array': 549 (_, el_typ, n, 
bound) = outer_typ 550 cond = get_styp_condition_inner1 (inner_typ, ('Type', el_typ)) 551 el_size = mk_word32 (el_typ.size ()) 552 size = mk_times (n, el_size) 553 if bound == 'Strong' and cond: 554 return lambda offs: mk_and (mk_less (offs, size), 555 cond (mk_modulus (offs, el_size))) 556 elif cond: 557 return lambda offs: cond (mk_modulus (offs, el_size)) 558 else: 559 return None 560 else: 561 return None 562 563def all_vars_have_prop (expr, prop): 564 class Failed (Exception): 565 pass 566 def visit (expr): 567 if expr.kind != 'Var': 568 return 569 v2 = (expr.name, expr.typ) 570 if not prop (v2): 571 raise Failed () 572 try: 573 expr.visit (visit) 574 return True 575 except Failed: 576 return False 577 578def all_vars_in_set (expr, var_set): 579 return all_vars_have_prop (expr, lambda v: v in var_set) 580 581def var_not_in_expr (var, expr): 582 v2 = (var.name, var.typ) 583 return all_vars_have_prop (expr, lambda v: v != v2) 584 585def mk_array_size_ineq (typ, num, p): 586 align = typ.align () 587 size = mk_times (mk_word32 (typ.size ()), num) 588 size_lim = ((2 ** 32) - 4) / typ.size () 589 return mk_less_eq (num, mk_word32 (size_lim)) 590 591def mk_align_valid_ineq (typ, p): 592 if typ[0] == 'Type': 593 (_, typ) = typ 594 align = typ.align () 595 size = mk_word32 (typ.size ()) 596 size_req = [] 597 else: 598 assert typ[0] == 'Array', typ 599 (kind, typ, num) = typ 600 align = typ.align () 601 size = mk_times (mk_word32 (typ.size ()), num) 602 size_req = [mk_array_size_ineq (typ, num, p)] 603 assert align in [1, 4, 8] 604 w0 = mk_word32 (0) 605 if align > 1: 606 align_req = [mk_eq (mk_bwand (p, mk_word32 (align - 1)), w0)] 607 else: 608 align_req = [] 609 return foldr1 (mk_and, align_req + size_req + [mk_not (mk_eq (p, w0)), 610 mk_implies (mk_less (w0, size), 611 mk_less_eq (p, mk_uminus (size)))]) 612 613# generic operations on function/problem graphs 614def dict_list (xys, keys = None): 615 """dict_list ([(1, 2), (1, 3), (2, 4)]) = {1: [2, 3], 2: [4]}""" 
616 d = {} 617 for (x, y) in xys: 618 d.setdefault (x, []) 619 d[x].append (y) 620 if keys: 621 for x in keys: 622 d.setdefault (x, []) 623 return d 624 625def compute_preds (nodes): 626 preds = dict_list ([(c, n) for n in nodes 627 for c in nodes[n].get_conts ()], 628 keys = nodes) 629 for n in ['Ret', 'Err']: 630 preds.setdefault (n, []) 631 preds = dict ([(n, sorted (set (ps))) 632 for (n, ps) in preds.iteritems ()]) 633 return preds 634 635def simplify_node_elementary(node): 636 if node.kind == 'Cond' and node.cond == true_term: 637 return Node ('Basic', node.left, []) 638 elif node.kind == 'Cond' and node.cond == false_term: 639 return Node ('Basic', node.right, []) 640 elif node.kind == 'Cond' and node.left == node.right: 641 return Node ('Basic', node.left, []) 642 else: 643 return node 644 645def compute_var_flows (nodes, outputs, preds, override_lvals_rvals = {}): 646 # compute a graph of reverse var flows to pass to tarjan's algorithm 647 graph = {} 648 entries = ['Ret'] 649 for (n, node) in nodes.iteritems (): 650 if node.kind == 'Basic': 651 for (lv, rv) in node.upds: 652 graph[(n, 'Post', lv)] = [(n, 'Pre', v) 653 for v in syntax.get_expr_var_set (rv)] 654 elif node.is_noop (): 655 pass 656 else: 657 if n in override_lvals_rvals: 658 (lvals, rvals) = override_lvals_rvals[n] 659 else: 660 rvals = syntax.get_node_rvals (node) 661 rvals = set (rvals.iteritems ()) 662 lvals = set (node.get_lvals ()) 663 if node.kind != 'Basic': 664 lvals = list (lvals) + ['PC'] 665 entries.append ((n, 'Post', 'PC')) 666 for lv in lvals: 667 graph[(n, 'Post', lv)] = [(n, 'Pre', rv) 668 for rv in rvals] 669 graph['Ret'] = [(n, 'Post', v) 670 for n in preds['Ret'] for v in outputs (n)] 671 vs = set ([v for k in graph for (_, _, v) in graph[k]]) 672 for v in vs: 673 for n in nodes: 674 graph.setdefault ((n, 'Post', v), [(n, 'Pre', v)]) 675 graph[(n, 'Pre', v)] = [(n2, 'Post', v) 676 for n2 in preds[n]] 677 678 comps = tarjan (graph, entries) 679 return (graph, comps) 680 
681def mk_not_red (v): 682 if v.is_op ('Not'): 683 [v] = v.vals 684 return v 685 else: 686 return syntax.mk_not (v) 687 688def cont_with_conds (nodes, n, conds): 689 while True: 690 if n not in nodes or nodes[n].kind != 'Cond': 691 return n 692 cond = nodes[n].cond 693 if cond in conds: 694 n = nodes[n].left 695 elif mk_not_red (cond) in conds: 696 n = nodes[n].right 697 else: 698 return n 699 700def contextual_conds (nodes, preds): 701 """computes a collection of conditions that can be assumed true 702 at any point in the node graph.""" 703 pre_conds = {} 704 arc_conds = {} 705 visit = [n for n in nodes if not (preds[n])] 706 while visit: 707 n = visit.pop () 708 if n not in nodes: 709 continue 710 in_arc_conds = [arc_conds.get ((pre, n), set ()) 711 for pre in preds[n]] 712 if not in_arc_conds: 713 conds = set () 714 else: 715 conds = set.intersection (* in_arc_conds) 716 if pre_conds.get (n) == conds: 717 continue 718 pre_conds[n] = conds 719 if n not in nodes: 720 continue 721 if nodes[n].kind == 'Cond' and nodes[n].left == nodes[n].right: 722 c_conds = [conds, conds] 723 elif nodes[n].kind == 'Cond': 724 c_conds = [nodes[n].cond, mk_not_red (nodes[n].cond)] 725 c_conds = [set.union (set ([c]), conds) 726 for c in c_conds] 727 else: 728 upds = set (nodes[n].get_lvals ()) 729 c_conds = [set ([c for c in conds if 730 not set.intersection (upds, 731 syntax.get_expr_var_set (c))])] 732 for (cont, conds) in zip (nodes[n].get_conts (), c_conds): 733 arc_conds[(n, cont)] = conds 734 visit.append (cont) 735 return (arc_conds, pre_conds) 736 737def contextual_cond_simps (nodes, preds): 738 """a common pattern in architectures with conditional operations is 739 a sequence of instructions with the same condition. 740 we can usually then reduce to a single contional block. 
741 b e => b-e 742 / \ / \ => / \ 743 a-c-d-f-g => a-c-f-g 744 this is sometimes important if b calculates a register that e uses 745 since variable dependency analysis will see this register escape via 746 the impossible path a-c-d-e 747 """ 748 (arc_conds, pre_conds) = contextual_conds (nodes, preds) 749 nodes = dict (nodes) 750 for n in nodes: 751 if nodes[n].kind == 'Cond': 752 continue 753 cont = nodes[n].cont 754 conds = arc_conds[(n, cont)] 755 cont2 = cont_with_conds (nodes, cont, conds) 756 if cont2 != cont: 757 nodes[n] = syntax.copy_rename (nodes[n], 758 ({}, {cont: cont2})) 759 return nodes 760 761def minimal_loop_node_set (p): 762 """discover a minimal set of loop addresses, excluding some operations 763 using conditional instructions which are syntactically within the 764 loop but semantically must always be followed by an immediate loop 765 exit. 766 767 amounts to rerunning loop detection after contextual_cond_simps.""" 768 769 loop_ns = set (p.loop_data) 770 really_in_loop = {} 771 nodes = contextual_cond_simps (p.nodes, p.preds) 772 def is_really_in_loop (n): 773 if n in really_in_loop: 774 return really_in_loop[n] 775 ns = [] 776 r = None 777 while r == None: 778 ns.append (n) 779 if n not in loop_ns: 780 r = False 781 elif n in p.splittable_points (n): 782 r = True 783 else: 784 conts = [n2 for n2 in nodes[n].get_conts () 785 if n2 != 'Err'] 786 if len (conts) > 1: 787 r = True 788 else: 789 [n] = conts 790 for n in ns: 791 really_in_loop[n] = r 792 return r 793 return set ([n for n in loop_ns if is_really_in_loop (n)]) 794 795def possible_graph_divs (p, min_cost = 20, max_cost = 20, ratio = 0.85, 796 trace = None): 797 es = [e[0] for e in p.entries] 798 divs = [] 799 direct_costs = {} 800 future_costs = {'Ret': set (), 'Err': set ()} 801 prev_costs = {} 802 int_costs = {} 803 fracs = {} 804 for n in p.nodes: 805 node = p.nodes[n] 806 if node.kind == 'Call': 807 cost = set ([(n, 20)]) 808 elif p.loop_id (n): 809 cost = set ([(p.loop_id (n), 
50)]) 810 else: 811 cost = set ([(n, len (node.get_mem_accesses ()))]) 812 cost.discard ((n, 0)) 813 direct_costs[n] = cost 814 for n in p.tarjan_order: 815 prev_costs[n] = set.union (* ([direct_costs[n]] 816 + [prev_costs.get (c, set ()) for c in p.preds[n]])) 817 for n in reversed (p.tarjan_order): 818 cont_costs = [future_costs.get (c, set ()) 819 for c in p.nodes[n].get_conts ()] 820 cost = set.union (* ([direct_costs[n]] + cont_costs)) 821 p_ct = sum ([c for (_, c) in prev_costs[n]]) 822 future_costs[n] = cost 823 if p.nodes[n].kind != 'Cond' or p_ct > max_cost: 824 continue 825 ct = sum ([c for (_, c) in set.union (cost, prev_costs[n])]) 826 if ct < min_cost: 827 continue 828 [c1, c2] = [sum ([c for (_, c) 829 in set.union (cs, prev_costs[n])]) 830 for cs in cont_costs] 831 fracs[n] = ((c1 * c1) + (c2 * c2)) / (ct * ct * 1.0) 832 if fracs[n] < ratio: 833 divs.append (n) 834 divs.reverse () 835 if trace != None: 836 trace[0] = (direct_costs, future_costs, prev_costs, 837 int_costs, fracs) 838 return divs 839 840def compute_var_deps (nodes, outputs, preds, override_lvals_rvals = {}, 841 trace = None): 842 # outs = list of (outname, retvars) 843 var_deps = {} 844 visit = set () 845 visit.update (preds['Ret']) 846 visit.update (preds['Err']) 847 848 nodes = contextual_cond_simps (nodes, preds) 849 850 while visit: 851 n = visit.pop () 852 853 node = simplify_node_elementary (nodes[n]) 854 if n in override_lvals_rvals: 855 (lvals, rvals) = override_lvals_rvals[n] 856 lvals = set (lvals) 857 rvals = set (rvals) 858 elif node.is_noop (): 859 lvals = set ([]) 860 rvals = set ([]) 861 else: 862 rvals = syntax.get_node_rvals (node) 863 rvals = set (rvals.iteritems ()) 864 lvals = set (node.get_lvals ()) 865 cont_vs = set () 866 867 for c in node.get_conts (): 868 if c == 'Ret': 869 cont_vs.update (outputs (n)) 870 elif c == 'Err': 871 pass 872 else: 873 cont_vs.update (var_deps.get (c, [])) 874 vs = set.union (rvals, cont_vs - lvals) 875 876 if n in var_deps and vs <= 
var_deps[n]: 877 continue 878 if trace and n in trace: 879 diff = vs - var_deps.get (n, set()) 880 printout ('add %s at %d' % (diff, n)) 881 printout (' %s, %s, %s, %s' % (len (vs), len (cont_vs), len (lvals), len (rvals))) 882 var_deps[n] = vs 883 visit.update (preds[n]) 884 885 return var_deps 886 887def compute_loop_var_analysis (p, var_deps, n, override_nodes = None): 888 if override_nodes == None: 889 nodes = p.nodes 890 else: 891 nodes = override_nodes 892 893 upd_vs = set ([v for n2 in p.loop_body (n) 894 if not nodes[n2].is_noop () 895 for v in nodes[n2].get_lvals ()]) 896 const_vs = set ([v for n2 in p.loop_body (n) 897 for v in var_deps[n2] if v not in upd_vs]) 898 899 vca = compute_var_cycle_analysis (p, nodes, n, 900 const_vs, set (var_deps[n])) 901 vca = [(syntax.mk_var (nm, typ), data) 902 for ((nm, typ), data) in vca.items ()] 903 return vca 904 905cvca_trace = [] 906cvca_diag = [False] 907no_accum_expressions = set () 908 909def compute_var_cycle_analysis (p, nodes, n, const_vars, vs, diag = None): 910 911 if diag == None: 912 diag = cvca_diag[0] 913 914 cache = {} 915 del cvca_trace[:] 916 impossible_nodes = {} 917 loop = p.loop_body (n) 918 919 def warm_cache_before (n2, v): 920 cvca_trace.append ((n2, v)) 921 cvca_trace.append ('(') 922 arc = [] 923 for i in range (100000): 924 opts = [n3 for n3 in p.preds[n2] if n3 in loop 925 if v not in nodes[n3].get_lvals () 926 if n3 != n 927 if (n3, v) not in cache] 928 if not opts: 929 break 930 n2 = opts[0] 931 arc.append (n2) 932 if not (len (arc) < 100000): 933 trace ('warmup overrun in compute_var_cycle_analysis') 934 trace ('chasing %s in %s' % (v, set (arc))) 935 assert False, (v, arc[-500:]) 936 for n2 in reversed (arc): 937 var_eval_before (n2, v) 938 cvca_trace.append (')') 939 940 def var_eval_before (n2, v, do_cmp = True): 941 if (n2, v) in cache and do_cmp: 942 return cache[(n2, v)] 943 if n2 == n and do_cmp: 944 var_exp = mk_var (v[0], v[1]) 945 vs = set ([v for v in [v] if v not in 
const_vars]) 946 return (vs, var_exp) 947 warm_cache_before (n2, v) 948 ps = [n3 for n3 in p.preds[n2] if n3 in loop 949 if not node_impossible (n3)] 950 if not ps: 951 return None 952 vs = [var_eval_after (n3, v) for n3 in ps] 953 if not all ([v3 == vs[0] for v3 in vs]): 954 if diag: 955 trace ('vs disagree for %s @ %d: %s' % (v, n2, vs)) 956 r = None 957 else: 958 r = vs[0] 959 if do_cmp: 960 cache[(n2, v)] = r 961 return r 962 def var_eval_after (n2, v): 963 node = nodes[n2] 964 if node.kind == 'Call' and v in node.rets: 965 if diag: 966 trace ('fetched %s from call at %d' % (v, n2)) 967 return None 968 elif node.kind == 'Basic': 969 for (lv, val) in node.upds: 970 if lv == v: 971 return expr_eval_before (n2, val) 972 return var_eval_before (n2, v) 973 else: 974 return var_eval_before (n2, v) 975 def expr_eval_before (n2, expr): 976 if expr.kind == 'Op': 977 if expr.vals == []: 978 return (set(), expr) 979 vals = [expr_eval_before (n2, v) 980 for v in expr.vals] 981 if None in vals: 982 return None 983 s = set.union (* [s for (s, v) in vals]) 984 if len(s) > 1: 985 if diag: 986 trace ('too many vars for %s @ %d: %s' % (expr, n2, s)) 987 return None 988 return (s, Expr ('Op', expr.typ, 989 name = expr.name, 990 vals = [v for (s, v) in vals])) 991 elif expr.kind == 'Num': 992 return (set(), expr) 993 elif expr.kind == 'Var': 994 return var_eval_before (n2, 995 (expr.name, expr.typ)) 996 else: 997 if diag: 998 trace ('Unwalkable expr %s' % expr) 999 return None 1000 def node_impossible (n2): 1001 if n2 in impossible_nodes: 1002 return impossible_nodes[n2] 1003 if n2 == n or n2 in p.get_loop_splittables (n): 1004 imposs = False 1005 else: 1006 pres = [n3 for n3 in p.preds[n2] 1007 if n3 in loop if not node_impossible (n3)] 1008 if n2 in impossible_nodes: 1009 imposs = impossible_nodes[n2] 1010 else: 1011 imposs = not bool (pres) 1012 impossible_nodes[n2] = imposs 1013 node = nodes[n2] 1014 if imposs or node.kind != 'Cond': 1015 return imposs 1016 if 1 >= len ([n3 
for n3 in node.get_conts () 1017 if n3 in loop]): 1018 return imposs 1019 c = expr_eval_before (n2, node.cond) 1020 if c != None: 1021 c = try_eval_expr (c[1]) 1022 if c != None: 1023 trace ('determined loop inner cond at %d equals %s' 1024 % (n2, c == syntax.true_term)) 1025 if c == syntax.true_term: 1026 impossible_nodes[node.right] = True 1027 elif c == syntax.false_term: 1028 impossible_nodes[node.left] = True 1029 return imposs 1030 1031 vca = {} 1032 for v in vs: 1033 rv = var_eval_before (n, v, do_cmp = False) 1034 if rv == None: 1035 vca[v] = 'LoopVariable' 1036 continue 1037 (s, expr) = rv 1038 if expr == mk_var (v[0], v[1]): 1039 vca[v] = 'LoopConst' 1040 continue 1041 if all_vars_in_set (expr, const_vars): 1042 # a repeatedly evaluated const expression, is const 1043 vca[v] = 'LoopConst' 1044 continue 1045 if var_not_in_expr (mk_var (v[0], v[1]), expr): 1046 # leaf calculations do not have data flow to 1047 # themselves. the search algorithm doesn't 1048 # have to worry about these. 1049 vca[v] = 'LoopLeaf' 1050 continue 1051 (form, offs) = accumulator_closed_form (expr, v) 1052 if form != None and all_vars_in_set (form (), const_vars): 1053 vca[v] = ('LoopLinearSeries', form, offs) 1054 else: 1055 if diag: 1056 trace ('No accumulator %s => %s' 1057 % (v, expr)) 1058 no_accum_expressions.add ((v, expr)) 1059 vca[v] = 'LoopVariable' 1060 return vca 1061 1062eval_expr_solver = [None] 1063 1064def try_eval_expr (expr): 1065 """attempt to reduce an expression to a single result, vaguely like 1066 what constant propagation would do. 
it might work!""" 1067 import search 1068 if not eval_expr_solver[0]: 1069 import solver 1070 eval_expr_solver[0] = solver.Solver () 1071 try: 1072 return search.eval_model_expr ({}, eval_expr_solver[0], expr) 1073 except KeyboardInterrupt, e: 1074 raise e 1075 except Exception, e: 1076 return None 1077 1078expr_linear_sum = set (['Plus', 'Minus']) 1079expr_linear_cast = set (['WordCast', 'WordCastSigned']) 1080 1081expr_linear_all = set.union (expr_linear_sum, expr_linear_cast, 1082 ['Times', 'ShiftLeft']) 1083 1084def possibly_linear (expr): 1085 if expr.kind in set (['Var', 'Num', 'Symbol', 'Type', 'Token']): 1086 return True 1087 elif expr.is_op (expr_linear_all): 1088 return all ([possibly_linear (x) for x in expr.vals]) 1089 else: 1090 return False 1091 1092def lv_expr (expr, env): 1093 if expr in env: 1094 return env[expr] 1095 elif expr.kind in set (['Num', 'Symbol', 'Type', 'Token']): 1096 return (expr, 'LoopConst', None, set ()) 1097 elif expr.kind == 'Var': 1098 return (None, None, None, None) 1099 elif expr.kind != 'Op': 1100 assert expr in env, expr 1101 1102 lvs = [lv_expr (v, env) for v in expr.vals] 1103 rs = [lv[1] for lv in lvs] 1104 mk_offs = lambda vals: syntax.adjust_op_vals (expr, vals) 1105 if None in rs: 1106 return (None, None, None, None) 1107 if set (rs) == set (['LoopConst']): 1108 return (expr, 'LoopConst', None, set ()) 1109 offs_set = set.union (* ([lv[3] for lv in lvs] + [set ()])) 1110 arg_offs = [] 1111 for (expr2, k, offs, _) in lvs: 1112 if k == 'LoopConst' and expr2.typ.kind == 'Word': 1113 arg_offs.append (syntax.mk_num (0, expr2.typ)) 1114 else: 1115 arg_offs.append (offs) 1116 if expr.is_op (expr_linear_sum): 1117 if set (rs) == set (['LoopConst', 'LoopLinearSeries']): 1118 return (expr, 'LoopLinearSeries', mk_offs (arg_offs), 1119 offs_set) 1120 elif expr.is_op ('Times'): 1121 if set (rs) == set (['LoopLinearSeries', 'LoopConst']): 1122 # the new offset is the product of the linear offset 1123 # and the constant value 1124 
[linear_offs] = [offs for (_, k, offs, _) in lvs 1125 if k == 'LoopLinearSeries'] 1126 [const_value] = [v for (v, k, _, _) in lvs 1127 if k == 'LoopConst'] 1128 return (expr, 'LoopLinearSeries', 1129 mk_offs ([linear_offs, const_value]), offs_set) 1130 if expr.is_op ('ShiftLeft'): 1131 if rs == ['LoopLinearSeries', 'LoopConst']: 1132 return (expr, 'LoopLinearSeries', 1133 mk_offs ([arg_offs[0], lvs[1][0]]), offs_set) 1134 if expr.is_op (expr_linear_cast): 1135 if rs == ['LoopLinearSeries']: 1136 return (expr, 'LoopLinearSeries', mk_offs (arg_offs), 1137 offs_set) 1138 return (None, None, None, None) 1139 1140# FIXME: this should probably be unified with compute_var_cycle_analysis, 1141# but doing so is complicated 1142def linear_series_exprs (p, loop, va): 1143 def lv_init (v, data): 1144 if data[0] == 'LoopLinearSeries': 1145 return (v, 'LoopLinearSeries', data[2], set ([data[2]])) 1146 elif data == 'LoopConst': 1147 return (v, 'LoopConst', None, set ()) 1148 else: 1149 return (None, None, None, None) 1150 cache = {loop: dict ([(v, lv_init (v, data)) for (v, data) in va])} 1151 post_cache = {} 1152 loop_body = p.loop_body (loop) 1153 frontier = [n2 for n2 in p.nodes[loop].get_conts () 1154 if n2 in loop_body] 1155 def lv_merge ((v1, lv1, offs1, oset1), (v2, lv2, offs2, oset2)): 1156 if v1 != v2: 1157 return (None, None, None, None) 1158 assert lv1 == lv2 and offs1 == offs2 1159 return (v1, lv1, offs1, oset1) 1160 def compute_post (n): 1161 if n in post_cache: 1162 return post_cache[n] 1163 pre_env = cache[n] 1164 env = dict (cache[n]) 1165 if p.nodes[n].kind == 'Basic': 1166 for ((v, typ), rexpr) in p.nodes[n].upds: 1167 env[mk_var (v, typ)] = lv_expr (rexpr, pre_env) 1168 elif p.nodes[n].kind == 'Call': 1169 for (v, typ) in p.nodes[n].get_lvals (): 1170 env[mk_var (v, typ)] = (None, None, None, None) 1171 post_cache[n] = env 1172 return env 1173 while frontier: 1174 n = frontier.pop () 1175 if [n2 for n2 in p.preds[n] if n2 in loop_body 1176 if n2 not in cache]: 
1177 continue 1178 if n in cache: 1179 continue 1180 envs = [compute_post (n2) for n2 in p.preds[n] 1181 if n2 in loop_body] 1182 all_vs = set.union (* [set (env) for env in envs]) 1183 cache[n] = dict ([(v, foldr1 (lv_merge, 1184 [env.get (v, (None, None, None, None)) 1185 for env in envs])) 1186 for v in all_vs]) 1187 frontier.extend ([n2 for n2 in p.nodes[n].get_conts () 1188 if n2 in loop_body]) 1189 return cache 1190 1191def get_loop_linear_offs (p, loop_head): 1192 import search 1193 va = search.get_loop_var_analysis_at (p, loop_head) 1194 exprs = linear_series_exprs (p, loop_head, va) 1195 def offs_fn (n, expr): 1196 assert p.loop_id (n) == loop_head 1197 env = exprs[n] 1198 rv = lv_expr (expr, env) 1199 if rv[1] == None: 1200 return None 1201 elif rv[1] == 'LoopConst': 1202 return mk_num (0, expr.typ) 1203 elif rv[1] == 'LoopLinearSeries': 1204 return rv[2] 1205 else: 1206 assert not 'lv_expr kind understood', rv 1207 return offs_fn 1208 1209def interesting_node_exprs (p, n, tags = None, use_pairings = True): 1210 if tags == None: 1211 tags = p.pairing.tags 1212 node = p.nodes[n] 1213 memaccs = node.get_mem_accesses () 1214 vs = [(kind, ptr) for (kind, ptr, v, m) in memaccs] 1215 vs += [('MemUpdateArg', v) for (kind, ptr, v, m) in memaccs 1216 if kind == 'MemUpdate'] 1217 1218 if node.kind == 'Call' and use_pairings: 1219 tag = p.node_tags[n][0] 1220 from target_objects import functions, pairings 1221 import solver 1222 fun = functions[node.fname] 1223 arg_input_map = dict (azip (fun.inputs, node.args)) 1224 pairs = [pair for pair in pairings.get (node.fname, []) 1225 if pair.tags == tags] 1226 if not pairs: 1227 return vs 1228 [pair] = pairs 1229 in_eq_vs = [(('Call', pair.name, i), 1230 var_subst (v, arg_input_map)) 1231 for (i, ((lhs, l_s), (rhs, r_s))) 1232 in enumerate (pair.eqs[0]) 1233 if l_s.endswith ('_IN') and r_s.endswith ('_IN') 1234 if l_s != r_s 1235 if solver.typ_representable (lhs.typ) 1236 for (v, site) in [(lhs, l_s), (rhs, r_s)] 1237 if 
site == '%s_IN' % tag] 1238 vs.extend (in_eq_vs) 1239 return vs 1240 1241def interesting_linear_series_exprs (p, loop, va, tags = None, 1242 use_pairings = True): 1243 if tags == None: 1244 tags = p.pairing.tags 1245 expr_env = linear_series_exprs (p, loop, va) 1246 res_env = {} 1247 for (n, env) in expr_env.iteritems (): 1248 vs = interesting_node_exprs (p, n) 1249 1250 vs = [(kind, v, lv_expr (v, env)) for (kind, v) in vs] 1251 vs = [(kind, v, offs, offs_set) 1252 for (kind, v, (_, lv, offs, offs_set)) in vs 1253 if lv == 'LoopLinearSeries'] 1254 if vs: 1255 res_env[n] = vs 1256 return res_env 1257 1258def mk_var_renames (xs, ys): 1259 renames = {} 1260 for (x, y) in azip (xs, ys): 1261 assert x.kind == 'Var' and y.kind == 'Var' 1262 assert x.name not in renames 1263 renames[x.name] = y.name 1264 return renames 1265 1266def first_aligned_address (nodes, radix): 1267 ks = [k for k in nodes 1268 if k % radix == 0] 1269 if ks: 1270 return min (ks) 1271 else: 1272 return None 1273 1274def entry_aligned_address (fun, radix): 1275 n = fun.entry 1276 while n % radix != 0: 1277 ns = fun.nodes[n].get_conts () 1278 assert len (ns) == 1, (fun.name, n) 1279 [n] = ns 1280 return n 1281 1282def aligned_address_sanity (functions, symbols, radix): 1283 for (f, func) in functions.iteritems (): 1284 if f not in symbols: 1285 # happens for static or invented functions sometimes 1286 continue 1287 if func.entry: 1288 addr = first_aligned_address (func.nodes, radix) 1289 if addr == None: 1290 printout ('Warning: %s: no aligned instructions' % f) 1291 continue 1292 addr2 = symbols[f][0] 1293 if addr != addr2: 1294 printout ('target mismatch on func %s' % f) 1295 printout (' (starts at 0x%x not 0x%x)' % (addr, addr2)) 1296 return False 1297 addr3 = entry_aligned_address (func, radix) 1298 if addr3 != addr2: 1299 printout ('entry mismatch on func %s' % f) 1300 printout (' (enters at 0x%x not 0x%x)' % (addr3, addr2)) 1301 return False 1302 return True 1303 1304# variant of tarjan's 
# strongly connected component algorithm
def tarjan (graph, entries):
    """tarjan (graph, entries)
    variant of tarjan's strongly connected component algorithm
    e.g. tarjan ({1: [2, 3], 2: [], 3: [4], 4: [3]}, [1])
    entries should not be reachable

    graph maps each node to the list of its successors. every node
    reached from entries needs its own key, since graph[v] is looked
    up for each of them. returns a list of (head, tail) pairs, one per
    component found, where tail lists the members other than head and
    is empty for a single-node component."""
    data = {}
    comps = []
    for v in entries:
        assert v not in data
        tarjan1 (graph, v, data, [], set ([]), comps)
    return comps

def tarjan1 (graph, v, data, stack, stack_set, comps):
    """worker for tarjan. visits v and everything reachable from it.

    data maps each visited node to a two element list [index, lowest
    index reached], stack/stack_set hold the nodes not yet assigned to
    a component, and finished components are appended to comps."""
    vs = []
    while True:
        # skip through nodes with single successors. these are walked
        # in this loop rather than by recursing into each of them.
        data[v] = [len (data), len (data)]
        stack.append (v)
        stack_set.add (v)
        cs = graph[v]
        if len (cs) != 1 or cs[0] in data:
            break
        vs.append ((v, cs[0]))
        [v] = cs

    for c in graph[v]:
        if c not in data:
            tarjan1 (graph, c, data, stack, stack_set, comps)
            data[v][1] = min (data[v][1], data[c][1])
        elif c in stack_set:
            data[v][1] = min (data[v][1], data[c][0])

    # propagate the low values back along the chain walked above,
    # from the far end towards the node we started at.
    vs.reverse ()
    for (v2, c) in vs:
        data[v2][1] = min (data[v2][1], data[c][1])

    for (v2, _) in [(v, 0)] + vs:
        if data[v2][1] == data[v2][0]:
            # v2 is the head of a component. everything above it on
            # the stack belongs to that component.
            comp = []
            while True:
                x = stack.pop ()
                stack_set.remove (x)
                if x == v2:
                    break
                comp.append (x)
            comps.append ((v2, comp))

def divides_loop (graph, split_set):
    """test whether cutting the outgoing edges of every node in
    split_set leaves graph without any multi-node strongly connected
    component. the key 'ENTRY_POINT' must not be in use in graph."""
    graph2 = dict (graph)
    for n in split_set:
        graph2[n] = []
    assert 'ENTRY_POINT' not in graph2
    # a synthetic entry with an edge to every node, so that tarjan
    # reaches the whole graph from a single unreachable start.
    graph2['ENTRY_POINT'] = list (graph)
    comps = tarjan (graph2, ['ENTRY_POINT'])
    return not ([(h, t) for (h, t) in comps if t])

def strongly_connected_split_points1 (graph, diag = False):
    """find the nodes of a strongly connected
    component which, when removed, disconnect the component.
    complex loops lack such a split point.

    graph maps each node to a non-empty list of successors. if diag
    is set the state of the search is reported via trace at each step
    (this was previously printed to stdout unconditionally)."""

    # find one simple cycle in the graph
    walk = []
    walk_set = set ()
    n = min (graph)
    while n not in walk_set:
        walk.append (n)
        walk_set.add (n)
        n = graph[n][0]
    i = walk.index (n)
    cycle = walk[i:]

    def subgraph_test (subgraph):
        # true if the graph restricted to subgraph still contains a
        # multi-node strongly connected component.
        graph2 = dict ([(n, [n2 for n2 in graph[n] if n2 in subgraph])
            for n in subgraph])
        graph2['HEAD'] = list (subgraph)
        comps = tarjan (graph2, ['HEAD'])
        return bool ([h for (h, t) in comps if t])

    # each cycle element is (kind, nodes, needs subgraph test,
    # successors not yet explored). initially one 'Node' element per
    # node of the simple cycle, with the edge followed above excluded.
    cycle_set = set (cycle)
    cycle = [('Node', set ([n]), False,
            [n2 for n2 in graph[n] if n2 != graph[n][0]])
        for n in cycle]
    i = 0
    while i < len (cycle):
        if diag:
            trace ('split point search: %s %s' % (i, cycle))
        (kind, ns, test, unvisited) = cycle[i]
        if not unvisited:
            i += 1
            continue
        n = unvisited.pop ()
        arc_set = set ()
        while n not in cycle_set:
            if n in arc_set:
                # found two totally disjoint loops, so there
                # are no splitting points
                return set ()
            arc_set.add (n)
            n = graph[n][0]
        if n in ns:
            if kind == 'Node':
                # only this node can be a splittable now.
                if subgraph_test (set (graph) - set ([n])):
                    return set ()
                else:
                    return set ([n])
            else:
                cycle[i] = (kind, ns, True, unvisited)
                ns.update (arc_set)
                continue
        # the arc rejoins the cycle at a later element. merge every
        # element it jumps over, and the arc itself, into one 'Group'.
        j = (i + 1) % len (cycle)
        new_ns = set ()
        new_unvisited = set ()
        new_test = False
        while n not in cycle[j][1]:
            new_ns.update (cycle[j][1])
            new_unvisited.update (cycle[j][3])
            new_test = cycle[j][2] or new_test
            j = (j + 1) % len (cycle)
        new_ns.update (arc_set)
        new_unvisited.update ([n3 for n2 in arc_set for n3 in graph[n2]])
        new_v = ('Group', new_ns, new_test, list (new_unvisited - new_ns))
        if diag:
            trace ('split point search: %s %s %s' % (i, j, n))
        if j > i:
            cycle[i + 1:j] = [new_v]
        else:
            # the arc wrapped around the end of the list. rebuild the
            # list starting from the current element.
            cycle = [cycle[i], new_v] + cycle[j:i]
            i = 0
        cycle_set.update (new_ns)
    for (kind, ns, test, unvisited) in cycle:
        if test and subgraph_test (ns):
            return set ()
    return set ([n for (kind, ns, _, _) in cycle
        if kind == 'Node' for n in ns])

def strongly_connected_split_points (graph):
    """compute split points via strongly_connected_split_points1 and
    check them against a direct calculation: a node is a split point
    if cutting its outgoing edges leaves no multi-node strongly
    connected component. asserts that the two results agree. the key
    'ENTRY' must not be in use in graph."""
    res = strongly_connected_split_points1 (graph)
    res2 = set ()
    for n in graph:
        graph2 = dict (graph)
        graph2[n] = []
        graph2['ENTRY'] = list (graph)
        comps = tarjan (graph2, ['ENTRY'])
        if not [comp for comp in comps if comp[1]]:
            res2.add (n)
    assert res == res2, (graph, res, res2)
    return res

def get_one_loop_splittable (p, loop_set):
    """discover a component of a strongly connected
    component which, when removed, disconnects the component.
    complex loops lack such a split point.

    p supplies the nodes (p.nodes[x].get_conts () gives successors)
    and loop_set is the set of node addresses in the loop. returns one
    such node, or None if no candidate survives."""
    candidates = set (loop_set)
    graph = dict ([(x, [y for y in p.nodes[x].get_conts ()
        if y in loop_set]) for x in loop_set])
    while candidates:
        # any split point must lie on every cycle, so a cycle that
        # avoids the candidates where possible narrows them quickly.
        loop2 = find_loop_avoiding (graph, loop_set, candidates)
        candidates = set.intersection (loop2, candidates)
        if not candidates:
            return None
        n = candidates.pop ()
        # drop all edges into n and look for cycles that remain.
        graph2 = dict ([(x, [y for y in graph[x] if y != n])
            for x in graph])
        comps = tarjan (graph2, [n])
        comps = [(h, t) for (h, t) in comps if t]
        if not comps:
            return n
        for (h, t) in comps:
            s = set ([h] + t)
            candidates = set.intersection (s, candidates)
    return None

def find_loop_avoiding (graph, loop, avoid):
    """walk graph from a node of loop (outside avoid if there is one),
    preferring successors outside avoid, until a visited node is
    reached again. returns the set of nodes on the cycle found. loop
    and avoid are sets, and each node walked needs a successor."""
    n = (list (loop - avoid) + list (loop))[0]
    arc = [n]
    visited = set ([n])
    while True:
        cs = set (graph[n])
        acs = cs - avoid
        vcs = set.intersection (cs, visited)
        if vcs:
            # stepping to a visited node closes the cycle
            n = vcs.pop ()
            break
        elif acs:
            n = acs.pop ()
        else:
            n = cs.pop ()
        visited.add (n)
        arc.append (n)
    [i] = [i for (i, n2) in enumerate (arc) if n2 == n]
    return set (arc[i:])

# non-equality relations in proof hypotheses are recorded as a pretend
# equality and reverted to their 'real' meaning here.
def mk_stack_wrapper (stack_ptr, stack, excepts):
    """build a 'StackWrapper' relation term from a stack pointer, a
    stack memory and a list of excepted addresses."""
    return syntax.mk_rel_wrapper ('StackWrapper',
        [stack_ptr, stack] + excepts)

def mk_mem_acc_wrapper (addr, v):
    """build a 'MemAccWrapper' relation term from an address and a
    value."""
    return syntax.mk_rel_wrapper ('MemAccWrapper', [addr, v])

def mk_mem_wrapper (m):
    """build a 'MemWrapper' relation term around the memory m."""
    return syntax.mk_rel_wrapper ('MemWrapper', [m])

def tm_with_word32_list (xs):
    """encode a list of numbers as a term: a right-nested sum of
    word32 constants, or the negation of word32 zero for the empty
    list. word32_list_from_tm is the inverse."""
    if xs:
        return foldr1 (mk_plus, [mk_word32 (x) for x in xs])
    else:
        return mk_uminus (mk_word32 (0))

def word32_list_from_tm (t):
    """decode the list of numbers stored by tm_with_word32_list."""
    xs = []
    while t.is_op ('Plus'):
        [x, t] = t.vals
        assert x.kind == 'Num' and x.typ == word32T
        xs.append (x.val)
    # the innermost term is the last element, unless the list was
    # empty, in which case it is not a 'Num' at all.
    if t.kind == 'Num':
        xs.append (t.val)
    return xs

def mk_eq_selective_wrapper (v, xs_ys):
    """build an 'EqSelectiveWrapper' relation term around v, carrying
    the two lists of numbers in the pair xs_ys. inst_eq_at_visit reads
    the lists back."""
    (xs, ys) = xs_ys
    # this is a huge hack, but we need to put these lists somewhere
    xs = tm_with_word32_list (xs)
    ys = tm_with_word32_list (ys)
    return syntax.mk_rel_wrapper ('EqSelectiveWrapper', [v, xs, ys])

def apply_rel_wrapper (lhs, rhs):
    """convert a pretend equality between two relation wrapper terms
    into the boolean expression it stands for.

    lhs and rhs must both be 'Op' terms of type RelWrapper. the
    supported pairs are two 'StackWrapper's, a 'MemAccWrapper' with a
    'MemWrapper' (in either order), and two 'EqSelectiveWrapper's.
    anything else fails an assertion."""
    assert lhs.typ == syntax.builtinTs['RelWrapper']
    assert rhs.typ == syntax.builtinTs['RelWrapper']
    assert lhs.kind == 'Op'
    assert rhs.kind == 'Op'
    ops = set ([lhs.name, rhs.name])
    if ops == set (['StackWrapper']):
        [sp1, st1] = lhs.vals[:2]
        [sp2, st2] = rhs.vals[:2]
        # overwrite every excepted address, from either side, with
        # zero in both stacks before they are compared.
        excepts = list (set (lhs.vals[2:] + rhs.vals[2:]))
        for p in excepts:
            st1 = syntax.mk_memupd (st1, p, syntax.mk_word32 (0))
            st2 = syntax.mk_memupd (st2, p, syntax.mk_word32 (0))
        return syntax.Expr ('Op', boolT, name = 'StackEquals',
            vals = [sp1, st1, sp2, st2])
    elif ops == set (['MemAccWrapper', 'MemWrapper']):
        [acc] = [v for v in [lhs, rhs] if v.is_op ('MemAccWrapper')]
        [addr, val] = acc.vals
        assert addr.typ == syntax.word32T
        [m] = [v for v in [lhs, rhs] if v.is_op ('MemWrapper')]
        [m] = m.vals
        assert m.typ == builtinTs['Mem']
        expr = mk_eq (mk_memacc (m, addr, val.typ), val)
        return expr
    elif ops == set (['EqSelectiveWrapper']):
        [lhs_v, _, _] = lhs.vals
        [rhs_v, _, _] = rhs.vals
        if lhs_v.typ == syntax.builtinTs['RelWrapper']:
            return apply_rel_wrapper (lhs_v, rhs_v)
        else:
            # compare the wrapped values. the wrappers themselves
            # (lhs, rhs) were compared here previously, which left
            # the pretend equality in place rather than reverting it.
            return mk_eq (lhs_v, rhs_v)
    else:
        assert not 'rel wrapper opname understood', ops

def inst_eq_at_visit (exp, vis):
    """decide whether the equality term exp applies at the visit vis.
    always true unless exp is an 'EqSelectiveWrapper', in which case
    vis.n must appear in its first list for a 'Number' visit or in its
    second list for an 'Offset' visit."""
    if not exp.is_op ('EqSelectiveWrapper'):
        return True
    [_, xs, ys] = exp.vals
    # hacks
    xs = word32_list_from_tm (xs)
    ys = word32_list_from_tm (ys)
    if vis.kind == 'Number':
        return vis.n in xs
    elif vis.kind == 'Offset':
        return vis.n in ys
    else:
        assert not 'visit kind useable', vis

def strengthen_hyp (expr, sign = 1):
    """rewrite the 'StackEquals' and 'ROData' operators within the
    boolean structure of expr into their one-directional variants.

    sign is 1 at positive positions and -1 at negative ones. it is
    flipped beneath 'Not' and on the left of 'Implies'. boolean
    equalities against true/false are unfolded first. everything else
    is returned unchanged."""
    if not expr.kind == 'Op':
        return expr
    if expr.name in ['And', 'Or']:
        vals = [strengthen_hyp (v, sign) for v in expr.vals]
        return syntax.adjust_op_vals (expr, vals)
    elif expr.name == 'Implies':
        [l, r] = expr.vals
        l = strengthen_hyp (l, - sign)
        r = strengthen_hyp (r, sign)
        return syntax.mk_implies (l, r)
    elif expr.name == 'Not':
        [x] = expr.vals
        x = strengthen_hyp (x, - sign)
        return syntax.mk_not (x)
    elif expr.name == 'StackEquals':
        if sign == 1:
            return syntax.Expr ('Op', boolT,
                name = 'ImpliesStackEquals', vals = expr.vals)
        else:
            return syntax.Expr ('Op', boolT,
                name = 'StackEqualsImplies', vals = expr.vals)
    elif expr.name == 'ROData':
        if sign == 1:
            return syntax.Expr ('Op', boolT,
                name = 'ImpliesROData', vals = expr.vals)
        else:
            return expr
    elif expr.name == 'Equals' and expr.vals[0].typ == boolT:
        vals = expr.vals
        # put a true/false constant, if there is one, on the left
        if vals[1] in [syntax.true_term, syntax.false_term]:
            vals = [vals[1], vals[0]]
        if vals[0] == syntax.true_term:
            return strengthen_hyp (vals[1], sign)
        elif vals[0] == syntax.false_term:
            return strengthen_hyp (syntax.mk_not (vals[1]), sign)
        else:
            return expr
    else:
        return expr

def weaken_assert (expr):
    """strengthen_hyp with the sign reversed."""
    return strengthen_hyp (expr, -1)

pred_logic_ops = set (['Not', 'And', 'Or', 'Implies'])

def norm_neg (expr):
    """push a leading 'Not' one level inward. double negation is
    removed, and a negated 'And', 'Or' or 'Implies' is rewritten by
    de morgan's laws. any other expression is returned unchanged."""
    if not expr.is_op ('Not'):
        return expr
    [nexpr] = expr.vals
    if not nexpr.is_op (pred_logic_ops):
        return expr
    if nexpr.is_op ('Not'):
        [expr] = nexpr.vals
        return norm_neg (expr)
    [x, y] = nexpr.vals
    if nexpr.is_op ('And'):
        return mk_or (norm_mk_not (x), norm_mk_not (y))
    elif nexpr.is_op ('Or'):
        return mk_and (norm_mk_not (x), norm_mk_not (y))
    elif nexpr.is_op ('Implies'):
        return mk_and (x, mk_not (y))

def norm_mk_not (expr):
    """negate expr and normalise the result with norm_neg."""
    return norm_neg (mk_not (expr))

def split_conjuncts (expr):
    """list the conjuncts of expr, normalising negations with
    norm_neg at each level so that negated disjunctions and
    implications are split as well."""
    expr = norm_neg (expr)
    if expr.is_op ('And'):
        [x, y] = expr.vals
        return split_conjuncts (x) + split_conjuncts (y)
    else:
        return [expr]

def split_disjuncts (expr):
    """list the disjuncts of expr, normalising negations with
    norm_neg at each level so that negated conjunctions are split as
    well."""
    expr = norm_neg (expr)
    if expr.is_op ('Or'):
        [x, y] = expr.vals
        return split_disjuncts (x) + split_disjuncts (y)
    else:
        return [expr]

def binary_search_least (test, minimum, maximum):
    """find least n, minimum <= n <= maximum, for which test (n).

    the search assumes test is monotonic, i.e. once it holds for some
    n it holds for all greater n in the range. returns None if test
    does not hold at maximum."""
    assert maximum >= minimum
    if test (minimum):
        return minimum
    if maximum == minimum or not test (maximum):
        return None
    # invariant: test fails at minimum and holds at maximum.
    while maximum > minimum + 1:
        cur = (minimum + maximum) // 2
        if test (cur):
            maximum = cur
        else:
            # cur itself is the new failing bound. moving on to
            # cur + 1 would step onto an untested value.
            minimum = cur
    assert minimum + 1 == maximum
    return maximum

def binary_search_greatest (test, minimum, maximum):
    """find greatest n, minimum <= n <= maximum, for which test (n).

    the search assumes test is monotonic, i.e. once it fails for some
    n it fails for all greater n in the range. returns None if test
    does not hold at minimum."""
    assert maximum >= minimum
    if test (maximum):
        return maximum
    if maximum == minimum or not test (minimum):
        return None
    # invariant: test holds at minimum and fails at maximum.
    while maximum > minimum + 1:
        cur = (minimum + maximum) // 2
        if test (cur):
            minimum = cur
        else:
            # cur itself is the new failing bound. moving on to
            # cur - 1 would step onto an untested value.
            maximum = cur
    assert minimum + 1 == maximum
    return minimum