# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

import syntax
import solver
import problem
import rep_graph
import search
import logic
import check

from target_objects import functions, trace, pairings, pre_pairings, printout
import target_objects

from logic import azip

from syntax import mk_var, word32T, builtinTs, mk_eq, mk_less_eq

last_stuff = [0]

def default_n_vc (p, n):
    """Return the default visit-count context ``(n, restrs)`` for node ``n``.

    Every loop head of ``p`` other than the one containing ``n`` is given
    the options ``vc_options ([0], [1])``. If ``n`` is inside a loop, that
    loop's head is given ``vc_offs (1)``."""
    head = p.loop_id (n)
    general = [(n2, rep_graph.vc_options ([0], [1]))
        for n2 in p.loop_heads ()
        if n2 != head]
    specific = [(head, rep_graph.vc_offs (1))] if head else []
    return (n, tuple (general + specific))

def split_sum_s_expr (expr, solv, extra_defs, typ):
    """Divide up a linear SMT expression such as 'a - b - 1 + a' into
    ({'a': 2, 'b': -1}, -1), i.e. 'a' times 2 etc., plus a constant -1.

    Names defined in ``solv.defs`` or ``extra_defs`` are expanded first.
    Only 'bvadd' and 'bvsub' are decomposed. Numeric literals ('#x...' or
    '#b...') contribute to the constant. Any other term (including 'ite')
    is an opaque atom with multiplicity 1."""
    def rec (expr):
        return split_sum_s_expr (expr, solv, extra_defs, typ)
    if expr[0] == 'bvadd':
        var = {}
        const = 0
        for x in expr[1:]:
            (var2, const2) = rec (x)
            for (v, count) in var2.iteritems ():
                var.setdefault (v, 0)
                var[v] += count
            const += const2
        return (var, const)
    elif expr[0] == 'bvsub':
        (_, lhs, rhs) = expr
        (lvar, lconst) = rec (lhs)
        (rvar, rconst) = rec (rhs)
        const = lconst - rconst
        var = dict ([(v, lvar.get (v, 0) - rvar.get (v, 0))
            for v in set.union (set (lvar), set (rvar))])
        return (var, const)
    elif expr in solv.defs:
        return rec (solv.defs[expr])
    elif expr in extra_defs:
        return rec (extra_defs[expr])
    elif expr[:2] in ['#x', '#b']:
        val = solver.smt_to_val (expr)
        assert val.kind == 'Num'
        return ({}, val.val)
    else:
        return ({expr: 1}, 0)

def split_merge_ite_sum_sexpr (expr, solv, extra_defs, typ = syntax.word32T):
    """Split an SMT term ('ite', cond, x, y) into the (atoms, const) form
    produced by split_sum_s_expr.

    The term is first rewritten as (ite cond (x - y) 0) + y, so that only
    'ite' terms with a zero else-branch need to be handled. For such a
    term, each atom v of x with count k becomes the atom (ite cond v 0)
    with count k. The constant part of x becomes the count of the atom
    (ite cond 1 0).

    The previous version referenced undefined names and could not run.
    NOTE(review): no caller is visible in this part of the module."""
    (_, cond, x, y) = expr
    (s0, s1) = [solver.smt_num_t (n, typ) for n in [0, 1]]
    if y != s0:
        # ite c x y == (ite c (x - y) 0) + y; the inner call has y == 0.
        (var, const) = split_merge_ite_sum_sexpr (
            ('ite', cond, ('bvsub', x, y), s0), solv, extra_defs, typ)
        (yvar, yconst) = split_sum_s_expr (y, solv, extra_defs, typ)
        for (v, count) in yvar.iteritems ():
            var.setdefault (v, 0)
            var[v] += count
        return (var, const + yconst)
    (xvar, xconst) = split_sum_s_expr (x, solv, extra_defs, typ)
    var = dict ([(('ite', cond, v, s0), n)
        for (v, n) in xvar.iteritems ()])
    var.setdefault (('ite', cond, s1, s0), 0)
    var[('ite', cond, s1, s0)] += xconst
    return (var, 0)

def simplify_expr_whyps (sexpr, rep, hyps, cache = None, extra_defs = {},
        bool_hyps = None):
    """Simplify the parsed SMT expression ``sexpr`` under ``hyps``.

    The top-level term is expanded one step through ``extra_defs`` and
    then through ``rep.solv.defs``. If the result is an 'ite' whose
    condition is decided by the hypotheses plus the path conditions in
    ``bool_hyps``, the chosen branch is returned as-is. Otherwise both
    branches are simplified recursively under the extended path
    condition, and the 'ite' is dropped if the branches become equal.
    ``extra_defs`` is only read."""
    if cache == None:
        cache = {}
    if bool_hyps == None:
        bool_hyps = []
    if sexpr in extra_defs:
        sexpr = extra_defs[sexpr]
    if sexpr in rep.solv.defs:
        sexpr = rep.solv.defs[sexpr]
    if sexpr[0] != 'ite':
        return sexpr
    (_, cond, x, y) = sexpr
    cond_exp = solver.mk_smt_expr (solver.flat_s_expression (cond),
        syntax.boolT)
    (mk_nimp, mk_not) = (syntax.mk_n_implies, syntax.mk_not)
    if rep.test_hyp_whyps (mk_nimp (bool_hyps, cond_exp),
            hyps, cache = cache):
        return x
    elif rep.test_hyp_whyps (mk_nimp (bool_hyps, mk_not (cond_exp)),
            hyps, cache = cache):
        return y
    x = simplify_expr_whyps (x, rep, hyps, cache = cache,
        extra_defs = extra_defs,
        bool_hyps = bool_hyps + [cond_exp])
    y = simplify_expr_whyps (y, rep, hyps, cache = cache,
        extra_defs = extra_defs,
        bool_hyps = bool_hyps + [mk_not (cond_exp)])
    if x == y:
        return x
    return ('ite', cond, x, y)

last_10_non_const = []

def offs_expr_const (addr_expr, sp_expr, rep, hyps, extra_defs = {},
        cache = None, typ = syntax.word32T):
    """If the offset between a stack addr and the initial stack pointer
    is a constant offset, try to compute it.

    ``addr_expr`` and ``sp_expr`` are SMT expression strings. The
    difference addr - sp is repeatedly split into a linear combination
    of atoms (split_sum_s_expr), and the remaining atoms are simplified
    under ``hyps`` (simplify_expr_whyps).

    Returns the accumulated integer constant once every atom's
    coefficient is a multiple of 2 ** typ.num. The constant is not
    normalised into range. Returns None if a round makes no progress;
    that failure is traced and kept in last_10_non_const."""
    if cache == None:
        # one cache for every simplification round of this query
        # (previously a fresh cache was built on each call)
        cache = {}
    addr_x = solver.parse_s_expression (addr_expr)
    sp_x = solver.parse_s_expression (sp_expr)
    vs = [(addr_x, 1), (sp_x, -1)]
    const = 0
    modulus = 2 ** typ.num

    while True:
        start_vs = list (vs)
        new_vs = {}
        for (x, mult) in vs:
            (var, c) = split_sum_s_expr (x, rep.solv, extra_defs,
                typ = typ)
            for v in var:
                new_vs.setdefault (v, 0)
                new_vs[v] += var[v] * mult
            const += c * mult
        # coefficients that are multiples of the word modulus cancel out
        vs = [(x, n) for (x, n) in new_vs.iteritems ()
            if n % modulus != 0]
        if not vs:
            return const
        vs = [(simplify_expr_whyps (x, rep, hyps,
                cache = cache, extra_defs = extra_defs), n)
            for (x, n) in vs]
        if sorted (vs) == sorted (start_vs):
            trace ('offs_expr_const: not const')
            trace ('%s - %s' % (addr_expr, sp_expr))
            trace (str (vs))
            trace (str (hyps))
            last_10_non_const.append ((addr_expr, sp_expr, vs, hyps))
            del last_10_non_const[:-10]
            return None

