1# 2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) 3# 4# SPDX-License-Identifier: BSD-2-Clause 5# 6 7from target_objects import functions, pairings 8import target_objects 9from problem import Problem 10import problem 11import logic 12import syntax 13import solver 14import search 15import rep_graph 16import check 17 18import random 19 20def check_entry_var_deps (f): 21 if not f.entry: 22 return set () 23 p = f.as_problem (Problem) 24 diff = check_problem_entry_var_deps (p) 25 26 return diff 27 28def check_problem_entry_var_deps (p, var_deps = None): 29 if var_deps == None: 30 var_deps = p.compute_var_dependencies () 31 for (entry, tag, _, inputs) in p.entries: 32 if entry not in var_deps: 33 print 'Entry missing from var_deps: %d' % entry 34 continue 35 diff = set (var_deps[entry]) - set (inputs) 36 if diff: 37 print 'Vars deps escaped in %s in %s: %s' % (tag, 38 p.name, diff) 39 return diff 40 return set () 41 42def check_all_var_deps (): 43 return [f for f in functions if check_entry_var_deps(functions[f])] 44 45def walk_var_deps (p, n, v, var_deps = None, 46 interest = set (), symmetric = False): 47 if var_deps == None: 48 var_deps = p.compute_var_dependencies () 49 while True: 50 if n == 'Ret' or n == 'Err': 51 print n 52 return n 53 if symmetric: 54 opts = set ([n2 for n2 in p.preds[n] if n2 in p.nodes]) 55 else: 56 opts = set ([n2 for n2 in p.nodes[n].get_conts () 57 if n2 in p.nodes]) 58 choices = [n2 for n2 in opts if v in var_deps[n2]] 59 if not choices: 60 print 'Walk ends at %d.' % n 61 return 62 if len (choices) > 1: 63 print 'choices %s, gambling' % choices 64 random.shuffle (choices) 65 print ' ... rolled a %s' % choices[0] 66 elif len (opts) > 1: 67 print 'picked %s from %s' % (choices[0], opts) 68 n = choices[0] 69 if n in interest: 70 print '** %d' % n 71 else: 72 print n 73 74def diagram_var_deps (p, fname, v, var_deps = None): 75 if var_deps == None: 76 var_deps = p.compute_var_dependencies () 77 cols = {} 78 for n in p.nodes: 79 if n not in var_deps: 80 cols[n] = 'darkgrey' 81 elif v not in var_deps[n]: 82 cols[n] = 'darkblue' 83 else: 84 cols[n] = 'orange' 85 problem.save_graph (p.nodes, fname, cols = cols) 86 87def trace_model (rep, m, simplify = True): 88 p = rep.p 89 tags = set ([tag for (tag, n, vc) in rep.node_pc_env_order]) 90 if p.pairing and tags == set (p.pairing.tags): 91 tags = reversed (p.pairing.tags) 92 for tag in tags: 93 print "Walking %s in model" % tag 94 n_vcs = walk_model (rep, tag, m) 95 prev_era = None 96 for (i, (n, vc)) in enumerate (n_vcs): 97 era = n_vc_era (p, (n, vc)) 98 if era != prev_era: 99 print 'now in era %s' % era 100 prev_era = era 101 if n in ['Ret', 'Err']: 102 print 'ends at %s' % n 103 break 104 node = logic.simplify_node_elementary (p.nodes[n]) 105 if node.kind != 'Cond': 106 continue 107 name = rep.cond_name ((n, vc)) 108 cond = m[name] == syntax.true_term 109 print '%s: %s (%s, %s)' % (name, cond, 110 node.left, node.right) 111 investigate_cond (rep, m, name, simplify) 112 113def walk_model (rep, tag, m): 114 n_vcs = [(n, vc) for (tag2, n, vc) in rep.node_pc_env_order 115 if tag2 == tag 116 if search.eval_model_expr (m, rep.solv, 117 rep.get_pc ((n, vc), tag)) 118 == syntax.true_term] 119 120 n_vcs = era_sort (rep, n_vcs) 121 122 return n_vcs 123 124def investigate_cond (rep, m, cond, simplify = True, rec = True): 125 cond_def = rep.solv.defs[cond] 126 while rec and type (cond_def) == str and cond_def in rep.solv.defs: 127 cond_def = rep.solv.defs[cond_def] 128 def do_bit (bit): 129 if bit == 'true': 130 return True 131 valid = eval_model_bool (m, bit) 132 if simplify: 133 # looks a bit strange to do this now but some pointer 134 # lookups have to be done with unmodified s-exprs 135 bit = simplify_sexp (bit, rep, m, flatten = False) 136 print ' %s: %s' % (valid, solver.flat_s_expression (bit)) 137 return valid 138 while cond_def[0] == '=>': 139 valid = do_bit (cond_def[1]) 140 if not valid: 141 break 142 cond_def = cond_def[2] 143 bits = solver.split_hyp_sexpr (cond_def, []) 144 for bit in bits: 145 do_bit (bit) 146 147def eval_model_bool (m, x): 148 if hasattr (x, 'typ'): 149 x = solver.smt_expr (x, {}, None) 150 x = solver.parse_s_expression (x) 151 try: 152 r = search.eval_model (m, x) 153 assert r in [syntax.true_term, syntax.false_term], r 154 return r == syntax.true_term 155 except: 156 return 'EXCEPT' 157 158def funcall_name (rep): 159 return lambda n_vc: "%s @%s" % (rep.p.nodes[n_vc[0]].fname, 160 rep.node_count_name (n_vc)) 161 162def n_vc_era (p, (n, vc)): 163 era = 0 164 for (split, vcount) in vc: 165 if not p.loop_id (split): 166 continue 167 (ns, os) = vcount.get_opts () 168 if len (ns + os) > 1: 169 era += 3 170 elif ns: 171 era += 1 172 elif os: 173 era += 2 174 return era 175 176def era_merge (era): 177 # fold onramp to loops into pre-loop era 178 if era % 3 == 1: 179 era -= 1 180 return era 181 182def do_era_merge (do_merge, era): 183 if do_merge: 184 return era_merge (era) 185 else: 186 return era 187 188def era_sort (rep, n_vcs): 189 with_eras = [(n_vc_era (rep.p, n_vc), n_vc) for n_vc in n_vcs] 190 with_eras.sort (key = lambda x: x[0]) 191 for i in range (len (with_eras) - 1): 192 (e1, n_vc1) = with_eras[i] 193 (e2, n_vc2) = with_eras[i + 1] 194 if e1 != e2: 195 continue 196 if n_vc1[0] in ['Ret', 'Err']: 197 assert not 'Era issues', n_vcs 198 assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2] 199 return [n_vc for (_, n_vc) in with_eras] 200 201def investigate_funcalls (rep, m, verbose = False, verbose_imp = False, 202 simplify = True, pairing = 'Args', era_merge = True): 203 l_tag, r_tag = rep.p.pairing.tags 204 l_ns = walk_model (rep, l_tag, m) 205 r_ns = walk_model (rep, r_tag, m) 206 nodes = rep.p.nodes 207 208 l_calls = [n_vc for n_vc in l_ns if n_vc in rep.funcs] 209 r_calls = [n_vc for n_vc in r_ns if n_vc in rep.funcs] 210 print '%s calls: %s' % (l_tag, map (funcall_name (rep), l_calls)) 211 print '%s calls: %s' % (r_tag, map (funcall_name (rep), r_calls)) 212 213 if pairing == 'Eras': 214 fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls, 215 era_m = era_merge) 216 elif pairing == 'Seq': 217 fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls) 218 elif pairing == 'Args': 219 fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls, 220 era_m = era_merge) 221 elif pairing == 'All': 222 fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls] 223 else: 224 assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing 225 226 for (l_n_vc, r_n_vc) in fc_pairs: 227 if not rep.get_func_pairing (l_n_vc, r_n_vc): 228 print 'call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc) 229 continue 230 investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, 231 verbose, verbose_imp, simplify) 232 233def pair_funcalls_by_era (rep, l_calls, r_calls, era_m = True): 234 eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) 235 eras = sorted (eras + set (map (era_merge, eras))) 236 pairs = [] 237 for era in eras: 238 ls = [n_vc for n_vc in l_calls 239 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 240 rs = [n_vc for n_vc in r_calls 241 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 242 if len (ls) != len (rs): 243 print 'call seq length mismatch in era %d:' % era 244 print map (funcall_name (rep), ls) 245 print map (funcall_name (rep), rs) 246 pairs.extend (zip (ls, rs)) 247 return pairs 248 249def pair_funcalls_sequential (rep, l_calls, r_calls): 250 if len (l_calls) != len (r_calls): 251 print 'call seq tail mismatch' 252 if len (l_calls) > len (r_calls): 253 print 'dropping lhs: %s' % map (funcall_name (rep), 254 l_calls[len (r_calls):]) 255 else: 256 print 'dropping rhs: %s' % map (funcall_name (rep), 257 r_calls[len (l_calls):]) 258 # really should add some smarts to this to 'recover' from upsets or 259 # reorders, but maybe not worth it. 260 return zip (l_calls, r_calls) 261 262def pair_funcalls_by_match (rep, m, l_calls, r_calls, era_m = True): 263 eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) 264 eras = sorted (set.union (eras, set (map (era_merge, eras)))) 265 pairs = [] 266 for era in eras: 267 ls = [n_vc for n_vc in l_calls 268 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 269 rs = [n_vc for n_vc in r_calls 270 if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] 271 res = None 272 matches = [(1 - func_assert_premise_strength (rep, m, 273 n_vc, n_vc2), i, j) 274 for (i, n_vc) in enumerate (ls) 275 for (j, n_vc2) in enumerate (rs) 276 if rep.get_func_pairing (n_vc, n_vc2)] 277 matches.sort () 278 if not matches: 279 print 'Cannot match any (%d, %d) at era %d' % (len (ls), 280 len (rs), era) 281 continue 282 (_, i, j) = matches[0] 283 if i > j: 284 pairs.extend ((zip (ls[i - j:], rs))) 285 else: 286 pairs.extend ((zip (ls, rs[j - i:]))) 287 return pairs 288 289def func_assert_premise_strength (rep, m, l_n_vc, r_n_vc): 290 imp = rep.get_func_assert (l_n_vc, r_n_vc) 291 assert imp.is_op ('Implies'), imp 292 [pred, concl] = imp.vals 293 pred = solver.smt_expr (pred, {}, rep.solv) 294 pred = solver.parse_s_expression (pred) 295 bits = solver.split_hyp_sexpr (pred, []) 296 assert bits, bits 297 scores = [] 298 for bit in bits: 299 try: 300 res = eval_model_bool (m, bit) 301 if res: 302 scores.append (1.0) 303 else: 304 scores.append (0.0) 305 except solver.EnvMiss, e: 306 scores.append (0.5) 307 except AssertionError, e: 308 scores.append (0.5) 309 return sum (scores) / len (scores) 310 return all ([eval_model_bool (m, v) for v in bits]) 311 312def investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, 313 verbose = False, verbose_imp = False, simplify = True): 314 315 l_nm = "%s @ %s" % (rep.p.nodes[l_n_vc[0]].fname, rep.node_count_name (l_n_vc)) 316 r_nm = "%s @ %s" % (rep.p.nodes[r_n_vc[0]].fname, rep.node_count_name (r_n_vc)) 317 print 'Attempt match %s -> %s' % (l_nm, r_nm) 318 imp = rep.get_func_assert (l_n_vc, r_n_vc) 319 imp = logic.weaken_assert (imp) 320 if verbose_imp: 321 imp2 = solver.smt_expr (imp, {}, rep.solv) 322 if simplify: 323 imp2 = simplify_sexp (imp2, rep, m) 324 print imp2 325 assert imp.is_op ('Implies'), imp 326 [pred, concl] = imp.vals 327 pred = solver.smt_expr (pred, {}, rep.solv) 328 pred = solver.parse_s_expression (pred) 329 bits = solver.split_hyp_sexpr (pred, []) 330 xs = [eval_model_bool (m, v) for v in bits] 331 print ' %s' % xs 332 for (v, bit) in zip (xs, bits): 333 if v != True or verbose: 334 print ' %s: %s' % (v, bit) 335 if bit[0] == 'word32-eq': 336 vs = [model_sx_word (m, x) 337 for x in bit[1:]] 338 print ' (%s = %s)' % tuple (vs) 339 340def model_sx_word (m, sx): 341 v = search.eval_model (m, sx) 342 x = expr_num (v) 343 return solver.smt_num_t (x, v.typ) 344 345def expr_num (expr): 346 assert expr.typ.kind == 'Word' 347 return expr.val & ((1 << expr.typ.num) - 1) 348 349def str_to_num (smt_str): 350 v = solver.smt_to_val(smt_str) 351 return expr_num (v) 352 353def m_var_name (expr): 354 while expr.is_op ('MemUpdate'): 355 [expr, p, v] = expr.vals 356 if expr.kind == 'Var': 357 return expr.name 358 elif expr.kind == 'Op': 359 return '<Op %s>' % op.name 360 else: 361 return '<Expr %s>' % expr.kind 362 363def eval_str (expr, env, solv, m): 364 expr = solver.to_smt_expr (expr, env, solv) 365 v = search.eval_model_expr (m, solv, expr) 366 if v.typ == syntax.boolT: 367 assert v in [syntax.true_term, syntax.false_term] 368 return v.name 369 elif v.typ.kind == 'Word': 370 return solver.smt_num_t (v.val, v.typ) 371 else: 372 assert not 'type printable', v 373 374def trace_mem (rep, tag, m, verbose = False, simplify = True, symbs = True, 375 resolve_addrs = False): 376 p = rep.p 377 ns = walk_model (rep, tag, m) 378 trace = [] 379 for (n, vc) in ns: 380 if (n, vc) not in rep.arc_pc_envs: 381 # this n_vc has a pre-state, but has not been emitted. 382 # no point trying to evaluate its expressions, the 383 # solve won't have seen them yet. 384 continue 385 n_nm = rep.node_count_name ((n, vc)) 386 node = p.nodes[n] 387 if node.kind == 'Call': 388 exprs = list (node.args) 389 elif node.kind == 'Basic': 390 exprs = [expr for (_, expr) in node.upds] 391 elif node.kind == 'Cond': 392 exprs = [node.cond] 393 env = rep.node_pc_envs[(tag, n, vc)][1] 394 accs = list (set ([acc for expr in exprs 395 for acc in expr.get_mem_accesses ()])) 396 for (kind, addr, v, mem) in accs: 397 addr_s = solver.smt_expr (addr, env, rep.solv) 398 v_s = solver.smt_expr (v, env, rep.solv) 399 addr = eval_str (addr, env, rep.solv, m) 400 v = eval_str (v, env, rep.solv, m) 401 m_nm = m_var_name (mem) 402 print '%s: %s @ <%s> -- %s -- %s' % (kind, m_nm, addr, v, n_nm) 403 if simplify: 404 addr_s = simplify_sexp (addr_s, rep, m) 405 v_s = simplify_sexp (v_s, rep, m) 406 if verbose: 407 print '\t %s -- %s' % (addr_s, v_s) 408 if symbs: 409 addr_n = str_to_num (addr) 410 (hit_symbs, secs) = find_symbol (addr_n, output = False) 411 ss = hit_symbs + secs 412 if ss: 413 print '\t [%s]' % ', '.join (ss) 414 if resolve_addrs: 415 accs = [(kind, solver.to_smt_expr (addr, env, rep.solv), 416 solver.to_smt_expr (v, env, rep.solv), mem) 417 for (kind, addr, v, mem) in accs] 418 trace.extend ([(kind, addr, v, mem, n, vc) 419 for (kind, addr, v, mem) in accs]) 420 if node.kind == 'Call': 421 msg = '<function call to %s at %s>' % (node.fname, n_nm) 422 print msg 423 trace.append (msg) 424 return trace 425 426def simplify_sexp (smt_xp, rep, m, flatten = True): 427 if type (smt_xp) == str: 428 smt_xp = solver.parse_s_expression (smt_xp) 429 if smt_xp[0] == 'ite': 430 (_, c, x, y) = smt_xp 431 if eval_model_bool (m, c): 432 return simplify_sexp (x, rep, m, flatten) 433 else: 434 return simplify_sexp (y, rep, m, flatten) 435 if type (smt_xp) == tuple: 436 smt_xp = tuple ([simplify_sexp (x, rep, m, False) 437 for x in smt_xp]) 438 if flatten: 439 return solver.flat_s_expression (smt_xp) 440 else: 441 return smt_xp 442 443def trace_mems (rep, m, verbose = False, symbs = True, tags = None): 444 if tags == None: 445 if rep.p.pairing: 446 tags = reversed (rep.p.pairing.tags) 447 else: 448 tags = rep.p.tags () 449 for tag in tags: 450 print '%s mem trace:' % tag 451 trace_mem (rep, tag, m, verbose = verbose, symbs = symbs) 452 453def trace_mems_diff (rep, m, tags = ['ASM', 'C']): 454 asms = trace_mem (rep, tags[0], m, resolve_addrs = True) 455 cs = trace_mem (rep, tags[1], m, resolve_addrs = True) 456 ev = lambda expr: eval_str (expr, {}, None, m) 457 c_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in cs 458 if kind == 'MemUpdate'] 459 asm_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in asms 460 if kind == 'MemUpdate' and 'mem' in m_var_name (mem)] 461 c_upd_d = dict (c_upds) 462 asm_upd_d = dict (asm_upds) 463 addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds 464 if addr not in asm_upd_d] 465 mism = [addr for addr in addr_ord 466 if c_upd_d.get (addr) != asm_upd_d.get (addr)] 467 return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds) 468 469def get_pv_type (pv): 470 assert pv.is_op (['PValid', 'PArrayValid']) 471 typ_v = pv.vals[1] 472 assert typ_v.kind == 'Type' 473 typ = typ_v.val 474 if pv.is_op ('PArrayValid'): 475 return ('PArrayValid', typ, pv.vals[3]) 476 else: 477 return ('PValid', typ, None) 478 479def guess_pv (p, n, addr_expr): 480 vs = syntax.get_expr_var_set (addr_expr) 481 [pred] = p.preds[n] 482 pvs = [] 483 def vis (expr): 484 if expr.is_op (['PValid', 'PArrayValid']): 485 pvs.append (expr) 486 p.nodes[pred].cond.visit (vis) 487 match_pvs = [pv for pv in pvs 488 if set.union (* [syntax.get_expr_var_set (v) for v in pv.vals[2:]]) 489 == vs] 490 if len (match_pvs) > 1: 491 match_pvs = [pv for pv in match_pvs if pv.is_op ('PArrayValid')] 492 pv = match_pvs[0] 493 return pv 494 495def eval_pv_type (rep, (n, vc), m, data): 496 if data[0] == 'PValid': 497 return data 498 else: 499 (nm, typ, offs) = data 500 offs = rep.to_smt_expr (offs, (n, vc)) 501 offs = search.eval_model_expr (m, rep.solv, offs) 502 return (nm, typ, offs) 503 504def trace_suspicious_mem (rep, m, tag = 'C'): 505 cs = trace_mem (rep, tag, m) 506 data = [(addr, search.eval_model_expr (m, rep.solv, 507 rep.to_smt_expr (addr, (n, vc))), (n, vc)) 508 for (kind, addr, v, mem, n, vc) in cs] 509 addr_sets = {} 510 for (addr, addr_v, _) in data: 511 addr_sets.setdefault (addr_v, set ()) 512 addr_sets[addr_v].add (addr) 513 dup_addrs = set ([addr_v for addr_v in addr_sets 514 if len (addr_sets[addr_v]) > 1]) 515 data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc)) 516 for (addr, addr_v, (n, vc)) in data 517 if addr_v in dup_addrs] 518 data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m, 519 get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n) 520 for (addr, addr_v, pv, (n, vc)) in data] 521 dup_addr_types = set ([addr_v for addr_v in dup_addrs 522 if len (set ([t for (_, addr_v2, t, _, _) in data 523 if addr_v2 == addr_v])) > 1]) 524 res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data 525 if addr_v2 == addr_v]) 526 for addr_v in dup_addr_types] 527 for (addr_v, insts) in res: 528 print 'Address %s' % addr_v 529 for (t, pv, n) in insts: 530 print ' -- accessed with type %s at %s' % (t, n) 531 print ' (covered by %s)' % pv 532 return res 533 534def trace_var (rep, tag, m, v): 535 p = rep.p 536 ns = walk_model (rep, tag, m) 537 vds = rep.p.compute_var_dependencies () 538 trace = [] 539 vs = syntax.get_expr_var_set (v) 540 def fetch ((n, vc)): 541 if n in vds and [(nm, typ) for (nm, typ) in vs 542 if (nm, typ) not in vds[n]]: 543 return None 544 try: 545 (_, env) = rep.get_node_pc_env ((n, vc), tag) 546 s = solver.smt_expr (v, env, rep.solv) 547 s_x = solver.parse_s_expression (s) 548 ev = search.eval_model (m, s_x) 549 return (s, solver.smt_expr (ev, {}, None)) 550 except solver.EnvMiss, e: 551 return None 552 except AssertionError, e: 553 return None 554 val = None 555 for (n, vc) in ns: 556 n_nm = rep.node_count_name ((n, vc)) 557 val2 = fetch ((n, vc)) 558 if val2 != val: 559 if val2 == None: 560 print 'at %s: undefined' % n_nm 561 else: 562 print 'at %s:\t\t%s:\t\t%s' % (n_nm, 563 val2[0], val2[1]) 564 val = val2 565 trace.append (((n, vc), val)) 566 if n not in p.nodes: 567 break 568 node = p.nodes[n] 569 if node.kind == 'Call': 570 msg = '<function call to %s at %s>' % (node.fname, 571 rep.node_count_name ((n, vc))) 572 print msg 573 trace.append (msg) 574 return trace 575 576def trace_deriv_ops (rep, m, tag): 577 n_vcs = walk_model (rep, tag, m) 578 derivs = set (('CountTrailingZeroes', 'CountLeadingZeroes', 579 'WordReverse')) 580 def get_derivs (node): 581 dvs = set () 582 def visit (expr): 583 if expr.is_op (derivs): 584 dvs.add (expr) 585 node.visit (lambda x: (), visit) 586 return dvs 587 for (n, vc) in n_vcs: 588 if n not in rep.p.nodes: 589 continue 590 dvs = get_derivs (rep.p.nodes[n]) 591 if not dvs: 592 continue 593 print '%s:' % (rep.node_count_name ((n, vc))) 594 for dv in dvs: 595 [x] = dv.vals 596 x = rep.to_smt_expr (x, (n, vc)) 597 x = eval_str (x, {}, rep.solv, m) 598 print '\t%s: %s' % (dv.name, x) 599 600def check_pairings (): 601 for p in pairings.itervalues (): 602 print p['C'], p['ASM'] 603 as_args = functions[p['ASM']].inputs 604 c_args = functions[p['C']].inputs 605 print as_args, c_args 606 logic.mk_fun_inp_eqs (as_args, c_args, True) 607 608def loop_var_deps (p): 609 return [(n, [v for v in p.var_deps[n] 610 if p.var_deps[n][v] == 'LoopVariable']) 611 for n in p.loop_data] 612 613def find_symbol (n, output = True): 614 from target_objects import symbols, sections 615 symbs = [] 616 secs = [] 617 if output: 618 def p (s): 619 print s 620 else: 621 p = lambda s: () 622 for (s, (addr, size, _)) in symbols.iteritems (): 623 if addr <= n and n < addr + size: 624 symbs.append (s) 625 p ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1)) 626 for (s, (start, end)) in sections.iteritems (): 627 if start <= n and n <= end: 628 secs.append (s) 629 p ('%x in section %s (%x - %x)' % (n, s, start, end)) 630 return (symbs, secs) 631 632def assembly_point (p, n): 633 (_, hints) = p.node_tags[n] 634 if type (hints) != tuple or not logic.is_int (hints[1]): 635 return None 636 while p.node_tags[n][1][1] % 4 != 0: 637 [n] = p.preds[n] 638 return p.node_tags[n][1][1] 639 640def assembly_points (p, ns): 641 ns = [assembly_point (p, n) for n in ns] 642 ns = [n for n in ns if n != None] 643 return ns 644 645def disassembly_lines (addrs): 646 f = open ('%s/kernel.elf.txt' % target_objects.target_dir) 647 addr_set = set (['%x' % addr for addr in addrs]) 648 ss = [l.strip () 649 for l in f if ':' in l and l.split(':', 1)[0] in addr_set] 650 return ss 651 652def disassembly (p, n): 653 if hasattr (n, '__iter__'): 654 ns = set (n) 655 else: 656 ns = [n] 657 addrs = sorted (set ([assembly_point (p, n) for n in ns]) 658 - set ([None])) 659 print 'asm %s' % ', '.join (['0x%x' % addr for addr in addrs]) 660 for s in disassembly_lines (addrs): 661 print s 662 663def disassembly_loop (p, n): 664 head = p.loop_id (n) 665 loop = p.loop_body (n) 666 ns = sorted (set (assembly_points (p, loop))) 667 entries = assembly_points (p, [n for n in p.preds[head] 668 if n not in loop]) 669 print 'Loop: [%s]' % ', '.join (['%x' % addr for addr in ns]) 670 for s in disassembly_lines (ns): 671 print s 672 print 'entry from %s' % ', '.join (['%x' % addr for addr in entries]) 673 for s in disassembly_lines (entries): 674 print s 675 676def try_interpret_hyp (rep, hyp): 677 try: 678 expr = rep.interpret_hyp (hyp) 679 solver.smt_expr (expr, {}, rep.solv) 680 return None 681 except: 682 return ('Broken Hyp', hyp) 683 684def check_checks (): 685 p = problem.last_problem[0] 686 rep = rep_graph.mk_graph_slice (p) 687 proof = search.last_proof[0] 688 checks = check.proof_checks (p, proof) 689 all_hyps = set ([hyp for (_, hyp, _) in checks] 690 + [hyp for (hyps, _, _) in checks for hyp in hyps]) 691 results = [try_interpret_hyp (rep, hyp) for hyp in all_hyps] 692 return [r[1] for r in results if r] 693 694def proof_failed_groups (p = None, proof = None): 695 if p == None: 696 p = problem.last_problem[0] 697 if proof == None: 698 proof = search.last_proof[0] 699 checks = check.proof_checks (p, proof) 700 groups = check.proof_check_groups (checks) 701 failed = [] 702 for group in groups: 703 rep = rep_graph.mk_graph_slice (p) 704 (res, el) = check.test_hyp_group (rep, group) 705 if not res: 706 failed.append (group) 707 print 'Failed element: %s' % el 708 failed_nms = set ([s for group in failed for (_, _, s) in group]) 709 print 'Failed: %s' % failed_nms 710 return failed 711 712def read_summary (f): 713 results = {} 714 times = {} 715 for line in f: 716 if not line.startswith ('Time taken to'): 717 continue 718 bits = line.split () 719 assert bits[:4] == ['Time', 'taken', 'to', 'check'] 720 res = bits[4] 721 [ref] = [i for (i, b) in enumerate (bits) if b == '<='] 722 f = bits[ref + 1] 723 [pair] = [pair for pair in pairings[f] 724 if pair.name in line] 725 time = float (bits[-1]) 726 results[pair] = res 727 times[pair] = time 728 return (results, times) 729 730def unfold_defs_sexpr (defs, sexpr, depthlimit = -1): 731 if type (sexpr) == str: 732 sexpr = defs.get (sexpr, sexpr) 733 print sexpr 734 return sexpr 735 elif depthlimit == 0: 736 return sexpr 737 return tuple ([sexpr[0]] + [unfold_defs_sexpr (defs, s, depthlimit - 1) 738 for s in sexpr[1:]]) 739 740def unfold_defs (defs, hyp, depthlimit = -1): 741 return solver.flat_s_expression (unfold_defs_sexpr (defs, 742 solver.parse_s_expression (hyp), depthlimit)) 743 744def investigate_unsat (solv, hyps = None): 745 if hyps == None: 746 hyps = list (solver.last_hyps[0]) 747 assert solv.hyps_sat_raw (hyps) == 'unsat', hyps 748 kept_hyps = [] 749 while hyps: 750 h = hyps.pop () 751 if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat': 752 kept_hyps.append (h) 753 assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps 754 split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps 755 for hyp2 in solver.split_hyp (hyp)])) 756 if len (split_hyps) > len (kept_hyps): 757 return investigate_unsat (solv, split_hyps) 758 def_hyps = [(unfold_defs (solv.defs, h, 2), tag) 759 for (h, tag) in kept_hyps] 760 if def_hyps != kept_hyps: 761 return investigate_unsat (solv, def_hyps) 762 return kept_hyps 763 764def test_interesting_linear_series_exprs (): 765 pairs = set ([pair for f in pairings for pair in pairings[f]]) 766 notes = {} 767 for pair in pairs: 768 p = check.build_problem (pair) 769 for n in search.init_loops_to_split (p, ()): 770 intr = logic.interesting_linear_series_exprs (p, n, 771 search.get_loop_var_analysis_at (p, n)) 772 if intr: 773 notes[pair.name] = True 774 if 'Call' in str (intr): 775 notes[pair.name] = 'Call!' 776 return notes 777 778def var_analysis (p, n): 779 va = search.get_loop_var_analysis_at (p, n) 780 cats = {} 781 for (v, kind) in va: 782 if kind[0] == 'LoopLinearSeries': 783 offs = kind[2] 784 kind = kind[0] 785 else: 786 offs = None 787 cats.setdefault (kind, []) 788 cats[kind].append ((v, offs)) 789 for kind in cats: 790 print '%s:' % kind 791 for (v, offs) in cats[kind]: 792 print ' %s (%s)' % (syntax.pretty_expr (v), 793 syntax.pretty_type (v.typ)) 794 if offs: 795 print ' ++ %s' % syntax.pretty_expr (offs) 796 797def var_value_sites (rep, v): 798 if type (v) == str: 799 matches = lambda (nm, _): v in nm 800 elif type (v) == tuple: 801 matches = lambda (nm, typ): v == (nm, typ) 802 v_ord = [] 803 d = {} 804 for (tag, n, vc) in rep.node_pc_env_order: 805 (pc, env) = rep.get_node_pc_env ((n, vc), tag = tag) 806 for (v2, smt_exp) in env.iteritems (): 807 if matches (v2): 808 if smt_exp not in d: 809 v_ord.append (smt_exp) 810 d[smt_exp] = [] 811 d[smt_exp].append ((n, vc)) 812 for smt_exp in v_ord: 813 print smt_exp 814 if smt_exp in rep.solv.defs: 815 print (' = %s' % repr (rep.solv.defs[smt_exp])) 816 print (' - at: %s' % d[smt_exp]) 817 if v_ord: 818 print ('') 819 return (v_ord, d) 820 821def loop_num_leaves (p, n): 822 for n in p.loop_body (n): 823 va = search.get_loop_var_analysis_at (p, n) 824 n_leaf = len ([1 for (v, kind) in va if kind == 'LoopLeaf']) 825 print (n, n_leaf) 826 827def try_pairing_at_funcall (p, name, head = None, restrs = None, hyps = None, 828 at = 'At'): 829 pairs = set (pairings[name]) 830 addrs = [n for (n, name2) in p.function_call_addrs () 831 if [pair for pair in pairings[name2] if pair in pairs]] 832 assert at in ['At', 'After'] 833 if at == 'After': 834 addrs = [p.nodes[n].cont for n in addrs] 835 if head == None: 836 tags = p.pairing.tags 837 [head] = [n for n in search.init_loops_to_split (p, ()) 838 if p.node_tags[n][0] == tags[0]] 839 if restrs == None: 840 restrs = () 841 if hyps == None: 842 hyps = check.init_point_hyps (p) 843 while True: 844 res = search.find_split_loop (p, head, restrs, hyps, 845 node_restrs = set (addrs)) 846 if res[0] == 'CaseSplit': 847 (_, ((n, tag), _)) = res 848 hyp = rep_graph.pc_true_hyp (((n, restrs), tag)) 849 hyps = hyps + [hyp] 850 else: 851 return res 852 853def init_true_hyp (p, tag, expr): 854 n = p.get_entry (tag) 855 vis = ((n, ()), tag) 856 assert expr.typ == syntax.boolT, expr 857 return rep_graph.eq_hyp ((expr, vis), (syntax.true_term, vis)) 858 859def smt_print (expr): 860 env = {} 861 while True: 862 try: 863 return solver.smt_expr (expr, env, None) 864 except solver.EnvMiss, e: 865 env[(e.name, e.typ)] = e.name 866 867