#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# Loop-bound discovery for binary (ASM) loops, used by the WCET toolchain.
# Bounds are searched for per loop head and per calling context, proven via
# the rep_graph/check SMT machinery, and cached in <target_dir>/LoopBounds.txt.
# NOTE: Python 2 source (print statements, iteritems, `except E, e`).

import check,search,problem,syntax,solver,logic,rep_graph,re
from rep_graph import vc_num, vc_offs, pc_true_hyp, Hyp, eq_hyp
from target_objects import functions, trace
from check import restr_others,split_visit_one_visit
from problem import inline_at_point
from syntax import mk_not, true_term, false_term, mk_implies, Expr, Type, unspecified_precond_term,mk_and
from rep_graph import mk_graph_slice, VisitCount, to_smt_expr
from search import eval_model_expr
import target_objects
import trace_refute
import stack_logic
import time

#tryFun must take exactly 1 argument
def downBinSearch(minimum, maximum, tryFun):
    """Classic binary search for the least value in [minimum, maximum]
    at which tryFun succeeds.  Assumes tryFun is monotone (False below
    the threshold, True at and above it) and that tryFun(maximum) holds."""
    upperBound = maximum
    lowerBound = minimum
    while upperBound > lowerBound:
        print 'searching in %d - %d' % (lowerBound,upperBound)
        cur = (lowerBound + upperBound) / 2
        if tryFun(cur):
            upperBound = cur
        else:
            lowerBound = cur + 1
    assert upperBound == lowerBound
    ret = lowerBound
    return ret

def upDownBinSearch (minimum, maximum, tryFun):
    """performs a binary search between minimum and maximum, but does not start
    in the middle. instead it does a binary escalation up from the minimum
    first. this makes sense for ranges e.g. 2 - 1000000 where the bound is
    likely to be near the bottom of the range.  it also avoids testing values
    more than twice as high as the bound, which may avoid some issues."""
    upperBound = 2 * minimum
    while upperBound < maximum:
        if tryFun (upperBound):
            return downBinSearch (minimum, upperBound, tryFun)
        else:
            upperBound *= 2
    # escalation overshot the cap: fall back to the full range, or give up.
    if tryFun (maximum):
        return downBinSearch (minimum, maximum, tryFun)
    else:
        return None

def addr_of_node (preds, n):
    """Walk unique predecessors from node n until an instruction address
    is reached.  Asserts (via the destructuring) that each intermediate
    node has exactly one predecessor."""
    while not trace_refute.is_addr (n):
        [n] = preds[n]
    return n

def all_asm_functions ():
    # all ASM-tagged functions, excluding single-instruction stubs
    ss = stack_logic.get_functions_with_tag ('ASM')
    return [s for s in ss if not stack_logic.is_instruction (s)]

# maps callee function name -> set of instruction addresses of its call sites.
# Populated lazily by build_call_site_set; the ('IsLoaded', None) key is a
# sentinel marking the cache as built.
call_site_set = {}

def build_call_site_set ():
    """Populate call_site_set from every ASM function's Call nodes."""
    for f in all_asm_functions ():
        preds = logic.compute_preds (functions[f].nodes)
        for (n, node) in functions[f].nodes.iteritems ():
            if node.kind == 'Call':
                s = call_site_set.setdefault (node.fname, set ())
                s.add (addr_of_node (preds, n))
    call_site_set[('IsLoaded', None)] = True

def all_call_sites (f):
    """Return the known call-site addresses of function f (possibly empty)."""
    if not call_site_set:
        build_call_site_set ()
    return list (call_site_set.get (f, []))

#naive binary search to find loop bounds
def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None):
    """Naive loop-bound search for loop head p_n in problem p: probe an
    escalating sequence of candidate bounds, then binary-search between
    the last failing and first succeeding candidate.  Returns the bound
    (an int) or None if no candidate in the sequence succeeds."""
    if hyps == None:
        hyps = []
    #print 'restrs: %s' % str(restrs)
    if try_seq == None:
        #bound_try_seq = [1,2,3,4,5,10,50,130,200,260]
        #bound_try_seq = [0,1,2,3,4,5,10,50,260]
        # loops containing calls get a shorter probe sequence,
        # presumably because each probe is more expensive -- TODO confirm
        calls = [n for n in p.loop_body (p_n) if p.nodes[n].kind == 'Call']
        if calls:
            bound_try_seq = [0,1,20]
        else:
            bound_try_seq = [0,1,20,34]
    else:
        bound_try_seq = try_seq
    rep = mk_graph_slice (p, fast = True)
    #get the head
    #print 'Binary addr: %s' % toHexs(self.toPhyAddrs(p_loop_heads))
    loop_bound = None
    p_loop_heads = [n for n in p.loop_data if p.loop_data[n][0] == 'Head']
    print 'p_loop_heads: %s' % p_loop_heads

    if restrs == None:
        # restrict all *other* loop heads so only p_n is unconstrained
        others = [x for x in p_loop_heads if not x == p_n]
        #vc_options([concrete numbers], [offsets])
        restrs = tuple( [(n2, rep_graph.vc_options([0],[1])) for n2 in others] )

    print 'getting the initial bound'
    #try:
    index = tryLoopBound(p_n,p,bound_try_seq,rep,restrs = restrs, hyps=hyps)
    if index == -1:
        return None
    print 'got the initial bound %d' % bound_try_seq[index]

    #do a downward binary search to find the concrete loop bound
    if index == 0:
        loop_bound = bound_try_seq[0]
        print 'bound = %d' % loop_bound
        return loop_bound
    loop_bound = downBinSearch(bound_try_seq[index-1], bound_try_seq[index], lambda x: tryLoopBound(p_n,p,[x],rep,restrs=restrs, hyps=hyps, bin_return=True))
    print 'bound = %d' % loop_bound
    return loop_bound

