1# * Copyright 2015, NICTA 2# * 3# * This software may be distributed and modified according to the terms of 4# * the BSD 2-Clause license. Note that NO WARRANTY is provided. 5# * See "LICENSE_BSD2.txt" for details. 6# * 7# * @TAG(NICTA_BSD) 8 9from target_objects import functions, pairings 10import target_objects 11from problem import Problem 12import problem 13import logic 14import syntax 15import solver 16import search 17import rep_graph 18import check 19 20import random 21 22def check_entry_var_deps (f): 23 if not f.entry: 24 return set () 25 p = f.as_problem (Problem) 26 diff = check_problem_entry_var_deps (p) 27 28 return diff 29 30def check_problem_entry_var_deps (p, var_deps = None): 31 if var_deps == None: 32 var_deps = p.compute_var_dependencies () 33 for (entry, tag, _, inputs) in p.entries: 34 if entry not in var_deps: 35 print 'Entry missing from var_deps: %d' % entry 36 continue 37 diff = set (var_deps[entry]) - set (inputs) 38 if diff: 39 print 'Vars deps escaped in %s in %s: %s' % (tag, 40 p.name, diff) 41 return diff 42 return set () 43 44def check_all_var_deps (): 45 return [f for f in functions if check_entry_var_deps(functions[f])] 46 47def walk_var_deps (p, n, v, var_deps = None, 48 interest = set (), symmetric = False): 49 if var_deps == None: 50 var_deps = p.compute_var_dependencies () 51 while True: 52 if n == 'Ret' or n == 'Err': 53 print n 54 return n 55 if symmetric: 56 opts = set ([n2 for n2 in p.preds[n] if n2 in p.nodes]) 57 else: 58 opts = set ([n2 for n2 in p.nodes[n].get_conts () 59 if n2 in p.nodes]) 60 choices = [n2 for n2 in opts if v in var_deps[n2]] 61 if not choices: 62 print 'Walk ends at %d.' % n 63 return 64 if len (choices) > 1: 65 print 'choices %s, gambling' % choices 66 random.shuffle (choices) 67 print ' ... rolled a %s' % choices[0] 68 elif len (opts) > 1: 69 print 'picked %s from %s' % (choices[0], opts) 70 n = choices[0] 71 if n in interest: 72 print '** %d' % n 73 else: 74 print n 75 76def diagram_var_deps (p, fname, v, var_deps = None): 77 if var_deps == None: 78 var_deps = p.compute_var_dependencies () 79 cols = {} 80 for n in p.nodes: 81 if n not in var_deps: 82 cols[n] = 'darkgrey' 83 elif v not in var_deps[n]: 84 cols[n] = 'darkblue' 85 else: 86 cols[n] = 'orange' 87 problem.save_graph (p.nodes, fname, cols = cols) 88 89def trace_model (rep, m, simplify = True): 90 p = rep.p 91 tags = set ([tag for (tag, n, vc) in rep.node_pc_env_order]) 92 if p.pairing and tags == set (p.pairing.tags): 93 tags = reversed (p.pairing.tags) 94 for tag in tags: 95 print "Walking %s in model" % tag 96 n_vcs = walk_model (rep, tag, m) 97 prev_era = None 98 for (i, (n, vc)) in enumerate (n_vcs): 99 era = n_vc_era (p, (n, vc)) 100 if era != prev_era: 101 print 'now in era %s' % era 102 prev_era = era 103 if n in ['Ret', 'Err']: 104 print 'ends at %s' % n 105 break 106 node = logic.simplify_node_elementary (p.nodes[n]) 107 if node.kind != 'Cond': 108 continue 109 name = rep.cond_name ((n, vc)) 110 cond = m[name] == syntax.true_term 111 print '%s: %s (%s, %s)' % (name, cond, 112 node.left, node.right) 113 investigate_cond (rep, m, name, simplify) 114 115def walk_model (rep, tag, m): 116 n_vcs = [(n, vc) for (tag2, n, vc) in rep.node_pc_env_order 117 if tag2 == tag 118 if search.eval_model_expr (m, rep.solv, 119 rep.get_pc ((n, vc), tag)) 120 == syntax.true_term] 121 122 n_vcs = era_sort (rep, n_vcs) 123 124 return n_vcs 125 126def investigate_cond (rep, m, cond, simplify = True, rec = True): 127 cond_def = rep.solv.defs[cond] 128 while rec and type (cond_def) == str and cond_def in rep.solv.defs: 129 cond_def = rep.solv.defs[cond_def] 130 def do_bit (bit): 131 if bit == 'true': 132 return True 133 valid = eval_model_bool (m, bit) 134 if simplify: 135 # looks a bit strange to do this now but some pointer 136 # lookups have to be done with unmodified s-exprs 137 bit = simplify_sexp (bit, rep, m, flatten = False) 138 print ' %s: %s' % (valid, solver.flat_s_expression (bit)) 139 return valid 140 while cond_def[0] == '=>': 141 valid = do_bit (cond_def[1]) 142 if not valid: 143 break 144 cond_def = cond_def[2] 145 bits = solver.split_hyp_sexpr (cond_def, []) 146 for bit in bits: 147 do_bit (bit) 148 149def eval_model_bool (m, x): 150 if hasattr (x, 'typ'): 151 x = solver.smt_expr (x, {}, None) 152 x = solver.parse_s_expression (x) 153 try: 154 r = search.eval_model (m, x) 155 assert r in [syntax.true_term, syntax.false_term], r 156 return r == syntax.true_term 157 except: 158 return 'EXCEPT' 159 160def funcall_name (rep): 161 return lambda n_vc: "%s @%s" % (rep.p.nodes[n_vc[0]].fname, 162 rep.node_count_name (n_vc)) 163 164def n_vc_era (p, (n, vc)): 165 era = 0 166 for (split, vcount) in vc: 167 if not p.loop_id (split): 168 continue 169 (ns, os) = vcount.get_opts () 170 if len (ns + os) > 1: 171 era += 3 172 elif ns: 173 era += 1 174 elif os: 175 era += 2 176 return era 177 178def era_merge (era): 179 # fold onramp to loops into pre-loop era 180 if era % 3 == 1: 181 era -= 1 182 return era 183 184def do_era_merge (do_merge, era): 185 if do_merge: 186 return era_merge (era) 187 else: 188 return era 189 190def era_sort (rep, n_vcs): 191 with_eras = [(n_vc_era (rep.p, n_vc), n_vc) for n_vc in n_vcs] 192 with_eras.sort (key = lambda x: x[0]) 193 for i in range (len (with_eras) - 1): 194 (e1, n_vc1) = with_eras[i] 195 (e2, n_vc2) = with_eras[i + 1] 196 if e1 != e2: 197 continue 198 if n_vc1[0] in ['Ret', 'Err']: 199 assert not 'Era issues', n_vcs 200 assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2] 201 return [n_vc for (_, n_vc) in with_eras] 202 203def investigate_funcalls (rep, m, verbose = False, verbose_imp = False, 204 simplify = True, pairing = 'Args', era_merge = True): 205 l_tag, r_tag = rep.p.pairing.tags 206 l_ns = walk_model (rep, l_tag, m) 207 r_ns = walk_model (rep, r_tag, m) 208 nodes = rep.p.nodes 209 210 l_calls = [n_vc for n_vc in l_ns if n_vc in rep.funcs] 211 r_calls = [n_vc for n_vc in r_ns if n_vc in rep.funcs] 212 print '%s calls: %s' % (l_tag, map (funcall_name (rep), l_calls)) 213 print '%s calls: %s' % (r_tag, map (funcall_name (rep), r_calls)) 214 215 if pairing == 'Eras': 216 fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls, 217 era_m = era_merge) 218 elif pairing == 'Seq': 219 fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls) 220 elif pairing == 'Args': 221 fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls, 222 era_m = era_merge) 223 elif pairing == 'All': 224 fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls] 225 else: 226 assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing 227 228 for (l_n_vc, r_n_vc) in fc_pairs: 229 if not rep.get_func_pairing (l_n_vc, r_n_vc): 230 print 'call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc) 231 continue 232 investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, 233 verbose, verbose_imp, simplify) 234 235def pair_funcalls_by_era (rep, l_calls, r_calls, era_m = True): 236 eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) 237 eras = sorted (eras + set (map (era_merge, eras))) 238 pairs = [] 239 for era in eras: 240 ls = [n_vc for n_vc in l_calls 241 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 242 rs = [n_vc for n_vc in r_calls 243 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 244 if len (ls) != len (rs): 245 print 'call seq length mismatch in era %d:' % era 246 print map (funcall_name (rep), ls) 247 print map (funcall_name (rep), rs) 248 pairs.extend (zip (ls, rs)) 249 return pairs 250 251def pair_funcalls_sequential (rep, l_calls, r_calls): 252 if len (l_calls) != len (r_calls): 253 print 'call seq tail mismatch' 254 if len (l_calls) > len (r_calls): 255 print 'dropping lhs: %s' % map (funcall_name (rep), 256 l_calls[len (r_calls):]) 257 else: 258 print 'dropping rhs: %s' % map (funcall_name (rep), 259 r_calls[len (l_calls):]) 260 # really should add some smarts to this to 'recover' from upsets or 261 # reorders, but maybe not worth it. 262 return zip (l_calls, r_calls) 263 264def pair_funcalls_by_match (rep, m, l_calls, r_calls, era_m = True): 265 eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) 266 eras = sorted (set.union (eras, set (map (era_merge, eras)))) 267 pairs = [] 268 for era in eras: 269 ls = [n_vc for n_vc in l_calls 270 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 271 rs = [n_vc for n_vc in r_calls 272 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 273 res = None 274 matches = [(1 - func_assert_premise_strength (rep, m, 275 n_vc, n_vc2), i, j) 276 for (i, n_vc) in enumerate (ls) 277 for (j, n_vc2) in enumerate (rs) 278 if rep.get_func_pairing (n_vc, n_vc2)] 279 matches.sort () 280 if not matches: 281 print 'Cannot match any (%d, %d) at era %d' % (len (ls), 282 len (rs), era) 283 continue 284 (_, i, j) = matches[0] 285 if i > j: 286 pairs.extend ((zip (ls[i - j:], rs))) 287 else: 288 pairs.extend ((zip (ls, rs[j - i:]))) 289 return pairs 290 291def func_assert_premise_strength (rep, m, l_n_vc, r_n_vc): 292 imp = rep.get_func_assert (l_n_vc, r_n_vc) 293 assert imp.is_op ('Implies'), imp 294 [pred, concl] = imp.vals 295 pred = solver.smt_expr (pred, {}, rep.solv) 296 pred = solver.parse_s_expression (pred) 297 bits = solver.split_hyp_sexpr (pred, []) 298 assert bits, bits 299 scores = [] 300 for bit in bits: 301 try: 302 res = eval_model_bool (m, bit) 303 if res: 304 scores.append (1.0) 305 else: 306 scores.append (0.0) 307 except solver.EnvMiss, e: 308 scores.append (0.5) 309 except AssertionError, e: 310 scores.append (0.5) 311 return sum (scores) / len (scores) 312 return all ([eval_model_bool (m, v) for v in bits]) 313 314def investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, 315 verbose = False, verbose_imp = False, simplify = True): 316 317 l_nm = "%s @ %s" % (rep.p.nodes[l_n_vc[0]].fname, rep.node_count_name (l_n_vc)) 318 r_nm = "%s @ %s" % (rep.p.nodes[r_n_vc[0]].fname, rep.node_count_name (r_n_vc)) 319 print 'Attempt match %s -> %s' % (l_nm, r_nm) 320 imp = rep.get_func_assert (l_n_vc, r_n_vc) 321 imp = logic.weaken_assert (imp) 322 if verbose_imp: 323 imp2 = solver.smt_expr (imp, {}, rep.solv) 324 if simplify: 325 imp2 = simplify_sexp (imp2, rep, m) 326 print imp2 327 assert imp.is_op ('Implies'), imp 328 [pred, concl] = imp.vals 329 pred = solver.smt_expr (pred, {}, rep.solv) 330 pred = solver.parse_s_expression (pred) 331 bits = solver.split_hyp_sexpr (pred, []) 332 xs = [eval_model_bool (m, v) for v in bits] 333 print ' %s' % xs 334 for (v, bit) in zip (xs, bits): 335 if v != True or verbose: 336 print ' %s: %s' % (v, bit) 337 if bit[0] == 'word32-eq': 338 vs = [model_sx_word (m, x) 339 for x in bit[1:]] 340 print ' (%s = %s)' % tuple (vs) 341 342def model_sx_word (m, sx): 343 v = search.eval_model (m, sx) 344 x = expr_num (v) 345 return solver.smt_num_t (x, v.typ) 346 347def expr_num (expr): 348 assert expr.typ.kind == 'Word' 349 return expr.val & ((1 << expr.typ.num) - 1) 350 351def str_to_num (smt_str): 352 v = solver.smt_to_val(smt_str) 353 return expr_num (v) 354 355def m_var_name (expr): 356 while expr.is_op ('MemUpdate'): 357 [expr, p, v] = expr.vals 358 if expr.kind == 'Var': 359 return expr.name 360 elif expr.kind == 'Op': 361 return '<Op %s>' % op.name 362 else: 363 return '<Expr %s>' % expr.kind 364 365def eval_str (expr, env, solv, m): 366 expr = solver.to_smt_expr (expr, env, solv) 367 v = search.eval_model_expr (m, solv, expr) 368 if v.typ == syntax.boolT: 369 assert v in [syntax.true_term, syntax.false_term] 370 return v.name 371 elif v.typ.kind == 'Word': 372 return solver.smt_num_t (v.val, v.typ) 373 else: 374 assert not 'type printable', v 375 376def trace_mem (rep, tag, m, verbose = False, simplify = True, symbs = True, 377 resolve_addrs = False): 378 p = rep.p 379 ns = walk_model (rep, tag, m) 380 trace = [] 381 for (n, vc) in ns: 382 if (n, vc) not in rep.arc_pc_envs: 383 # this n_vc has a pre-state, but has not been emitted. 384 # no point trying to evaluate its expressions, the 385 # solve won't have seen them yet. 386 continue 387 n_nm = rep.node_count_name ((n, vc)) 388 node = p.nodes[n] 389 if node.kind == 'Call': 390 exprs = list (node.args) 391 elif node.kind == 'Basic': 392 exprs = [expr for (_, expr) in node.upds] 393 elif node.kind == 'Cond': 394 exprs = [node.cond] 395 env = rep.node_pc_envs[(tag, n, vc)][1] 396 accs = list (set ([acc for expr in exprs 397 for acc in expr.get_mem_accesses ()])) 398 for (kind, addr, v, mem) in accs: 399 addr_s = solver.smt_expr (addr, env, rep.solv) 400 v_s = solver.smt_expr (v, env, rep.solv) 401 addr = eval_str (addr, env, rep.solv, m) 402 v = eval_str (v, env, rep.solv, m) 403 m_nm = m_var_name (mem) 404 print '%s: %s @ <%s> -- %s -- %s' % (kind, m_nm, addr, v, n_nm) 405 if simplify: 406 addr_s = simplify_sexp (addr_s, rep, m) 407 v_s = simplify_sexp (v_s, rep, m) 408 if verbose: 409 print '\t %s -- %s' % (addr_s, v_s) 410 if symbs: 411 addr_n = str_to_num (addr) 412 (hit_symbs, secs) = find_symbol (addr_n, output = False) 413 ss = hit_symbs + secs 414 if ss: 415 print '\t [%s]' % ', '.join (ss) 416 if resolve_addrs: 417 accs = [(kind, solver.to_smt_expr (addr, env, rep.solv), 418 solver.to_smt_expr (v, env, rep.solv), mem) 419 for (kind, addr, v, mem) in accs] 420 trace.extend ([(kind, addr, v, mem, n, vc) 421 for (kind, addr, v, mem) in accs]) 422 if node.kind == 'Call': 423 msg = '<function call to %s at %s>' % (node.fname, n_nm) 424 print msg 425 trace.append (msg) 426 return trace 427 428def simplify_sexp (smt_xp, rep, m, flatten = True): 429 if type (smt_xp) == str: 430 smt_xp = solver.parse_s_expression (smt_xp) 431 if smt_xp[0] == 'ite': 432 (_, c, x, y) = smt_xp 433 if eval_model_bool (m, c): 434 return simplify_sexp (x, rep, m, flatten) 435 else: 436 return simplify_sexp (y, rep, m, flatten) 437 if type (smt_xp) == tuple: 438 smt_xp = tuple ([simplify_sexp (x, rep, m, False) 439 for x in smt_xp]) 440 if flatten: 441 return solver.flat_s_expression (smt_xp) 442 else: 443 return smt_xp 444 445def trace_mems (rep, m, verbose = False, symbs = True, tags = None): 446 if tags == None: 447 if rep.p.pairing: 448 tags = reversed (rep.p.pairing.tags) 449 else: 450 tags = rep.p.tags () 451 for tag in tags: 452 print '%s mem trace:' % tag 453 trace_mem (rep, tag, m, verbose = verbose, symbs = symbs) 454 455def trace_mems_diff (rep, m, tags = ['ASM', 'C']): 456 asms = trace_mem (rep, tags[0], m, resolve_addrs = True) 457 cs = trace_mem (rep, tags[1], m, resolve_addrs = True) 458 ev = lambda expr: eval_str (expr, {}, None, m) 459 c_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in cs 460 if kind == 'MemUpdate'] 461 asm_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in asms 462 if kind == 'MemUpdate' and 'mem' in m_var_name (mem)] 463 c_upd_d = dict (c_upds) 464 asm_upd_d = dict (asm_upds) 465 addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds 466 if addr not in asm_upd_d] 467 mism = [addr for addr in addr_ord 468 if c_upd_d.get (addr) != asm_upd_d.get (addr)] 469 return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds) 470 471def get_pv_type (pv): 472 assert pv.is_op (['PValid', 'PArrayValid']) 473 typ_v = pv.vals[1] 474 assert typ_v.kind == 'Type' 475 typ = typ_v.val 476 if pv.is_op ('PArrayValid'): 477 return ('PArrayValid', typ, pv.vals[3]) 478 else: 479 return ('PValid', typ, None) 480 481def guess_pv (p, n, addr_expr): 482 vs = syntax.get_expr_var_set (addr_expr) 483 [pred] = p.preds[n] 484 pvs = [] 485 def vis (expr): 486 if expr.is_op (['PValid', 'PArrayValid']): 487 pvs.append (expr) 488 p.nodes[pred].cond.visit (vis) 489 match_pvs = [pv for pv in pvs 490 if set.union (* [syntax.get_expr_var_set (v) for v in pv.vals[2:]]) 491 == vs] 492 if len (match_pvs) > 1: 493 match_pvs = [pv for pv in match_pvs if pv.is_op ('PArrayValid')] 494 pv = match_pvs[0] 495 return pv 496 497def eval_pv_type (rep, (n, vc), m, data): 498 if data[0] == 'PValid': 499 return data 500 else: 501 (nm, typ, offs) = data 502 offs = rep.to_smt_expr (offs, (n, vc)) 503 offs = search.eval_model_expr (m, rep.solv, offs) 504 return (nm, typ, offs) 505 506def trace_suspicious_mem (rep, m, tag = 'C'): 507 cs = trace_mem (rep, tag, m) 508 data = [(addr, search.eval_model_expr (m, rep.solv, 509 rep.to_smt_expr (addr, (n, vc))), (n, vc)) 510 for (kind, addr, v, mem, n, vc) in cs] 511 addr_sets = {} 512 for (addr, addr_v, _) in data: 513 addr_sets.setdefault (addr_v, set ()) 514 addr_sets[addr_v].add (addr) 515 dup_addrs = set ([addr_v for addr_v in addr_sets 516 if len (addr_sets[addr_v]) > 1]) 517 data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc)) 518 for (addr, addr_v, (n, vc)) in data 519 if addr_v in dup_addrs] 520 data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m, 521 get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n) 522 for (addr, addr_v, pv, (n, vc)) in data] 523 dup_addr_types = set ([addr_v for addr_v in dup_addrs 524 if len (set ([t for (_, addr_v2, t, _, _) in data 525 if addr_v2 == addr_v])) > 1]) 526 res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data 527 if addr_v2 == addr_v]) 528 for addr_v in dup_addr_types] 529 for (addr_v, insts) in res: 530 print 'Address %s' % addr_v 531 for (t, pv, n) in insts: 532 print ' -- accessed with type %s at %s' % (t, n) 533 print ' (covered by %s)' % pv 534 return res 535 536def trace_var (rep, tag, m, v): 537 p = rep.p 538 ns = walk_model (rep, tag, m) 539 vds = rep.p.compute_var_dependencies () 540 trace = [] 541 vs = syntax.get_expr_var_set (v) 542 def fetch ((n, vc)): 543 if n in vds and [(nm, typ) for (nm, typ) in vs 544 if (nm, typ) not in vds[n]]: 545 return None 546 try: 547 (_, env) = rep.get_node_pc_env ((n, vc), tag) 548 s = solver.smt_expr (v, env, rep.solv) 549 s_x = solver.parse_s_expression (s) 550 ev = search.eval_model (m, s_x) 551 return (s, solver.smt_expr (ev, {}, None)) 552 except solver.EnvMiss, e: 553 return None 554 except AssertionError, e: 555 return None 556 val = None 557 for (n, vc) in ns: 558 n_nm = rep.node_count_name ((n, vc)) 559 val2 = fetch ((n, vc)) 560 if val2 != val: 561 if val2 == None: 562 print 'at %s: undefined' % n_nm 563 else: 564 print 'at %s:\t\t%s:\t\t%s' % (n_nm, 565 val2[0], val2[1]) 566 val = val2 567 trace.append (((n, vc), val)) 568 if n not in p.nodes: 569 break 570 node = p.nodes[n] 571 if node.kind == 'Call': 572 msg = '<function call to %s at %s>' % (node.fname, 573 rep.node_count_name ((n, vc))) 574 print msg 575 trace.append (msg) 576 return trace 577 578def trace_deriv_ops (rep, m, tag): 579 n_vcs = walk_model (rep, tag, m) 580 derivs = set (('CountTrailingZeroes', 'CountLeadingZeroes', 581 'WordReverse')) 582 def get_derivs (node): 583 dvs = set () 584 def visit (expr): 585 if expr.is_op (derivs): 586 dvs.add (expr) 587 node.visit (lambda x: (), visit) 588 return dvs 589 for (n, vc) in n_vcs: 590 if n not in rep.p.nodes: 591 continue 592 dvs = get_derivs (rep.p.nodes[n]) 593 if not dvs: 594 continue 595 print '%s:' % (rep.node_count_name ((n, vc))) 596 for dv in dvs: 597 [x] = dv.vals 598 x = rep.to_smt_expr (x, (n, vc)) 599 x = eval_str (x, {}, rep.solv, m) 600 print '\t%s: %s' % (dv.name, x) 601 602def check_pairings (): 603 for p in pairings.itervalues (): 604 print p['C'], p['ASM'] 605 as_args = functions[p['ASM']].inputs 606 c_args = functions[p['C']].inputs 607 print as_args, c_args 608 logic.mk_fun_inp_eqs (as_args, c_args, True) 609 610def loop_var_deps (p): 611 return [(n, [v for v in p.var_deps[n] 612 if p.var_deps[n][v] == 'LoopVariable']) 613 for n in p.loop_data] 614 615def find_symbol (n, output = True): 616 from target_objects import symbols, sections 617 symbs = [] 618 secs = [] 619 if output: 620 def p (s): 621 print s 622 else: 623 p = lambda s: () 624 for (s, (addr, size, _)) in symbols.iteritems (): 625 if addr <= n and n < addr + size: 626 symbs.append (s) 627 p ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1)) 628 for (s, (start, end)) in sections.iteritems (): 629 if start <= n and n <= end: 630 secs.append (s) 631 p ('%x in section %s (%x - %x)' % (n, s, start, end)) 632 return (symbs, secs) 633 634def assembly_point (p, n): 635 (_, hints) = p.node_tags[n] 636 if type (hints) != tuple or not logic.is_int (hints[1]): 637 return None 638 while p.node_tags[n][1][1] % 4 != 0: 639 [n] = p.preds[n] 640 return p.node_tags[n][1][1] 641 642def assembly_points (p, ns): 643 ns = [assembly_point (p, n) for n in ns] 644 ns = [n for n in ns if n != None] 645 return ns 646 647def disassembly_lines (addrs): 648 f = open ('%s/kernel.elf.txt' % target_objects.target_dir) 649 addr_set = set (['%x' % addr for addr in addrs]) 650 ss = [l.strip () 651 for l in f if ':' in l and l.split(':', 1)[0] in addr_set] 652 return ss 653 654def disassembly (p, n): 655 if hasattr (n, '__iter__'): 656 ns = set (n) 657 else: 658 ns = [n] 659 addrs = sorted (set ([assembly_point (p, n) for n in ns]) 660 - set ([None])) 661 print 'asm %s' % ', '.join (['0x%x' % addr for addr in addrs]) 662 for s in disassembly_lines (addrs): 663 print s 664 665def disassembly_loop (p, n): 666 head = p.loop_id (n) 667 loop = p.loop_body (n) 668 ns = sorted (set (assembly_points (p, loop))) 669 entries = assembly_points (p, [n for n in p.preds[head] 670 if n not in loop]) 671 print 'Loop: [%s]' % ', '.join (['%x' % addr for addr in ns]) 672 for s in disassembly_lines (ns): 673 print s 674 print 'entry from %s' % ', '.join (['%x' % addr for addr in entries]) 675 for s in disassembly_lines (entries): 676 print s 677 678def try_interpret_hyp (rep, hyp): 679 try: 680 expr = rep.interpret_hyp (hyp) 681 solver.smt_expr (expr, {}, rep.solv) 682 return None 683 except: 684 return ('Broken Hyp', hyp) 685 686def check_checks (): 687 p = problem.last_problem[0] 688 rep = rep_graph.mk_graph_slice (p) 689 proof = search.last_proof[0] 690 checks = check.proof_checks (p, proof) 691 all_hyps = set ([hyp for (_, hyp, _) in checks] 692 + [hyp for (hyps, _, _) in checks for hyp in hyps]) 693 results = [try_interpret_hyp (rep, hyp) for hyp in all_hyps] 694 return [r[1] for r in results if r] 695 696def proof_failed_groups (p = None, proof = None): 697 if p == None: 698 p = problem.last_problem[0] 699 if proof == None: 700 proof = search.last_proof[0] 701 checks = check.proof_checks (p, proof) 702 groups = check.proof_check_groups (checks) 703 failed = [] 704 for group in groups: 705 rep = rep_graph.mk_graph_slice (p) 706 (res, el) = check.test_hyp_group (rep, group) 707 if not res: 708 failed.append (group) 709 print 'Failed element: %s' % el 710 failed_nms = set ([s for group in failed for (_, _, s) in group]) 711 print 'Failed: %s' % failed_nms 712 return failed 713 714def read_summary (f): 715 results = {} 716 times = {} 717 for line in f: 718 if not line.startswith ('Time taken to'): 719 continue 720 bits = line.split () 721 assert bits[:4] == ['Time', 'taken', 'to', 'check'] 722 res = bits[4] 723 [ref] = [i for (i, b) in enumerate (bits) if b == '<='] 724 f = bits[ref + 1] 725 [pair] = [pair for pair in pairings[f] 726 if pair.name in line] 727 time = float (bits[-1]) 728 results[pair] = res 729 times[pair] = time 730 return (results, times) 731 732def unfold_defs_sexpr (defs, sexpr, depthlimit = -1): 733 if type (sexpr) == str: 734 sexpr = defs.get (sexpr, sexpr) 735 print sexpr 736 return sexpr 737 elif depthlimit == 0: 738 return sexpr 739 return tuple ([sexpr[0]] + [unfold_defs_sexpr (defs, s, depthlimit - 1) 740 for s in sexpr[1:]]) 741 742def unfold_defs (defs, hyp, depthlimit = -1): 743 return solver.flat_s_expression (unfold_defs_sexpr (defs, 744 solver.parse_s_expression (hyp), depthlimit)) 745 746def investigate_unsat (solv, hyps = None): 747 if hyps == None: 748 hyps = list (solver.last_hyps[0]) 749 assert solv.hyps_sat_raw (hyps) == 'unsat', hyps 750 kept_hyps = [] 751 while hyps: 752 h = hyps.pop () 753 if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat': 754 kept_hyps.append (h) 755 assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps 756 split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps 757 for hyp2 in solver.split_hyp (hyp)])) 758 if len (split_hyps) > len (kept_hyps): 759 return investigate_unsat (solv, split_hyps) 760 def_hyps = [(unfold_defs (solv.defs, h, 2), tag) 761 for (h, tag) in kept_hyps] 762 if def_hyps != kept_hyps: 763 return investigate_unsat (solv, def_hyps) 764 return kept_hyps 765 766def test_interesting_linear_series_exprs (): 767 pairs = set ([pair for f in pairings for pair in pairings[f]]) 768 notes = {} 769 for pair in pairs: 770 p = check.build_problem (pair) 771 for n in search.init_loops_to_split (p, ()): 772 intr = logic.interesting_linear_series_exprs (p, n, 773 search.get_loop_var_analysis_at (p, n)) 774 if intr: 775 notes[pair.name] = True 776 if 'Call' in str (intr): 777 notes[pair.name] = 'Call!' 778 return notes 779 780def var_analysis (p, n): 781 va = search.get_loop_var_analysis_at (p, n) 782 cats = {} 783 for (v, kind) in va: 784 if kind[0] == 'LoopLinearSeries': 785 offs = kind[2] 786 kind = kind[0] 787 else: 788 offs = None 789 cats.setdefault (kind, []) 790 cats[kind].append ((v, offs)) 791 for kind in cats: 792 print '%s:' % kind 793 for (v, offs) in cats[kind]: 794 print ' %s (%s)' % (syntax.pretty_expr (v), 795 syntax.pretty_type (v.typ)) 796 if offs: 797 print ' ++ %s' % syntax.pretty_expr (offs) 798 799def var_value_sites (rep, v): 800 if type (v) == str: 801 matches = lambda (nm, _): v in nm 802 elif type (v) == tuple: 803 matches = lambda (nm, typ): v == (nm, typ) 804 v_ord = [] 805 d = {} 806 for (tag, n, vc) in rep.node_pc_env_order: 807 (pc, env) = rep.get_node_pc_env ((n, vc), tag = tag) 808 for (v2, smt_exp) in env.iteritems (): 809 if matches (v2): 810 if smt_exp not in d: 811 v_ord.append (smt_exp) 812 d[smt_exp] = [] 813 d[smt_exp].append ((n, vc)) 814 for smt_exp in v_ord: 815 print smt_exp 816 if smt_exp in rep.solv.defs: 817 print (' = %s' % repr (rep.solv.defs[smt_exp])) 818 print (' - at: %s' % d[smt_exp]) 819 if v_ord: 820 print ('') 821 return (v_ord, d) 822 823def loop_num_leaves (p, n): 824 for n in p.loop_body (n): 825 va = search.get_loop_var_analysis_at (p, n) 826 n_leaf = len ([1 for (v, kind) in va if kind == 'LoopLeaf']) 827 print (n, n_leaf) 828 829def try_pairing_at_funcall (p, name, head = None, restrs = None, hyps = None, 830 at = 'At'): 831 pairs = set (pairings[name]) 832 addrs = [n for (n, name2) in p.function_call_addrs () 833 if [pair for pair in pairings[name2] if pair in pairs]] 834 assert at in ['At', 'After'] 835 if at == 'After': 836 addrs = [p.nodes[n].cont for n in addrs] 837 if head == None: 838 tags = p.pairing.tags 839 [head] = [n for n in search.init_loops_to_split (p, ()) 840 if p.node_tags[n][0] == tags[0]] 841 if restrs == None: 842 restrs = () 843 if hyps == None: 844 hyps = check.init_point_hyps (p) 845 while True: 846 res = search.find_split_loop (p, head, restrs, hyps, 847 node_restrs = set (addrs)) 848 if res[0] == 'CaseSplit': 849 (_, ((n, tag), _)) = res 850 hyp = rep_graph.pc_true_hyp (((n, restrs), tag)) 851 hyps = hyps + [hyp] 852 else: 853 return res 854 855def init_true_hyp (p, tag, expr): 856 n = p.get_entry (tag) 857 vis = ((n, ()), tag) 858 assert expr.typ == syntax.boolT, expr 859 return rep_graph.eq_hyp ((expr, vis), (syntax.true_term, vis)) 860 861def smt_print (expr): 862 env = {} 863 while True: 864 try: 865 return solver.smt_expr (expr, env, None) 866 except solver.EnvMiss, e: 867 env[(e.name, e.typ)] = e.name 868 869