def has_stack_var (expr, stack_var):
    """Check whether the memory expression ``expr`` -- a chain of
    MemUpdate ops around a base variable -- is based on ``stack_var``.

    Raises AssertionError if the base is not a variable."""
    while expr.is_op ('MemUpdate'):
        [m, _, _] = expr.vals
        expr = m
    if expr.kind == 'Var':
        return expr == stack_var
    # an explicit raise: a plain 'assert' here would loop forever under -O
    raise AssertionError ('has_stack_var: expr kind', expr)

def mk_not_callable_hyps (p):
    """Hypotheses that every Call node of ``p`` whose callee is not
    callable (see get_asm_callable) is unreachable."""
    hyps = []
    for n in p.nodes:
        if p.nodes[n].kind != 'Call':
            continue
        if get_asm_callable (p.nodes[n].fname):
            continue
        tag = p.node_tags[n][0]
        hyp = rep_graph.pc_false_hyp ((default_n_vc (p, n), tag))
        hyps.append (hyp)
    return hyps

last_get_ptr_offsets = [0]
last_get_ptr_offsets_setup = [0]

def get_ptr_offsets (p, n_ptrs, bases, hyps = [], cache = None,
        fail_early = False):
    """Detect which ptrs are guaranteed to be at constant offsets
    from some set of basis ptrs.

    ``n_ptrs`` is a list of (node, ptr) pairs. ``bases`` is a list of
    (node, ptr, key) triples. Each pointer is evaluated at its node, and
    the bases are tried in order.

    Returns a list of ((node, ptr), offset, key) triples for the pointers
    whose offset was found. Pointers whose node has no pc/env in the
    graph slice are skipped. With ``fail_early``, an unresolved pointer
    raises AssertionError. ``hyps`` is not modified."""
    rep = rep_graph.mk_graph_slice (p, fast = True)
    if cache == None:
        cache = {}
    last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps)

    smt_bases = []
    for (n, ptr, k) in bases:
        n_vc = default_n_vc (p, n)
        (_, env) = rep.get_node_pc_env (n_vc)
        smt = solver.smt_expr (ptr, env, rep.solv)
        smt_bases.append ((smt, k))
        # word type used for the modular arithmetic (last base's type)
        ptr_typ = ptr.typ

    smt_ptrs = []
    for (n, ptr) in n_ptrs:
        n_vc = default_n_vc (p, n)
        pc_env = rep.get_node_pc_env (n_vc)
        if not pc_env:
            continue
        smt = solver.smt_expr (ptr, pc_env[1], rep.solv)
        hyp = rep_graph.pc_true_hyp ((n_vc, p.node_tags[n][0]))
        smt_ptrs.append (((n, ptr), smt, hyp))

    tags = set ([p.node_tags[n][0] for (n, _) in n_ptrs])
    hyps = hyps + mk_not_callable_hyps (p)
    for tag in tags:
        hyps = hyps + init_correctness_hyps (p, tag)
    ex_defs = {}
    for t in tags:
        ex_defs.update (get_extra_sp_defs (rep, t))

    offs = []
    for (v, ptr, hyp) in smt_ptrs:
        off = None
        for (ptr2, k) in smt_bases:
            off = offs_expr_const (ptr, ptr2, rep, [hyp] + hyps,
                cache = cache, extra_defs = ex_defs,
                typ = ptr_typ)
            if off != None:
                offs.append ((v, off, k))
                break
        if off == None:
            trace ('get_ptr_offs fallthrough at %d: %s' % v)
            trace (str ([hyp] + hyps))
            assert not fail_early, (v, ptr)
    return offs

def init_correctness_hyps (p, tag):
    """Instantiate, for the function entered at ``tag``, the input
    equalities of the pairing it belongs to.

    Returns [] if the function has no pairing, or if neither ``tag`` nor
    its hook tag hint is one side of that pairing. Only equalities whose
    two sides are both tagged '<side>_IN' are used."""
    (_, fname, _) = p.get_entry_details (tag)
    if fname not in pairings:
        # conveniently handles bootstrap case
        return []
    # revise if multi-pairings for ASM an option
    [pair] = pairings[fname]
    true_tag = None
    if tag in pair.funs:
        true_tag = tag
    elif p.hook_tag_hints.get (tag, tag) in pair.funs:
        true_tag = p.hook_tag_hints.get (tag, tag)
    if true_tag == None:
        return []
    (inp_eqs, _) = pair.eqs
    in_tag = "%s_IN" % true_tag
    eqs = [eq for eq in inp_eqs if eq[0][1] == in_tag
        and eq[1][1] == in_tag]
    return check.inst_eqs (p, (), eqs, {true_tag: tag})

extra_symbols = set ()

def preserves_sp (fname):
    """All functions will keep the stack pointer equal, whether they have
    pairing partners or not.

    Returns a truthy value if ``fname`` has an ASM calling convention,
    if the 'assume_sp_equal' hook is set, or if ``fname`` is a target
    symbol name (or that name with '.' replaced by '_')."""
    assume_sp_equal = bool (target_objects.hooks ('assume_sp_equal'))
    if not extra_symbols:
        # built lazily, once, from the target's symbol table
        for fname2 in target_objects.symbols:
            extra_symbols.add (fname2)
            extra_symbols.add ('_'.join (fname2.split ('.')))
    return (get_asm_calling_convention (fname)
        or assume_sp_equal
        or fname in extra_symbols)
def get_extra_sp_defs (rep, tag):
    """Collect extra stack-pointer definitions for stack depth analysis.

    Considers every function-call instance in ``rep.funcs`` that belongs
    to ``tag`` and whose callee preserves the stack pointer (see
    preserves_sp). If the call's outputs include the stack pointer and
    its parsed SMT output value differs from its input value, the
    returned dict maps the output value to the input value."""
    # FIXME how to parametrise this?
    sp = mk_var ('r13', syntax.word32T)
    sp_key = (sp.name, sp.typ)

    def wanted (n_vc):
        node_id = n_vc[0]
        return (logic.is_int (node_id)
            and rep.p.node_tags[node_id][0] == tag
            and preserves_sp (rep.p.nodes[node_id].fname))

    def sp_value_in (env):
        smt = solver.smt_expr (sp, env, rep.solv)
        return solver.parse_s_expression (smt)

    calls = [n_vc for n_vc in rep.funcs if wanted (n_vc)]
    equalities = {}
    for (n, vc) in calls:
        (inputs, outputs, _) = rep.funcs[(n, vc)]
        if sp_key not in outputs:
            continue
        before = sp_value_in (inputs)
        after = sp_value_in (outputs)
        if before != after:
            equalities[after] = before
    return equalities