def default_n_vc_cases (p, n):
    """Default visit-count restriction cases for node n: all other loop
    heads get the generic [0]/[1] options; n's own loop head (if any) is
    split into the number-1 and offset-1 cases."""
    head = p.loop_id (n)
    general = [(n2, rep_graph.vc_options ([0], [1]))
        for n2 in p.loop_heads ()
        if n2 != head]

    if head:
        return [(n, tuple (general + [(head, rep_graph.vc_num (1))])),
            (n, tuple (general + [(head, rep_graph.vc_offs (1))]))]
    # NOTE(review): unreachable-in-effect -- `specific` is always [] here
    # since head is falsy past the early return above.
    specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head]
    return [(n, tuple (general + specific))]

def callNodes(p, fs= None):
    """Call nodes of problem p, optionally filtered to callees in fs."""
    ns = [n for n in p.nodes if p.nodes[n].kind == 'Call']
    if fs != None:
        ns = [n for n in ns if p.nodes[n].fname in fs]
    return ns

def noHaltHyps(split,p):
    """Hypotheses asserting that no 'halt' call is reached, for every
    default visit-count case of every halt call site in p."""
    ret = []
    all_halts = callNodes(p,fs=['halt'])
    for x in all_halts:
        ret += [rep_graph.pc_false_hyp((n_vc, p.node_tags[x][0]))
            for n_vc in default_n_vc_cases (p, x)]
    return ret

def tryLoopBound(p_n, p, bounds,rep,restrs =None, hints =None,kind = 'Number',bin_return = False,hyps = None):
    """Test candidate bounds for loop head p_n in increasing order.
    Returns the index of the first bound that is provable (or True/False
    for a single candidate when bin_return is set); -1 / False on failure."""
    if restrs == None:
        restrs = ()
    if hints == None:
        hints = []
    if hyps == None:
        hyps = []
    tag = p.node_tags[p_n][0]
    from stack_logic import default_n_vc
    print 'trying bound: %s' % bounds
    ret_bounds = []
    for (index,i) in enumerate(bounds):
        print 'testing %d' % i
        restrs2 = restrs + ((p_n, VisitCount (kind, i)), )
        try:
            pc = rep.get_pc ((p_n, restrs2))
        except:
            print 'get_pc failed'
            if bin_return:
                return False
            else:
                return -1
        #print 'got rep_.get_pc'
        restrs3 = restr_others (p, restrs2, 2)
        epc = rep.get_pc (('Err', restrs3), tag = tag)
        # unreachable (modulo the error path) after i visits
        hyp = mk_implies (mk_not (epc), mk_not (pc))
        # NOTE(review): this appends the same noHaltHyps again on every
        # iteration, so hyps grows each time round the loop -- looks
        # accidental but is harmless to soundness.
        hyps = hyps + noHaltHyps(p_n,p)

        #hyps = []
        #print 'calling test_hyp_whyps'
        if rep.test_hyp_whyps (hyp, hyps):
            print 'p_n %d: split limit found: %d' % (p_n, i)
            if bin_return:
                return True
            return index
    if bin_return:
        return False
    print 'loop bound not found!'
    return -1
    # NOTE(review): unreachable after the return above.
    assert False, 'failed to find loop bound for p_n %d' % p_n

def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False):
    """Find the maximal set of candidate loop-invariant equalities at loop
    head `split` that can be proven by induction (greedy fixpoint over the
    candidates).  Results are cached in p.cached_analysis.  With
    omit_standard, the standard mk_seq_eqs candidates are removed from the
    returned set."""
    k = ('linear_series_eqs', split, restrs, tuple (hyps))
    if k in p.cached_analysis:
        if omit_standard:
            standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
            return set (p.cached_analysis[k]) - standard
        return p.cached_analysis[k]

    cands = search.mk_seq_eqs (p, split, 1, with_rodata = False)
    cands += candidate_additional_eqs (p, split)
    (tag, _) = p.node_tags[split]

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def do_checks (eqs_assume, eqs):
        # prove induction step (assuming eqs_assume) and base case for eqs
        checks = (check.single_loop_induct_step_checks (p, restrs, hyps, tag,
                split, 1, eqs, eqs_assume = eqs_assume)
            + check.single_loop_induct_base_checks (p, restrs, hyps, tag,
                split, 1, eqs))

        groups = check.proof_check_groups (checks)
        for group in groups:
            (res, _) = check.test_hyp_group (rep, group)
            if not res:
                return False
        return True

    eqs = []
    failed = []
    while cands:
        cand = cands.pop ()
        if do_checks (eqs, [cand]):
            eqs.append (cand)
            # a new invariant may make previously-failed candidates provable:
            # requeue them (reversed to preserve original order)
            failed.reverse ()
            cands = failed + cands
            failed = []
        else:
            failed.append (cand)

    # final sanity check: the whole set must hold unconditionally
    assert do_checks ([], eqs)
    p.cached_analysis[k] = eqs
    if omit_standard:
        standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
        return set (eqs) - standard
    return eqs

