1import check,search,problem,syntax,solver,logic,rep_graph,re 2from rep_graph import vc_num, vc_offs, pc_true_hyp, Hyp, eq_hyp 3from target_objects import functions, trace 4from check import restr_others,split_visit_one_visit 5from problem import inline_at_point 6from syntax import mk_not, true_term, false_term, mk_implies, Expr, Type, unspecified_precond_term,mk_and 7from rep_graph import mk_graph_slice, VisitCount, to_smt_expr 8from search import eval_model_expr 9import target_objects 10import trace_refute 11import stack_logic 12import time 13 14#tryFun must take exactly 1 argument 15def downBinSearch(minimum, maximum, tryFun): 16 upperBound = maximum 17 lowerBound = minimum 18 while upperBound > lowerBound: 19 print 'searching in %d - %d' % (lowerBound,upperBound) 20 cur = (lowerBound + upperBound) / 2 21 if tryFun(cur): 22 upperBound = cur 23 else: 24 lowerBound = cur + 1 25 assert upperBound == lowerBound 26 ret = lowerBound 27 return ret 28 29def upDownBinSearch (minimum, maximum, tryFun): 30 """performs a binary search between minimum and maximum, but does not start 31 in the middle. instead it does a binary escalation up from the minimum 32 first. this makes sense for ranges e.g. 2 - 1000000 where the bound is 33 likely to be near the bottom of the range. 
it also avoids testing values 34 more than twice as high as the bound, which may avoid some issues.""" 35 upperBound = 2 * minimum 36 while upperBound < maximum: 37 if tryFun (upperBound): 38 return downBinSearch (minimum, upperBound, tryFun) 39 else: 40 upperBound *= 2 41 if tryFun (maximum): 42 return downBinSearch (minimum, maximum, tryFun) 43 else: 44 return None 45 46def addr_of_node (preds, n): 47 while not trace_refute.is_addr (n): 48 [n] = preds[n] 49 return n 50 51def all_asm_functions (): 52 ss = stack_logic.get_functions_with_tag ('ASM') 53 return [s for s in ss if not stack_logic.is_instruction (s)] 54 55call_site_set = {} 56 57def build_call_site_set (): 58 for f in all_asm_functions (): 59 preds = logic.compute_preds (functions[f].nodes) 60 for (n, node) in functions[f].nodes.iteritems (): 61 if node.kind == 'Call': 62 s = call_site_set.setdefault (node.fname, set ()) 63 s.add (addr_of_node (preds, n)) 64 call_site_set[('IsLoaded', None)] = True 65 66def all_call_sites (f): 67 if not call_site_set: 68 build_call_site_set () 69 return list (call_site_set.get (f, [])) 70 71#naive binary search to find loop bounds 72def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None): 73 if hyps == None: 74 hyps = [] 75 #print 'restrs: %s' % str(restrs) 76 if try_seq == None: 77 #bound_try_seq = [1,2,3,4,5,10,50,130,200,260] 78 #bound_try_seq = [0,1,2,3,4,5,10,50,260] 79 calls = [n for n in p.loop_body (p_n) if p.nodes[n].kind == 'Call'] 80 if calls: 81 bound_try_seq = [0,1,20] 82 else: 83 bound_try_seq = [0,1,20,34] 84 else: 85 bound_try_seq = try_seq 86 rep = mk_graph_slice (p, fast = True) 87 #get the head 88 #print 'Binary addr: %s' % toHexs(self.toPhyAddrs(p_loop_heads)) 89 loop_bound = None 90 p_loop_heads = [n for n in p.loop_data if p.loop_data[n][0] == 'Head'] 91 print 'p_loop_heads: %s' % p_loop_heads 92 93 if restrs == None: 94 others = [x for x in p_loop_heads if not x == p_n] 95 #vc_options([concrete numbers], [offsets]) 96 restrs = tuple( 
[(n2, rep_graph.vc_options([0],[1])) for n2 in others] ) 97 98 print 'getting the initial bound' 99 #try: 100 index = tryLoopBound(p_n,p,bound_try_seq,rep,restrs = restrs, hyps=hyps) 101 if index == -1: 102 return None 103 print 'got the initial bound %d' % bound_try_seq[index] 104 105 #do a downward binary search to find the concrete loop bound 106 if index == 0: 107 loop_bound = bound_try_seq[0] 108 print 'bound = %d' % loop_bound 109 return loop_bound 110 loop_bound = downBinSearch(bound_try_seq[index-1], bound_try_seq[index], lambda x: tryLoopBound(p_n,p,[x],rep,restrs=restrs, hyps=hyps, bin_return=True)) 111 print 'bound = %d' % loop_bound 112 return loop_bound 113 114def default_n_vc_cases (p, n): 115 head = p.loop_id (n) 116 general = [(n2, rep_graph.vc_options ([0], [1])) 117 for n2 in p.loop_heads () 118 if n2 != head] 119 120 if head: 121 return [(n, tuple (general + [(head, rep_graph.vc_num (1))])), 122 (n, tuple (general + [(head, rep_graph.vc_offs (1))]))] 123 specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head] 124 return [(n, tuple (general + specific))] 125 126def callNodes(p, fs= None): 127 ns = [n for n in p.nodes if p.nodes[n].kind == 'Call'] 128 if fs != None: 129 ns = [n for n in ns if p.nodes[n].fname in fs] 130 return ns 131 132def noHaltHyps(split,p): 133 ret = [] 134 all_halts = callNodes(p,fs=['halt']) 135 for x in all_halts: 136 ret += [rep_graph.pc_false_hyp((n_vc, p.node_tags[x][0])) 137 for n_vc in default_n_vc_cases (p, x)] 138 return ret 139 140def tryLoopBound(p_n, p, bounds,rep,restrs =None, hints =None,kind = 'Number',bin_return = False,hyps = None): 141 if restrs == None: 142 restrs = () 143 if hints == None: 144 hints = [] 145 if hyps == None: 146 hyps = [] 147 tag = p.node_tags[p_n][0] 148 from stack_logic import default_n_vc 149 print 'trying bound: %s' % bounds 150 ret_bounds = [] 151 for (index,i) in enumerate(bounds): 152 print 'testing %d' % i 153 restrs2 = restrs + ((p_n, VisitCount (kind, i)), ) 154 try: 155 
pc = rep.get_pc ((p_n, restrs2)) 156 except: 157 print 'get_pc failed' 158 if bin_return: 159 return False 160 else: 161 return -1 162 #print 'got rep_.get_pc' 163 restrs3 = restr_others (p, restrs2, 2) 164 epc = rep.get_pc (('Err', restrs3), tag = tag) 165 hyp = mk_implies (mk_not (epc), mk_not (pc)) 166 hyps = hyps + noHaltHyps(p_n,p) 167 168 #hyps = [] 169 #print 'calling test_hyp_whyps' 170 if rep.test_hyp_whyps (hyp, hyps): 171 print 'p_n %d: split limit found: %d' % (p_n, i) 172 if bin_return: 173 return True 174 return index 175 if bin_return: 176 return False 177 print 'loop bound not found!' 178 return -1 179 assert False, 'failed to find loop bound for p_n %d' % p_n 180 181def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False): 182 k = ('linear_series_eqs', split, restrs, tuple (hyps)) 183 if k in p.cached_analysis: 184 if omit_standard: 185 standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False)) 186 return set (p.cached_analysis[k]) - standard 187 return p.cached_analysis[k] 188 189 cands = search.mk_seq_eqs (p, split, 1, with_rodata = False) 190 cands += candidate_additional_eqs (p, split) 191 (tag, _) = p.node_tags[split] 192 193 rep = rep_graph.mk_graph_slice (p, fast = True) 194 195 def do_checks (eqs_assume, eqs): 196 checks = (check.single_loop_induct_step_checks (p, restrs, hyps, tag, 197 split, 1, eqs, eqs_assume = eqs_assume) 198 + check.single_loop_induct_base_checks (p, restrs, hyps, tag, 199 split, 1, eqs)) 200 201 groups = check.proof_check_groups (checks) 202 for group in groups: 203 (res, _) = check.test_hyp_group (rep, group) 204 if not res: 205 return False 206 return True 207 208 eqs = [] 209 failed = [] 210 while cands: 211 cand = cands.pop () 212 if do_checks (eqs, [cand]): 213 eqs.append (cand) 214 failed.reverse () 215 cands = failed + cands 216 failed = [] 217 else: 218 failed.append (cand) 219 220 assert do_checks ([], eqs) 221 p.cached_analysis[k] = eqs 222 if omit_standard: 223 standard = set 
(search.mk_seq_eqs (p, split, 1, with_rodata = False)) 224 return set (eqs) - standard 225 return eqs 226 227def get_linear_series_hyps (p, split, restrs, hyps): 228 eqs = get_linear_series_eqs (p, split, restrs, hyps) 229 (tag, _) = p.node_tags[split] 230 hyps = [h for (h, _) in linear_eq_hyps_at_visit (tag, split, eqs, 231 restrs, vc_offs (0))] 232 return hyps 233 234def is_zero (expr): 235 return expr.kind == 'Num' and expr.val & ((1 << expr.typ.num) - 1) == 0 236 237def candidate_additional_eqs (p, split): 238 eq_vals = set () 239 def visitor (expr): 240 if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word': 241 [x, y] = expr.vals 242 eq_vals.update ([(x, y), (y, x)]) 243 for n in p.loop_body (split): 244 p.nodes[n].visit (lambda x: (), visitor) 245 for (x, y) in list (eq_vals): 246 if is_zero (x) and y.is_op ('Plus'): 247 [x, y] = y.vals 248 eq_vals.add ((x, syntax.mk_uminus (y))) 249 eq_vals.add ((y, syntax.mk_uminus (x))) 250 elif is_zero (x) and y.is_op ('Minus'): 251 [x, y] = y.vals 252 eq_vals.add ((x, y)) 253 eq_vals.add ((y, x)) 254 255 loop = syntax.mk_var ('%i', syntax.word32T) 256 minus_loop_step = syntax.mk_uminus (loop) 257 258 vas = search.get_loop_var_analysis_at(p, split) 259 ls_vas = dict ([(var, [data]) for (var, data) in vas 260 if data[0] == 'LoopLinearSeries']) 261 cmp_series = [(x, y, rew, offs) for (x, y) in eq_vals 262 for (_, rew, offs) in ls_vas.get (x, [])] 263 odd_eqs = [] 264 for (x, y, rew, offs) in cmp_series: 265 x_init_cmp1 = syntax.mk_less_eq (x, rew (x, minus_loop_step)) 266 x_init_cmp2 = syntax.mk_less_eq (rew (x, minus_loop_step), x) 267 fin_cmp1 = syntax.mk_less (x, y) 268 fin_cmp2 = syntax.mk_less (y, x) 269 odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp1)) 270 odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp1)) 271 odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp2)) 272 odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp2)) 273 274 ass_eqs = [] 275 var_deps = p.compute_var_dependencies () 276 for hook in 
target_objects.hooks ('extra_wcet_assertions'): 277 for assn in hook (var_deps[split]): 278 ass_eqs.append (assn) 279 280 return odd_eqs + ass_eqs 281 282extra_loop_consts = [2 ** 16] 283 284call_ctxt_problems = [] 285 286avoid_C_information = [False] 287 288def get_call_ctxt_problem (split, call_ctxt, timing = True): 289 # time this for diagnostic reasons 290 start = time.time () 291 from trace_refute import identify_function, build_compound_problem_with_links 292 f = identify_function (call_ctxt, [split]) 293 for (ctxt2, p, hyps, addr_map) in call_ctxt_problems: 294 if ctxt2 == (call_ctxt, f): 295 return (p, hyps, addr_map) 296 297 (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f) 298 if avoid_C_information[0]: 299 hyps = [h for h in hyps if not has_C_information (p, h)] 300 call_ctxt_problems.append(((call_ctxt, f), p, hyps, addr_map)) 301 del call_ctxt_problems[: -20] 302 303 end = time.time () 304 if timing: 305 save_extra_timing ('GetProblem', call_ctxt + [split], end - start) 306 307 return (p, hyps, addr_map) 308 309def has_C_information (p, hyp): 310 for (n_vc, tag) in hyp.visits (): 311 if not p.hook_tag_hints.get (tag, None) == 'ASM': 312 return True 313 314known_bound_restr_hyps = {} 315 316known_bounds = {} 317 318def serialise_bound (addr, bound_info): 319 if bound_info == None: 320 return [hex(addr), "None", "None"] 321 else: 322 (bound, kind) = bound_info 323 assert logic.is_int (bound) 324 assert str (kind) == kind 325 return [hex (addr), str (bound), kind] 326 327def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound, 328 time = None): 329 f_names = [trace_refute.get_body_addrs_fun (x) 330 for x in call_ctxt + [split_bin_addr]] 331 loop_name = '<%s>' % ' -> '.join (f_names) 332 comment = '# bound for loop in %s:' % loop_name 333 ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound) 334 if glob: 335 ss[0] = 'GlobalLoopBound' 336 ss += [str (len (call_ctxt))] + map (hex, call_ctxt) 337 ss += [str 
(prob_hash)] 338 if glob: 339 assert prev_bounds == None 340 else: 341 ss += [str (len (prev_bounds))] 342 for (split, bound) in prev_bounds: 343 ss += serialise_bound (split, bound) 344 s = ' '.join (ss) 345 f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a') 346 f.write (comment + '\n') 347 f.write (s + '\n') 348 if time != None: 349 ctxt2 = call_ctxt + [split_bin_addr] 350 ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2)) 351 f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time)) 352 f.close () 353 trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr, 354 loop_name)) 355 356def save_extra_timing (nm, ctxt, time): 357 ss = ['ExtraTiming', nm, str (len (ctxt))] + map (hex, ctxt) + [str(time)] 358 f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a') 359 f.write (' '.join (ss) + '\n') 360 f.close () 361 362def parse_bound (ss, n): 363 addr = syntax.parse_int (ss[n]) 364 bound = ss[n + 1] 365 if bound == 'None': 366 bound = None 367 return (n + 3, (addr, None)) 368 else: 369 bound = syntax.parse_int (bound) 370 kind = ss[n + 2] 371 return (n + 3, (addr, (bound, kind))) 372 373def parse_ctxt_id (bits, n): 374 return (n + 1, syntax.parse_int (bits[n])) 375 376def parse_ctxt (bits, n): 377 return syntax.parse_list (parse_ctxt_id, bits, n) 378 379def load_bounds (): 380 try: 381 f = open ('%s/LoopBounds.txt' % target_objects.target_dir) 382 ls = list (f) 383 f.close () 384 except IOError, e: 385 ls = [] 386 from syntax import parse_int, parse_list 387 for l in ls: 388 bits = l.split () 389 if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]: 390 continue 391 (n, (addr, bound)) = parse_bound (bits, 1) 392 (n, ctxt) = parse_ctxt (bits, n) 393 prob_hash = parse_int (bits[n]) 394 n += 1 395 if bits[0] == 'LoopBound': 396 (n, prev_bounds) = parse_list (parse_bound, bits, n) 397 assert n == len (bits), bits 398 known = known_bounds.setdefault (addr, []) 399 known.append ((ctxt, prob_hash, prev_bounds, bound)) 400 else: 401 assert n 
== len (bits), bits 402 known = known_bounds.setdefault ((addr, 'Global'), []) 403 known.append ((ctxt, prob_hash, bound)) 404 known_bounds['Loaded'] = True 405 406def get_bound_ctxt (split, call_ctxt, use_cache = True): 407 trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt)) 408 (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt) 409 410 orig_split = split 411 split = p.loop_id (addr_map[split]) 412 assert split, (orig_split, call_ctxt) 413 split_bin_addr = min ([addr for addr in addr_map 414 if p.loop_id (addr_map[addr]) == split]) 415 416 prior = get_prior_loop_heads (p, split) 417 restrs = () 418 prev_bounds = [] 419 for split2 in prior: 420 # recursion! 421 split2 = p.loop_id (split2) 422 assert split2 423 addr = min ([addr for addr in addr_map 424 if p.loop_id (addr_map[addr]) == split2]) 425 bound = get_bound_ctxt (addr, call_ctxt) 426 prev_bounds.append ((addr, bound)) 427 k = (p.name, split2, bound, restrs, tuple (hyps)) 428 if k in known_bound_restr_hyps: 429 (restrs, hyps) = known_bound_restr_hyps[k] 430 else: 431 (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps, 432 split2, bound, call_ctxt + [orig_split]) 433 known_bound_restr_hyps[k] = (restrs, hyps) 434 435 # start timing now. we miss some setup time, but it avoids double counting 436 # the recursive searches. 
437 start = time.time () 438 439 p_h = problem_hash (p) 440 prev_bounds = sorted (prev_bounds) 441 if not known_bounds: 442 load_bounds () 443 known = known_bounds.get (split_bin_addr, []) 444 for (call_ctxt2, h, prev_bounds2, bound) in known: 445 match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2) 446 if match and use_cache and h == p_h and prev_bounds2 == prev_bounds: 447 return bound 448 bound = search_bin_bound (p, restrs, hyps, split) 449 known = known_bounds.setdefault (split_bin_addr, []) 450 known.append ((call_ctxt, p_h, prev_bounds, bound)) 451 end = time.time () 452 save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound, 453 time = end - start) 454 return bound 455 456def problem_hash (p): 457 return syntax.hash_tuplify ([p.name, p.entries, 458 sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())]) 459 460def search_bin_bound (p, restrs, hyps, split): 461 trace ('Searching for bound for 0x%x in %s.', (split, p.name)) 462 bound = search_bound (p, restrs, hyps, split) 463 if bound: 464 return bound 465 466 # try to use a bound inferred from C 467 if avoid_C_information[0]: 468 # OK told not to 469 return None 470 if get_prior_loop_heads (p, split): 471 # too difficult for now 472 return None 473 asm_tag = p.node_tags[split][0] 474 (_, fname, _) = p.get_entry_details (asm_tag) 475 funs = [f for pair in target_objects.pairings[fname] 476 for f in pair.funs.values ()] 477 c_tags = [tag for tag in p.tags () 478 if p.get_entry_details (tag)[1] in funs and tag != asm_tag] 479 if len (c_tags) != 1: 480 print 'Surprised to see multiple matching tags %s' % c_tags 481 return None 482 483 [c_tag] = c_tags 484 485 rep = rep_graph.mk_graph_slice (p) 486 if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1: 487 # technical, but it's not going to work in this case 488 return None 489 490 return getBinaryBoundFromC (p, c_tag, split, restrs, hyps) 491 492def rab_test (): 493 [split_bin_addr] = get_loop_heads 
(functions['resolveAddressBits']) 494 (p, hyps, addr_map) = get_call_ctxt_problem (split_bin_addr, []) 495 split = p.loop_id (addr_map[split_bin_addr]) 496 [c_tag] = [tag for tag in p.tags () if tag != p.node_tags[split][0]] 497 return getBinaryBoundFromC (p, c_tag, split, (), hyps) 498 499last_search_bound = [0] 500 501def search_bound (p, restrs, hyps, split): 502 last_search_bound[0] = (p, restrs, hyps, split) 503 504 # try a naive bin search first 505 # limit this to a small bound for time purposes 506 # - for larger bounds the less naive approach can be faster 507 bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps, 508 try_seq = [0, 1, 6]) 509 if bound != None: 510 return (bound, 'NaiveBinSearch') 511 512 l_hyps = get_linear_series_hyps (p, split, restrs, hyps) 513 514 rep = rep_graph.mk_graph_slice (p, fast = True) 515 516 def test (n): 517 assert n > 10 518 hyp = check.mk_loop_counter_eq_hyp (p, split, restrs, n - 2) 519 visit = ((split, vc_offs (2)), ) + restrs 520 continue_to_split_guess = rep.get_pc ((split, visit)) 521 return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess), 522 [hyp] + l_hyps + hyps) 523 524 # findLoopBoundBS always checks to at least 16 525 min_bound = 16 526 max_bound = max_acceptable_bound[0] 527 bound = upDownBinSearch (min_bound, max_bound, test) 528 if bound != None and test (bound): 529 return (bound, 'InductiveBinSearch') 530 531 # let the naive bin search go a bit further 532 bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps) 533 if bound != None: 534 return (bound, 'NaiveBinSearch') 535 536 return None 537 538def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps): 539 c_heads = [h for h in search.init_loops_to_split (p, restrs) 540 if p.node_tags[h][0] == c_tag] 541 c_bounds = [(p.loop_id (split), search_bound (p, (), hyps, split)) 542 for split in c_heads] 543 if not [b for (n, b) in c_bounds if b]: 544 trace ('no C bounds found (%s).' 
% c_bounds) 545 return None 546 547 asm_tag = p.node_tags[asm_split][0] 548 549 rep = rep_graph.mk_graph_slice (p) 550 i_seq_opts = [(0, 1), (1, 1), (2, 1)] 551 j_seq_opts = [(0, 1), (0, 2), (1, 1)] 552 tags = [p.node_tags[asm_split][0], c_tag] 553 try: 554 split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts, 555 j_seq_opts, 5, tags = [asm_tag, c_tag]) 556 except solver.SolverFailure, e: 557 return None 558 if not split or split[0] != 'Split': 559 trace ('no split found (%s).' % repr (split)) 560 return None 561 (_, split) = split 562 rep = rep_graph.mk_graph_slice (p) 563 checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag]) 564 groups = check.proof_check_groups (checks) 565 try: 566 for group in groups: 567 (res, el) = check.test_hyp_group (rep, group) 568 if not res: 569 trace ('split check failed!') 570 trace ('failed at %s' % el) 571 return None 572 except solver.SolverFailure, e: 573 return None 574 (as_details, c_details, _, n, _) = split 575 (c_split, (seq_start, step), _) = c_details 576 c_bound = dict (c_bounds).get (p.loop_id (c_split)) 577 if not c_bound: 578 trace ('key split was not bounded (%r, %r).' 
% (c_split, c_bounds)) 579 return None 580 (c_bound, _) = c_bound 581 max_it = (c_bound - seq_start) / step 582 assert max_it > n, (max_it, n) 583 (_, (seq_start, step), _) = as_details 584 as_bound = seq_start + (max_it * step) 585 # increment by 1 as this may be a bound for a different splitting point 586 # which occurs later in the loop 587 as_bound += 1 588 return (as_bound, 'FromC') 589 590def get_prior_loop_heads (p, split, use_rep = None): 591 if use_rep: 592 rep = use_rep 593 else: 594 rep = rep_graph.mk_graph_slice (p) 595 prior = [] 596 split = p.loop_id (split) 597 for h in p.loop_heads (): 598 s = set (prior) 599 if h not in s and rep.get_reachable (h, split) and h != split: 600 # need to recurse to ensure prior are in order 601 prior2 = get_prior_loop_heads (p, h, use_rep = rep) 602 prior.extend ([h2 for h2 in prior2 if h2 not in s]) 603 prior.append (h) 604 return prior 605 606def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt): 607 # time this for diagnostic reasons 608 start = time.time () 609 610 #vc_options([concrete numbers], [offsets]) 611 hyps = hyps + get_linear_series_hyps (p, split, restrs, hyps) 612 hyps = list (set (hyps)) 613 if bound == None or bound >= 10: 614 restrs = restrs + ((split, rep_graph.vc_options([0],[1])),) 615 else: 616 restrs = restrs + ((split, rep_graph.vc_upto (bound+1)),) 617 618 end = time.time () 619 save_extra_timing ('LoopBoundRestrHyps', ctxt, end - start) 620 621 return (restrs, hyps) 622 623max_acceptable_bound = [1000000] 624 625functions_hash = [None] 626 627def get_functions_hash (): 628 if functions_hash[0] != None: 629 return functions_hash[0] 630 h = hash (tuple (sorted ([(f, hash (functions[f])) for f in functions]))) 631 functions_hash[0] = h 632 return h 633 634addr_to_loop_id_cache = {} 635complex_loop_id_cache = {} 636 637def addr_to_loop_id (split): 638 if split not in addr_to_loop_id_cache: 639 add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) 640 return 
addr_to_loop_id_cache[split] 641 642def is_complex_loop (split): 643 split = addr_to_loop_id (split) 644 if split not in complex_loop_id_cache: 645 add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) 646 return complex_loop_id_cache[split] 647 648def get_loop_addrs (split): 649 split = addr_to_loop_id (split) 650 f = functions[trace_refute.get_body_addrs_fun (split)] 651 return [addr for addr in f.nodes if trace_refute.is_addr (addr) 652 if addr_to_loop_id_cache.get (addr) == split] 653 654def add_fun_to_loop_data_cache (fname): 655 p = functions[fname].as_problem (problem.Problem) 656 p.do_loop_analysis () 657 for h in p.loop_heads (): 658 addrs = [n for n in p.loop_body (h) 659 if trace_refute.is_addr (n)] 660 min_addr = min (addrs) 661 for addr in addrs: 662 addr_to_loop_id_cache[addr] = min_addr 663 complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h) 664 return min_addr 665 666def get_bound_super_ctxt (split, call_ctxt, no_splitting=False, 667 known_bound_only=False): 668 if not known_bounds: 669 load_bounds () 670 for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []): 671 if ctxt2 == call_ctxt and fn_hash == get_functions_hash (): 672 return bound 673 min_loop_addr = addr_to_loop_id (split) 674 if min_loop_addr != split: 675 return get_bound_super_ctxt (min_loop_addr, call_ctxt, 676 no_splitting = no_splitting, known_bound_only = known_bound_only) 677 678 if known_bound_only: 679 return None 680 no_splitting_abort = [False] 681 try: 682 bound = get_bound_super_ctxt_inner (split, call_ctxt, 683 no_splitting = (no_splitting, no_splitting_abort)) 684 except problem.Abort, e: 685 bound = None 686 if no_splitting_abort[0]: 687 # don't record this bound, since it might change if splitting was allowed 688 return bound 689 known = known_bounds.setdefault ((split, 'Global'), []) 690 known.append ((call_ctxt, get_functions_hash (), bound)) 691 save_bound (True, split, call_ctxt, get_functions_hash (), None, bound) 692 
return bound 693 694from trace_refute import (function_limit, ctxt_within_function_limits) 695 696def call_ctxt_computable (split, call_ctxt): 697 fs = [trace_refute.identify_function ([], [call_site]) 698 for call_site in call_ctxt] 699 non_computable = [f for f in fs if trace_refute.has_complex_loop (f)] 700 if non_computable: 701 trace ('avoiding functions with complex loops: %s' % non_computable) 702 return not non_computable 703 704def get_bound_super_ctxt_inner (split, call_ctxt, 705 no_splitting = (False, None)): 706 first_f = trace_refute.identify_function ([], (call_ctxt + [split])[:1]) 707 call_sites = all_call_sites (first_f) 708 709 if function_limit (first_f) == 0: 710 return (0, 'FunctionLimit') 711 safe_call_sites = [cs for cs in call_sites 712 if ctxt_within_function_limits ([cs] + call_ctxt)] 713 if call_sites and not safe_call_sites: 714 return (0, 'FunctionLimit') 715 716 if len (call_ctxt) < 3 and len (safe_call_sites) == 1: 717 call_ctxt2 = list (safe_call_sites) + call_ctxt 718 if call_ctxt_computable (split, call_ctxt2): 719 trace ('using unique calling context %s' % str ((split, call_ctxt2))) 720 return get_bound_super_ctxt (split, call_ctxt2) 721 722 fname = trace_refute.identify_function (call_ctxt, [split]) 723 bound = function_limit_bound (fname, split) 724 if bound: 725 return bound 726 727 bound = get_bound_ctxt (split, call_ctxt) 728 if bound: 729 return bound 730 731 trace ('no bound found immediately.') 732 733 if no_splitting[0]: 734 assert no_splitting[1], no_splitting 735 no_splitting[1][0] = True 736 trace ('cannot split by context (recursion).') 737 return None 738 739 # try to split over potential call sites 740 if len (call_ctxt) >= 3: 741 trace ('cannot split by context (context depth).') 742 return None 743 744 if len (call_sites) == 0: 745 # either entry point or nonsense 746 trace ('cannot split by context (reached top level).') 747 return None 748 749 problem_sites = [call_site for call_site in safe_call_sites 750 if not 
call_ctxt_computable (split, [call_site] + call_ctxt)] 751 if problem_sites: 752 trace ('cannot split by context (issues in %s).' % problem_sites) 753 return None 754 755 anc_bounds = [get_bound_super_ctxt (split, [call_site] + call_ctxt, 756 no_splitting = True) 757 for call_site in safe_call_sites] 758 if None in anc_bounds: 759 return None 760 (bound, kind) = max (anc_bounds) 761 return (bound, 'MergedBound') 762 763def function_limit_bound (fname, split): 764 p = functions[fname].as_problem (problem.Problem) 765 p.do_analysis () 766 cuts = [n for n in p.loop_body (split) 767 if p.nodes[n].kind == 'Call' 768 if function_limit (p.nodes[n].fname) != None] 769 if not cuts: 770 return None 771 graph = p.mk_node_graph (p.loop_body (split)) 772 # it is not possible to iterate the loop without visiting a bounded 773 # function. naively, this sets the limit to the sum of all the possible 774 # bounds, plus one because we can enter the loop a final time without 775 # visiting any function call site yet. 776 if logic.divides_loop (graph, set (cuts)): 777 fnames = set ([p.nodes[n].fname for n in cuts]) 778 return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit') 779 780def loop_bound_difficulty_estimates (split, ctxt): 781 # various guesses at how hard the loop bounding problem is. 782 (p, hyps, addr_map) = get_call_ctxt_problem (split, ctxt, timing = False) 783 784 loop_id = p.loop_id (addr_map[split]) 785 assert loop_id 786 787 # number of instructions in the loop 788 inst_node_ids = set (addr_map.itervalues ()) 789 l_insts = [n for n in p.loop_body (loop_id) if n in inst_node_ids] 790 791 # number of instructions in the function 792 tag = p.node_tags[loop_id][0] 793 f_insts = [n for n in inst_node_ids if p.node_tags[n][0] == tag] 794 795 # number of instructions in the whole calling context 796 ctxt_insts = len (inst_node_ids) 797 798 # well, what else? 
799 return (len (l_insts), len (f_insts), ctxt_insts) 800 801def load_timing (): 802 f = open ('%s/LoopBounds.txt' % target_objects.target_dir) 803 timing = {} 804 loop_time = 0.0 805 ext_time = 0.0 806 for line in f: 807 bits = line.split () 808 if not (bits and 'Timing' in bits[0]): 809 continue 810 if bits[0] == 'LoopBoundTiming': 811 (n, ext_ctxt) = parse_ctxt (bits, 1) 812 assert n == len (bits) - 1 813 time = float (bits[n]) 814 ctxt = ext_ctxt[:-1] 815 split = ext_ctxt[-1] 816 timing[(split, tuple(ctxt))] = time 817 loop_time += time 818 elif bits[0] == 'ExtraTiming': 819 time = float (bits[-1]) 820 ext_time += time 821 f.close () 822 f = open ('%s/time' % target_objects.target_dir) 823 [l] = [l for l in f if '(wall clock)' in l] 824 f.close () 825 tot_time_str = l.split ()[-1] 826 tot_time = sum ([float(s) * (60 ** i) 827 for (i, s) in enumerate (reversed (tot_time_str.split(':')))]) 828 829 return (loop_time, ext_time, tot_time, timing) 830 831def mk_timing_metrics (): 832 if not known_bounds: 833 load_bounds () 834 probs = [(split_bin_addr, tuple (call_ctxt), bound) 835 for (split_bin_addr, known) in known_bounds.iteritems () 836 if type (split_bin_addr) == int 837 for (call_ctxt, h, prev_bounds, bound) in known] 838 probs = set (probs) 839 data = [(split, ctxt, bound, 840 loop_bound_difficulty_estimates (split, list (ctxt))) 841 for (split, ctxt, bound) in probs] 842 return data 843 844# sigh, this is so much work. 
# Numeric code for each bound kind, written as a plot column by
# save_timing_metrics (a missing bound is written with code 7).
bound_kind_nums = {
    'FunctionLimit': 2,
    'NaiveBinSearch': 3,
    'InductiveBinSearch': 4,
    'FromC': 5,
    'MergedBound': 6,
}

gnuplot_colours = [
    "dark-red", "dark-blue", "dark-green", "dark-grey",
    "dark-orange", "dark-magenta", "dark-cyan"]


def save_timing_metrics (num):
    """Write LoopTimingMetrics.txt for the current target.

    One line per contextual bound: target name, difficulty estimates, bound
    value and kind code, plot colour, and search time. num selects the
    colour; colours cycle if num exceeds the palette."""
    (loop_time, ext_time, tot_time, timing) = load_timing ()

    col = gnuplot_colours[num % len (gnuplot_colours)]
    from target import short_name

    time_ests = mk_timing_metrics ()
    f = open ('%s/LoopTimingMetrics.txt' % target_objects.target_dir, 'w')
    try:
        f.write ('"%s"\n' % short_name)

        for (split, ctxt, bound, ests) in time_ests:
            search_time = timing[(split, tuple (ctxt))]
            if bound is None:
                bdata = "1000000 7"
            else:
                bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]])
            (l_i, f_i, ct_i) = ests
            f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i,
                bdata, col, search_time))
    finally:
        f.close ()


def get_loop_heads (fun):
    """Return (smallest address, function name, has_inner_loop) for each
    loop in fun; empty if fun has no entry."""
    if not fun.entry:
        return []
    p = fun.as_problem (problem.Problem)
    p.do_loop_analysis ()
    loops = set ()
    for h in p.loop_heads ():
        # any address in the loop will do. pick the smallest one
        addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)])
        loops.add ((addr, fun.name, problem.has_inner_loop (p, h)))
    return list (loops)


def get_all_loop_heads ():
    """Collect get_loop_heads over every ASM function, tracing the names of
    functions whose analysis raised problem.Abort."""
    loops = set ()
    abort_funs = set ()
    for f in all_asm_functions ():
        try:
            loops.update (get_loop_heads (functions[f]))
        except problem.Abort:
            abort_funs.add (f)
    if abort_funs:
        trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs))
    return loops


def get_complex_loops ():
    """Return (loop address, function name) for loops with inner loops."""
    return [(loop, name) for (loop, name, compl) in get_all_loop_heads ()
        if compl]


def search_all_loops ():
    """Compute (and save) a bound for every loop with an empty context."""
    all_loops = get_all_loop_heads ()
    for (loop, _, _) in all_loops:
        get_bound_super_ctxt (loop, [])


main = search_all_loops

if __name__ == '__main__':
    import sys
    args = target_objects.load_target_args ()
    if args == ['search']:
        search_all_loops ()
    elif args[:1] == ['metrics']:
        # the target's position in the argument list picks its plot colour
        num = args[1:].index (str (target_objects.target_dir))
        save_timing_metrics (num)