def get_stack_sp (p, tag):
    """Return the ``(stack, sp)`` variables of ``tag``, renamed to the
    names they carry at the tag's entry ('<tag>_IN' renaming)."""
    # NOTE(review): the result is unused; the call is kept from the
    # original in case get_entry validates ``tag``.
    p.get_entry (tag)
    entry_renames = p.entry_exit_renames (tags = [tag])[tag + '_IN']

    sp_var = mk_var ('r13', syntax.word32T)
    renamed_sp = syntax.rename_expr (sp_var, entry_renames)
    stack_var = mk_var ('stack', syntax.builtinTs['Mem'])
    renamed_stack = syntax.rename_expr (stack_var, entry_renames)
    return (renamed_stack, renamed_sp)

def pseudo_node_lvals_rvals (node):
    """For a Call node with a known ASM calling convention, return
    ``(rets, arg_vars)``.

    ``rets`` are the node's returned (name, type) pairs that are not
    callee-saved. ``arg_vars`` is the set of variables mentioned in the
    convention's argument expressions. Returns None when there is no
    calling convention."""
    assert node.kind == 'Call'
    conv = get_asm_calling_convention_at_node (node)
    if not conv:
        return None

    arg_vars = set ()
    for arg in conv['args']:
        arg_vars.update (syntax.get_expr_var_set (arg))

    saved = set (conv['callee_saved'])
    rets = []
    for (nm, typ) in node.rets:
        if mk_var (nm, typ) not in saved:
            rets.append ((nm, typ))

    return (rets, arg_vars)

def is_asm_node (p, n):
    """True if node ``n`` belongs to the 'ASM' tag, or to a tag whose hook
    hint is 'ASM'."""
    tag = p.node_tags[n][0]
    if tag == 'ASM':
        return True
    return p.hook_tag_hints.get (tag, None) == 'ASM'

def all_pseudo_node_lvals_rvals (p):
    """Map each ASM Call node of ``p`` that has a calling convention to
    its ``pseudo_node_lvals_rvals`` result."""
    result = {}
    for n in p.nodes:
        if not is_asm_node (p, n) or p.nodes[n].kind != 'Call':
            continue
        lvals_rvals = pseudo_node_lvals_rvals (p.nodes[n])
        if lvals_rvals is not None:
            result[n] = lvals_rvals
    return result