def get_linear_series_hyps (p, split, restrs, hyps):
    """The linear-series equalities at `split` converted into hypotheses
    at visit-offset 0."""
    eqs = get_linear_series_eqs (p, split, restrs, hyps)
    (tag, _) = p.node_tags[split]
    hyps = [h for (h, _) in linear_eq_hyps_at_visit (tag, split, eqs,
        restrs, vc_offs (0))]
    return hyps

def is_zero (expr):
    """True if expr is a numeric literal equal to zero modulo its word width."""
    return expr.kind == 'Num' and expr.val & ((1 << expr.typ.num) - 1) == 0

def candidate_additional_eqs (p, split):
    """Extra candidate invariants for loop `split`: derived from word
    equalities seen in the loop body (rewriting x == 0 +/- y forms), plus
    comparisons between loop-linear-series variables and their bounds, plus
    any target-provided 'extra_wcet_assertions' hook output."""
    eq_vals = set ()
    def visitor (expr):
        if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word':
            [x, y] = expr.vals
            eq_vals.update ([(x, y), (y, x)])
    for n in p.loop_body (split):
        p.nodes[n].visit (lambda x: (), visitor)
    for (x, y) in list (eq_vals):
        if is_zero (x) and y.is_op ('Plus'):
            # 0 == a + b  ==>  a == -b (both orientations)
            [x, y] = y.vals
            eq_vals.add ((x, syntax.mk_uminus (y)))
            eq_vals.add ((y, syntax.mk_uminus (x)))
        elif is_zero (x) and y.is_op ('Minus'):
            # 0 == a - b  ==>  a == b (both orientations)
            [x, y] = y.vals
            eq_vals.add ((x, y))
            eq_vals.add ((y, x))

    # '%i' is the symbolic loop iteration counter
    loop = syntax.mk_var ('%i', syntax.word32T)
    minus_loop_step = syntax.mk_uminus (loop)

    vas = search.get_loop_var_analysis_at(p, split)
    ls_vas = dict ([(var, [data]) for (var, data) in vas
        if data[0] == 'LoopLinearSeries'])
    cmp_series = [(x, y, rew, offs) for (x, y) in eq_vals
        for (_, rew, offs) in ls_vas.get (x, [])]
    odd_eqs = []
    for (x, y, rew, offs) in cmp_series:
        # relate the direction of the series at iteration 0 to the final
        # comparison against y, in all four sign combinations
        x_init_cmp1 = syntax.mk_less_eq (x, rew (x, minus_loop_step))
        x_init_cmp2 = syntax.mk_less_eq (rew (x, minus_loop_step), x)
        fin_cmp1 = syntax.mk_less (x, y)
        fin_cmp2 = syntax.mk_less (y, x)
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp2))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp2))

    ass_eqs = []
    var_deps = p.compute_var_dependencies ()
    for hook in target_objects.hooks ('extra_wcet_assertions'):
        for assn in hook (var_deps[split]):
            ass_eqs.append (assn)

    return odd_eqs + ass_eqs

extra_loop_consts = [2 ** 16]

# LRU-ish cache (last 20) of ((call_ctxt, fname), problem, hyps, addr_map)
call_ctxt_problems = []

# single-element mutable flag: when set, drop hypotheses that mention the
# C side of the problem (ASM-only analysis)
avoid_C_information = [False]

def get_call_ctxt_problem (split, call_ctxt, timing = True):
    """Build (or fetch from cache) the compound problem for address `split`
    in calling context `call_ctxt`.  Returns (problem, hyps, addr_map)."""
    # time this for diagnostic reasons
    start = time.time ()
    from trace_refute import identify_function, build_compound_problem_with_links
    f = identify_function (call_ctxt, [split])
    for (ctxt2, p, hyps, addr_map) in call_ctxt_problems:
        if ctxt2 == (call_ctxt, f):
            return (p, hyps, addr_map)

    (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f)
    if avoid_C_information[0]:
        hyps = [h for h in hyps if not has_C_information (p, h)]
    call_ctxt_problems.append(((call_ctxt, f), p, hyps, addr_map))
    # keep only the 20 most recent problems
    del call_ctxt_problems[: -20]

    end = time.time ()
    if timing:
        save_extra_timing ('GetProblem', call_ctxt + [split], end - start)

    return (p, hyps, addr_map)

def has_C_information (p, hyp):
    """True if hyp visits any tag not hinted as 'ASM'.  (Falls through to
    an implicit None when every visit is ASM.)"""
    for (n_vc, tag) in hyp.visits ():
        if not p.hook_tag_hints.get (tag, None) == 'ASM':
            return True

# cache of (problem name, loop, bound, restrs, hyps) -> (restrs, hyps)
known_bound_restr_hyps = {}

