1# 2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) 3# 4# SPDX-License-Identifier: BSD-2-Clause 5# 6 7import syntax 8import solver 9import problem 10import rep_graph 11import search 12import logic 13import check 14 15from target_objects import functions, trace, pairings, pre_pairings, printout 16import target_objects 17 18from logic import azip 19 20from syntax import mk_var, word32T, builtinTs, mk_eq, mk_less_eq 21 22last_stuff = [0] 23 24def default_n_vc (p, n): 25 head = p.loop_id (n) 26 general = [(n2, rep_graph.vc_options ([0], [1])) 27 for n2 in p.loop_heads () 28 if n2 != head] 29 specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head] 30 return (n, tuple (general + specific)) 31 32def split_sum_s_expr (expr, solv, extra_defs, typ): 33 """divides up a linear expression 'a - b - 1 + a' 34 into ({'a':2, 'b': -1}, -1) i.e. 'a' times 2 etc and constant 35 value of -1.""" 36 def rec (expr): 37 return split_sum_s_expr (expr, solv, extra_defs, typ) 38 if expr[0] == 'bvadd': 39 var = {} 40 const = 0 41 for x in expr[1:]: 42 (var2, const2) = rec (x) 43 for (v, count) in var2.iteritems (): 44 var.setdefault (v, 0) 45 var[v] += count 46 const += const2 47 return (var, const) 48 elif expr[0] == 'bvsub': 49 (_, lhs, rhs) = expr 50 (lvar, lconst) = rec (lhs) 51 (rvar, rconst) = rec (rhs) 52 const = lconst - rconst 53 var = dict ([(v, lvar.get (v, 0) - rvar.get (v, 0)) 54 for v in set.union (set (lvar), set (rvar))]) 55 return (var, const) 56 elif expr in solv.defs: 57 return rec (solv.defs[expr]) 58 elif expr in extra_defs: 59 return rec (extra_defs[expr]) 60 elif expr[:2] in ['#x', '#b']: 61 val = solver.smt_to_val (expr) 62 assert val.kind == 'Num' 63 return ({}, val.val) 64 else: 65 return ({expr: 1}, 0) 66 67def split_merge_ite_sum_sexpr (foo): 68 (s0, s1) = [solver.smt_num_t (n, typ) for n in [0, 1]] 69 if y != s0: 70 expr = ('bvadd', ('ite', cond, ('bvsub', x, y), s0), y) 71 return rec (expr) 72 (xvar, xconst) = rec (x) 73 var = dict ([(('ite', cond, v, s0), n) 74 for (v, n) in xvar.iteritems ()]) 75 var.setdefault (('ite', cond, s1, s0), 0) 76 var[('ite', cond, s1, s0)] += xconst 77 return (var, 0) 78 79def simplify_expr_whyps (sexpr, rep, hyps, cache = None, extra_defs = {}, 80 bool_hyps = None): 81 if cache == None: 82 cache = {} 83 if bool_hyps == None: 84 bool_hyps = [] 85 if sexpr in extra_defs: 86 sexpr = extra_defs[sexpr] 87 if sexpr in rep.solv.defs: 88 sexpr = rep.solv.defs[sexpr] 89 if sexpr[0] == 'ite': 90 (_, cond, x, y) = sexpr 91 cond_exp = solver.mk_smt_expr (solver.flat_s_expression (cond), 92 syntax.boolT) 93 (mk_nimp, mk_not) = (syntax.mk_n_implies, syntax.mk_not) 94 if rep.test_hyp_whyps (mk_nimp (bool_hyps, cond_exp), 95 hyps, cache = cache): 96 return x 97 elif rep.test_hyp_whyps (mk_nimp (bool_hyps, mk_not (cond_exp)), 98 hyps, cache = cache): 99 return y 100 x = simplify_expr_whyps (x, rep, hyps, cache = cache, 101 extra_defs = extra_defs, 102 bool_hyps = bool_hyps + [cond_exp]) 103 y = simplify_expr_whyps (y, rep, hyps, cache = cache, 104 extra_defs = extra_defs, 105 bool_hyps = bool_hyps + [syntax.mk_not (cond_exp)]) 106 if x == y: 107 return x 108 return ('ite', cond, x, y) 109 return sexpr 110 111last_10_non_const = [] 112 113def offs_expr_const (addr_expr, sp_expr, rep, hyps, extra_defs = {}, 114 cache = None, typ = syntax.word32T): 115 """if the offset between a stack addr and the initial stack pointer 116 is a constant offset, try to compute it.""" 117 addr_x = solver.parse_s_expression (addr_expr) 118 sp_x = solver.parse_s_expression (sp_expr) 119 vs = [(addr_x, 1), (sp_x, -1)] 120 const = 0 121 122 while True: 123 start_vs = list (vs) 124 new_vs = {} 125 for (x, mult) in vs: 126 (var, c) = split_sum_s_expr (x, rep.solv, extra_defs, 127 typ = typ) 128 for v in var: 129 new_vs.setdefault (v, 0) 130 new_vs[v] += var[v] * mult 131 const += c * mult 132 vs = [(x, n) for (x, n) in new_vs.iteritems () 133 if n % (2 ** typ.num) != 0] 134 if not vs: 135 return const 136 vs = [(simplify_expr_whyps (x, rep, hyps, 137 cache = cache, extra_defs = extra_defs), n) 138 for (x, n) in vs] 139 if sorted (vs) == sorted (start_vs): 140 pass # vs = split_merge_ite_sum_sexpr (vs) 141 if sorted (vs) == sorted (start_vs): 142 trace ('offs_expr_const: not const') 143 trace ('%s - %s' % (addr_expr, sp_expr)) 144 trace (str (vs)) 145 trace (str (hyps)) 146 last_10_non_const.append ((addr_expr, sp_expr, vs, hyps)) 147 del last_10_non_const[:-10] 148 return None 149 150def has_stack_var (expr, stack_var): 151 while True: 152 if expr.is_op ('MemUpdate'): 153 [m, p, v] = expr.vals 154 expr = m 155 elif expr.kind == 'Var': 156 return expr == stack_var 157 else: 158 assert not 'has_stack_var: expr kind', expr 159 160def mk_not_callable_hyps (p): 161 hyps = [] 162 for n in p.nodes: 163 if p.nodes[n].kind != 'Call': 164 continue 165 if get_asm_callable (p.nodes[n].fname): 166 continue 167 tag = p.node_tags[n][0] 168 hyp = rep_graph.pc_false_hyp ((default_n_vc (p, n), tag)) 169 hyps.append (hyp) 170 return hyps 171 172last_get_ptr_offsets = [0] 173last_get_ptr_offsets_setup = [0] 174 175def get_ptr_offsets (p, n_ptrs, bases, hyps = [], cache = None, 176 fail_early = False): 177 """detect which ptrs are guaranteed to be at constant offsets 178 from some set of basis ptrs""" 179 rep = rep_graph.mk_graph_slice (p, fast = True) 180 if cache == None: 181 cache = {} 182 last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps) 183 184 smt_bases = [] 185 for (n, ptr, k) in bases: 186 n_vc = default_n_vc (p, n) 187 (_, env) = rep.get_node_pc_env (n_vc) 188 smt = solver.smt_expr (ptr, env, rep.solv) 189 smt_bases.append ((smt, k)) 190 ptr_typ = ptr.typ 191 192 smt_ptrs = [] 193 for (n, ptr) in n_ptrs: 194 n_vc = default_n_vc (p, n) 195 pc_env = rep.get_node_pc_env (n_vc) 196 if not pc_env: 197 continue 198 smt = solver.smt_expr (ptr, pc_env[1], rep.solv) 199 hyp = rep_graph.pc_true_hyp ((n_vc, p.node_tags[n][0])) 200 smt_ptrs.append (((n, ptr), smt, hyp)) 201 202 hyps = hyps + mk_not_callable_hyps (p) 203 for tag in set ([p.node_tags[n][0] for (n, _) in n_ptrs]): 204 hyps = hyps + init_correctness_hyps (p, tag) 205 tags = set ([p.node_tags[n][0] for (n, ptr) in n_ptrs]) 206 ex_defs = {} 207 for t in tags: 208 ex_defs.update (get_extra_sp_defs (rep, t)) 209 210 offs = [] 211 for (v, ptr, hyp) in smt_ptrs: 212 off = None 213 for (ptr2, k) in smt_bases: 214 off = offs_expr_const (ptr, ptr2, rep, [hyp] + hyps, 215 cache = cache, extra_defs = ex_defs, 216 typ = ptr_typ) 217 if off != None: 218 offs.append ((v, off, k)) 219 break 220 if off == None: 221 trace ('get_ptr_offs fallthrough at %d: %s' % v) 222 trace (str ([hyp] + hyps)) 223 assert not fail_early, (v, ptr) 224 return offs 225 226def init_correctness_hyps (p, tag): 227 (_, fname, _) = p.get_entry_details (tag) 228 if fname not in pairings: 229 # conveniently handles bootstrap case 230 return [] 231 # revise if multi-pairings for ASM an option 232 [pair] = pairings[fname] 233 true_tag = None 234 if tag in pair.funs: 235 true_tag = tag 236 elif p.hook_tag_hints.get (tag, tag) in pair.funs: 237 true_tag = p.hook_tag_hints.get (tag, tag) 238 if true_tag == None: 239 return [] 240 (inp_eqs, _) = pair.eqs 241 in_tag = "%s_IN" % true_tag 242 eqs = [eq for eq in inp_eqs if eq[0][1] == in_tag 243 and eq[1][1] == in_tag] 244 return check.inst_eqs (p, (), eqs, {true_tag: tag}) 245 246extra_symbols = set () 247 248def preserves_sp (fname): 249 """all functions will keep the stack pointer equal, whether they have 250 pairing partners or not.""" 251 assume_sp_equal = bool (target_objects.hooks ('assume_sp_equal')) 252 if not extra_symbols: 253 for fname2 in target_objects.symbols: 254 extra_symbols.add(fname2) 255 extra_symbols.add('_'.join (fname2.split ('.'))) 256 return (get_asm_calling_convention (fname) 257 or assume_sp_equal 258 or fname in extra_symbols) 259 260def get_extra_sp_defs (rep, tag): 261 """add extra defs/equalities about stack pointer for the 262 purposes of stack depth analysis.""" 263 # FIXME how to parametrise this? 264 sp = mk_var ('r13', syntax.word32T) 265 defs = {} 266 267 fcalls = [n_vc for n_vc in rep.funcs 268 if logic.is_int (n_vc[0]) 269 if rep.p.node_tags[n_vc[0]][0] == tag 270 if preserves_sp (rep.p.nodes[n_vc[0]].fname)] 271 for (n, vc) in fcalls: 272 (inputs, outputs, _) = rep.funcs[(n, vc)] 273 if (sp.name, sp.typ) not in outputs: 274 continue 275 inp_sp = solver.smt_expr (sp, inputs, rep.solv) 276 inp_sp = solver.parse_s_expression (inp_sp) 277 out_sp = solver.smt_expr (sp, outputs, rep.solv) 278 out_sp = solver.parse_s_expression (out_sp) 279 if inp_sp != out_sp: 280 defs[out_sp] = inp_sp 281 return defs 282 283def get_stack_sp (p, tag): 284 """get stack and stack-pointer variables""" 285 entry = p.get_entry (tag) 286 renames = p.entry_exit_renames (tags = [tag]) 287 r = renames[tag + '_IN'] 288 289 sp = syntax.rename_expr (mk_var ('r13', syntax.word32T), r) 290 stack = syntax.rename_expr (mk_var ('stack', 291 syntax.builtinTs['Mem']), r) 292 return (stack, sp) 293 294def pseudo_node_lvals_rvals (node): 295 assert node.kind == 'Call' 296 cc = get_asm_calling_convention_at_node (node) 297 if not cc: 298 return None 299 300 arg_vars = set ([var for arg in cc['args'] 301 for var in syntax.get_expr_var_set (arg)]) 302 303 callee_saved_set = set (cc['callee_saved']) 304 rets = [(nm, typ) for (nm, typ) in node.rets 305 if mk_var (nm, typ) not in callee_saved_set] 306 307 return (rets, arg_vars) 308 309def is_asm_node (p, n): 310 tag = p.node_tags[n][0] 311 return tag == 'ASM' or p.hook_tag_hints.get (tag, None) == 'ASM' 312 313def all_pseudo_node_lvals_rvals (p): 314 pseudo = {} 315 for n in p.nodes: 316 if not is_asm_node (p, n): 317 continue 318 elif p.nodes[n].kind != 'Call': 319 continue 320 ps = pseudo_node_lvals_rvals (p.nodes[n]) 321 if ps != None: 322 pseudo[n] = ps 323 return pseudo 324 325def adjusted_var_dep_outputs_for_tag (p, tag): 326 (ent, fname, _) = p.get_entry_details (tag) 327 fun = functions[fname] 328 cc = get_asm_calling_convention (fname) 329 callee_saved_set = set (cc['callee_saved']) 330 ret_set = set ([(nm, typ) for ret in cc['rets'] 331 for (nm, typ) in syntax.get_expr_var_set (ret)]) 332 rets = [(nm2, typ) for ((nm, typ), (nm2, _)) 333 in azip (fun.outputs, p.outputs[tag]) 334 if (nm, typ) in ret_set 335 or mk_var (nm, typ) in callee_saved_set] 336 return rets 337 338def adjusted_var_dep_outputs (p): 339 outputs = {} 340 for tag in p.outputs: 341 ent = p.get_entry (tag) 342 if is_asm_node (p, ent): 343 outputs[tag] = adjusted_var_dep_outputs_for_tag (p, tag) 344 else: 345 outputs[tag] = p.outputs[tag] 346 def output (n): 347 tag = p.node_tags[n][0] 348 return outputs[tag] 349 return output 350 351def is_stack (expr): 352 return expr.kind == 'Var' and 'stack' in expr.name 353 354class StackOffsMissing (Exception): 355 pass 356 357def stack_virtualise_expr (expr, sp_offs): 358 if expr.is_op ('MemAcc') and is_stack (expr.vals[0]): 359 [m, p] = expr.vals 360 if expr.typ == syntax.word8T: 361 ps = [(syntax.mk_minus (p, syntax.mk_word32 (n)), n) 362 for n in [0, 1, 2, 3]] 363 elif expr.typ == syntax.word32T: 364 ps = [(p, 0)] 365 else: 366 assert expr.typ == syntax.word32T, expr 367 ptrs = [(p, 'MemAcc') for (p, _) in ps] 368 if sp_offs == None: 369 return (ptrs, None) 370 # FIXME: very 32-bit specific 371 ps = [(p, n) for (p, n) in ps if p in sp_offs 372 if sp_offs[p][1] % 4 == 0] 373 if not ps: 374 return (ptrs, expr) 375 [(p, n)] = ps 376 if p not in sp_offs: 377 raise StackOffsMissing () 378 (k, offs) = sp_offs[p] 379 v = mk_var (('Fake', k, offs), syntax.word32T) 380 if n != 0: 381 v = syntax.mk_shiftr (v, n * 8) 382 v = syntax.mk_cast (v, expr.typ) 383 return (ptrs, v) 384 elif expr.kind == 'Op': 385 vs = [stack_virtualise_expr (v, sp_offs) for v in expr.vals] 386 return ([p for (ptrs, _) in vs for p in ptrs], 387 syntax.adjust_op_vals (expr, [v for (_, v) in vs])) 388 else: 389 return ([], expr) 390 391def stack_virtualise_upd (((nm, typ), expr), sp_offs): 392 if 'stack' in nm: 393 upds = [] 394 ptrs = [] 395 while expr.is_op ('MemUpdate'): 396 [m, p, v] = expr.vals 397 ptrs.append ((p, 'MemUpdate')) 398 (ptrs2, v2) = stack_virtualise_expr (v, sp_offs) 399 ptrs.extend (ptrs2) 400 if sp_offs != None: 401 if p not in sp_offs: 402 raise StackOffsMissing () 403 (k, offs) = sp_offs[p] 404 upds.append (((('Fake', k, offs), 405 syntax.word32T), v2)) 406 expr = m 407 assert is_stack (expr), expr 408 return (ptrs, upds) 409 else: 410 (ptrs, expr2) = stack_virtualise_expr (expr, sp_offs) 411 return (ptrs, [((nm, typ), expr2)]) 412 413def stack_virtualise_ret (expr, sp_offs): 414 if expr.kind == 'Var': 415 return ([], (expr.name, expr.typ)) 416 elif expr.is_op ('MemAcc'): 417 [m, p] = expr.vals 418 assert expr.typ == syntax.word32T, expr 419 assert is_stack (m), expr 420 if sp_offs != None: 421 (k, offs) = sp_offs[p] 422 r = (('Fake', k, offs), syntax.word32T) 423 else: 424 r = None 425 return ([(p, 'MemUpdate')], r) 426 else: 427 assert not 'ret expr understood', expr 428 429def stack_virtualise_node (node, sp_offs): 430 if node.kind == 'Cond': 431 (ptrs, cond) = stack_virtualise_expr (node.cond, sp_offs) 432 if sp_offs == None: 433 return (ptrs, None) 434 else: 435 return (ptrs, syntax.Node ('Cond', 436 node.get_conts (), cond)) 437 elif node.kind == 'Call': 438 if is_instruction (node.fname): 439 return ([], node) 440 cc = get_asm_calling_convention_at_node (node) 441 assert cc != None, node.fname 442 args = [arg for arg in cc['args'] if not is_stack (arg)] 443 args = [stack_virtualise_expr (arg, sp_offs) for arg in args] 444 rets = [ret for ret in cc['rets_inp'] if not is_stack (ret)] 445 rets = [stack_virtualise_ret (ret, sp_offs) for ret in rets] 446 ptrs = list (set ([p for (ps, _) in args for p in ps] 447 + [p for (ps, _) in rets for p in ps])) 448 if sp_offs == None: 449 return (ptrs, None) 450 else: 451 return (ptrs, syntax.Node ('Call', node.cont, 452 (None, [v for (_, v) in args] 453 + [p for (p, _) in ptrs], 454 [r for (_, r) in rets]))) 455 elif node.kind == 'Basic': 456 upds = [stack_virtualise_upd (upd, sp_offs) for upd in node.upds] 457 ptrs = list (set ([p for (ps, _) in upds for p in ps])) 458 if sp_offs == None: 459 return (ptrs, None) 460 else: 461 ptr_upds = [(('unused#ptr#name%d' % i, syntax.word32T), 462 ptr) for (i, (ptr, _)) in enumerate (ptrs)] 463 return (ptrs, syntax.Node ('Basic', node.cont, 464 [upd for (_, us) in upds for upd in us] 465 + ptr_upds)) 466 else: 467 assert not "node kind understood", node.kind 468 469def mk_get_local_offs (p, tag, sp_reps): 470 (stack, _) = get_stack_sp (p, tag) 471 def mk_local (n, kind, off, k): 472 (v, off2) = sp_reps[n][k] 473 ptr = syntax.mk_plus (v, syntax.mk_word32 (off + off2)) 474 if kind == 'Ptr': 475 return ptr 476 elif kind == 'MemAcc': 477 return syntax.mk_memacc (stack, ptr, syntax.word32T) 478 return mk_local 479 480def adjust_ret_ptr (ptr): 481 """this is a bit of a hack. 482 483 the return slots are named based on r0_input, which will be unchanged, 484 which is handy, but we really want to be talking about r0, which will 485 produce meaningful offsets against the pointers actually used in the 486 program.""" 487 488 return logic.var_subst (ptr, {('ret_addr_input', syntax.word32T): 489 syntax.mk_var ('r0', syntax.word32T)}, must_subst = False) 490 491def get_loop_virtual_stack_analysis (p, tag): 492 """computes variable liveness etc analyses with stack slots treated 493 as virtual variables.""" 494 cache_key = ('loop_stack_analysis', tag) 495 if cache_key in p.cached_analysis: 496 return p.cached_analysis[cache_key] 497 498 (ent, fname, _) = p.get_entry_details (tag) 499 (_, sp) = get_stack_sp (p, tag) 500 cc = get_asm_calling_convention (fname) 501 rets = list (set ([ptr for arg in cc['rets'] 502 for (ptr, _) in stack_virtualise_expr (arg, None)[0]])) 503 rets = [adjust_ret_ptr (ret) for ret in rets] 504 renames = p.entry_exit_renames (tags = [tag]) 505 r = renames[tag + '_OUT'] 506 rets = [syntax.rename_expr (ret, r) for ret in rets] 507 508 ns = [n for n in p.nodes if p.node_tags[n][0] == tag] 509 loop_ns = logic.minimal_loop_node_set (p) 510 511 ptrs = list (set ([(n, ptr) for n in ns 512 for ptr in (stack_virtualise_node (p.nodes[n], None))[0]])) 513 ptrs += [(n, (sp, 'StackPointer')) for n in ns if n in loop_ns] 514 offs = get_ptr_offsets (p, [(n, ptr) for (n, (ptr, _)) in ptrs], 515 [(ent, sp, 'stack')] 516 + [(ent, ptr, 'indirect_ret') for ptr in rets[:1]]) 517 518 ptr_offs = {} 519 rep_offs = {} 520 upd_offsets = {} 521 for ((n, ptr), off, k) in offs: 522 off = norm_int (off, 32) 523 ptr_offs.setdefault (n, {}) 524 rep_offs.setdefault (n, {}) 525 ptr_offs[n][ptr] = (k, off) 526 rep_offs[n][k] = (ptr, - off) 527 528 for (n, (ptr, kind)) in ptrs: 529 if kind == 'MemUpdate' and n in loop_ns: 530 loop = p.loop_id (n) 531 (k, off) = ptr_offs[n][ptr] 532 upd_offsets.setdefault (loop, set ()) 533 upd_offsets[loop].add ((k, off)) 534 loc_offs = mk_get_local_offs (p, tag, rep_offs) 535 536 adj_nodes = {} 537 for n in ns: 538 try: 539 (_, node) = stack_virtualise_node (p.nodes[n], 540 ptr_offs.get (n, {})) 541 except StackOffsMissing, e: 542 printout ("Stack analysis issue at (%d, %s)." 543 % (n, p.node_tags[n])) 544 node = p.nodes[n] 545 adj_nodes[n] = node 546 547 # finally do analysis on this collection of nodes 548 549 preds = dict (p.preds) 550 preds['Ret'] = [n for n in preds['Ret'] if p.node_tags[n][0] == tag] 551 preds['Err'] = [n for n in preds['Err'] if p.node_tags[n][0] == tag] 552 vds = logic.compute_var_deps (adj_nodes, 553 adjusted_var_dep_outputs (p), preds) 554 555 result = (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs)) 556 p.cached_analysis[cache_key] = result 557 return result 558 559def norm_int (n, radix): 560 n = n & ((1 << radix) - 1) 561 n2 = n - (1 << radix) 562 if abs (n2) < abs (n): 563 return n2 564 else: 565 return n 566 567def loop_var_analysis (p, split): 568 """computes the same loop dataflow analysis as in the 'logic' module 569 but with stack slots treated as virtual variables.""" 570 if not is_asm_node (p, split): 571 return None 572 head = p.loop_id (split) 573 tag = p.node_tags[split][0] 574 assert head 575 576 key = ('loop_stack_virtual_var_cycle_analysis', split) 577 if key in p.cached_analysis: 578 return p.cached_analysis[key] 579 580 (vds, adj_nodes, loc_offs, 581 upd_offsets, _) = get_loop_virtual_stack_analysis (p, tag) 582 loop = p.loop_body (head) 583 584 va = logic.compute_loop_var_analysis (p, vds, split, 585 override_nodes = adj_nodes) 586 587 (stack, _) = get_stack_sp (p, tag) 588 589 va2 = [] 590 uoffs = upd_offsets.get (head, []) 591 for (v, data) in va: 592 if v.kind == 'Var' and v.name[0] == 'Fake': 593 (_, k, offs) = v.name 594 if (k, offs) not in uoffs: 595 continue 596 v2 = loc_offs (split, 'MemAcc', offs, k) 597 va2.append ((v2, data)) 598 elif v.kind == 'Var' and v.name.startswith ('stack'): 599 assert v.typ == stack.typ 600 continue 601 else: 602 va2.append ((v, data)) 603 stack_const = stack 604 for (k, off) in uoffs: 605 stack_const = syntax.mk_memupd (stack_const, 606 loc_offs (split, 'Ptr', off, k), 607 syntax.mk_word32 (0)) 608 sp = asm_stack_rep_hook (p, (stack.name, stack.typ), 'Loop', split) 609 assert sp and sp[0] == 'SplitMem', (split, sp) 610 (_, st_split) = sp 611 stack_const = logic.mk_stack_wrapper (st_split, stack_const, []) 612 stack_const = logic.mk_eq_selective_wrapper (stack_const, 613 ([], [0])) 614 615 va2.append ((stack_const, 'LoopConst')) 616 617 p.cached_analysis[key] = va2 618 return va2 619 620def inline_no_pre_pairing (p): 621 # FIXME: handle code sharing with check.inline_completely_unmatched 622 while True: 623 ns = [n for n in p.nodes if p.nodes[n].kind == 'Call' 624 if p.nodes[n].fname not in pre_pairings 625 if not is_instruction (p.nodes[n].fname)] 626 for n in ns: 627 trace ('Inlining %s at %d.' % (p.nodes[n].fname, n)) 628 problem.inline_at_point (p, n) 629 if not ns: 630 return 631 632last_asm_stack_depth_fun = [0] 633 634def check_before_guess_asm_stack_depth (fun): 635 from solver import smt_expr 636 if not fun.entry: 637 return None 638 p = fun.as_problem (problem.Problem, name = 'Target') 639 try: 640 p.do_analysis () 641 p.check_no_inner_loops () 642 inline_no_pre_pairing (p) 643 except problem.Abort, e: 644 return None 645 rep = rep_graph.mk_graph_slice (p, fast = True) 646 try: 647 rep.get_pc (default_n_vc (p, 'Ret'), 'Target') 648 err_pc = rep.get_pc (default_n_vc (p, 'Err'), 'Target') 649 except solver.EnvMiss, e: 650 return None 651 652 inlined_funs = set ([fn for (_, _, fn) in p.inline_scripts['Target']]) 653 if inlined_funs: 654 printout (' (stack analysis also involves %s)' 655 % ', '.join(inlined_funs)) 656 657 return p 658 659def guess_asm_stack_depth (fun): 660 p = check_before_guess_asm_stack_depth (fun) 661 if not p: 662 return (0, {}) 663 664 last_asm_stack_depth_fun[0] = fun.name 665 666 entry = p.get_entry ('Target') 667 (_, sp) = get_stack_sp (p, 'Target') 668 669 nodes = get_asm_reachable_nodes (p, tag_set = ['Target']) 670 671 offs = get_ptr_offsets (p, [(n, sp) for n in nodes], 672 [(entry, sp, 'InitSP')], fail_early = True) 673 674 assert len (offs) == len (nodes), map (hex, set (nodes) 675 - set ([n for ((n, _), _, _) in offs])) 676 677 all_offs = [(n, signed_offset (off, 32, 10 ** 6)) 678 for ((n, ptr), off, _) in offs] 679 min_offs = min ([offs for (n, offs) in all_offs]) 680 max_offs = max ([offs for (n, offs) in all_offs]) 681 682 assert min_offs >= 0 or max_offs <= 0, all_offs 683 multiplier = 1 684 if min_offs < 0: 685 multiplier = -1 686 max_offs = - min_offs 687 688 fcall_offs = [(p.nodes[n].fname, offs * multiplier) 689 for (n, offs) in all_offs if p.nodes[n].kind == 'Call'] 690 fun_offs = {} 691 for f in set ([f for (f, _) in fcall_offs]): 692 fun_offs[f] = max ([offs for (f2, offs) in fcall_offs 693 if f2 == f]) 694 695 return (max_offs, fun_offs) 696 697def signed_offset (n, bits, bound = 0): 698 n = n & ((1 << bits) - 1) 699 if n >= (1 << (bits - 1)): 700 n = n - (1 << bits) 701 if bound: 702 assert n <= bound, (n, bound) 703 assert n >= (- bound), (n, bound) 704 return n 705 706def ident_conds (fname, idents): 707 rolling = syntax.true_term 708 conds = [] 709 for ident in idents.get (fname, [syntax.true_term]): 710 conds.append ((ident, syntax.mk_and (rolling, ident))) 711 rolling = syntax.mk_and (rolling, syntax.mk_not (ident)) 712 return conds 713 714def ident_callables (fname, callees, idents): 715 from solver import to_smt_expr, smt_expr 716 from syntax import mk_not, mk_and, true_term 717 718 auto_callables = dict ([((ident, f, true_term), True) 719 for ident in idents.get (fname, [true_term]) 720 for f in callees if f not in idents]) 721 722 if not [f for f in callees if f in idents]: 723 return auto_callables 724 725 fun = functions[fname] 726 p = fun.as_problem (problem.Problem, name = 'Target') 727 check_ns = [(n, ident, cond) for n in p.nodes 728 if p.nodes[n].kind == 'Call' 729 if p.nodes[n].fname in idents 730 for (ident, cond) in ident_conds (p.nodes[n].fname, idents)] 731 732 p.do_analysis () 733 assert check_ns 734 735 rep = rep_graph.mk_graph_slice (p, fast = True) 736 err_hyp = rep_graph.pc_false_hyp ((default_n_vc (p, 'Err'), 'Target')) 737 738 callables = auto_callables 739 nhyps = mk_not_callable_hyps (p) 740 741 for (ident, cond) in ident_conds (fname, idents): 742 renames = p.entry_exit_renames (tags = ['Target']) 743 cond = syntax.rename_expr (cond, renames['Target_IN']) 744 entry = p.get_entry ('Target') 745 e_vis = ((entry, ()), 'Target') 746 hyps = [err_hyp, rep_graph.eq_hyp ((cond, e_vis), 747 (true_term, e_vis))] 748 749 for (n, ident2, cond2) in check_ns: 750 k = (ident, p.nodes[n].fname, ident2) 751 (inp_env, _, _) = rep.get_func (default_n_vc (p, n)) 752 pc = rep.get_pc (default_n_vc (p, n)) 753 cond2 = to_smt_expr (cond2, inp_env, rep.solv) 754 if rep.test_hyp_whyps (mk_not (mk_and (pc, cond2)), 755 hyps + nhyps): 756 callables[k] = False 757 else: 758 callables[k] = True 759 return callables 760 761def compute_immediate_stack_bounds (idents, names): 762 from syntax import true_term 763 immed = {} 764 names = sorted (names) 765 for (i, fname) in enumerate (names): 766 printout ('Doing stack analysis for %r. (%d of %d)' % (fname, 767 i + 1, len (names))) 768 fun = functions[fname] 769 (offs, fn_offs) = guess_asm_stack_depth (fun) 770 callables = ident_callables (fname, fn_offs.keys (), idents) 771 for ident in idents.get (fname, [true_term]): 772 calls = [((fname2, ident2), fn_offs[fname2]) 773 for fname2 in fn_offs 774 for ident2 in idents.get (fname2, [true_term]) 775 if callables[(ident, fname2, ident2)]] 776 immed[(fname, ident)] = (offs, dict (calls)) 777 last_immediate_stack_bounds[0] = immed 778 return immed 779 780last_immediate_stack_bounds = [0] 781 782def immediate_stack_bounds_loop (immed): 783 graph = dict ([(k, immed[k][1].keys ()) for k in immed]) 784 graph['ENTRY'] = list (immed) 785 comps = logic.tarjan (graph, ['ENTRY']) 786 rec_comps = [[x] + y for (x, y) in comps if y] 787 return rec_comps 788 789def compute_recursive_stack_bounds (immed): 790 assert not immediate_stack_bounds_loop (immed) 791 bounds = {} 792 todo = immed.keys () 793 report = 1000 794 while todo: 795 if len (todo) >= report: 796 trace ('todo length %d' % len (todo)) 797 trace ('tail: %s' % todo[-20:]) 798 report += 1000 799 (fname, ident) = todo.pop () 800 if (fname, ident) in bounds: 801 continue 802 (static, calls) = immed[(fname, ident)] 803 if [1 for k in calls if k not in bounds]: 804 todo.append ((fname, ident)) 805 todo.extend (calls.keys ()) 806 continue 807 else: 808 bounds[(fname, ident)] = max ([static] 809 + [bounds[k] + calls[k] for k in calls]) 810 return bounds 811 812def stack_bounds_to_closed_form (bounds, names, idents): 813 closed = {} 814 for fname in names: 815 res = syntax.mk_word32 (bounds[(fname, syntax.true_term)]) 816 extras = [] 817 if fname in idents: 818 assert idents[fname][-1] == syntax.true_term 819 extras = reversed (idents[fname][:-1]) 820 for ident in extras: 821 alt = syntax.mk_word32 (bounds[(fname, ident)]) 822 res = syntax.mk_if (ident, alt, res) 823 closed[fname] = res 824 return closed 825 826def compute_asm_stack_bounds (idents, names): 827 immed = compute_immediate_stack_bounds (idents, names) 828 bounds = compute_recursive_stack_bounds (immed) 829 closed = stack_bounds_to_closed_form (bounds, names, idents) 830 return closed 831 832recursion_trace = [] 833recursion_last_assns = [[]] 834 835def get_recursion_identifiers (funs, extra_unfolds = []): 836 idents = {} 837 del recursion_trace[:] 838 graph = dict ([(f, list (functions[f].function_calls ())) 839 for f in functions]) 840 fs = funs 841 fs2 = set () 842 while fs2 != fs: 843 fs2 = fs 844 fs = set.union (set ([f for f in graph if [f2 for f2 in graph[f] 845 if f2 in fs2]]), 846 set ([f2 for f in fs2 for f2 in graph[f]]), fs2) 847 graph = dict ([(f, graph[f]) for f in fs]) 848 entries = list (fs - set ([f2 for f in graph for f2 in graph[f]])) 849 comps = logic.tarjan (graph, entries) 850 for (head, tail) in comps: 851 if tail or head in graph[head]: 852 group = [head] + list (tail) 853 idents2 = compute_recursion_idents (group, 854 extra_unfolds) 855 idents.update (idents2) 856 return idents 857 858def compute_recursion_idents (group, extra_unfolds): 859 idents = {} 860 group = set (group) 861 recursion_trace.append ('Computing for group %s' % group) 862 printout ('Doing recursion analysis for function group:') 863 printout (' %s' % list(group)) 864 prevs = set ([f for f in functions 865 if [f2 for f2 in functions[f].function_calls () if f2 in group]]) 866 for f in prevs - group: 867 recursion_trace.append (' checking for %s' % f) 868 trace ('Checking idents for %s' % f) 869 while add_recursion_ident (f, group, idents, extra_unfolds): 870 pass 871 return idents 872 873def function_link_assns (p, call_site, tag): 874 call_vis = (default_n_vc (p, call_site), p.node_tags[call_site][0]) 875 return rep_graph.mk_function_link_hyps (p, call_vis, tag) 876 877def add_recursion_ident (f, group, idents, extra_unfolds): 878 from syntax import mk_eq, mk_implies, mk_var 879 p = problem.Problem (None, name = 'Recursion Test') 880 chain = [] 881 tag = 'fun0' 882 p.add_entry_function (functions[f], tag) 883 p.do_analysis () 884 assns = [] 885 recursion_last_assns[0] = assns 886 887 while True: 888 res = find_unknown_recursion (p, group, idents, tag, assns, 889 extra_unfolds) 890 if res == None: 891 break 892 if p.nodes[res].fname not in group: 893 problem.inline_at_point (p, res) 894 continue 895 fname = p.nodes[res].fname 896 chain.append (fname) 897 tag = 'fun%d' % len (chain) 898 (args, _, entry) = p.add_entry_function (functions[fname], tag) 899 p.do_analysis () 900 assns += function_link_assns (p, res, tag) 901 if chain == []: 902 return None 903 recursion_trace.append (' created fun chain %s' % chain) 904 word_args = [(i, mk_var (s, typ)) 905 for (i, (s, typ)) in enumerate (args) 906 if typ.kind == 'Word'] 907 rep = rep_graph.mk_graph_slice (p, fast = True) 908 (_, env) = rep.get_node_pc_env ((entry, ())) 909 910 m = {} 911 res = rep.test_hyp_whyps (syntax.false_term, assns, model = m) 912 assert m 913 914 if find_unknown_recursion (p, group, idents, tag, [], []) == None: 915 idents.setdefault (fname, []) 916 idents[fname].append (syntax.true_term) 917 recursion_trace.append (' found final ident for %s' % fname) 918 return syntax.true_term 919 assert word_args 920 recursion_trace.append (' scanning for ident for %s' % fname) 921 for (i, arg) in word_args: 922 (nm, typ) = functions[fname].inputs[i] 923 arg_smt = solver.to_smt_expr (arg, env, rep.solv) 924 val = search.eval_model_expr (m, rep.solv, arg_smt) 925 if not rep.test_hyp_whyps (mk_eq (arg_smt, val), assns): 926 recursion_trace.append (' discarded %s = 0x%x, not stable' % (nm, val.val)) 927 continue 928 entry_vis = ((entry, ()), tag) 929 ass = rep_graph.eq_hyp ((arg, entry_vis), (val, entry_vis)) 930 res = find_unknown_recursion (p, group, idents, tag, 931 assns + [ass], []) 932 if res: 933 fname2 = p.nodes[res].fname 934 recursion_trace.append (' discarded %s, allows recursion to %s' % (nm, fname2)) 935 continue 936 eq = syntax.mk_eq (mk_var (nm, typ), val) 937 idents.setdefault (fname, []) 938 idents[fname].append (eq) 939 recursion_trace.append (' found ident for %s: %s' % (fname, eq)) 940 return eq 941 assert not "identifying assertion found" 942 943def find_unknown_recursion (p, group, idents, tag, assns, extra_unfolds): 944 from syntax import mk_not, mk_and, foldr1 945 rep = rep_graph.mk_graph_slice (p, fast = True) 946 for n in p.nodes: 947 if p.nodes[n].kind != 'Call': 948 continue 949 if p.node_tags[n][0] != tag: 950 continue 951 fname = p.nodes[n].fname 952 if fname in extra_unfolds: 953 return n 954 if fname not in group: 955 continue 956 (inp_env, _, _) = rep.get_func (default_n_vc (p, n)) 957 pc = rep.get_pc (default_n_vc (p, n)) 958 new = foldr1 (mk_and, [pc] + [syntax.mk_not ( 959 solver.to_smt_expr (ident, inp_env, rep.solv)) 960 for ident in idents.get (fname, [])]) 961 if rep.test_hyp_whyps (mk_not (new), assns): 962 continue 963 return n 964 return None 965 966asm_cc_cache = {} 967 968def is_instruction (fname): 969 bits = fname.split ("'") 970 return bits[1:] and bits[:1] in [["l_impl"], ["instruction"]] 971 972def get_asm_calling_convention (fname): 973 if fname in asm_cc_cache: 974 return asm_cc_cache[fname] 975 if fname not in pre_pairings: 976 bits = fname.split ("'") 977 if not is_instruction (fname): 978 trace ("Warning: unusual unmatched function (%s, %s)." 979 % (fname, bits)) 980 return None 981 pair = pre_pairings[fname] 982 assert pair['ASM'] == fname 983 c_fun = functions[pair['C']] 984 from logic import split_scalar_pairs 985 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_fun.inputs) 986 (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_fun.outputs) 987 988 num_args = len (var_c_args) 989 num_rets = len (var_c_rets) 990 const_mem = not (c_omem) 991 992 cc = get_asm_calling_convention_inner (num_args, num_rets, const_mem) 993 asm_cc_cache[fname] = cc 994 return cc 995 996def get_asm_calling_convention_inner (num_c_args, num_c_rets, const_mem): 997 key = ('Inner', num_c_args, num_c_rets, const_mem) 998 if key in asm_cc_cache: 999 return asm_cc_cache[key] 1000 1001 from logic import mk_var_list, mk_stack_sequence 1002 from syntax import mk_var, word32T, builtinTs 1003 1004 arg_regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T) 1005 r0 = arg_regs[0] 1006 sp = mk_var ('r13', word32T) 1007 st = mk_var ('stack', builtinTs['Mem']) 1008 r0_input = mk_var ('ret_addr_input', word32T) 1009 1010 mem = mk_var ('mem', builtinTs['Mem']) 1011 dom = mk_var ('dom', builtinTs['Dom']) 1012 dom_stack = mk_var ('dom_stack', builtinTs['Dom']) 1013 1014 global_args = [mem, dom, st, dom_stack, sp, mk_var ('ret', word32T)] 1015 1016 sregs = mk_stack_sequence (sp, 4, st, word32T, num_c_args + 1) 1017 1018 arg_seq = [r for r in arg_regs] + [s for (s, _) in sregs] 1019 if num_c_rets > 1: 1020 # the 'return-too-much' issue. 1021 # instead r0 is a save-returns-here pointer 1022 arg_seq.pop (0) 1023 rets = mk_stack_sequence (r0_input, 4, st, word32T, num_c_rets) 1024 rets = [r for (r, _) in rets] 1025 else: 1026 rets = [r0] 1027 1028 callee_saved_vars = ([mk_var (v, word32T) 1029 for v in 'r4 r5 r6 r7 r8 r9 r10 r11 r13'.split ()] 1030 + [dom, dom_stack]) 1031 1032 if const_mem: 1033 callee_saved_vars += [mem] 1034 else: 1035 rets += [mem] 1036 rets += [st] 1037 1038 cc = {'args': arg_seq[: num_c_args] + global_args, 1039 'rets': rets, 'callee_saved': callee_saved_vars} 1040 1041 asm_cc_cache[key] = cc 1042 return cc 1043 1044def get_asm_calling_convention_at_node (node): 1045 cc = get_asm_calling_convention (node.fname) 1046 if not cc: 1047 return None 1048 1049 fun = functions[node.fname] 1050 arg_input_map = dict (azip (fun.inputs, node.args)) 1051 ret_output_map = dict (azip (fun.outputs, 1052 [mk_var (nm, typ) for (nm, typ) in node.rets])) 1053 1054 args = [logic.var_subst (arg, arg_input_map) for arg in cc['args']] 1055 rets = [logic.var_subst (ret, ret_output_map) for ret in cc['rets']] 1056 # these are useful because they happen to map ret r0_input back to 1057 # the previous value r0, rather than the useless value r0_input_ignore. 1058 rets_inp = [logic.var_subst (ret, arg_input_map) for ret in cc['rets']] 1059 saved = [logic.var_subst (v, ret_output_map) 1060 for v in cc['callee_saved']] 1061 return {'args': args, 'rets': rets, 1062 'rets_inp': rets_inp, 'callee_saved': saved} 1063 1064call_cache = {} 1065 1066def get_asm_callable (fname): 1067 if fname not in pre_pairings: 1068 return True 1069 c_fun = pre_pairings[fname]['C'] 1070 1071 if not call_cache: 1072 for f in functions: 1073 call_cache[f] = False 1074 for f in functions: 1075 fun = functions[f] 1076 for n in fun.reachable_nodes (simplify = True): 1077 if fun.nodes[n].kind == 'Call': 1078 call_cache[fun.nodes[n].fname] = True 1079 return call_cache[c_fun] 1080 1081def get_asm_reachable_nodes (p, tag_set = None): 1082 if tag_set == None: 1083 tag_set = [tag for tag in p.tags () 1084 if is_asm_node (p, p.get_entry (tag))] 1085 frontier = [p.get_entry (tag) for tag in tag_set] 1086 nodes = set () 1087 while frontier: 1088 n = frontier.pop () 1089 if n in nodes or n not in p.nodes: 1090 continue 1091 nodes.add (n) 1092 node = p.nodes[n] 1093 if node.kind == 'Call' and not get_asm_callable (node.fname): 1094 continue 1095 node = logic.simplify_node_elementary (node) 1096 frontier.extend (node.get_conts ()) 1097 return nodes 1098 1099def convert_recursion_idents (idents): 1100 asm_idents = {} 1101 for f in idents: 1102 if f not in pre_pairings: 1103 continue 1104 f2 = pre_pairings[f]['ASM'] 1105 assert f2 != f 1106 asm_idents[f2] = [] 1107 for ident in idents[f]: 1108 if ident.is_op ('True'): 1109 asm_idents[f2].append (ident) 1110 elif ident.is_op ('Equals'): 1111 [x, y] = ident.vals 1112 # this is a bit hacky 1113 [i] = [i for (i, (nm, typ)) 1114 in enumerate (functions[f].inputs) 1115 if x.is_var ((nm, typ))] 1116 cc = get_asm_calling_convention (f2) 1117 x = cc['args'][i] 1118 asm_idents[f2].append (syntax.mk_eq (x, y)) 1119 else: 1120 assert not 'ident kind convertible' 1121 return asm_idents 1122 1123def mk_pairing (pre_pair, stack_bounds): 1124 asm_f = pre_pair['ASM'] 1125 sz = stack_bounds[asm_f] 1126 c_fun = functions[pre_pair['C']] 1127 1128 from logic import split_scalar_pairs 1129 (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_fun.inputs) 1130 (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_fun.outputs) 1131 1132 eqs = logic.mk_eqs_arm_none_eabi_gnu (var_c_args, var_c_rets, 1133 c_imem, c_omem, sz) 1134 1135 return logic.Pairing (['ASM', 'C'], 1136 {'ASM': asm_f, 'C': c_fun.name}, eqs) 1137 1138def mk_pairings (stack_bounds): 1139 new_pairings = {} 1140 for f in pre_pairings: 1141 if f in new_pairings: 1142 continue 1143 pair = mk_pairing (pre_pairings[f], stack_bounds) 1144 for fun in pair.funs.itervalues (): 1145 new_pairings[fun] = [pair] 1146 return new_pairings 1147 1148def serialise_stack_bounds (stack_bounds): 1149 lines = [] 1150 for fname in stack_bounds: 1151 ss = ['StackBound', fname] 1152 stack_bounds[fname].serialise (ss) 1153 lines.append (' '.join (ss) + '\n') 1154 return lines 1155 1156def deserialise_stack_bounds (lines): 1157 bounds = {} 1158 for line in lines: 1159 bits = line.split () 1160 if not bits: 1161 continue 1162 assert bits[0] == 'StackBound' 1163 fname = bits[1] 1164 (_, bound) = syntax.parse_expr (bits, 2) 1165 bounds[fname] = bound 1166 return bounds 1167 1168funs_with_tag = {} 1169 1170def get_functions_with_tag (tag): 1171 if tag in funs_with_tag: 1172 return funs_with_tag[tag] 1173 visit = set ([pre_pairings[f][tag] for f in pre_pairings 1174 if tag in pre_pairings[f]]) 1175 visit.update ([pair.funs[tag] for f in pairings 1176 for pair in pairings[f] if tag in pair.funs]) 1177 funs = set (visit) 1178 while visit: 1179 f = visit.pop () 1180 funs.add (f) 1181 visit.update (set (functions[f].function_calls ()) - funs) 1182 funs_with_tag[tag] = funs 1183 return funs 1184 1185def compute_stack_bounds (quiet = False): 1186 prev_tracer = target_objects.tracer[0] 1187 if quiet: 1188 target_objects.tracer[0] = lambda s, n: () 1189 1190 try: 1191 c_fs = get_functions_with_tag ('C') 1192 idents = get_recursion_identifiers (c_fs) 1193 asm_idents = convert_recursion_idents (idents) 1194 asm_fs = get_functions_with_tag ('ASM') 1195 printout ('Computed recursion limits.') 1196 1197 bounds = compute_asm_stack_bounds (asm_idents, asm_fs) 1198 printout ('Computed stack bounds.') 1199 except Exception, e: 1200 if quiet: 1201 target_objects.tracer[0] = prev_tracer 1202 raise 1203 1204 if quiet: 1205 target_objects.tracer[0] = prev_tracer 1206 return bounds 1207 1208def read_fn_hash (fname): 1209 try: 1210 f = open (fname) 1211 s = f.readline () 1212 bits = s.split () 1213 if bits[0] != 'FunctionHash' or len (bits) != 2: 1214 return None 1215 return int (bits[1]) 1216 except ValueError, e: 1217 return None 1218 except IndexError, e: 1219 return None 1220 except IOError, e: 1221 return None 1222 1223def mk_stack_pairings (pairing_tups, stack_bounds_fname = None, 1224 quiet = True): 1225 """build the stack-aware calling-convention-aware logical pairings 1226 once a collection of function pairs have been read.""" 1227 1228 # simplifies interactive testing of this function 1229 pre_pairings.clear () 1230 1231 for (asm_f, c_f) in pairing_tups: 1232 pair = {'ASM': asm_f, 'C': c_f} 1233 assert c_f not in pre_pairings 1234 assert asm_f not in pre_pairings 1235 pre_pairings[c_f] = pair 1236 pre_pairings[asm_f] = pair 1237 1238 fn_hash = hash (tuple (sorted ([(f, hash (functions[f])) 1239 for f in functions]))) 1240 prev_hash = read_fn_hash (stack_bounds_fname) 1241 if prev_hash == fn_hash: 1242 f = open (stack_bounds_fname) 1243 f.readline () 1244 stack_bounds = deserialise_stack_bounds (f) 1245 f.close () 1246 else: 1247 printout ('Computing stack bounds.') 1248 stack_bounds = compute_stack_bounds (quiet = quiet) 1249 f = open (stack_bounds_fname, 'w') 1250 f.write ('FunctionHash %s\n' % fn_hash) 1251 for line in serialise_stack_bounds (stack_bounds): 1252 f.write(line) 1253 f.close () 1254 1255 problematic_synthetic () 1256 1257 return mk_pairings (stack_bounds) 1258 1259def asm_stack_rep_hook (p, (nm, typ), kind, n): 1260 if not is_asm_node (p, n): 1261 return None 1262 1263 if not (nm.startswith ('stack') and typ == syntax.builtinTs['Mem']): 1264 return None 1265 1266 assert kind in ['Call', 'Init', 'Loop'], kind 1267 if kind == 'Init': 1268 return None 1269 1270 tag = p.node_tags[n][0] 1271 (_, sp) = get_stack_sp (p, tag) 1272 1273 return ('SplitMem', sp) 1274 1275reg_aliases = {'r11': ['fp'], 'r14': ['lr'], 'r13': ['sp']} 1276 1277def inst_const_rets (node): 1278 assert "instruction'" in node.fname 1279 bits = set ([s.lower () for s in node.fname.split ('_')]) 1280 fun = functions[node.fname] 1281 def is_const (nm, typ): 1282 if typ in [builtinTs['Mem'], builtinTs['Dom']]: 1283 return True 1284 if typ != word32T: 1285 return False 1286 return not (nm in bits or [al for al in reg_aliases.get (nm, []) 1287 if al in bits]) 1288 is_consts = [is_const (nm, typ) for (nm, typ) in fun.outputs] 1289 input_set = set ([v for arg in node.args 1290 for v in syntax.get_expr_var_set (arg)]) 1291 return [mk_var (nm, typ) 1292 for ((nm, typ), const) in azip (node.rets, is_consts) 1293 if const and (nm, typ) in input_set] 1294 1295def node_const_rets (node): 1296 if "instruction'" in node.fname: 1297 return inst_const_rets (node) 1298 if node.fname in pre_pairings: 1299 if pre_pairings[node.fname]['ASM'] != node.fname: 1300 return None 1301 cc = get_asm_calling_convention_at_node (node) 1302 input_set = set ([v for arg in node.args 1303 for v in syntax.get_expr_var_set (arg)]) 1304 callee_saved_set = set (cc['callee_saved']) 1305 return [mk_var (nm, typ) for (nm, typ) in node.rets 1306 if mk_var (nm, typ) in callee_saved_set 1307 if (nm, typ) in input_set] 1308 elif preserves_sp (node.fname): 1309 if node.fname not in get_functions_with_tag ('ASM'): 1310 return None 1311 f_outs = functions[node.fname].outputs 1312 return [mk_var (nm, typ) 1313 for ((nm, typ), (nm2, _)) in azip (node.rets, f_outs) 1314 if nm2 == 'r13'] 1315 else: 1316 return None 1317 1318def const_ret_hook (node, nm, typ): 1319 consts = node_const_rets (node) 1320 return consts and mk_var (nm, typ) in consts 1321 1322def get_const_rets (p, node_set = None): 1323 if node_set == None: 1324 node_set = p.nodes 1325 const_rets = {} 1326 for n in node_set: 1327 if p.nodes[n].kind != 'Call': 1328 continue 1329 consts = node_const_rets (node) 1330 const_rets[n] = [(v.name, v.typ) for v in consts] 1331 return const_rets 1332 1333def problematic_synthetic (): 1334 synth = [s for s in target_objects.symbols 1335 if '.clone.' in s or '.part.' in s or '.constprop.' in s] 1336 synth = ['_'.join (s.split ('.')) for s in synth] 1337 if not synth: 1338 return 1339 printout ('Synthetic symbols: %s' % synth) 1340 synth_calls = set ([f for f in synth 1341 if f in functions 1342 if functions[f].function_calls ()]) 1343 printout ('Synthetic symbols which make function calls: %s' 1344 % synth_calls) 1345 if not synth_calls: 1346 return 1347 synth_stack = set ([f for f in synth_calls 1348 if [node for node in functions[f].nodes.itervalues () 1349 if node.kind == 'Basic' 1350 if ('r13', word32T) in node.get_lvals ()]]) 1351 printout ('Synthetic symbols which call and move sp: %s' 1352 % synth_stack) 1353 synth_problems = set ([f for f in synth_stack 1354 if [f2 for f2 in functions 1355 if f in functions[f2].function_calls () 1356 if len (set (functions[f2].function_calls ())) > 1] 1357 ]) 1358 printout ('Problematic synthetics: %s' % synth_problems) 1359 return synth_problems 1360 1361def add_hooks (): 1362 k = 'stack_logic' 1363 add = target_objects.add_hook 1364 add ('problem_var_rep', k, asm_stack_rep_hook) 1365 add ('loop_var_analysis', k, loop_var_analysis) 1366 add ('rep_unsafe_const_ret', k, const_ret_hook) 1367 1368add_hooks () 1369 1370