326 327def adjusted_var_dep_outputs_for_tag (p, tag): 328 (ent, fname, _) = p.get_entry_details (tag) 329 fun = functions[fname] 330 cc = get_asm_calling_convention (fname) 331 callee_saved_set = set (cc['callee_saved']) 332 ret_set = set ([(nm, typ) for ret in cc['rets'] 333 for (nm, typ) in syntax.get_expr_var_set (ret)]) 334 rets = [(nm2, typ) for ((nm, typ), (nm2, _)) 335 in azip (fun.outputs, p.outputs[tag]) 336 if (nm, typ) in ret_set 337 or mk_var (nm, typ) in callee_saved_set] 338 return rets 339 340def adjusted_var_dep_outputs (p): 341 outputs = {} 342 for tag in p.outputs: 343 ent = p.get_entry (tag) 344 if is_asm_node (p, ent): 345 outputs[tag] = adjusted_var_dep_outputs_for_tag (p, tag) 346 else: 347 outputs[tag] = p.outputs[tag] 348 def output (n): 349 tag = p.node_tags[n][0] 350 return outputs[tag] 351 return output 352 353def is_stack (expr): 354 return expr.kind == 'Var' and 'stack' in expr.name 355 356class StackOffsMissing (Exception): 357 pass 358 359def stack_virtualise_expr (expr, sp_offs): 360 if expr.is_op ('MemAcc') and is_stack (expr.vals[0]): 361 [m, p] = expr.vals 362 if expr.typ == syntax.word8T: 363 ps = [(syntax.mk_minus (p, syntax.mk_word32 (n)), n) 364 for n in [0, 1, 2, 3]] 365 elif expr.typ == syntax.word32T: 366 ps = [(p, 0)] 367 else: 368 assert expr.typ == syntax.word32T, expr 369 ptrs = [(p, 'MemAcc') for (p, _) in ps] 370 if sp_offs == None: 371 return (ptrs, None) 372 # FIXME: very 32-bit specific 373 ps = [(p, n) for (p, n) in ps if p in sp_offs 374 if sp_offs[p][1] % 4 == 0] 375 if not ps: 376 return (ptrs, expr) 377 [(p, n)] = ps 378 if p not in sp_offs: 379 raise StackOffsMissing () 380 (k, offs) = sp_offs[p] 381 v = mk_var (('Fake', k, offs), syntax.word32T) 382 if n != 0: 383 v = syntax.mk_shiftr (v, n * 8) 384 v = syntax.mk_cast (v, expr.typ) 385 return (ptrs, v) 386 elif expr.kind == 'Op': 387 vs = [stack_virtualise_expr (v, sp_offs) for v in expr.vals] 388 return ([p for (ptrs, _) in vs for p in ptrs], 389 
syntax.adjust_op_vals (expr, [v for (_, v) in vs])) 390 else: 391 return ([], expr) 392 393def stack_virtualise_upd (((nm, typ), expr), sp_offs): 394 if 'stack' in nm: 395 upds = [] 396 ptrs = [] 397 while expr.is_op ('MemUpdate'): 398 [m, p, v] = expr.vals 399 ptrs.append ((p, 'MemUpdate')) 400 (ptrs2, v2) = stack_virtualise_expr (v, sp_offs) 401 ptrs.extend (ptrs2) 402 if sp_offs != None: 403 if p not in sp_offs: 404 raise StackOffsMissing () 405 (k, offs) = sp_offs[p] 406 upds.append (((('Fake', k, offs), 407 syntax.word32T), v2)) 408 expr = m 409 assert is_stack (expr), expr 410 return (ptrs, upds) 411 else: 412 (ptrs, expr2) = stack_virtualise_expr (expr, sp_offs) 413 return (ptrs, [((nm, typ), expr2)]) 414 415def stack_virtualise_ret (expr, sp_offs): 416 if expr.kind == 'Var': 417 return ([], (expr.name, expr.typ)) 418 elif expr.is_op ('MemAcc'): 419 [m, p] = expr.vals 420 assert expr.typ == syntax.word32T, expr 421 assert is_stack (m), expr 422 if sp_offs != None: 423 (k, offs) = sp_offs[p] 424 r = (('Fake', k, offs), syntax.word32T) 425 else: 426 r = None 427 return ([(p, 'MemUpdate')], r) 428 else: 429 assert not 'ret expr understood', expr 430 431def stack_virtualise_node (node, sp_offs): 432 if node.kind == 'Cond': 433 (ptrs, cond) = stack_virtualise_expr (node.cond, sp_offs) 434 if sp_offs == None: 435 return (ptrs, None) 436 else: 437 return (ptrs, syntax.Node ('Cond', 438 node.get_conts (), cond)) 439 elif node.kind == 'Call': 440 if is_instruction (node.fname): 441 return ([], node) 442 cc = get_asm_calling_convention_at_node (node) 443 assert cc != None, node.fname 444 args = [arg for arg in cc['args'] if not is_stack (arg)] 445 args = [stack_virtualise_expr (arg, sp_offs) for arg in args] 446 rets = [ret for ret in cc['rets_inp'] if not is_stack (ret)] 447 rets = [stack_virtualise_ret (ret, sp_offs) for ret in rets] 448 ptrs = list (set ([p for (ps, _) in args for p in ps] 449 + [p for (ps, _) in rets for p in ps])) 450 if sp_offs == None: 451 
return (ptrs, None) 452 else: 453 return (ptrs, syntax.Node ('Call', node.cont, 454 (None, [v for (_, v) in args] 455 + [p for (p, _) in ptrs], 456 [r for (_, r) in rets]))) 457 elif node.kind == 'Basic': 458 upds = [stack_virtualise_upd (upd, sp_offs) for upd in node.upds] 459 ptrs = list (set ([p for (ps, _) in upds for p in ps])) 460 if sp_offs == None: 461 return (ptrs, None) 462 else: 463 ptr_upds = [(('unused#ptr#name%d' % i, syntax.word32T), 464 ptr) for (i, (ptr, _)) in enumerate (ptrs)] 465 return (ptrs, syntax.Node ('Basic', node.cont, 466 [upd for (_, us) in upds for upd in us] 467 + ptr_upds)) 468 else: 469 assert not "node kind understood", node.kind 470 471def mk_get_local_offs (p, tag, sp_reps): 472 (stack, _) = get_stack_sp (p, tag) 473 def mk_local (n, kind, off, k): 474 (v, off2) = sp_reps[n][k] 475 ptr = syntax.mk_plus (v, syntax.mk_word32 (off + off2)) 476 if kind == 'Ptr': 477 return ptr 478 elif kind == 'MemAcc': 479 return syntax.mk_memacc (stack, ptr, syntax.word32T) 480 return mk_local 481 482def adjust_ret_ptr (ptr): 483 """this is a bit of a hack. 
484 485 the return slots are named based on r0_input, which will be unchanged, 486 which is handy, but we really want to be talking about r0, which will 487 produce meaningful offsets against the pointers actually used in the 488 program.""" 489 490 return logic.var_subst (ptr, {('r0_input', syntax.word32T): 491 syntax.mk_var ('r0', syntax.word32T)}, must_subst = False) 492 493def get_loop_virtual_stack_analysis (p, tag): 494 """computes variable liveness etc analyses with stack slots treated 495 as virtual variables.""" 496 cache_key = ('loop_stack_analysis', tag) 497 if cache_key in p.cached_analysis: 498 return p.cached_analysis[cache_key] 499 500 (ent, fname, _) = p.get_entry_details (tag) 501 (_, sp) = get_stack_sp (p, tag) 502 cc = get_asm_calling_convention (fname) 503 rets = list (set ([ptr for arg in cc['rets'] 504 for (ptr, _) in stack_virtualise_expr (arg, None)[0]])) 505 rets = [adjust_ret_ptr (ret) for ret in rets] 506 renames = p.entry_exit_renames (tags = [tag]) 507 r = renames[tag + '_OUT'] 508 rets = [syntax.rename_expr (ret, r) for ret in rets] 509 510 ns = [n for n in p.nodes if p.node_tags[n][0] == tag] 511 loop_ns = logic.minimal_loop_node_set (p) 512 513 ptrs = list (set ([(n, ptr) for n in ns 514 for ptr in (stack_virtualise_node (p.nodes[n], None))[0]])) 515 ptrs += [(n, (sp, 'StackPointer')) for n in ns if n in loop_ns] 516 offs = get_ptr_offsets (p, [(n, ptr) for (n, (ptr, _)) in ptrs], 517 [(ent, sp, 'stack')] 518 + [(ent, ptr, 'indirect_ret') for ptr in rets[:1]]) 519 520 ptr_offs = {} 521 rep_offs = {} 522 upd_offsets = {} 523 for ((n, ptr), off, k) in offs: 524 off = norm_int (off, 32) 525 ptr_offs.setdefault (n, {}) 526 rep_offs.setdefault (n, {}) 527 ptr_offs[n][ptr] = (k, off) 528 rep_offs[n][k] = (ptr, - off) 529 530 for (n, (ptr, kind)) in ptrs: 531 if kind == 'MemUpdate' and n in loop_ns: 532 loop = p.loop_id (n) 533 (k, off) = ptr_offs[n][ptr] 534 upd_offsets.setdefault (loop, set ()) 535 upd_offsets[loop].add ((k, off)) 536 
loc_offs = mk_get_local_offs (p, tag, rep_offs) 537 538 adj_nodes = {} 539 for n in ns: 540 try: 541 (_, node) = stack_virtualise_node (p.nodes[n], 542 ptr_offs.get (n, {})) 543 except StackOffsMissing, e: 544 printout ("Stack analysis issue at (%d, %s)." 545 % (n, p.node_tags[n])) 546 node = p.nodes[n] 547 adj_nodes[n] = node 548 549 # finally do analysis on this collection of nodes 550 551 preds = dict (p.preds) 552 preds['Ret'] = [n for n in preds['Ret'] if p.node_tags[n][0] == tag] 553 preds['Err'] = [n for n in preds['Err'] if p.node_tags[n][0] == tag] 554 vds = logic.compute_var_deps (adj_nodes, 555 adjusted_var_dep_outputs (p), preds) 556 557 result = (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs)) 558 p.cached_analysis[cache_key] = result 559 return result 560 561def norm_int (n, radix): 562 n = n & ((1 << radix) - 1) 563 n2 = n - (1 << radix) 564 if abs (n2) < abs (n): 565 return n2 566 else: 567 return n 568 569def loop_var_analysis (p, split): 570 """computes the same loop dataflow analysis as in the 'logic' module 571 but with stack slots treated as virtual variables.""" 572 if not is_asm_node (p, split): 573 return None 574 head = p.loop_id (split) 575 tag = p.node_tags[split][0] 576 assert head 577 578 key = ('loop_stack_virtual_var_cycle_analysis', split) 579 if key in p.cached_analysis: 580 return p.cached_analysis[key] 581 582 (vds, adj_nodes, loc_offs, 583 upd_offsets, _) = get_loop_virtual_stack_analysis (p, tag) 584 loop = p.loop_body (head) 585 586 va = logic.compute_loop_var_analysis (p, vds, split, 587 override_nodes = adj_nodes) 588 589 (stack, _) = get_stack_sp (p, tag) 590 591 va2 = [] 592 uoffs = upd_offsets.get (head, []) 593 for (v, data) in va: 594 if v.kind == 'Var' and v.name[0] == 'Fake': 595 (_, k, offs) = v.name 596 if (k, offs) not in uoffs: 597 continue 598 v2 = loc_offs (split, 'MemAcc', offs, k) 599 va2.append ((v2, data)) 600 elif v.kind == 'Var' and v.name.startswith ('stack'): 601 assert v.typ == stack.typ 602 
continue 603 else: 604 va2.append ((v, data)) 605 stack_const = stack 606 for (k, off) in uoffs: 607 stack_const = syntax.mk_memupd (stack_const, 608 loc_offs (split, 'Ptr', off, k), 609 syntax.mk_word32 (0)) 610 sp = asm_stack_rep_hook (p, (stack.name, stack.typ), 'Loop', split) 611 assert sp and sp[0] == 'SplitMem', (split, sp) 612 (_, st_split) = sp 613 stack_const = logic.mk_stack_wrapper (st_split, stack_const, []) 614 stack_const = logic.mk_eq_selective_wrapper (stack_const, 615 ([], [0])) 616 617 va2.append ((stack_const, 'LoopConst')) 618 619 p.cached_analysis[key] = va2 620 return va2 621 622def inline_no_pre_pairing (p): 623 # FIXME: handle code sharing with check.inline_completely_unmatched 624 while True: 625 ns = [n for n in p.nodes if p.nodes[n].kind == 'Call' 626 if p.nodes[n].fname not in pre_pairings 627 if not is_instruction (p.nodes[n].fname)] 628 for n in ns: 629 trace ('Inlining %s at %d.' % (p.nodes[n].fname, n)) 630 problem.inline_at_point (p, n) 631 if not ns: 632 return 633 634last_asm_stack_depth_fun = [0] 635 636def check_before_guess_asm_stack_depth (fun): 637 from solver import smt_expr 638 if not fun.entry: 639 return None 640 p = fun.as_problem (problem.Problem, name = 'Target') 641 try: 642 p.do_analysis () 643 p.check_no_inner_loops () 644 inline_no_pre_pairing (p) 645 except problem.Abort, e: 646 return None 647 rep = rep_graph.mk_graph_slice (p, fast = True) 648 try: 649 rep.get_pc (default_n_vc (p, 'Ret'), 'Target') 650 err_pc = rep.get_pc (default_n_vc (p, 'Err'), 'Target') 651 except solver.EnvMiss, e: 652 return None 653 654 inlined_funs = set ([fn for (_, _, fn) in p.inline_scripts['Target']]) 655 if inlined_funs: 656 printout (' (stack analysis also involves %s)' 657 % ', '.join(inlined_funs)) 658 659 return p 660 661def guess_asm_stack_depth (fun): 662 p = check_before_guess_asm_stack_depth (fun) 663 if not p: 664 return (0, {}) 665 666 last_asm_stack_depth_fun[0] = fun.name 667 668 entry = p.get_entry ('Target') 669 (_, 
sp) = get_stack_sp (p, 'Target') 670 671 nodes = get_asm_reachable_nodes (p, tag_set = ['Target']) 672 673 offs = get_ptr_offsets (p, [(n, sp) for n in nodes], 674 [(entry, sp, 'InitSP')], fail_early = True) 675 676 assert len (offs) == len (nodes), map (hex, set (nodes) 677 - set ([n for ((n, _), _, _) in offs])) 678 679 all_offs = [(n, signed_offset (off, 32, 10 ** 6)) 680 for ((n, ptr), off, _) in offs] 681 min_offs = min ([offs for (n, offs) in all_offs]) 682 max_offs = max ([offs for (n, offs) in all_offs]) 683 684 assert min_offs >= 0 or max_offs <= 0, all_offs 685 multiplier = 1 686 if min_offs < 0: 687 multiplier = -1 688 max_offs = - min_offs 689 690 fcall_offs = [(p.nodes[n].fname, offs * multiplier) 691 for (n, offs) in all_offs if p.nodes[n].kind == 'Call'] 692 fun_offs = {} 693 for f in set ([f for (f, _) in fcall_offs]): 694 fun_offs[f] = max ([offs for (f2, offs) in fcall_offs 695 if f2 == f]) 696 697 return (max_offs, fun_offs) 698 699def signed_offset (n, bits, bound = 0): 700 n = n & ((1 << bits) - 1) 701 if n >= (1 << (bits - 1)): 702 n = n - (1 << bits) 703 if bound: 704 assert n <= bound, (n, bound) 705 assert n >= (- bound), (n, bound) 706 return n 707 708def ident_conds (fname, idents): 709 rolling = syntax.true_term 710 conds = [] 711 for ident in idents.get (fname, [syntax.true_term]): 712 conds.append ((ident, syntax.mk_and (rolling, ident))) 713 rolling = syntax.mk_and (rolling, syntax.mk_not (ident)) 714 return conds 715 716def ident_callables (fname, callees, idents): 717 from solver import to_smt_expr, smt_expr 718 from syntax import mk_not, mk_and, true_term 719 720 auto_callables = dict ([((ident, f, true_term), True) 721 for ident in idents.get (fname, [true_term]) 722 for f in callees if f not in idents]) 723 724 if not [f for f in callees if f in idents]: 725 return auto_callables 726 727 fun = functions[fname] 728 p = fun.as_problem (problem.Problem, name = 'Target') 729 check_ns = [(n, ident, cond) for n in p.nodes 730 if 
p.nodes[n].kind == 'Call' 731 if p.nodes[n].fname in idents 732 for (ident, cond) in ident_conds (p.nodes[n].fname, idents)] 733 734 p.do_analysis () 735 assert check_ns 736 737 rep = rep_graph.mk_graph_slice (p, fast = True) 738 err_hyp = rep_graph.pc_false_hyp ((default_n_vc (p, 'Err'), 'Target')) 739 740 callables = auto_callables 741 nhyps = mk_not_callable_hyps (p) 742 743 for (ident, cond) in ident_conds (fname, idents): 744 renames = p.entry_exit_renames (tags = ['Target']) 745 cond = syntax.rename_expr (cond, renames['Target_IN']) 746 entry = p.get_entry ('Target') 747 e_vis = ((entry, ()), 'Target') 748 hyps = [err_hyp, rep_graph.eq_hyp ((cond, e_vis), 749 (true_term, e_vis))] 750 751 for (n, ident2, cond2) in check_ns: 752 k = (ident, p.nodes[n].fname, ident2) 753 (inp_env, _, _) = rep.get_func (default_n_vc (p, n)) 754 pc = rep.get_pc (default_n_vc (p, n)) 755 cond2 = to_smt_expr (cond2, inp_env, rep.solv) 756 if rep.test_hyp_whyps (mk_not (mk_and (pc, cond2)), 757 hyps + nhyps): 758 callables[k] = False 759 else: 760 callables[k] = True 761 return callables 762 763def compute_immediate_stack_bounds (idents, names): 764 from syntax import true_term 765 immed = {} 766 names = sorted (names) 767 for (i, fname) in enumerate (names): 768 printout ('Doing stack analysis for %r. 
(%d of %d)' % (fname, 769 i + 1, len (names))) 770 fun = functions[fname] 771 (offs, fn_offs) = guess_asm_stack_depth (fun) 772 callables = ident_callables (fname, fn_offs.keys (), idents) 773 for ident in idents.get (fname, [true_term]): 774 calls = [((fname2, ident2), fn_offs[fname2]) 775 for fname2 in fn_offs 776 for ident2 in idents.get (fname2, [true_term]) 777 if callables[(ident, fname2, ident2)]] 778 immed[(fname, ident)] = (offs, dict (calls)) 779 last_immediate_stack_bounds[0] = immed 780 return immed 781 782last_immediate_stack_bounds = [0] 783 784def immediate_stack_bounds_loop (immed): 785 graph = dict ([(k, immed[k][1].keys ()) for k in immed]) 786 graph['ENTRY'] = list (immed) 787 comps = logic.tarjan (graph, ['ENTRY']) 788 rec_comps = [[x] + y for (x, y) in comps if y] 789 return rec_comps 790 791def compute_recursive_stack_bounds (immed): 792 assert not immediate_stack_bounds_loop (immed) 793 bounds = {} 794 todo = immed.keys () 795 report = 1000 796 while todo: 797 if len (todo) >= report: 798 trace ('todo length %d' % len (todo)) 799 trace ('tail: %s' % todo[-20:]) 800 report += 1000 801 (fname, ident) = todo.pop () 802 if (fname, ident) in bounds: 803 continue 804 (static, calls) = immed[(fname, ident)] 805 if [1 for k in calls if k not in bounds]: 806 todo.append ((fname, ident)) 807 todo.extend (calls.keys ()) 808 continue 809 else: 810 bounds[(fname, ident)] = max ([static] 811 + [bounds[k] + calls[k] for k in calls]) 812 return bounds 813 814def stack_bounds_to_closed_form (bounds, names, idents): 815 closed = {} 816 for fname in names: 817 res = syntax.mk_word32 (bounds[(fname, syntax.true_term)]) 818 extras = [] 819 if fname in idents: 820 assert idents[fname][-1] == syntax.true_term 821 extras = reversed (idents[fname][:-1]) 822 for ident in extras: 823 alt = syntax.mk_word32 (bounds[(fname, ident)]) 824 res = syntax.mk_if (ident, alt, res) 825 closed[fname] = res 826 return closed 827 828def compute_asm_stack_bounds (idents, names): 829 
immed = compute_immediate_stack_bounds (idents, names) 830 bounds = compute_recursive_stack_bounds (immed) 831 closed = stack_bounds_to_closed_form (bounds, names, idents) 832 return closed 833 834recursion_trace = [] 835recursion_last_assns = [[]] 836 837def get_recursion_identifiers (funs, extra_unfolds = []): 838 idents = {} 839 del recursion_trace[:] 840 graph = dict ([(f, list (functions[f].function_calls ())) 841 for f in functions]) 842 fs = funs 843 fs2 = set () 844 while fs2 != fs: 845 fs2 = fs 846 fs = set.union (set ([f for f in graph if [f2 for f2 in graph[f] 847 if f2 in fs2]]), 848 set ([f2 for f in fs2 for f2 in graph[f]]), fs2) 849 graph = dict ([(f, graph[f]) for f in fs]) 850 entries = list (fs - set ([f2 for f in graph for f2 in graph[f]])) 851 comps = logic.tarjan (graph, entries) 852 for (head, tail) in comps: 853 if tail or head in graph[head]: 854 group = [head] + list (tail) 855 idents2 = compute_recursion_idents (group, 856 extra_unfolds) 857 idents.update (idents2) 858 return idents 859 860def compute_recursion_idents (group, extra_unfolds): 861 idents = {} 862 group = set (group) 863 recursion_trace.append ('Computing for group %s' % group) 864 printout ('Doing recursion analysis for function group:') 865 printout (' %s' % list(group)) 866 prevs = set ([f for f in functions 867 if [f2 for f2 in functions[f].function_calls () if f2 in group]]) 868 for f in prevs - group: 869 recursion_trace.append (' checking for %s' % f) 870 trace ('Checking idents for %s' % f) 871 while add_recursion_ident (f, group, idents, extra_unfolds): 872 pass 873 return idents 874 875def function_link_assns (p, call_site, tag): 876 call_vis = (default_n_vc (p, call_site), p.node_tags[call_site][0]) 877 return rep_graph.mk_function_link_hyps (p, call_vis, tag) 878 879def add_recursion_ident (f, group, idents, extra_unfolds): 880 from syntax import mk_eq, mk_implies, mk_var 881 p = problem.Problem (None, name = 'Recursion Test') 882 chain = [] 883 tag = 'fun0' 884 
p.add_entry_function (functions[f], tag) 885 p.do_analysis () 886 assns = [] 887 recursion_last_assns[0] = assns 888 889 while True: 890 res = find_unknown_recursion (p, group, idents, tag, assns, 891 extra_unfolds) 892 if res == None: 893 break 894 if p.nodes[res].fname not in group: 895 problem.inline_at_point (p, res) 896 continue 897 fname = p.nodes[res].fname 898 chain.append (fname) 899 tag = 'fun%d' % len (chain) 900 (args, _, entry) = p.add_entry_function (functions[fname], tag) 901 p.do_analysis () 902 assns += function_link_assns (p, res, tag) 903 if chain == []: 904 return None 905 recursion_trace.append (' created fun chain %s' % chain) 906 word_args = [(i, mk_var (s, typ)) 907 for (i, (s, typ)) in enumerate (args) 908 if typ.kind == 'Word'] 909 rep = rep_graph.mk_graph_slice (p, fast = True) 910 (_, env) = rep.get_node_pc_env ((entry, ())) 911 912 m = {} 913 res = rep.test_hyp_whyps (syntax.false_term, assns, model = m) 914 assert m 915 916 if find_unknown_recursion (p, group, idents, tag, [], []) == None: 917 idents.setdefault (fname, []) 918 idents[fname].append (syntax.true_term) 919 recursion_trace.append (' found final ident for %s' % fname) 920 return syntax.true_term 921 assert word_args 922 recursion_trace.append (' scanning for ident for %s' % fname) 923 for (i, arg) in word_args: 924 (nm, typ) = functions[fname].inputs[i] 925 arg_smt = solver.to_smt_expr (arg, env, rep.solv) 926 val = search.eval_model_expr (m, rep.solv, arg_smt) 927 if not rep.test_hyp_whyps (mk_eq (arg_smt, val), assns): 928 recursion_trace.append (' discarded %s = 0x%x, not stable' % (nm, val.val)) 929 continue 930 entry_vis = ((entry, ()), tag) 931 ass = rep_graph.eq_hyp ((arg, entry_vis), (val, entry_vis)) 932 res = find_unknown_recursion (p, group, idents, tag, 933 assns + [ass], []) 934 if res: 935 fname2 = p.nodes[res].fname 936 recursion_trace.append (' discarded %s, allows recursion to %s' % (nm, fname2)) 937 continue 938 eq = syntax.mk_eq (mk_var (nm, typ), val) 
939 idents.setdefault (fname, []) 940 idents[fname].append (eq) 941 recursion_trace.append (' found ident for %s: %s' % (fname, eq)) 942 return eq 943 assert not "identifying assertion found" 944 945def find_unknown_recursion (p, group, idents, tag, assns, extra_unfolds): 946 from syntax import mk_not, mk_and, foldr1 947 rep = rep_graph.mk_graph_slice (p, fast = True) 948 for n in p.nodes: 949 if p.nodes[n].kind != 'Call': 950 continue 951 if p.node_tags[n][0] != tag: 952 continue 953 fname = p.nodes[n].fname 954 if fname in extra_unfolds: 955 return n 956 if fname not in group: 957 continue 958 (inp_env, _, _) = rep.get_func (default_n_vc (p, n)) 959 pc = rep.get_pc (default_n_vc (p, n)) 960 new = foldr1 (mk_and, [pc] + [syntax.mk_not ( 961 solver.to_smt_expr (ident, inp_env, rep.solv)) 962 for ident in idents.get (fname, [])]) 963 if rep.test_hyp_whyps (mk_not (new), assns): 964 continue 965 return n 966 return None 967 968asm_cc_cache = {} 969 970def is_instruction (fname): 971 bits = fname.split ("'") 972 return bits[1:] and bits[:1] in [["l_impl"], ["instruction"]] 973 974def get_asm_calling_convention (fname): 975 if fname in asm_cc_cache: 976 return asm_cc_cache[fname] 977 if fname not in pre_pairings: 978 bits = fname.split ("'") 979 if not is_instruction (fname): 980 trace ("Warning: unusual unmatched function (%s, %s)." 
981 % (fname, bits)) 982 return None 983 pair = pre_pairings[fname] 984 assert pair['ASM'] == fname 985 c_fun = functions[pair['C']] 986 from logic import split_scalar_pairs 987 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_fun.inputs) 988 (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_fun.outputs) 989 990 num_args = len (var_c_args) 991 num_rets = len (var_c_rets) 992 const_mem = not (c_omem) 993 994 cc = get_asm_calling_convention_inner (num_args, num_rets, const_mem) 995 asm_cc_cache[fname] = cc 996 return cc 997 998def get_asm_calling_convention_inner (num_c_args, num_c_rets, const_mem): 999 key = ('Inner', num_c_args, num_c_rets, const_mem) 1000 if key in asm_cc_cache: 1001 return asm_cc_cache[key] 1002 1003 from logic import mk_var_list, mk_stack_sequence 1004 from syntax import mk_var, word32T, builtinTs 1005 1006 arg_regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T) 1007 r0 = arg_regs[0] 1008 sp = mk_var ('r13', word32T) 1009 st = mk_var ('stack', builtinTs['Mem']) 1010 r0_input = mk_var ('r0_input', word32T) 1011 1012 mem = mk_var ('mem', builtinTs['Mem']) 1013 dom = mk_var ('dom', builtinTs['Dom']) 1014 dom_stack = mk_var ('dom_stack', builtinTs['Dom']) 1015 1016 global_args = [mem, dom, st, dom_stack, sp, mk_var ('ret', word32T)] 1017 1018 sregs = mk_stack_sequence (sp, 4, st, word32T, num_c_args + 1) 1019 1020 arg_seq = [r for r in arg_regs] + [s for (s, _) in sregs] 1021 if num_c_rets > 1: 1022 # the 'return-too-much' issue. 
1023 # instead r0 is a save-returns-here pointer 1024 arg_seq.pop (0) 1025 rets = mk_stack_sequence (r0_input, 4, st, word32T, num_c_rets) 1026 rets = [r for (r, _) in rets] 1027 else: 1028 rets = [r0] 1029 1030 callee_saved_vars = ([mk_var (v, word32T) 1031 for v in 'r4 r5 r6 r7 r8 r9 r10 r11 r13'.split ()] 1032 + [dom, dom_stack]) 1033 1034 if const_mem: 1035 callee_saved_vars += [mem] 1036 else: 1037 rets += [mem] 1038 rets += [st] 1039 1040 cc = {'args': arg_seq[: num_c_args] + global_args, 1041 'rets': rets, 'callee_saved': callee_saved_vars} 1042 1043 asm_cc_cache[key] = cc 1044 return cc 1045 1046def get_asm_calling_convention_at_node (node): 1047 cc = get_asm_calling_convention (node.fname) 1048 if not cc: 1049 return None 1050 1051 fun = functions[node.fname] 1052 arg_input_map = dict (azip (fun.inputs, node.args)) 1053 ret_output_map = dict (azip (fun.outputs, 1054 [mk_var (nm, typ) for (nm, typ) in node.rets])) 1055 1056 args = [logic.var_subst (arg, arg_input_map) for arg in cc['args']] 1057 rets = [logic.var_subst (ret, ret_output_map) for ret in cc['rets']] 1058 # these are useful because they happen to map ret r0_input back to 1059 # the previous value r0, rather than the useless value r0_input_ignore. 
1060 rets_inp = [logic.var_subst (ret, arg_input_map) for ret in cc['rets']] 1061 saved = [logic.var_subst (v, ret_output_map) 1062 for v in cc['callee_saved']] 1063 return {'args': args, 'rets': rets, 1064 'rets_inp': rets_inp, 'callee_saved': saved} 1065 1066call_cache = {} 1067 1068def get_asm_callable (fname): 1069 if fname not in pre_pairings: 1070 return True 1071 c_fun = pre_pairings[fname]['C'] 1072 1073 if not call_cache: 1074 for f in functions: 1075 call_cache[f] = False 1076 for f in functions: 1077 fun = functions[f] 1078 for n in fun.reachable_nodes (simplify = True): 1079 if fun.nodes[n].kind == 'Call': 1080 call_cache[fun.nodes[n].fname] = True 1081 return call_cache[c_fun] 1082 1083def get_asm_reachable_nodes (p, tag_set = None): 1084 if tag_set == None: 1085 tag_set = [tag for tag in p.tags () 1086 if is_asm_node (p, p.get_entry (tag))] 1087 frontier = [p.get_entry (tag) for tag in tag_set] 1088 nodes = set () 1089 while frontier: 1090 n = frontier.pop () 1091 if n in nodes or n not in p.nodes: 1092 continue 1093 nodes.add (n) 1094 node = p.nodes[n] 1095 if node.kind == 'Call' and not get_asm_callable (node.fname): 1096 continue 1097 node = logic.simplify_node_elementary (node) 1098 frontier.extend (node.get_conts ()) 1099 return nodes 1100 1101def convert_recursion_idents (idents): 1102 asm_idents = {} 1103 for f in idents: 1104 if f not in pre_pairings: 1105 continue 1106 f2 = pre_pairings[f]['ASM'] 1107 assert f2 != f 1108 asm_idents[f2] = [] 1109 for ident in idents[f]: 1110 if ident.is_op ('True'): 1111 asm_idents[f2].append (ident) 1112 elif ident.is_op ('Equals'): 1113 [x, y] = ident.vals 1114 # this is a bit hacky 1115 [i] = [i for (i, (nm, typ)) 1116 in enumerate (functions[f].inputs) 1117 if x.is_var ((nm, typ))] 1118 cc = get_asm_calling_convention (f2) 1119 x = cc['args'][i] 1120 asm_idents[f2].append (syntax.mk_eq (x, y)) 1121 else: 1122 assert not 'ident kind convertible' 1123 return asm_idents 1124 1125def mk_pairing (pre_pair, 
stack_bounds): 1126 asm_f = pre_pair['ASM'] 1127 sz = stack_bounds[asm_f] 1128 c_fun = functions[pre_pair['C']] 1129 1130 from logic import split_scalar_pairs 1131 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_fun.inputs) 1132 (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_fun.outputs) 1133 1134 eqs = logic.mk_eqs_arm_none_eabi_gnu (var_c_args, var_c_rets, 1135 c_imem, c_omem, sz) 1136 1137 return logic.Pairing (['ASM', 'C'], 1138 {'ASM': asm_f, 'C': c_fun.name}, eqs) 1139 1140def mk_pairings (stack_bounds): 1141 new_pairings = {} 1142 for f in pre_pairings: 1143 if f in new_pairings: 1144 continue 1145 pair = mk_pairing (pre_pairings[f], stack_bounds) 1146 for fun in pair.funs.itervalues (): 1147 new_pairings[fun] = [pair] 1148 return new_pairings 1149 1150def serialise_stack_bounds (stack_bounds): 1151 lines = [] 1152 for fname in stack_bounds: 1153 ss = ['StackBound', fname] 1154 stack_bounds[fname].serialise (ss) 1155 lines.append (' '.join (ss) + '\n') 1156 return lines 1157 1158def deserialise_stack_bounds (lines): 1159 bounds = {} 1160 for line in lines: 1161 bits = line.split () 1162 if not bits: 1163 continue 1164 assert bits[0] == 'StackBound' 1165 fname = bits[1] 1166 (_, bound) = syntax.parse_expr (bits, 2) 1167 bounds[fname] = bound 1168 return bounds 1169 1170funs_with_tag = {} 1171 1172def get_functions_with_tag (tag): 1173 if tag in funs_with_tag: 1174 return funs_with_tag[tag] 1175 visit = set ([pre_pairings[f][tag] for f in pre_pairings 1176 if tag in pre_pairings[f]]) 1177 visit.update ([pair.funs[tag] for f in pairings 1178 for pair in pairings[f] if tag in pair.funs]) 1179 funs = set (visit) 1180 while visit: 1181 f = visit.pop () 1182 funs.add (f) 1183 visit.update (set (functions[f].function_calls ()) - funs) 1184 funs_with_tag[tag] = funs 1185 return funs 1186 1187def compute_stack_bounds (quiet = False): 1188 prev_tracer = target_objects.tracer[0] 1189 if quiet: 1190 target_objects.tracer[0] = lambda s, n: () 1191 
1192 try: 1193 c_fs = get_functions_with_tag ('C') 1194 idents = get_recursion_identifiers (c_fs) 1195 asm_idents = convert_recursion_idents (idents) 1196 asm_fs = get_functions_with_tag ('ASM') 1197 printout ('Computed recursion limits.') 1198 1199 bounds = compute_asm_stack_bounds (asm_idents, asm_fs) 1200 printout ('Computed stack bounds.') 1201 except Exception, e: 1202 if quiet: 1203 target_objects.tracer[0] = prev_tracer 1204 raise 1205 1206 if quiet: 1207 target_objects.tracer[0] = prev_tracer 1208 return bounds 1209 1210def read_fn_hash (fname): 1211 try: 1212 f = open (fname) 1213 s = f.readline () 1214 bits = s.split () 1215 if bits[0] != 'FunctionHash' or len (bits) != 2: 1216 return None 1217 return int (bits[1]) 1218 except ValueError, e: 1219 return None 1220 except IndexError, e: 1221 return None 1222 except IOError, e: 1223 return None 1224 1225def mk_stack_pairings (pairing_tups, stack_bounds_fname = None, 1226 quiet = True): 1227 """build the stack-aware calling-convention-aware logical pairings 1228 once a collection of function pairs have been read.""" 1229 1230 # simplifies interactive testing of this function 1231 pre_pairings.clear () 1232 1233 for (asm_f, c_f) in pairing_tups: 1234 pair = {'ASM': asm_f, 'C': c_f} 1235 assert c_f not in pre_pairings 1236 assert asm_f not in pre_pairings 1237 pre_pairings[c_f] = pair 1238 pre_pairings[asm_f] = pair 1239 1240 fn_hash = hash (tuple (sorted ([(f, hash (functions[f])) 1241 for f in functions]))) 1242 prev_hash = read_fn_hash (stack_bounds_fname) 1243 if prev_hash == fn_hash: 1244 f = open (stack_bounds_fname) 1245 f.readline () 1246 stack_bounds = deserialise_stack_bounds (f) 1247 f.close () 1248 else: 1249 printout ('Computing stack bounds.') 1250 stack_bounds = compute_stack_bounds (quiet = quiet) 1251 f = open (stack_bounds_fname, 'w') 1252 f.write ('FunctionHash %s\n' % fn_hash) 1253 for line in serialise_stack_bounds (stack_bounds): 1254 f.write(line) 1255 f.close () 1256 1257 
problematic_synthetic () 1258 1259 return mk_pairings (stack_bounds) 1260 1261def asm_stack_rep_hook (p, (nm, typ), kind, n): 1262 if not is_asm_node (p, n): 1263 return None 1264 1265 if not (nm.startswith ('stack') and typ == syntax.builtinTs['Mem']): 1266 return None 1267 1268 assert kind in ['Call', 'Init', 'Loop'], kind 1269 if kind == 'Init': 1270 return None 1271 1272 tag = p.node_tags[n][0] 1273 (_, sp) = get_stack_sp (p, tag) 1274 1275 return ('SplitMem', sp) 1276 1277reg_aliases = {'r11': ['fp'], 'r14': ['lr'], 'r13': ['sp']} 1278 1279def inst_const_rets (node): 1280 assert "instruction'" in node.fname 1281 bits = set ([s.lower () for s in node.fname.split ('_')]) 1282 fun = functions[node.fname] 1283 def is_const (nm, typ): 1284 if typ in [builtinTs['Mem'], builtinTs['Dom']]: 1285 return True 1286 if typ != word32T: 1287 return False 1288 return not (nm in bits or [al for al in reg_aliases.get (nm, []) 1289 if al in bits]) 1290 is_consts = [is_const (nm, typ) for (nm, typ) in fun.outputs] 1291 input_set = set ([v for arg in node.args 1292 for v in syntax.get_expr_var_set (arg)]) 1293 return [mk_var (nm, typ) 1294 for ((nm, typ), const) in azip (node.rets, is_consts) 1295 if const and (nm, typ) in input_set] 1296 1297def node_const_rets (node): 1298 if "instruction'" in node.fname: 1299 return inst_const_rets (node) 1300 if node.fname in pre_pairings: 1301 if pre_pairings[node.fname]['ASM'] != node.fname: 1302 return None 1303 cc = get_asm_calling_convention_at_node (node) 1304 input_set = set ([v for arg in node.args 1305 for v in syntax.get_expr_var_set (arg)]) 1306 callee_saved_set = set (cc['callee_saved']) 1307 return [mk_var (nm, typ) for (nm, typ) in node.rets 1308 if mk_var (nm, typ) in callee_saved_set 1309 if (nm, typ) in input_set] 1310 elif preserves_sp (node.fname): 1311 if node.fname not in get_functions_with_tag ('ASM'): 1312 return None 1313 f_outs = functions[node.fname].outputs 1314 return [mk_var (nm, typ) 1315 for ((nm, typ), (nm2, _)) 
in azip (node.rets, f_outs) 1316 if nm2 == 'r13'] 1317 else: 1318 return None 1319 1320def const_ret_hook (node, nm, typ): 1321 consts = node_const_rets (node) 1322 return consts and mk_var (nm, typ) in consts 1323 1324def get_const_rets (p, node_set = None): 1325 if node_set == None: 1326 node_set = p.nodes 1327 const_rets = {} 1328 for n in node_set: 1329 if p.nodes[n].kind != 'Call': 1330 continue 1331 consts = node_const_rets (node) 1332 const_rets[n] = [(v.name, v.typ) for v in consts] 1333 return const_rets 1334 1335def problematic_synthetic (): 1336 synth = [s for s in target_objects.symbols 1337 if '.clone.' in s or '.part.' in s or '.constprop.' in s] 1338 synth = ['_'.join (s.split ('.')) for s in synth] 1339 if not synth: 1340 return 1341 printout ('Synthetic symbols: %s' % synth) 1342 synth_calls = set ([f for f in synth 1343 if f in functions 1344 if functions[f].function_calls ()]) 1345 printout ('Synthetic symbols which make function calls: %s' 1346 % synth_calls) 1347 if not synth_calls: 1348 return 1349 synth_stack = set ([f for f in synth_calls 1350 if [node for node in functions[f].nodes.itervalues () 1351 if node.kind == 'Basic' 1352 if ('r13', word32T) in node.get_lvals ()]]) 1353 printout ('Synthetic symbols which call and move sp: %s' 1354 % synth_stack) 1355 synth_problems = set ([f for f in synth_stack 1356 if [f2 for f2 in functions 1357 if f in functions[f2].function_calls () 1358 if len (set (functions[f2].function_calls ())) > 1] 1359 ]) 1360 printout ('Problematic synthetics: %s' % synth_problems) 1361 return synth_problems 1362 1363def add_hooks (): 1364 k = 'stack_logic' 1365 add = target_objects.add_hook 1366 add ('problem_var_rep', k, asm_stack_rep_hook) 1367 add ('loop_var_analysis', k, loop_var_analysis) 1368 add ('rep_unsafe_const_ret', k, const_ret_hook) 1369 1370add_hooks () 1371 1372