# cache of discovered bounds, keyed by loop address (context-specific
# entries) or (addr, 'Global') (context-independent entries); the 'Loaded'
# key marks that LoopBounds.txt has been read.
known_bounds = {}

def serialise_bound (addr, bound_info):
    """Encode (addr, bound) as the three tokens used in LoopBounds.txt."""
    if bound_info == None:
        return [hex(addr), "None", "None"]
    else:
        (bound, kind) = bound_info
        assert logic.is_int (bound)
        assert str (kind) == kind
        return [hex (addr), str (bound), kind]

def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound,
        time = None):
    """Append a (Global)LoopBound record, and optionally its timing, to
    <target_dir>/LoopBounds.txt.
    NOTE(review): the `time` parameter shadows the `time` module within
    this function; safe here only because the module isn't used below."""
    f_names = [trace_refute.get_body_addrs_fun (x)
        for x in call_ctxt + [split_bin_addr]]
    loop_name = '<%s>' % ' -> '.join (f_names)
    comment = '# bound for loop in %s:' % loop_name
    ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound)
    if glob:
        ss[0] = 'GlobalLoopBound'
    ss += [str (len (call_ctxt))] + map (hex, call_ctxt)
    ss += [str (prob_hash)]
    if glob:
        assert prev_bounds == None
    else:
        ss += [str (len (prev_bounds))]
        for (split, bound) in prev_bounds:
            ss += serialise_bound (split, bound)
    s = ' '.join (ss)
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (comment + '\n')
    f.write (s + '\n')
    if time != None:
        ctxt2 = call_ctxt + [split_bin_addr]
        ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2))
        f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time))
    f.close ()
    trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr,
        loop_name))

def save_extra_timing (nm, ctxt, time):
    """Append an ExtraTiming record to LoopBounds.txt."""
    ss = ['ExtraTiming', nm, str (len (ctxt))] + map (hex, ctxt) + [str(time)]
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (' '.join (ss) + '\n')
    f.close ()

def parse_bound (ss, n):
    """Inverse of serialise_bound: parse tokens ss[n:n+3], returning
    (next index, (addr, bound-or-None))."""
    addr = syntax.parse_int (ss[n])
    bound = ss[n + 1]
    if bound == 'None':
        bound = None
        return (n + 3, (addr, None))
    else:
        bound = syntax.parse_int (bound)
        kind = ss[n + 2]
        return (n + 3, (addr, (bound, kind)))

def parse_ctxt_id (bits, n):
    # a single context element: one integer token
    return (n + 1, syntax.parse_int (bits[n]))

def parse_ctxt (bits, n):
    # a length-prefixed list of context ids
    return syntax.parse_list (parse_ctxt_id, bits, n)

def load_bounds ():
    """Populate known_bounds from <target_dir>/LoopBounds.txt (treating a
    missing file as empty)."""
    try:
        f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
        ls = list (f)
        f.close ()
    except IOError, e:
        ls = []
    from syntax import parse_int, parse_list
    for l in ls:
        bits = l.split ()
        if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]:
            continue
        (n, (addr, bound)) = parse_bound (bits, 1)
        (n, ctxt) = parse_ctxt (bits, n)
        prob_hash = parse_int (bits[n])
        n += 1
        if bits[0] == 'LoopBound':
            (n, prev_bounds) = parse_list (parse_bound, bits, n)
            assert n == len (bits), bits
            known = known_bounds.setdefault (addr, [])
            known.append ((ctxt, prob_hash, prev_bounds, bound))
        else:
            assert n == len (bits), bits
            known = known_bounds.setdefault ((addr, 'Global'), [])
            known.append ((ctxt, prob_hash, bound))
    known_bounds['Loaded'] = True

def get_bound_ctxt (split, call_ctxt, use_cache = True):
    """Compute (or fetch) the bound for binary loop `split` in calling
    context `call_ctxt`.  Recursively bounds all prior loops first, builds
    restriction/hypothesis sets from them, then consults the cache before
    running search_bin_bound.  The result is recorded via save_bound."""
    trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt)

    orig_split = split
    split = p.loop_id (addr_map[split])
    assert split, (orig_split, call_ctxt)
    # canonical binary address for this loop: the least address mapping into it
    split_bin_addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split])

    prior = get_prior_loop_heads (p, split)
    restrs = ()
    prev_bounds = []
    for split2 in prior:
        # recursion!
        split2 = p.loop_id (split2)
        assert split2
        addr = min ([addr for addr in addr_map
            if p.loop_id (addr_map[addr]) == split2])
        bound = get_bound_ctxt (addr, call_ctxt)
        prev_bounds.append ((addr, bound))
        k = (p.name, split2, bound, restrs, tuple (hyps))
        if k in known_bound_restr_hyps:
            (restrs, hyps) = known_bound_restr_hyps[k]
        else:
            (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps,
                split2, bound, call_ctxt + [orig_split])
            known_bound_restr_hyps[k] = (restrs, hyps)

    # start timing now. we miss some setup time, but it avoids double counting
    # the recursive searches.
    start = time.time ()

    p_h = problem_hash (p)
    prev_bounds = sorted (prev_bounds)
    if not known_bounds:
        load_bounds ()
    known = known_bounds.get (split_bin_addr, [])
    for (call_ctxt2, h, prev_bounds2, bound) in known:
        # a stored context matches if it is a suffix of ours
        match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2)
        if match and use_cache and h == p_h and prev_bounds2 == prev_bounds:
            return bound
    bound = search_bin_bound (p, restrs, hyps, split)
    known = known_bounds.setdefault (split_bin_addr, [])
    known.append ((call_ctxt, p_h, prev_bounds, bound))
    end = time.time ()
    save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound,
        time = end - start)
    return bound

