1# * Copyright 2015, NICTA 2# * 3# * This software may be distributed and modified according to the terms of 4# * the BSD 2-Clause license. Note that NO WARRANTY is provided. 5# * See "LICENSE_BSD2.txt" for details. 6# * 7# * @TAG(NICTA_BSD) 8 9from solver import Solver, merge_envs_pcs, smt_expr, mk_smt_expr, to_smt_expr 10from syntax import (true_term, false_term, boolT, mk_and, mk_not, mk_implies, 11 builtinTs, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word32, mk_var) 12import syntax 13import logic 14import solver 15from logic import azip 16 17from target_objects import functions, pairings, sections, trace, printout 18import target_objects 19import problem 20 21class VisitCount: 22 """Used to represent a target number of visits to a split point. 23 Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2), 24 or a list of options.""" 25 def __init__ (self, kind, value): 26 self.kind = kind 27 self.is_visit_count = True 28 if kind == 'Number': 29 self.n = value 30 elif kind == 'Offset': 31 self.n = value 32 elif kind == 'Options': 33 self.opts = tuple (value) 34 for opt in self.opts: 35 assert opt.kind in ['Number', 'Offset'] 36 else: 37 assert not 'VisitCount type understood' 38 39 def __hash__ (self): 40 if self.kind == 'Options': 41 return hash (self.opts) 42 else: 43 return hash (self.kind) + self.n 44 45 def __eq__ (self, other): 46 if not other: 47 return False 48 if self.kind == 'Options': 49 return (other.kind == 'Options' 50 and self.opts == other.opts) 51 else: 52 return self.kind == other.kind and self.n == other.n 53 54 def __neq__ (self, other): 55 if not other: 56 return True 57 return not (self == other) 58 59 def __str__ (self): 60 if self.kind == 'Number': 61 return str (self.n) 62 elif self.kind == 'Offset': 63 return 'i+%s' % self.n 64 elif self.kind == 'Options': 65 return '_'.join (map (str, self.opts)) 66 67 def __repr__ (self): 68 (ns, os) = self.get_opts () 69 return 'vc_options (%r, %r)' % (ns, os) 70 71 def get_opts (self): 72 if self.kind == 'Options': 73 opts = self.opts 74 else: 75 opts = [self] 76 ns = [vc.n for vc in opts if vc.kind == 'Number'] 77 os = [vc.n for vc in opts if vc.kind == 'Offset'] 78 return (ns, os) 79 80 def serialise (self, ss): 81 ss.append ('VC') 82 (ns, os) = self.get_opts () 83 ss.append ('%d' % len (ns)) 84 ss.extend (['%d' % n for n in ns]) 85 ss.append ('%d' % len (os)) 86 ss.extend (['%d' % n for n in os]) 87 88 def incr (self, incr): 89 if self.kind in ['Number', 'Offset']: 90 n = self.n + incr 91 if n < 0: 92 return None 93 return VisitCount (self.kind, n) 94 elif self.kind == 'Options': 95 opts = [vc.incr (incr) for vc in self.opts] 96 opts = [opt for opt in opts if opt] 97 if opts == []: 98 return None 99 return mk_vc_opts (opts) 100 else: 101 assert not 'VisitCount type understood' 102 103 def has_zero (self): 104 if self.kind == 'Options': 105 return bool ([vc for vc in self.opts 106 if vc.has_zero ()]) 107 else: 108 return self.kind == 'Number' and self.n == 0 109 110def mk_vc_opts (opts): 111 if len (opts) == 1: 112 return opts[0] 113 else: 114 return VisitCount ('Options', opts) 115 116def vc_options (nums, offsets): 117 return mk_vc_opts (map (vc_num, nums) + map (vc_offs, offsets)) 118 119def vc_num (n): 120 return VisitCount ('Number', n) 121 122def vc_upto (n): 123 return mk_vc_opts (map (vc_num, range (n))) 124 125def vc_offs (n): 126 return VisitCount ('Offset', n) 127 128def vc_offset_upto (n): 129 return mk_vc_opts (map (vc_offs, range (n))) 130 131def vc_double_range (n, m): 132 return mk_vc_opts (map (vc_num, range (n)) + map (vc_offs, range (m))) 133 134class InlineEvent(Exception): 135 pass 136 137class Hyp: 138 """Used to represent a proposition about path conditions or data at 139 various points in execution.""" 140 141 def __init__ (self, kind, arg1, arg2, induct = None): 142 self.kind = kind 143 if kind == 'PCImp': 144 self.pcs = [arg1, arg2] 145 elif kind == 'Eq': 146 self.vals = [arg1, arg2] 147 self.induct = induct 148 elif kind == 'EqIfAt': 149 self.vals = [arg1, arg2] 150 self.induct = induct 151 else: 152 assert not 'hyp kind understood' 153 154 def __repr__ (self): 155 if self.kind == 'PCImp': 156 vals = map (repr, self.pcs) 157 elif self.kind in ['Eq', 'EqIfAt']: 158 vals = map (repr, self.vals) 159 if self.induct: 160 vals += [repr (self.induct)] 161 else: 162 assert not 'hyp kind understood' 163 return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals)) 164 165 def hyp_tuple (self): 166 if self.kind == 'PCImp': 167 return ('PCImp', self.pcs[0], self.pcs[1]) 168 elif self.kind in ['Eq', 'EqIfAt']: 169 return (self.kind, self.vals[0], 170 self.vals[1], self.induct) 171 else: 172 assert not 'hyp kind understood' 173 174 def __hash__ (self): 175 return hash (self.hyp_tuple ()) 176 177 def __ne__ (self, other): 178 return not other or not (self == other) 179 180 def __cmp__ (self, other): 181 return cmp (self.hyp_tuple (), other.hyp_tuple ()) 182 183 def visits (self): 184 if self.kind == 'PCImp': 185 return [vis for vis in self.pcs 186 if vis[0] != 'Bool'] 187 elif self.kind in ['Eq', 'EqIfAt']: 188 return [vis for (_, vis) in self.vals] 189 else: 190 assert not 'hyp kind understood' 191 192 def get_vals (self): 193 if self.kind == 'PCImp': 194 return [] 195 else: 196 return [val for (val, _) in self.vals] 197 198 def serialise_visit (self, (n, restrs), ss): 199 ss.append ('%s' % n) 200 ss.append ('%d' % len (restrs)) 201 for (n2, vc) in restrs: 202 ss.append ('%d' % n2) 203 vc.serialise (ss) 204 205 def serialise_pc (self, pc, ss): 206 if pc[0] == 'Bool' and pc[1] == true_term: 207 ss.append ('True') 208 elif pc[0] == 'Bool' and pc[1] == false_term: 209 ss.append ('False') 210 else: 211 ss.append ('PC') 212 serialise_visit (pc[0], ss) 213 ss.append (pc[1]) 214 215 def serialise_hyp (self, ss): 216 if self.kind == 'PCImp': 217 (visit1, visit2) = self.pcs 218 ss.append ('PCImp') 219 self.serialise_pc (visit1, ss) 220 self.serialise_pc (visit2, ss) 221 elif self.kind in ['Eq', 'EqIfAt']: 222 assert len (self.vals) == 2 223 ss.extend (self.kind) 224 for (exp, visit) in self.vals: 225 exp.serialise (ss) 226 self.serialise_visit (visit, ss) 227 if induct: 228 ss.append ('%d' % induct[0]) 229 ss.append ('%d' % induct[1]) 230 else: 231 ss.extend (['None', 'None']) 232 else: 233 assert not 'hyp kind understood' 234 235 def interpret (self, rep): 236 if self.kind == 'PCImp': 237 ((visit1, tag1), (visit2, tag2)) = self.pcs 238 if visit1 == 'Bool': 239 pc1 = tag1 240 else: 241 pc1 = rep.get_pc (visit1, tag = tag1) 242 if visit2 == 'Bool': 243 pc2 = tag2 244 else: 245 pc2 = rep.get_pc (visit2, tag = tag2) 246 return mk_implies (pc1, pc2) 247 elif self.kind in ['Eq', 'EqIfAt']: 248 [(x, xvis), (y, yvis)] = self.vals 249 if self.induct: 250 v = rep.get_induct_var (self.induct) 251 x = subst_induct (x, v) 252 y = subst_induct (y, v) 253 x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1]) 254 y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1]) 255 if x_pc_env == None or y_pc_env == None: 256 if self.kind == 'EqIfAt': 257 return syntax.true_term 258 else: 259 return syntax.false_term 260 ((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env) 261 eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv) 262 if self.kind == 'EqIfAt': 263 x_pc = rep.get_pc (xvis[0], tag = xvis[1]) 264 y_pc = rep.get_pc (yvis[0], tag = yvis[1]) 265 return syntax.mk_n_implies ([x_pc, y_pc], eq) 266 else: 267 return eq 268 else: 269 assert not 'hypothesis type understood' 270 271def check_vis_is_vis (((n, vc), tag)): 272 assert vc[:0] == (), vc 273 274def eq_hyp (lhs, rhs, induct = None, use_if_at = False): 275 check_vis_is_vis (lhs[1]) 276 check_vis_is_vis (rhs[1]) 277 kind = 'Eq' 278 if use_if_at: 279 kind = 'EqIfAt' 280 return Hyp (kind, lhs, rhs, induct = induct) 281 282def true_if_at_hyp (expr, vis, induct = None): 283 check_vis_is_vis (vis) 284 return Hyp ('EqIfAt', (expr, vis), (true_term, vis), 285 induct = induct) 286 287def pc_true_hyp (vis): 288 check_vis_is_vis (vis) 289 return Hyp ('PCImp', ('Bool', true_term), vis) 290 291def pc_false_hyp (vis): 292 check_vis_is_vis (vis) 293 return Hyp ('PCImp', vis, ('Bool', false_term)) 294 295def pc_triv_hyp (vis): 296 check_vis_is_vis (vis) 297 return Hyp ('PCImp', vis, vis) 298 299class GraphSlice: 300 """Used to represent a slice of potential execution in a graph where 301 looping is limited to certain specific examples. For instance, we 302 might say that execution through node n will be represented only 303 by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The 304 variable state at visits 4 and i + 2 will be calculated but no 305 further execution will be done.""" 306 307 def __init__ (self, p, solv, inliner = None, fast = False): 308 self.p = p 309 self.solv = solv 310 self.inp_envs = {} 311 self.mem_calls = {} 312 self.add_input_envs () 313 314 self.node_pc_envs = {} 315 self.node_pc_env_order = [] 316 self.arc_pc_envs = {} 317 self.inliner = inliner 318 self.funcs = {} 319 self.pc_env_requests = set () 320 self.fast = fast 321 self.induct_var_env = {} 322 self.contractions = {} 323 324 self.local_defs_unsat = False 325 self.use_known_eqs = True 326 327 self.avail_hyps = set () 328 self.used_hyps = set () 329 330 def add_input_envs (self): 331 for (entry, _, _, args) in self.p.entries: 332 self.inp_envs[entry] = mk_inp_env (entry, args, self) 333 334 def get_reachable (self, split, n): 335 return self.p.is_reachable_from (split, n) 336 337 class TooGeneral (Exception): 338 def __init__ (self, split): 339 self.split = split 340 341 def get_tag_vcount (self, (n, vcount), tag): 342 if tag == None: 343 tag = self.p.node_tags[n][0] 344 vcount_r = [(split, count, self.get_reachable (split, n)) 345 for (split, count) in vcount 346 if self.p.node_tags[split][0] == tag] 347 for (split, count, r) in vcount_r: 348 if not r and not count.has_zero (): 349 return (tag, None) 350 assert count.is_visit_count 351 vcount = [(s, c) for (s, c, r) in vcount_r if r] 352 vcount = tuple (sorted (vcount)) 353 354 loop_id = self.p.loop_id (n) 355 if loop_id != None: 356 for (split, visits) in vcount: 357 if (self.p.loop_id (split) == loop_id 358 and visits.kind == 'Options'): 359 raise self.TooGeneral (split) 360 361 return (tag, vcount) 362 363 def get_node_pc_env (self, (n, vcount), tag = None, request = True): 364 tag, vcount = self.get_tag_vcount ((n, vcount), tag) 365 if vcount == None: 366 return None 367 368 if (tag, n, vcount) in self.node_pc_envs: 369 return self.node_pc_envs[(tag, n, vcount)] 370 371 if request: 372 self.pc_env_requests.add (((n, vcount), tag)) 373 374 self.warm_pc_env_cache ((n, vcount), tag) 375 376 pc_env = self.get_node_pc_env_raw ((n, vcount), tag) 377 if pc_env: 378 pc_env = self.apply_known_eqs_pc_env ((n, vcount), 379 tag, pc_env) 380 381 assert not (tag, n, vcount) in self.node_pc_envs 382 self.node_pc_envs[(tag, n, vcount)] = pc_env 383 if pc_env: 384 self.node_pc_env_order.append ((tag, n, vcount)) 385 386 return pc_env 387 388 def warm_pc_env_cache (self, n_vc, tag): 389 'this is to avoid recursion limits and spot bugs' 390 prev_chain = [] 391 for i in range (5000): 392 prevs = self.prevs (n_vc) 393 try: 394 prevs = [p for p in prevs 395 if (tag, p[0], p[1]) 396 not in self.node_pc_envs 397 if self.get_tag_vcount (p, None) 398 == (tag, n_vc[1])] 399 except self.TooGeneral: 400 break 401 if not prevs: 402 break 403 n_vc = prevs[0] 404 prev_chain.append(n_vc) 405 if not (len (prev_chain) < 5000): 406 printout ([n for (n, vc) in prev_chain]) 407 assert len (prev_chain) < 5000, (prev_chain[:10], 408 prev_chain[-10:]) 409 410 prev_chain.reverse () 411 for n_vc in prev_chain: 412 self.get_node_pc_env (n_vc, tag, request = False) 413 414 def get_loop_pc_env (self, split, vcount): 415 vcount2 = dict (vcount) 416 vcount2[split] = vc_num (0) 417 vcount2 = tuple (sorted (vcount2.items ())) 418 prev_pc_env = self.get_node_pc_env ((split, vcount2)) 419 if prev_pc_env == None: 420 return None 421 (_, prev_env) = prev_pc_env 422 mem_calls = self.scan_mem_calls (prev_env) 423 mem_calls = self.add_loop_mem_calls (split, mem_calls) 424 def av (nm, typ, mem_name = None): 425 nm2 = '%s_loop_at_%s' % (nm, split) 426 return self.add_var (nm2, typ, 427 mem_name = mem_name, mem_calls = mem_calls) 428 env = {} 429 consts = set () 430 for (nm, typ) in prev_env: 431 check_const = self.fast or (typ in 432 [builtinTs['HTD'], builtinTs['Dom']]) 433 if check_const and self.is_synt_const (nm, typ, split): 434 env[(nm, typ)] = prev_env[(nm, typ)] 435 consts.add ((nm, typ)) 436 else: 437 env[(nm, typ)] = av (nm + '_after', typ, 438 ('Loop', prev_env[(nm, typ)])) 439 for (nm, typ) in prev_env: 440 if (nm, typ) in consts: 441 continue 442 z = self.var_rep_request ((nm, typ), 'Loop', 443 (split, vcount), env) 444 if z: 445 env[(nm, typ)] = z 446 447 pc = mk_smt_expr (av ('pc_of', boolT), boolT) 448 if self.fast: 449 imp = syntax.mk_implies (pc, prev_pc_env[0]) 450 self.solv.assert_fact (imp, prev_env, 451 unsat_tag = ('LoopPCImp', split)) 452 453 return (pc, env) 454 455 def is_synt_const (self, nm, typ, split): 456 """check if a variable at a split point is a syntactic constant 457 which is always unmodified by the loop. 458 we allow cases where a variable is renamed and renamed back 459 during the loop (this often happens because of inlining). 460 the check is done by depth-first-search backward through the 461 graph looking for a source of a variant value.""" 462 loop = self.p.loop_id (split) 463 if problem.has_inner_loop (self.p, split): 464 return False 465 loop_set = set (self.p.loop_body (split)) 466 467 orig_nm = nm 468 safe = set ([(orig_nm, split)]) 469 first_step = True 470 visit = [] 471 count = 0 472 while first_step or visit: 473 if first_step: 474 (nm, n) = (orig_nm, split) 475 first_step = False 476 else: 477 (nm, n) = visit.pop () 478 if (nm, n) in safe: 479 continue 480 elif n == split: 481 return False 482 new_nm = nm 483 node = self.p.nodes[n] 484 if node.kind == 'Call': 485 if (nm, typ) not in node.rets: 486 pass 487 elif self.fast_const_ret (n, nm, typ): 488 pass 489 else: 490 return False 491 elif node.kind == 'Basic': 492 upds = [arg for (lv, arg) in node.upds 493 if lv == (nm, typ)] 494 if [v for v in upds if v.kind != 'Var']: 495 return False 496 if upds: 497 new_nm = upds[0].name 498 preds = [(new_nm, n2) for n2 in self.p.preds[n] 499 if n2 in loop_set] 500 unknowns = [p for p in preds if p not in safe] 501 if unknowns: 502 visit.extend ([(nm, n)] + unknowns) 503 else: 504 safe.add ((nm, n)) 505 count += 1 506 if count % 100000 == 0: 507 trace ('is_synt_const: %d iterations' % count) 508 trace ('visit length %d' % len (visit)) 509 trace ('visit tail %s' % visit[-20:]) 510 return True 511 512 def fast_const_ret (self, n, nm, typ): 513 """determine if we can heuristically consider this return 514 value to be the same as an input. this is known for some 515 function returns, e.g. memory. 516 this is important for heuristic "fast" analysis.""" 517 if not self.fast: 518 return False 519 node = self.p.nodes[n] 520 assert node.kind == 'Call' 521 for hook in target_objects.hooks ('rep_unsafe_const_ret'): 522 if hook (node, nm, typ): 523 return True 524 return False 525 526 def get_node_pc_env_raw (self, (n, vcount), tag): 527 if n in self.inp_envs: 528 return (true_term, self.inp_envs[n]) 529 530 for (split, count) in vcount: 531 if split == n and count == vc_offs (0): 532 return self.get_loop_pc_env (split, vcount) 533 534 pc_envs = [pc_env for n_prev in self.p.preds[n] 535 if self.p.node_tags[n_prev][0] == tag 536 for pc_env in self.get_arc_pc_envs (n_prev, 537 (n, vcount))] 538 539 pc_envs = [pc_env for pc_env in pc_envs if pc_env] 540 if pc_envs == []: 541 return None 542 543 if n == 'Err': 544 # we'll never care about variable values here 545 # and there are sometimes a LOT of arcs to Err 546 # so we save a lot of merge effort 547 pc_envs = [(to_smt_expr (pc, env, self.solv), {}) 548 for (pc, env) in pc_envs] 549 550 (pc, env, large) = merge_envs_pcs (pc_envs, self.solv) 551 552 if pc.kind != 'SMTExpr': 553 name = self.path_cond_name ((n, vcount), tag) 554 name = self.solv.add_def (name, pc, env) 555 pc = mk_smt_expr (name, boolT) 556 557 for (nm, typ) in env: 558 if len (env[(nm, typ)]) > 80: 559 env[(nm, typ)] = self.contract (nm, (n, vcount), 560 env[(nm, typ)], typ) 561 562 return (pc, env) 563 564 def contract (self, name, n_vc, val, typ): 565 if val in self.contractions: 566 return self.contractions[val] 567 568 name = self.local_name_before (name, n_vc) 569 name = self.solv.add_def (name, mk_smt_expr (val, typ), {}) 570 571 self.contractions[val] = name 572 return name 573 574 def get_arc_pc_envs (self, n, n_vc2): 575 try: 576 prevs = [n_vc for n_vc in self.prevs (n_vc2) 577 if n_vc[0] == n] 578 assert len (prevs) <= 1 579 return [self.get_arc_pc_env (n_vc, n_vc2) 580 for n_vc in prevs] 581 except self.TooGeneral, e: 582 # consider specialisations of the target 583 specs = self.specialise (n_vc2, e.split) 584 specs = [(n_vc2[0], spec) for spec in specs] 585 return [pc_env for spec in specs 586 for pc_env in self.get_arc_pc_envs (n, spec)] 587 588 def get_arc_pc_env (self, (n, vcount), n2): 589 tag, vcount = self.get_tag_vcount ((n, vcount), None) 590 591 if vcount == None: 592 return None 593 594 assert self.is_cont ((n, vcount), n2), ((n, vcount), 595 n2, self.p.nodes[n].get_conts ()) 596 597 if (n, vcount) in self.arc_pc_envs: 598 return self.arc_pc_envs[(n, vcount)].get (n2[0]) 599 600 if self.get_node_pc_env ((n, vcount), request = False) == None: 601 return None 602 603 arcs = self.emit_node ((n, vcount)) 604 self.post_emit_node_hooks ((n, vcount)) 605 arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs]) 606 607 self.arc_pc_envs[(n, vcount)] = arcs 608 return arcs.get (n2[0]) 609 610 def add_local_def (self, n, vname, name, val, env): 611 if self.local_defs_unsat: 612 smt_name = self.solv.add_var (name, val.typ) 613 eq = mk_eq (mk_smt_expr (smt_name, val.typ), val) 614 self.solv.assert_fact (eq, env, unsat_tag 615 = ('Def', n, vname)) 616 else: 617 smt_name = self.solv.add_def (name, val, env) 618 return smt_name 619 620 def add_var (self, name, typ, mem_name = None, mem_calls = None): 621 r = self.solv.add_var_restr (name, typ, mem_name = mem_name) 622 if typ == syntax.builtinTs['Mem']: 623 r_x = solver.parse_s_expression (r) 624 self.mem_calls[r_x] = mem_calls 625 return r 626 627 def var_rep_request (self, (nm, typ), kind, n_vc, env): 628 assert type (n_vc[0]) != str 629 for hook in target_objects.hooks ('problem_var_rep'): 630 z = hook (self.p, (nm, typ), kind, n_vc[0]) 631 if z == None: 632 continue 633 if z[0] == 'SplitMem': 634 assert typ == builtinTs['Mem'] 635 (_, addr) = z 636 addr = smt_expr (addr, env, self.solv) 637 name = '%s_for_%s' % (nm, 638 self.node_count_name (n_vc)) 639 return self.solv.add_split_mem_var (addr, name, 640 typ, mem_name = 'SplitMemNonsense') 641 else: 642 assert z == None 643 644 def emit_node (self, n): 645 (pc, env) = self.get_node_pc_env (n, request = False) 646 tag = self.p.node_tags[n[0]][0] 647 app_eqs = self.apply_known_eqs_tm (n, tag) 648 # node = logic.simplify_node_elementary (self.p.nodes[n[0]]) 649 # whether to ignore unreachable Cond arcs seems to be a huge 650 # dilemma. if we ignore them, some reachable sites become 651 # unreachable and we can't interpret all hyps 652 # if we don't ignore them, the variable set disagrees with 653 # var_deps and so the abstracted loop pc/env may not be 654 # sufficient and we get EnvMiss again. I don't really know 655 # what to do about this corner case. 656 node = self.p.nodes[n[0]] 657 env = dict (env) 658 659 if node.kind == 'Call': 660 self.try_inline (n[0], pc, env) 661 662 if pc == false_term: 663 return [(c, false_term, {}) for c in node.get_conts()] 664 elif node.kind == 'Cond' and node.left == node.right: 665 return [(node.left, pc, env)] 666 elif node.kind == 'Cond' and node.cond == true_term: 667 return [(node.left, pc, env), 668 (node.right, false_term, env)] 669 elif node.kind == 'Basic': 670 upds = [] 671 for (lv, v) in node.upds: 672 if v.kind == 'Var': 673 upds.append ((lv, env[(v.name, v.typ)])) 674 else: 675 name = self.local_name (lv[0], n) 676 v = app_eqs (v) 677 vname = self.add_local_def (n, 678 ('Var', lv), name, v, env) 679 upds.append ((lv, vname)) 680 for (lv, v) in upds: 681 env[lv] = v 682 return [(node.cont, pc, env)] 683 elif node.kind == 'Cond': 684 name = self.cond_name (n) 685 cond = self.p.fresh_var (name, boolT) 686 env[(cond.name, boolT)] = self.add_local_def (n, 687 'Cond', name, app_eqs (node.cond), env) 688 lpc = mk_and (cond, pc) 689 rpc = mk_and (mk_not (cond), pc) 690 return [(node.left, lpc, env), (node.right, rpc, env)] 691 elif node.kind == 'Call': 692 nm = self.success_name (node.fname, n) 693 success = self.solv.add_var (nm, boolT) 694 success = mk_smt_expr (success, boolT) 695 fun = functions[node.fname] 696 ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv)) 697 for ((x, typ), arg) in azip (fun.inputs, node.args)]) 698 mem_name = None 699 for (x, typ) in reversed (fun.inputs): 700 if typ == builtinTs['Mem']: 701 inp_mem = ins[(x, typ)] 702 mem_name = (node.fname, inp_mem) 703 mem_calls = self.scan_mem_calls (ins) 704 mem_calls = self.add_mem_call (node.fname, mem_calls) 705 outs = {} 706 for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs): 707 assert typ2 == typ 708 if self.fast_const_ret (n[0], x, typ): 709 outs[(y, typ2)] = env [(x, typ)] 710 continue 711 name = self.local_name (x, n) 712 env[(x, typ)] = self.add_var (name, typ, 713 mem_name = mem_name, 714 mem_calls = mem_calls) 715 outs[(y, typ2)] = env[(x, typ)] 716 for ((x, typ), (y, _)) in azip (node.rets, fun.outputs): 717 z = self.var_rep_request ((x, typ), 718 'Call', n, env) 719 if z != None: 720 env[(x, typ)] = z 721 outs[(y, typ)] = z 722 self.add_func (node.fname, ins, outs, success, n) 723 return [(node.cont, pc, env)] 724 else: 725 assert not 'node kind understood' 726 727 def post_emit_node_hooks (self, (n, vcount)): 728 for hook in target_objects.hooks ('post_emit_node'): 729 hook (self, (n, vcount)) 730 731 def fetch_known_eqs (self, n_vc, tag): 732 if not self.use_known_eqs: 733 return None 734 eqs = self.p.known_eqs.get ((n_vc, tag)) 735 if eqs == None: 736 return None 737 avail = [] 738 for (x, n_vc_y, tag_y, y, hyps) in eqs: 739 if hyps <= self.avail_hyps: 740 (_, env) = self.get_node_pc_env (n_vc_y, tag_y) 741 avail.append ((x, smt_expr (y, env, self.solv))) 742 self.used_hyps.update (hyps) 743 if avail: 744 return avail 745 return None 746 747 def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)): 748 eqs = self.fetch_known_eqs (n_vc, tag) 749 if eqs == None: 750 return (pc, env) 751 env = dict (env) 752 for (x, sx) in eqs: 753 if x.kind == 'Var': 754 cur_rhs = env[x.name] 755 for y in env: 756 if env[y] == cur_rhs: 757 trace ('substituted %s at %s.' % (y, n_vc)) 758 env[y] = sx 759 return (pc, env) 760 761 def apply_known_eqs_tm (self, n_vc, tag): 762 eqs = self.fetch_known_eqs (n_vc, tag) 763 if eqs == None: 764 return lambda x: x 765 eqs = dict ([(x, mk_smt_expr (sexpr, x.typ)) 766 for (x, sexpr) in eqs]) 767 return lambda tm: logic.recursive_term_subst (eqs, tm) 768 769 def rebuild (self, solv = None): 770 requests = self.pc_env_requests 771 772 self.node_pc_env_order = [] 773 self.node_pc_envs = {} 774 self.arc_pc_envs = {} 775 self.funcs = {} 776 self.pc_env_requests = set () 777 self.induct_var_env = {} 778 self.contractions = {} 779 780 if not solv: 781 solv = Solver (produce_unsat_cores 782 = self.local_defs_unsat) 783 self.solv = solv 784 785 self.add_input_envs () 786 787 self.used_hyps = set () 788 run_requests (self, requests) 789 790 def add_func (self, name, inputs, outputs, success, n_vc): 791 assert n_vc not in self.funcs 792 self.funcs[n_vc] = (inputs, outputs, success) 793 for pair in pairings.get (name, []): 794 self.funcs.setdefault (pair.name, []) 795 group = self.funcs[pair.name] 796 for n_vc2 in group: 797 if self.get_func_pairing (n_vc, n_vc2): 798 self.add_func_assert (n_vc, n_vc2) 799 group.append (n_vc) 800 801 def get_func (self, n_vc, tag = None): 802 """returns (input_env, output_env, success_var) for 803 function call at given n_vc.""" 804 tag, vc = self.get_tag_vcount (n_vc, tag) 805 n_vc = (n_vc[0], vc) 806 assert self.p.nodes[n_vc[0]].kind == 'Call' 807 808 if n_vc not in self.funcs: 809 # try to ensure n_vc has been emitted 810 cont = self.get_cont (n_vc) 811 self.get_node_pc_env (cont, tag = tag) 812 813 return self.funcs[n_vc] 814 815 def get_func_pairing_nocheck (self, n_vc, n_vc2): 816 fnames = [self.p.nodes[n_vc[0]].fname, 817 self.p.nodes[n_vc2[0]].fname] 818 pairs = [pair for pair in pairings[list (fnames)[0]] 819 if set (pair.funs.values ()) == set (fnames)] 820 if not pairs: 821 return None 822 [pair] = pairs 823 if pair.funs[pair.tags[0]] == fnames[0]: 824 return (pair, n_vc, n_vc2) 825 else: 826 return (pair, n_vc2, n_vc) 827 828 def get_func_pairing (self, n_vc, n_vc2): 829 res = self.get_func_pairing_nocheck (n_vc, n_vc2) 830 if not res: 831 return res 832 (pair, l_n_vc, r_n_vc) = res 833 (lin, _, _) = self.funcs[l_n_vc] 834 (rin, _, _) = self.funcs[r_n_vc] 835 l_mem_calls = self.scan_mem_calls (lin) 836 r_mem_calls = self.scan_mem_calls (rin) 837 tags = pair.tags 838 (c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls) 839 if not c: 840 trace ('skipped emitting func pairing %s -> %s' 841 % (l_n_vc, r_n_vc)) 842 trace (' ' + s) 843 return None 844 return res 845 846 def get_func_assert (self, n_vc, n_vc2): 847 (pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2) 848 (ltag, rtag) = pair.tags 849 (inp_eqs, out_eqs) = pair.eqs 850 (lin, lout, lsucc) = self.funcs[l_n_vc] 851 (rin, rout, rsucc) = self.funcs[r_n_vc] 852 lpc = self.get_pc (l_n_vc) 853 rpc = self.get_pc (r_n_vc) 854 envs = {ltag + '_IN': lin, rtag + '_IN': rin, 855 ltag + '_OUT': lout, rtag + '_OUT': rout} 856 inp_eqs = inst_eqs (inp_eqs, envs, self.solv) 857 out_eqs = inst_eqs (out_eqs, envs, self.solv) 858 succ_imp = mk_implies (rsucc, lsucc) 859 860 return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]), 861 foldr1 (mk_and, out_eqs + [succ_imp])) 862 863 def add_func_assert (self, n_vc, n_vc2): 864 imp = self.get_func_assert (n_vc, n_vc2) 865 imp = logic.weaken_assert (imp) 866 if self.local_defs_unsat: 867 self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq', 868 ln, rn)) 869 else: 870 self.solv.assert_fact (imp, {}) 871 872 def node_count_name (self, (n, vcount)): 873 name = str (n) 874 bits = [str (n)] + ['%s=%s' % (split, count) 875 for (split, count) in vcount] 876 return '_'.join (bits) 877 878 def get_mem_calls (self, mem_sexpr): 879 mem_sexpr = solver.parse_s_expression (mem_sexpr) 880 return self.get_mem_calls_sexpr (mem_sexpr) 881 882 def get_mem_calls_sexpr (self, mem_sexpr): 883 stores = set (['store-word32', 'store-word8', 'store-word64']) 884 if mem_sexpr in self.mem_calls: 885 return self.mem_calls[mem_sexpr] 886 elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores: 887 return self.get_mem_calls_sexpr (mem_sexpr[1]) 888 elif mem_sexpr[:1] == ('ite', ): 889 (_, _, x, y) = mem_sexpr 890 x_calls = self.get_mem_calls_sexpr (x) 891 y_calls = self.get_mem_calls_sexpr (y) 892 return merge_mem_calls (x_calls, y_calls) 893 elif mem_sexpr in self.solv.defs: 894 mem_sexpr = self.solv.defs[mem_sexpr] 895 return self.get_mem_calls_sexpr (mem_sexpr) 896 assert not "mem_calls fallthrough", mem_sexpr 897 898 def scan_mem_calls (self, env): 899 mem_vs = [env[(nm, typ)] 900 for (nm, typ) in env 901 if typ == syntax.builtinTs['Mem']] 902 mem_calls = [self.get_mem_calls (v) 903 for v in mem_vs if v[0] != 'SplitMem'] 904 if mem_calls: 905 return foldr1 (merge_mem_calls, mem_calls) 906 else: 907 return None 908 909 def add_mem_call (self, fname, mem_calls): 910 if mem_calls == None: 911 return None 912 mem_calls = dict (mem_calls) 913 (min_calls, max_calls) = mem_calls.get (fname, (0, 0)) 914 if max_calls == None: 915 mem_calls[fname] = (min_calls + 1, None) 916 else: 917 mem_calls[fname] = (min_calls + 1, max_calls + 1) 918 return mem_calls 919 920 def add_loop_mem_calls (self, split, mem_calls): 921 if mem_calls == None: 922 return None 923 fnames = set ([self.p.nodes[n].fname 924 for n in self.p.loop_body (split) 925 if self.p.nodes[n].kind == 'Call']) 926 if not fnames: 927 return mem_calls 928 mem_calls = dict (mem_calls) 929 for fname in fnames: 930 if fname not in mem_calls: 931 mem_calls[fname] = (0, None) 932 else: 933 (min_calls, max_calls) = mem_calls[fname] 934 mem_calls[fname] = (min_calls, None) 935 return mem_calls 936 937 # note these names are designed to be unique by suffix 938 # (so that smt names are independent of order of requests) 939 def local_name (self, s, n_vc): 940 return '%s_after_%s' % (s, self.node_count_name (n_vc)) 941 942 def local_name_before (self, s, n_vc): 943 return '%s_v_at_%s' % (s, self.node_count_name (n_vc)) 944 945 def cond_name (self, n_vc): 946 return 'cond_at_%s' % self.node_count_name (n_vc) 947 948 def path_cond_name (self, n_vc, tag): 949 return 'path_cond_to_%s_%s' % ( 950 self.node_count_name (n_vc), tag) 951 952 def success_name (self, fname, n_vc): 953 bits = fname.split ('.') 954 nms = ['_'.join (bits[i:]) for i in range (len (bits)) 955 if bits[i:][0].isalpha ()] 956 if nms: 957 nm = nms[-1] 958 else: 959 nm = 'fun' 960 return '%s_success_at_%s' % (nm, self.node_count_name (n_vc)) 961 962 def try_inline (self, n, pc, env): 963 if not self.inliner: 964 return False 965 966 inline = self.inliner ((self.p, n)) 967 if not inline: 968 return False 969 970 # make sure this node is reachable before inlining 971 if self.solv.test_hyp (mk_not (pc), env): 972 trace ('Skipped inlining at %d.' % n) 973 return False 974 975 trace ('Inlining at %d.' % n) 976 inline () 977 raise InlineEvent () 978 979 def incr (self, vcount, n, incr): 980 vcount2 = dict (vcount) 981 vcount2[n] = vcount2[n].incr (incr) 982 if vcount2[n] == None: 983 return None 984 return tuple (sorted (vcount2.items ())) 985 986 def get_cont (self, (n, vcount)): 987 [c] = self.p.nodes[n].get_conts () 988 vcount2 = dict (vcount) 989 if n in vcount2: 990 vcount = self.incr (vcount, n, 1) 991 cont = (c, vcount) 992 assert self.is_cont ((n, vcount), cont) 993 return cont 994 995 def is_cont (self, (n, vcount), (n2, vcount2)): 996 if n2 not in self.p.nodes[n].get_conts (): 997 trace ('Not a graph cont.') 998 return False 999 1000 vcount_d = dict (vcount) 1001 vcount_d2 = dict (vcount2) 1002 if n in vcount_d2: 1003 if n in vcount_d: 1004 assert vcount_d[n].kind != 'Options' 1005 vcount_d2[n] = vcount_d2[n].incr (-1) 1006 1007 if not vcount_d <= vcount_d2: 1008 trace ('Restrictions not subset.') 1009 return False 1010 1011 for (split, count) in vcount_d2.iteritems (): 1012 if split in vcount_d: 1013 continue 1014 if self.get_reachable (split, n): 1015 return False 1016 if not count.has_zero (): 1017 return False 1018 1019 return True 1020 1021 def prevs (self, (n, vcount)): 1022 prevs = [] 1023 vcount_d = dict (vcount) 1024 for p in self.p.preds[n]: 1025 if p in vcount_d: 1026 vcount2 = self.incr (vcount, p, -1) 1027 if vcount2 == None: 1028 continue 1029 prevs.append ((p, vcount2)) 1030 else: 1031 prevs.append ((p, vcount)) 1032 return prevs 1033 1034 def specialise (self, (n, vcount), split): 1035 vcount = dict (vcount) 1036 assert vcount[split].kind == 'Options' 1037 specs = [] 1038 for n in vcount[split].opts: 1039 v = dict (vcount) 1040 v[split] = n 1041 specs.append (tuple (sorted (v.items ()))) 1042 return specs 1043 1044 def get_pc (self, (n, vcount), tag = None): 1045 pc_env = self.get_node_pc_env ((n, vcount), tag = tag) 1046 if pc_env == None: 1047 trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag)) 1048 return false_term 1049 (pc, env) = pc_env 1050 return to_smt_expr (pc, env, self.solv) 1051 1052 def to_smt_expr (self, expr, (n, vcount), tag = None): 1053 pc_env = self.get_node_pc_env ((n, vcount), tag = tag) 1054 (pc, env) = pc_env 1055 return to_smt_expr (expr, env, self.solv) 1056 1057 def get_induct_var (self, (n1, n2)): 1058 if (n1, n2) not in self.induct_var_env: 1059 vname = self.solv.add_var ('induct_i_%d_%d' % (n1, n2), 1060 word32T) 1061 self.induct_var_env[(n1, n2)] = vname 1062 self.pc_env_requests.add (((n1, n2), 'InductVar')) 1063 else: 1064 vname = self.induct_var_env[(n1, n2)] 1065 return mk_smt_expr (vname, word32T) 1066 1067 def interpret_hyp (self, hyp): 1068 return hyp.interpret (self) 1069 1070 def interpret_hyp_imps (self, hyps, concl): 1071 hyps = map (self.interpret_hyp, hyps) 1072 return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl)) 1073 1074 def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False, 1075 model = None): 1076 self.avail_hyps = set (hyps) 1077 if not self.used_hyps <= self.avail_hyps: 1078 self.rebuild () 1079 1080 last_test[0] = (hyp, hyps, list (self.pc_env_requests)) 1081 1082 expr = self.interpret_hyp_imps (hyps, hyp) 1083 1084 trace ('Testing hyp whyps', push = 1) 1085 trace ('requests = %s' % self.pc_env_requests) 1086 1087 expr_s = smt_expr (expr, {}, self.solv) 1088 if cache and expr_s in cache: 1089 trace ('Cached: %s' % cache[expr_s]) 1090 return cache[expr_s] 1091 if fast: 1092 trace ('(not in cache)') 1093 return None 1094 1095 self.solv.add_pvalid_dom_assertions () 1096 1097 (result, _, _) = self.solv.parallel_test_hyps ([(None, expr)], 1098 {}, model = model) 1099 trace ('Result: %s' % result, push = -1) 1100 if cache != None: 1101 cache[expr_s] = result 1102 if not result: 1103 last_failed_test[0] = last_test[0] 1104 return result 1105 1106 def test_hyp_imp (self, hyps, hyp, model = None): 1107 return self.test_hyp_whyps (self.interpret_hyp (hyp), hyps, 1108 model = model) 1109 1110 def test_hyp_imps (self, imps): 1111 last_hyp_imps[0] = imps 1112 if imps == []: 1113 return (True, None) 1114 interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps, 1115 self.interpret_hyp (hyp)) 1116 for (hyps, hyp) in imps])) 1117 reqs = list (self.pc_env_requests) 1118 last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) 1119 self.solv.add_pvalid_dom_assertions () 1120 result = self.solv.parallel_test_hyps (interp_imps, {}) 1121 assert result[0] in [True, False], result 1122 if result[0] == False: 1123 (hyps, hyp) = imps[result[1]] 1124 last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) 1125 last_failed_test[0] = last_test[0] 1126 return result 1127 1128 def replay_requests (self, reqs): 1129 for ((n, vc), tag) in reqs: 1130 self.get_node_pc_env ((n, vc), tag = tag) 1131 1132last_test = [0] 1133last_failed_test = [0] 1134last_hyp_imps = [0] 1135 1136def to_smt_expr_under_op (expr, env, solv): 1137 if expr.kind == 'Op': 1138 vals = [to_smt_expr (v, env, solv) for v in expr.vals] 1139 return syntax.adjust_op_vals (expr, vals) 1140 else: 1141 return to_smt_expr (expr, env, solv) 1142 1143def inst_eq_with_envs ((x, env1), (y, env2), solv): 1144 x = to_smt_expr_under_op (x, env1, solv) 1145 y = to_smt_expr_under_op (y, env2, solv) 1146 if x.typ == syntax.builtinTs['RelWrapper']: 1147 return logic.apply_rel_wrapper (x, y) 1148 else: 1149 return mk_eq (x, y) 1150 1151def inst_eqs (eqs, envs, solv): 1152 return [inst_eq_with_envs ((x, envs[x_addr]), (y, envs[y_addr]), solv) 1153 for ((x, x_addr), (y, y_addr)) in eqs] 1154 1155def subst_induct (expr, induct_var): 1156 substs = {('%n', word32T): induct_var} 1157 return logic.var_subst (expr, substs, must_subst = False) 1158 1159printed_hyps = {} 1160def print_hyps (hyps): 1161 hyps = tuple (hyps) 1162 if hyps in printed_hyps: 1163 trace ('hyps = %s' % printed_hyps[hyps]) 1164 else: 1165 hname = 'hyp_set_%d' % (len (printed_hyps) + 1) 1166 trace ('%s = %s' % (hname, list (hyps))) 1167 printed_hyps[hname] = hyps 1168 trace ('hyps = %s' % hname) 1169 1170def merge_mem_calls (mem_calls_x, mem_calls_y): 1171 if mem_calls_x == mem_calls_y: 1172 return mem_calls_x 1173 mem_calls = {} 1174 for fname in set (mem_calls_x.keys () + mem_calls_y.keys ()): 1175 (min_x, max_x) = mem_calls_x.get (fname, (0, 0)) 1176 (min_y, max_y) = mem_calls_y.get (fname, (0, 0)) 1177 if None in [max_x, max_y]: 1178 max_v = None 1179 else: 1180 max_v = max (max_x, max_y) 1181 mem_calls[fname] = (min (min_x, min_y), max_v) 1182 return mem_calls 1183 1184def mem_calls_compatible (tags, l_mem_calls, r_mem_calls): 1185 if l_mem_calls == None or r_mem_calls == None: 1186 return (True, None) 1187 r_cast_calls = {} 1188 for (fname, calls) in l_mem_calls.iteritems (): 1189 pairs = [pair for pair in pairings[fname] 1190 if pair.tags == tags] 1191 if not pairs: 1192 return (None, 'no pairing for %s' % fname) 1193 assert len (pairs) <= 1, pairs 1194 [pair] = pairs 1195 r_fun = pair.funs[tags[1]] 1196 if not [nm for (nm, typ) in functions[r_fun].outputs 1197 if typ == syntax.builtinTs['Mem']]: 1198 continue 1199 r_cast_calls[pair.funs[tags[1]]] = calls 1200 for fname in set (r_cast_calls.keys () + r_mem_calls.keys ()): 1201 r_cast = r_cast_calls.get (fname, (0, 0)) 1202 r_actual = r_mem_calls.get (fname, (0, 0)) 1203 s = 'mismatch in calls to %s and pairs, %s / %s' % (fname, 1204 r_cast, r_actual) 1205 if r_cast[1] != None and r_cast[1] < r_actual[0]: 1206 return (None, s) 1207 if r_actual[1] != None and r_actual[1] < r_cast[0]: 1208 return (None, s) 1209 return (True, None) 1210 1211def mk_inp_env (n, args, rep): 1212 trace ('rep_graph setting up input env at %d' % n, push = 1) 1213 inp_env = {} 1214 1215 for (v_nm, typ) in args: 1216 inp_env[(v_nm, typ)] = rep.add_var (v_nm + '_init', typ, 1217 mem_name = 'Init', mem_calls = {}) 1218 for (v_nm, typ) in args: 1219 z = rep.var_rep_request ((v_nm, typ), 'Init', (n, ()), inp_env) 1220 if z: 1221 inp_env[(v_nm, typ)] = z 1222 1223 trace ('done setting up input env at %d' % n, push = -1) 1224 return inp_env 1225 1226def mk_graph_slice (p, inliner = None, fast = False, mk_solver = Solver): 1227 trace ('rep_graph setting up solver', push = 1) 1228 solv = mk_solver () 1229 trace ('rep_graph setting up solver', push = -1) 1230 return GraphSlice (p, solv, inliner, fast = fast) 1231 1232def run_requests (rep, requests): 1233 for (n_vc, tag) in requests: 1234 if tag == 'InductVar': 1235 rep.get_induct_var (n_vc) 1236 else: 1237 rep.get_pc (n_vc, tag = tag) 1238 rep.solv.add_pvalid_dom_assertions () 1239 1240import re 1241paren_w_re = re.compile (r"(\(|\)|\w+)") 1242 1243def mk_function_link_hyps (p, call_vis, tag, adjust_eq_seq = None): 1244 (entry, _, args) = p.get_entry_details (tag) 1245 ((call_site, restrs), call_tag) = call_vis 1246 assert p.nodes[call_site].kind == 'Call' 1247 entry_vis = ((entry, ()), p.node_tags[entry][0]) 1248 1249 args = [syntax.mk_var (nm, typ) for (nm, typ) in args] 1250 1251 pc = pc_true_hyp (call_vis) 1252 eq_seq = logic.azip (p.nodes[call_site].args, args) 1253 if adjust_eq_seq: 1254 eq_seq = adjust_eq_seq (eq_seq) 1255 hyps = [pc] + [eq_hyp ((x, call_vis), (y, entry_vis)) 1256 for (x, y) in eq_seq 1257 if x.typ.kind == 'Word' or x.typ == syntax.builtinTs['Mem'] 1258 or x.typ.kind == 'WordArray'] 1259 1260 return hyps 1261 1262