def problem_hash (p):
    """Hash of a problem's identity-relevant content, used to invalidate
    stored bounds when the problem changes."""
    return syntax.hash_tuplify ([p.name, p.entries,
        sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())])

def search_bin_bound (p, restrs, hyps, split):
    """Search for a bound on loop `split`: direct search first, then (if
    permitted and feasible) inference from the matching C loop."""
    # NOTE(review): trace is given two arguments here instead of a single
    # %-formatted string ('...%x in %s.' % (split, p.name)) as everywhere
    # else in this file -- looks like a latent bug.
    trace ('Searching for bound for 0x%x in %s.', (split, p.name))
    bound = search_bound (p, restrs, hyps, split)
    if bound:
        return bound

    # try to use a bound inferred from C
    if avoid_C_information[0]:
        # OK told not to
        return None
    if get_prior_loop_heads (p, split):
        # too difficult for now
        return None
    asm_tag = p.node_tags[split][0]
    (_, fname, _) = p.get_entry_details (asm_tag)
    funs = [f for pair in target_objects.pairings[fname]
        for f in pair.funs.values ()]
    c_tags = [tag for tag in p.tags ()
        if p.get_entry_details (tag)[1] in funs and tag != asm_tag]
    if len (c_tags) != 1:
        print 'Surprised to see multiple matching tags %s' % c_tags
        return None

    [c_tag] = c_tags

    rep = rep_graph.mk_graph_slice (p)
    if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1:
        # technical, but it's not going to work in this case
        return None

    return getBinaryBoundFromC (p, c_tag, split, restrs, hyps)

def rab_test ():
    """Ad-hoc test: bound the resolveAddressBits loop via its C bound.
    NOTE(review): get_loop_heads returns (addr, fname, has_inner) tuples,
    so split_bin_addr here is a tuple, not an address -- verify this still
    works with get_call_ctxt_problem."""
    [split_bin_addr] = get_loop_heads (functions['resolveAddressBits'])
    (p, hyps, addr_map) = get_call_ctxt_problem (split_bin_addr, [])
    split = p.loop_id (addr_map[split_bin_addr])
    [c_tag] = [tag for tag in p.tags () if tag != p.node_tags[split][0]]
    return getBinaryBoundFromC (p, c_tag, split, (), hyps)

# last arguments passed to search_bound, kept for interactive debugging
last_search_bound = [0]

def search_bound (p, restrs, hyps, split):
    """Direct bound search: cheap naive search, then an inductive
    binary search using linear-series invariants, then a final naive
    attempt.  Returns (bound, kind) or None."""
    last_search_bound[0] = (p, restrs, hyps, split)

    # try a naive bin search first
    # limit this to a small bound for time purposes
    # - for larger bounds the less naive approach can be faster
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps,
        try_seq = [0, 1, 6])
    if bound != None:
        return (bound, 'NaiveBinSearch')

    l_hyps = get_linear_series_hyps (p, split, restrs, hyps)

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def test (n):
        # bound candidate n holds if the loop cannot continue past visit
        # n - 2 + 2 given the linear-series invariants
        assert n > 10
        hyp = check.mk_loop_counter_eq_hyp (p, split, restrs, n - 2)
        visit = ((split, vc_offs (2)), ) + restrs
        continue_to_split_guess = rep.get_pc ((split, visit))
        return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess),
            [hyp] + l_hyps + hyps)

    # findLoopBoundBS always checks to at least 16
    min_bound = 16
    max_bound = max_acceptable_bound[0]
    bound = upDownBinSearch (min_bound, max_bound, test)
    if bound != None and test (bound):
        return (bound, 'InductiveBinSearch')

    # let the naive bin search go a bit further
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps)
    if bound != None:
        return (bound, 'NaiveBinSearch')

    return None

def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps):
    """Derive a bound for ASM loop asm_split from a bound on the matching
    C loop, by finding and proving a split (loop correspondence) between
    the two sides and translating the C iteration count through the
    linear-series details of the split.  Returns (bound, 'FromC') or None."""
    c_heads = [h for h in search.init_loops_to_split (p, restrs)
        if p.node_tags[h][0] == c_tag]
    c_bounds = [(p.loop_id (split), search_bound (p, (), hyps, split))
        for split in c_heads]
    if not [b for (n, b) in c_bounds if b]:
        trace ('no C bounds found (%s).' % c_bounds)
        return None

    asm_tag = p.node_tags[asm_split][0]

    rep = rep_graph.mk_graph_slice (p)
    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    tags = [p.node_tags[asm_split][0], c_tag]
    try:
        split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts,
            j_seq_opts, 5, tags = [asm_tag, c_tag])
    except solver.SolverFailure, e:
        return None
    if not split or split[0] != 'Split':
        trace ('no split found (%s).' % repr (split))
        return None
    (_, split) = split
    rep = rep_graph.mk_graph_slice (p)
    checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag])
    groups = check.proof_check_groups (checks)
    try:
        for group in groups:
            (res, el) = check.test_hyp_group (rep, group)
            if not res:
                trace ('split check failed!')
                trace ('failed at %s' % el)
                return None
    except solver.SolverFailure, e:
        return None
    (as_details, c_details, _, n, _) = split
    (c_split, (seq_start, step), _) = c_details
    c_bound = dict (c_bounds).get (p.loop_id (c_split))
    if not c_bound:
        trace ('key split was not bounded (%r, %r).' % (c_split, c_bounds))
        return None
    (c_bound, _) = c_bound
    # translate the C bound into an iteration count of the split sequence
    max_it = (c_bound - seq_start) / step
    assert max_it > n, (max_it, n)
    (_, (seq_start, step), _) = as_details
    as_bound = seq_start + (max_it * step)
    # increment by 1 as this may be a bound for a different splitting point
    # which occurs later in the loop
    as_bound += 1
    return (as_bound, 'FromC')

def get_prior_loop_heads (p, split, use_rep = None):
    """Loop heads from which `split` is reachable, in dependency order
    (a head's own priors always precede it)."""
    if use_rep:
        rep = use_rep
    else:
        rep = rep_graph.mk_graph_slice (p)
    prior = []
    split = p.loop_id (split)
    for h in p.loop_heads ():
        s = set (prior)
        if h not in s and rep.get_reachable (h, split) and h != split:
            # need to recurse to ensure prior are in order
            prior2 = get_prior_loop_heads (p, h, use_rep = rep)
            prior.extend ([h2 for h2 in prior2 if h2 not in s])
            prior.append (h)
    return prior

def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt):
    """Extend (restrs, hyps) with the consequences of loop `split` having
    bound `bound`: linear-series hypotheses plus a visit-count restriction
    (an exact upto-bound for small bounds, generic options otherwise)."""
    # time this for diagnostic reasons
    start = time.time ()

    #vc_options([concrete numbers], [offsets])
    hyps = hyps + get_linear_series_hyps (p, split, restrs, hyps)
    hyps = list (set (hyps))
    if bound == None or bound >= 10:
        restrs = restrs + ((split, rep_graph.vc_options([0],[1])),)
    else:
        restrs = restrs + ((split, rep_graph.vc_upto (bound+1)),)

    end = time.time ()
    save_extra_timing ('LoopBoundRestrHyps', ctxt, end - start)

    return (restrs, hyps)

# cap for the inductive binary search in search_bound
max_acceptable_bound = [1000000]

# memoised hash of the complete function set (single-element mutable cell)
functions_hash = [None]

def get_functions_hash ():
    """Hash of all loaded functions; used to invalidate global bounds."""
    if functions_hash[0] != None:
        return functions_hash[0]
    h = hash (tuple (sorted ([(f, hash (functions[f])) for f in functions])))
    functions_hash[0] = h
    return h

# caches filled by add_fun_to_loop_data_cache:
# instruction address -> canonical (minimum) address of its loop
addr_to_loop_id_cache = {}
# canonical loop address -> whether the loop contains an inner loop
complex_loop_id_cache = {}

def addr_to_loop_id (split):
    """Canonical loop id (minimum member address) for address `split`."""
    if split not in addr_to_loop_id_cache:
        add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return addr_to_loop_id_cache[split]

def is_complex_loop (split):
    """Whether the loop containing `split` has an inner loop."""
    split = addr_to_loop_id (split)
    if split not in complex_loop_id_cache:
        add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return complex_loop_id_cache[split]

def get_loop_addrs (split):
    """All instruction addresses belonging to the loop containing `split`."""
    split = addr_to_loop_id (split)
    f = functions[trace_refute.get_body_addrs_fun (split)]
    return [addr for addr in f.nodes if trace_refute.is_addr (addr)
        if addr_to_loop_id_cache.get (addr) == split]

def add_fun_to_loop_data_cache (fname):
    """Run loop analysis on function fname and record its loops in the
    addr_to_loop_id / complex_loop_id caches."""
    p = functions[fname].as_problem (problem.Problem)
    p.do_loop_analysis ()
    for h in p.loop_heads ():
        addrs = [n for n in p.loop_body (h)
            if trace_refute.is_addr (n)]
        min_addr = min (addrs)
        for addr in addrs:
            addr_to_loop_id_cache[addr] = min_addr
        complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h)
    # NOTE(review): returns the min_addr of only the last loop analysed;
    # callers in this file ignore the return value.
    return min_addr

def get_bound_super_ctxt (split, call_ctxt, no_splitting=False,
        known_bound_only=False):
    """Top-level bound query for loop `split` in context `call_ctxt`:
    consults the global cache, canonicalises the loop address, then
    delegates to get_bound_super_ctxt_inner and records the result
    (unless splitting was suppressed mid-search)."""
    if not known_bounds:
        load_bounds ()
    for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []):
        if ctxt2 == call_ctxt and fn_hash == get_functions_hash ():
            return bound
    min_loop_addr = addr_to_loop_id (split)
    if min_loop_addr != split:
        return get_bound_super_ctxt (min_loop_addr, call_ctxt,
            no_splitting = no_splitting, known_bound_only = known_bound_only)

    if known_bound_only:
        return None
    no_splitting_abort = [False]
    try:
        bound = get_bound_super_ctxt_inner (split, call_ctxt,
            no_splitting = (no_splitting, no_splitting_abort))
    except problem.Abort, e:
        bound = None
    if no_splitting_abort[0]:
        # don't record this bound, since it might change if splitting was allowed
        return bound
    known = known_bounds.setdefault ((split, 'Global'), [])
    known.append ((call_ctxt, get_functions_hash (), bound))
    save_bound (True, split, call_ctxt, get_functions_hash (), None, bound)
    return bound

from trace_refute import (function_limit, ctxt_within_function_limits)

def call_ctxt_computable (split, call_ctxt):
    """Whether all functions in the context are free of complex loops
    (complex loops make context-specific analysis infeasible)."""
    fs = [trace_refute.identify_function ([], [call_site])
        for call_site in call_ctxt]
    non_computable = [f for f in fs if trace_refute.has_complex_loop (f)]
    if non_computable:
        trace ('avoiding functions with complex loops: %s' % non_computable)
    return not non_computable

def get_bound_super_ctxt_inner (split, call_ctxt,
        no_splitting = (False, None)):
    """Bound-search strategy: function limits first, then a unique-caller
    context extension, then the direct context search, and finally a
    split over all safe call sites (taking the max of the results).
    no_splitting is (flag, abort-cell); the cell is set when splitting
    would have been tried but was disallowed."""
    first_f = trace_refute.identify_function ([], (call_ctxt + [split])[:1])
    call_sites = all_call_sites (first_f)

    if function_limit (first_f) == 0:
        return (0, 'FunctionLimit')
    safe_call_sites = [cs for cs in call_sites
        if ctxt_within_function_limits ([cs] + call_ctxt)]
    if call_sites and not safe_call_sites:
        return (0, 'FunctionLimit')

    if len (call_ctxt) < 3 and len (safe_call_sites) == 1:
        call_ctxt2 = list (safe_call_sites) + call_ctxt
        if call_ctxt_computable (split, call_ctxt2):
            trace ('using unique calling context %s' % str ((split, call_ctxt2)))
            return get_bound_super_ctxt (split, call_ctxt2)

    fname = trace_refute.identify_function (call_ctxt, [split])
    bound = function_limit_bound (fname, split)
    if bound:
        return bound

    bound = get_bound_ctxt (split, call_ctxt)
    if bound:
        return bound

    trace ('no bound found immediately.')

    if no_splitting[0]:
        assert no_splitting[1], no_splitting
        no_splitting[1][0] = True
        trace ('cannot split by context (recursion).')
        return None

    # try to split over potential call sites
    if len (call_ctxt) >= 3:
        trace ('cannot split by context (context depth).')
        return None

    if len (call_sites) == 0:
        # either entry point or nonsense
        trace ('cannot split by context (reached top level).')
        return None

    problem_sites = [call_site for call_site in safe_call_sites
        if not call_ctxt_computable (split, [call_site] + call_ctxt)]
    if problem_sites:
        trace ('cannot split by context (issues in %s).' % problem_sites)
        return None

    anc_bounds = [get_bound_super_ctxt (split, [call_site] + call_ctxt,
            no_splitting = True)
        for call_site in safe_call_sites]
    if None in anc_bounds:
        return None
    (bound, kind) = max (anc_bounds)
    return (bound, 'MergedBound')

def function_limit_bound (fname, split):
    """Bound implied by function limits: if the loop cannot iterate
    without visiting some limit-bounded call, the loop bound is the sum
    of those limits plus one."""
    p = functions[fname].as_problem (problem.Problem)
    p.do_analysis ()
    cuts = [n for n in p.loop_body (split)
        if p.nodes[n].kind == 'Call'
        if function_limit (p.nodes[n].fname) != None]
    if not cuts:
        return None
    graph = p.mk_node_graph (p.loop_body (split))
    # it is not possible to iterate the loop without visiting a bounded
    # function. naively, this sets the limit to the sum of all the possible
    # bounds, plus one because we can enter the loop a final time without
    # visiting any function call site yet.
    if logic.divides_loop (graph, set (cuts)):
        fnames = set ([p.nodes[n].fname for n in cuts])
        return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit')

def loop_bound_difficulty_estimates (split, ctxt):
    # various guesses at how hard the loop bounding problem is.
    (p, hyps, addr_map) = get_call_ctxt_problem (split, ctxt, timing = False)

    loop_id = p.loop_id (addr_map[split])
    assert loop_id

    # number of instructions in the loop
    inst_node_ids = set (addr_map.itervalues ())
    l_insts = [n for n in p.loop_body (loop_id) if n in inst_node_ids]

    # number of instructions in the function
    tag = p.node_tags[loop_id][0]
    f_insts = [n for n in inst_node_ids if p.node_tags[n][0] == tag]

    # number of instructions in the whole calling context
    ctxt_insts = len (inst_node_ids)

    # well, what else?
    return (len (l_insts), len (f_insts), ctxt_insts)

def load_timing ():
    """Parse timing records from LoopBounds.txt and the wall-clock total
    from <target_dir>/time.  Returns (loop_time, ext_time, tot_time,
    {(split, ctxt): seconds})."""
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
    timing = {}
    loop_time = 0.0
    ext_time = 0.0
    for line in f:
        bits = line.split ()
        if not (bits and 'Timing' in bits[0]):
            continue
        if bits[0] == 'LoopBoundTiming':
            (n, ext_ctxt) = parse_ctxt (bits, 1)
            assert n == len (bits) - 1
            # NOTE: `time` here shadows the time module (local only)
            time = float (bits[n])
            ctxt = ext_ctxt[:-1]
            split = ext_ctxt[-1]
            timing[(split, tuple(ctxt))] = time
            loop_time += time
        elif bits[0] == 'ExtraTiming':
            time = float (bits[-1])
            ext_time += time
    f.close ()
    f = open ('%s/time' % target_objects.target_dir)
    [l] = [l for l in f if '(wall clock)' in l]
    f.close ()
    # convert [hh:]mm:ss wall-clock format to seconds
    tot_time_str = l.split ()[-1]
    tot_time = sum ([float(s) * (60 ** i)
        for (i, s) in enumerate (reversed (tot_time_str.split(':')))])

    return (loop_time, ext_time, tot_time, timing)

def mk_timing_metrics ():
    """Pair every known context-specific bound with its difficulty
    estimates: [(split, ctxt, bound, estimates)].
    NOTE(review): the `type (split_bin_addr) == int` filter may miss
    Python 2 `long` addresses -- confirm address ranges."""
    if not known_bounds:
        load_bounds ()
    probs = [(split_bin_addr, tuple (call_ctxt), bound)
        for (split_bin_addr, known) in known_bounds.iteritems ()
        if type (split_bin_addr) == int
        for (call_ctxt, h, prev_bounds, bound) in known]
    probs = set (probs)
    data = [(split, ctxt, bound,
            loop_bound_difficulty_estimates (split, list (ctxt)))
        for (split, ctxt, bound) in probs]
    return data

# sigh, this is so much work.
# numeric codes for bound kinds, used as gnuplot data (7 = no bound found)
bound_kind_nums = {
    'FunctionLimit': 2,
    'NaiveBinSearch': 3,
    'InductiveBinSearch': 4,
    'FromC': 5,
    'MergedBound': 6,
}

# one colour per metrics series (indexed by the `num` argument below)
gnuplot_colours = [
    "dark-red", "dark-blue", "dark-green", "dark-grey",
    "dark-orange", "dark-magenta", "dark-cyan"]

def save_timing_metrics (num):
    """Write LoopTimingMetrics.txt: one gnuplot-ready line per bounded
    loop, combining the bound, its kind code, difficulty estimates and
    the time spent, coloured by series index `num`."""
    (loop_time, ext_time, tot_time, timing) = load_timing ()

    col = gnuplot_colours[num]
    from target import short_name

    time_ests = mk_timing_metrics ()
    import os
    f = open ('%s/LoopTimingMetrics.txt' % target_objects.target_dir, 'w')
    f.write ('"%s"\n' % short_name)

    for (split, ctxt, bound, ests) in time_ests:
        # NOTE: `time` here shadows the time module (local only)
        time = timing[(split, tuple (ctxt))]
        if bound == None:
            # sentinel for "no bound": huge value, kind code 7
            bdata = "1000000 7"
        else:
            bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]])
        (l_i, f_i, ct_i) = ests
        f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i,
            bdata, col, time))
    f.close ()

def get_loop_heads (fun):
    """Loop heads of function fun as (min instruction address, function
    name, has-inner-loop) triples; [] if fun has no entry point."""
    if not fun.entry:
        return []
    p = fun.as_problem (problem.Problem)
    p.do_loop_analysis ()
    loops = set ()
    for h in p.loop_heads ():
        # any address in the loop will do. pick the smallest one
        addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)])
        loops.add ((addr, fun.name, problem.has_inner_loop (p, h)))
    return list (loops)

def get_all_loop_heads ():
    """Loop heads of every ASM function; functions whose analysis aborts
    are reported via trace and skipped."""
    loops = set ()
    abort_funs = set ()
    for f in all_asm_functions ():
        try:
            loops.update (get_loop_heads (functions[f]))
        except problem.Abort, e:
            abort_funs.add (f)
    if abort_funs:
        trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs))
    return loops

def get_complex_loops ():
    # loops that contain an inner loop
    return [(loop, name) for (loop, name, compl) in get_all_loop_heads ()
        if compl]

def search_all_loops ():
    """Entry point: compute a bound for every loop in every ASM function
    (results are persisted to LoopBounds.txt as a side effect)."""
    all_loops = get_all_loop_heads ()
    for (loop, _, _) in all_loops:
        get_bound_super_ctxt (loop, [])

main = search_all_loops

if __name__ == '__main__':
    import sys
    args = target_objects.load_target_args ()
    if args == ['search']:
        search_all_loops ()
    elif args[:1] == ['metrics']:
        # series index = position of this target among the metrics args
        num = args[1:].index (str (target_objects.target_dir))
        save_timing_metrics (num)