1# 2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) 3# 4# SPDX-License-Identifier: BSD-2-Clause 5# 6 7from solver import Solver, merge_envs_pcs, smt_expr, mk_smt_expr, to_smt_expr 8from syntax import (true_term, false_term, boolT, mk_and, mk_not, mk_implies, 9 builtinTs, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word32, mk_var) 10import syntax 11import logic 12import solver 13from logic import azip 14 15from target_objects import functions, pairings, sections, trace, printout 16import target_objects 17import problem 18 19class VisitCount: 20 """Used to represent a target number of visits to a split point. 21 Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2), 22 or a list of options.""" 23 def __init__ (self, kind, value): 24 self.kind = kind 25 self.is_visit_count = True 26 if kind == 'Number': 27 self.n = value 28 elif kind == 'Offset': 29 self.n = value 30 elif kind == 'Options': 31 self.opts = tuple (value) 32 for opt in self.opts: 33 assert opt.kind in ['Number', 'Offset'] 34 else: 35 assert not 'VisitCount type understood' 36 37 def __hash__ (self): 38 if self.kind == 'Options': 39 return hash (self.opts) 40 else: 41 return hash (self.kind) + self.n 42 43 def __eq__ (self, other): 44 if not other: 45 return False 46 if self.kind == 'Options': 47 return (other.kind == 'Options' 48 and self.opts == other.opts) 49 else: 50 return self.kind == other.kind and self.n == other.n 51 52 def __neq__ (self, other): 53 if not other: 54 return True 55 return not (self == other) 56 57 def __str__ (self): 58 if self.kind == 'Number': 59 return str (self.n) 60 elif self.kind == 'Offset': 61 return 'i+%s' % self.n 62 elif self.kind == 'Options': 63 return '_'.join (map (str, self.opts)) 64 65 def __repr__ (self): 66 (ns, os) = self.get_opts () 67 return 'vc_options (%r, %r)' % (ns, os) 68 69 def get_opts (self): 70 if self.kind == 'Options': 71 opts = self.opts 72 else: 73 opts = [self] 74 ns = [vc.n for vc in opts if vc.kind == 'Number'] 75 os = [vc.n for vc in opts if vc.kind == 'Offset'] 76 return (ns, os) 77 78 def serialise (self, ss): 79 ss.append ('VC') 80 (ns, os) = self.get_opts () 81 ss.append ('%d' % len (ns)) 82 ss.extend (['%d' % n for n in ns]) 83 ss.append ('%d' % len (os)) 84 ss.extend (['%d' % n for n in os]) 85 86 def incr (self, incr): 87 if self.kind in ['Number', 'Offset']: 88 n = self.n + incr 89 if n < 0: 90 return None 91 return VisitCount (self.kind, n) 92 elif self.kind == 'Options': 93 opts = [vc.incr (incr) for vc in self.opts] 94 opts = [opt for opt in opts if opt] 95 if opts == []: 96 return None 97 return mk_vc_opts (opts) 98 else: 99 assert not 'VisitCount type understood' 100 101 def has_zero (self): 102 if self.kind == 'Options': 103 return bool ([vc for vc in self.opts 104 if vc.has_zero ()]) 105 else: 106 return self.kind == 'Number' and self.n == 0 107 108def mk_vc_opts (opts): 109 if len (opts) == 1: 110 return opts[0] 111 else: 112 return VisitCount ('Options', opts) 113 114def vc_options (nums, offsets): 115 return mk_vc_opts (map (vc_num, nums) + map (vc_offs, offsets)) 116 117def vc_num (n): 118 return VisitCount ('Number', n) 119 120def vc_upto (n): 121 return mk_vc_opts (map (vc_num, range (n))) 122 123def vc_offs (n): 124 return VisitCount ('Offset', n) 125 126def vc_offset_upto (n): 127 return mk_vc_opts (map (vc_offs, range (n))) 128 129def vc_double_range (n, m): 130 return mk_vc_opts (map (vc_num, range (n)) + map (vc_offs, range (m))) 131 132class InlineEvent(Exception): 133 pass 134 135class Hyp: 136 """Used to represent a proposition about path conditions or data at 137 various points in execution.""" 138 139 def __init__ (self, kind, arg1, arg2, induct = None): 140 self.kind = kind 141 if kind == 'PCImp': 142 self.pcs = [arg1, arg2] 143 elif kind == 'Eq': 144 self.vals = [arg1, arg2] 145 self.induct = induct 146 elif kind == 'EqIfAt': 147 self.vals = [arg1, arg2] 148 self.induct = induct 149 else: 150 assert not 'hyp kind understood' 151 152 def __repr__ (self): 153 if self.kind == 'PCImp': 154 vals = map (repr, self.pcs) 155 elif self.kind in ['Eq', 'EqIfAt']: 156 vals = map (repr, self.vals) 157 if self.induct: 158 vals += [repr (self.induct)] 159 else: 160 assert not 'hyp kind understood' 161 return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals)) 162 163 def hyp_tuple (self): 164 if self.kind == 'PCImp': 165 return ('PCImp', self.pcs[0], self.pcs[1]) 166 elif self.kind in ['Eq', 'EqIfAt']: 167 return (self.kind, self.vals[0], 168 self.vals[1], self.induct) 169 else: 170 assert not 'hyp kind understood' 171 172 def __hash__ (self): 173 return hash (self.hyp_tuple ()) 174 175 def __ne__ (self, other): 176 return not other or not (self == other) 177 178 def __cmp__ (self, other): 179 return cmp (self.hyp_tuple (), other.hyp_tuple ()) 180 181 def visits (self): 182 if self.kind == 'PCImp': 183 return [vis for vis in self.pcs 184 if vis[0] != 'Bool'] 185 elif self.kind in ['Eq', 'EqIfAt']: 186 return [vis for (_, vis) in self.vals] 187 else: 188 assert not 'hyp kind understood' 189 190 def get_vals (self): 191 if self.kind == 'PCImp': 192 return [] 193 else: 194 return [val for (val, _) in self.vals] 195 196 def serialise_visit (self, (n, restrs), ss): 197 ss.append ('%s' % n) 198 ss.append ('%d' % len (restrs)) 199 for (n2, vc) in restrs: 200 ss.append ('%d' % n2) 201 vc.serialise (ss) 202 203 def serialise_pc (self, pc, ss): 204 if pc[0] == 'Bool' and pc[1] == true_term: 205 ss.append ('True') 206 elif pc[0] == 'Bool' and pc[1] == false_term: 207 ss.append ('False') 208 else: 209 ss.append ('PC') 210 serialise_visit (pc[0], ss) 211 ss.append (pc[1]) 212 213 def serialise_hyp (self, ss): 214 if self.kind == 'PCImp': 215 (visit1, visit2) = self.pcs 216 ss.append ('PCImp') 217 self.serialise_pc (visit1, ss) 218 self.serialise_pc (visit2, ss) 219 elif self.kind in ['Eq', 'EqIfAt']: 220 assert len (self.vals) == 2 221 ss.extend (self.kind) 222 for (exp, visit) in self.vals: 223 exp.serialise (ss) 224 self.serialise_visit (visit, ss) 225 if induct: 226 ss.append ('%d' % induct[0]) 227 ss.append ('%d' % induct[1]) 228 else: 229 ss.extend (['None', 'None']) 230 else: 231 assert not 'hyp kind understood' 232 233 def interpret (self, rep): 234 if self.kind == 'PCImp': 235 ((visit1, tag1), (visit2, tag2)) = self.pcs 236 if visit1 == 'Bool': 237 pc1 = tag1 238 else: 239 pc1 = rep.get_pc (visit1, tag = tag1) 240 if visit2 == 'Bool': 241 pc2 = tag2 242 else: 243 pc2 = rep.get_pc (visit2, tag = tag2) 244 return mk_implies (pc1, pc2) 245 elif self.kind in ['Eq', 'EqIfAt']: 246 [(x, xvis), (y, yvis)] = self.vals 247 if self.induct: 248 v = rep.get_induct_var (self.induct) 249 x = subst_induct (x, v) 250 y = subst_induct (y, v) 251 x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1]) 252 y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1]) 253 if x_pc_env == None or y_pc_env == None: 254 if self.kind == 'EqIfAt': 255 return syntax.true_term 256 else: 257 return syntax.false_term 258 ((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env) 259 eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv) 260 if self.kind == 'EqIfAt': 261 x_pc = rep.get_pc (xvis[0], tag = xvis[1]) 262 y_pc = rep.get_pc (yvis[0], tag = yvis[1]) 263 return syntax.mk_n_implies ([x_pc, y_pc], eq) 264 else: 265 return eq 266 else: 267 assert not 'hypothesis type understood' 268 269def check_vis_is_vis (((n, vc), tag)): 270 assert vc[:0] == (), vc 271 272def eq_hyp (lhs, rhs, induct = None, use_if_at = False): 273 check_vis_is_vis (lhs[1]) 274 check_vis_is_vis (rhs[1]) 275 kind = 'Eq' 276 if use_if_at: 277 kind = 'EqIfAt' 278 return Hyp (kind, lhs, rhs, induct = induct) 279 280def true_if_at_hyp (expr, vis, induct = None): 281 check_vis_is_vis (vis) 282 return Hyp ('EqIfAt', (expr, vis), (true_term, vis), 283 induct = induct) 284 285def pc_true_hyp (vis): 286 check_vis_is_vis (vis) 287 return Hyp ('PCImp', ('Bool', true_term), vis) 288 289def pc_false_hyp (vis): 290 check_vis_is_vis (vis) 291 return Hyp ('PCImp', vis, ('Bool', false_term)) 292 293def pc_triv_hyp (vis): 294 check_vis_is_vis (vis) 295 return Hyp ('PCImp', vis, vis) 296 297class GraphSlice: 298 """Used to represent a slice of potential execution in a graph where 299 looping is limited to certain specific examples. For instance, we 300 might say that execution through node n will be represented only 301 by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The 302 variable state at visits 4 and i + 2 will be calculated but no 303 further execution will be done.""" 304 305 def __init__ (self, p, solv, inliner = None, fast = False): 306 self.p = p 307 self.solv = solv 308 self.inp_envs = {} 309 self.mem_calls = {} 310 self.add_input_envs () 311 312 self.node_pc_envs = {} 313 self.node_pc_env_order = [] 314 self.arc_pc_envs = {} 315 self.inliner = inliner 316 self.funcs = {} 317 self.pc_env_requests = set () 318 self.fast = fast 319 self.induct_var_env = {} 320 self.contractions = {} 321 322 self.local_defs_unsat = False 323 self.use_known_eqs = True 324 325 self.avail_hyps = set () 326 self.used_hyps = set () 327 328 def add_input_envs (self): 329 for (entry, _, _, args) in self.p.entries: 330 self.inp_envs[entry] = mk_inp_env (entry, args, self) 331 332 def get_reachable (self, split, n): 333 return self.p.is_reachable_from (split, n) 334 335 class TooGeneral (Exception): 336 def __init__ (self, split): 337 self.split = split 338 339 def get_tag_vcount (self, (n, vcount), tag): 340 if tag == None: 341 tag = self.p.node_tags[n][0] 342 vcount_r = [(split, count, self.get_reachable (split, n)) 343 for (split, count) in vcount 344 if self.p.node_tags[split][0] == tag] 345 for (split, count, r) in vcount_r: 346 if not r and not count.has_zero (): 347 return (tag, None) 348 assert count.is_visit_count 349 vcount = [(s, c) for (s, c, r) in vcount_r if r] 350 vcount = tuple (sorted (vcount)) 351 352 loop_id = self.p.loop_id (n) 353 if loop_id != None: 354 for (split, visits) in vcount: 355 if (self.p.loop_id (split) == loop_id 356 and visits.kind == 'Options'): 357 raise self.TooGeneral (split) 358 359 return (tag, vcount) 360 361 def get_node_pc_env (self, (n, vcount), tag = None, request = True): 362 tag, vcount = self.get_tag_vcount ((n, vcount), tag) 363 if vcount == None: 364 return None 365 366 if (tag, n, vcount) in self.node_pc_envs: 367 return self.node_pc_envs[(tag, n, vcount)] 368 369 if request: 370 self.pc_env_requests.add (((n, vcount), tag)) 371 372 self.warm_pc_env_cache ((n, vcount), tag) 373 374 pc_env = self.get_node_pc_env_raw ((n, vcount), tag) 375 if pc_env: 376 pc_env = self.apply_known_eqs_pc_env ((n, vcount), 377 tag, pc_env) 378 379 assert not (tag, n, vcount) in self.node_pc_envs 380 self.node_pc_envs[(tag, n, vcount)] = pc_env 381 if pc_env: 382 self.node_pc_env_order.append ((tag, n, vcount)) 383 384 return pc_env 385 386 def warm_pc_env_cache (self, n_vc, tag): 387 'this is to avoid recursion limits and spot bugs' 388 prev_chain = [] 389 for i in range (5000): 390 prevs = self.prevs (n_vc) 391 try: 392 prevs = [p for p in prevs 393 if (tag, p[0], p[1]) 394 not in self.node_pc_envs 395 if self.get_tag_vcount (p, None) 396 == (tag, n_vc[1])] 397 except self.TooGeneral: 398 break 399 if not prevs: 400 break 401 n_vc = prevs[0] 402 prev_chain.append(n_vc) 403 if not (len (prev_chain) < 5000): 404 printout ([n for (n, vc) in prev_chain]) 405 assert len (prev_chain) < 5000, (prev_chain[:10], 406 prev_chain[-10:]) 407 408 prev_chain.reverse () 409 for n_vc in prev_chain: 410 self.get_node_pc_env (n_vc, tag, request = False) 411 412 def get_loop_pc_env (self, split, vcount): 413 vcount2 = dict (vcount) 414 vcount2[split] = vc_num (0) 415 vcount2 = tuple (sorted (vcount2.items ())) 416 prev_pc_env = self.get_node_pc_env ((split, vcount2)) 417 if prev_pc_env == None: 418 return None 419 (_, prev_env) = prev_pc_env 420 mem_calls = self.scan_mem_calls (prev_env) 421 mem_calls = self.add_loop_mem_calls (split, mem_calls) 422 def av (nm, typ, mem_name = None): 423 nm2 = '%s_loop_at_%s' % (nm, split) 424 return self.add_var (nm2, typ, 425 mem_name = mem_name, mem_calls = mem_calls) 426 env = {} 427 consts = set () 428 for (nm, typ) in prev_env: 429 check_const = self.fast or (typ in 430 [builtinTs['HTD'], builtinTs['Dom']]) 431 if check_const and self.is_synt_const (nm, typ, split): 432 env[(nm, typ)] = prev_env[(nm, typ)] 433 consts.add ((nm, typ)) 434 else: 435 env[(nm, typ)] = av (nm + '_after', typ, 436 ('Loop', prev_env[(nm, typ)])) 437 for (nm, typ) in prev_env: 438 if (nm, typ) in consts: 439 continue 440 z = self.var_rep_request ((nm, typ), 'Loop', 441 (split, vcount), env) 442 if z: 443 env[(nm, typ)] = z 444 445 pc = mk_smt_expr (av ('pc_of', boolT), boolT) 446 if self.fast: 447 imp = syntax.mk_implies (pc, prev_pc_env[0]) 448 self.solv.assert_fact (imp, prev_env, 449 unsat_tag = ('LoopPCImp', split)) 450 451 return (pc, env) 452 453 def is_synt_const (self, nm, typ, split): 454 """check if a variable at a split point is a syntactic constant 455 which is always unmodified by the loop. 456 we allow cases where a variable is renamed and renamed back 457 during the loop (this often happens because of inlining). 458 the check is done by depth-first-search backward through the 459 graph looking for a source of a variant value.""" 460 loop = self.p.loop_id (split) 461 if problem.has_inner_loop (self.p, split): 462 return False 463 loop_set = set (self.p.loop_body (split)) 464 465 orig_nm = nm 466 safe = set ([(orig_nm, split)]) 467 first_step = True 468 visit = [] 469 count = 0 470 while first_step or visit: 471 if first_step: 472 (nm, n) = (orig_nm, split) 473 first_step = False 474 else: 475 (nm, n) = visit.pop () 476 if (nm, n) in safe: 477 continue 478 elif n == split: 479 return False 480 new_nm = nm 481 node = self.p.nodes[n] 482 if node.kind == 'Call': 483 if (nm, typ) not in node.rets: 484 pass 485 elif self.fast_const_ret (n, nm, typ): 486 pass 487 else: 488 return False 489 elif node.kind == 'Basic': 490 upds = [arg for (lv, arg) in node.upds 491 if lv == (nm, typ)] 492 if [v for v in upds if v.kind != 'Var']: 493 return False 494 if upds: 495 new_nm = upds[0].name 496 preds = [(new_nm, n2) for n2 in self.p.preds[n] 497 if n2 in loop_set] 498 unknowns = [p for p in preds if p not in safe] 499 if unknowns: 500 visit.extend ([(nm, n)] + unknowns) 501 else: 502 safe.add ((nm, n)) 503 count += 1 504 if count % 100000 == 0: 505 trace ('is_synt_const: %d iterations' % count) 506 trace ('visit length %d' % len (visit)) 507 trace ('visit tail %s' % visit[-20:]) 508 return True 509 510 def fast_const_ret (self, n, nm, typ): 511 """determine if we can heuristically consider this return 512 value to be the same as an input. this is known for some 513 function returns, e.g. memory. 514 this is important for heuristic "fast" analysis.""" 515 if not self.fast: 516 return False 517 node = self.p.nodes[n] 518 assert node.kind == 'Call' 519 for hook in target_objects.hooks ('rep_unsafe_const_ret'): 520 if hook (node, nm, typ): 521 return True 522 return False 523 524 def get_node_pc_env_raw (self, (n, vcount), tag): 525 if n in self.inp_envs: 526 return (true_term, self.inp_envs[n]) 527 528 for (split, count) in vcount: 529 if split == n and count == vc_offs (0): 530 return self.get_loop_pc_env (split, vcount) 531 532 pc_envs = [pc_env for n_prev in self.p.preds[n] 533 if self.p.node_tags[n_prev][0] == tag 534 for pc_env in self.get_arc_pc_envs (n_prev, 535 (n, vcount))] 536 537 pc_envs = [pc_env for pc_env in pc_envs if pc_env] 538 if pc_envs == []: 539 return None 540 541 if n == 'Err': 542 # we'll never care about variable values here 543 # and there are sometimes a LOT of arcs to Err 544 # so we save a lot of merge effort 545 pc_envs = [(to_smt_expr (pc, env, self.solv), {}) 546 for (pc, env) in pc_envs] 547 548 (pc, env, large) = merge_envs_pcs (pc_envs, self.solv) 549 550 if pc.kind != 'SMTExpr': 551 name = self.path_cond_name ((n, vcount), tag) 552 name = self.solv.add_def (name, pc, env) 553 pc = mk_smt_expr (name, boolT) 554 555 for (nm, typ) in env: 556 if len (env[(nm, typ)]) > 80: 557 env[(nm, typ)] = self.contract (nm, (n, vcount), 558 env[(nm, typ)], typ) 559 560 return (pc, env) 561 562 def contract (self, name, n_vc, val, typ): 563 if val in self.contractions: 564 return self.contractions[val] 565 566 name = self.local_name_before (name, n_vc) 567 name = self.solv.add_def (name, mk_smt_expr (val, typ), {}) 568 569 self.contractions[val] = name 570 return name 571 572 def get_arc_pc_envs (self, n, n_vc2): 573 try: 574 prevs = [n_vc for n_vc in self.prevs (n_vc2) 575 if n_vc[0] == n] 576 assert len (prevs) <= 1 577 return [self.get_arc_pc_env (n_vc, n_vc2) 578 for n_vc in prevs] 579 except self.TooGeneral, e: 580 # consider specialisations of the target 581 specs = self.specialise (n_vc2, e.split) 582 specs = [(n_vc2[0], spec) for spec in specs] 583 return [pc_env for spec in specs 584 for pc_env in self.get_arc_pc_envs (n, spec)] 585 586 def get_arc_pc_env (self, (n, vcount), n2): 587 tag, vcount = self.get_tag_vcount ((n, vcount), None) 588 589 if vcount == None: 590 return None 591 592 assert self.is_cont ((n, vcount), n2), ((n, vcount), 593 n2, self.p.nodes[n].get_conts ()) 594 595 if (n, vcount) in self.arc_pc_envs: 596 return self.arc_pc_envs[(n, vcount)].get (n2[0]) 597 598 if self.get_node_pc_env ((n, vcount), request = False) == None: 599 return None 600 601 arcs = self.emit_node ((n, vcount)) 602 self.post_emit_node_hooks ((n, vcount)) 603 arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs]) 604 605 self.arc_pc_envs[(n, vcount)] = arcs 606 return arcs.get (n2[0]) 607 608 def add_local_def (self, n, vname, name, val, env): 609 if self.local_defs_unsat: 610 smt_name = self.solv.add_var (name, val.typ) 611 eq = mk_eq (mk_smt_expr (smt_name, val.typ), val) 612 self.solv.assert_fact (eq, env, unsat_tag 613 = ('Def', n, vname)) 614 else: 615 smt_name = self.solv.add_def (name, val, env) 616 return smt_name 617 618 def add_var (self, name, typ, mem_name = None, mem_calls = None): 619 r = self.solv.add_var_restr (name, typ, mem_name = mem_name) 620 if typ == syntax.builtinTs['Mem']: 621 r_x = solver.parse_s_expression (r) 622 self.mem_calls[r_x] = mem_calls 623 return r 624 625 def var_rep_request (self, (nm, typ), kind, n_vc, env): 626 assert type (n_vc[0]) != str 627 for hook in target_objects.hooks ('problem_var_rep'): 628 z = hook (self.p, (nm, typ), kind, n_vc[0]) 629 if z == None: 630 continue 631 if z[0] == 'SplitMem': 632 assert typ == builtinTs['Mem'] 633 (_, addr) = z 634 addr = smt_expr (addr, env, self.solv) 635 name = '%s_for_%s' % (nm, 636 self.node_count_name (n_vc)) 637 return self.solv.add_split_mem_var (addr, name, 638 typ, mem_name = 'SplitMemNonsense') 639 else: 640 assert z == None 641 642 def emit_node (self, n): 643 (pc, env) = self.get_node_pc_env (n, request = False) 644 tag = self.p.node_tags[n[0]][0] 645 app_eqs = self.apply_known_eqs_tm (n, tag) 646 # node = logic.simplify_node_elementary (self.p.nodes[n[0]]) 647 # whether to ignore unreachable Cond arcs seems to be a huge 648 # dilemma. if we ignore them, some reachable sites become 649 # unreachable and we can't interpret all hyps 650 # if we don't ignore them, the variable set disagrees with 651 # var_deps and so the abstracted loop pc/env may not be 652 # sufficient and we get EnvMiss again. I don't really know 653 # what to do about this corner case. 654 node = self.p.nodes[n[0]] 655 env = dict (env) 656 657 if node.kind == 'Call': 658 self.try_inline (n[0], pc, env) 659 660 if pc == false_term: 661 return [(c, false_term, {}) for c in node.get_conts()] 662 elif node.kind == 'Cond' and node.left == node.right: 663 return [(node.left, pc, env)] 664 elif node.kind == 'Cond' and node.cond == true_term: 665 return [(node.left, pc, env), 666 (node.right, false_term, env)] 667 elif node.kind == 'Basic': 668 upds = [] 669 for (lv, v) in node.upds: 670 if v.kind == 'Var': 671 upds.append ((lv, env[(v.name, v.typ)])) 672 else: 673 name = self.local_name (lv[0], n) 674 v = app_eqs (v) 675 vname = self.add_local_def (n, 676 ('Var', lv), name, v, env) 677 upds.append ((lv, vname)) 678 for (lv, v) in upds: 679 env[lv] = v 680 return [(node.cont, pc, env)] 681 elif node.kind == 'Cond': 682 name = self.cond_name (n) 683 cond = self.p.fresh_var (name, boolT) 684 env[(cond.name, boolT)] = self.add_local_def (n, 685 'Cond', name, app_eqs (node.cond), env) 686 lpc = mk_and (cond, pc) 687 rpc = mk_and (mk_not (cond), pc) 688 return [(node.left, lpc, env), (node.right, rpc, env)] 689 elif node.kind == 'Call': 690 nm = self.success_name (node.fname, n) 691 success = self.solv.add_var (nm, boolT) 692 success = mk_smt_expr (success, boolT) 693 fun = functions[node.fname] 694 ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv)) 695 for ((x, typ), arg) in azip (fun.inputs, node.args)]) 696 mem_name = None 697 for (x, typ) in reversed (fun.inputs): 698 if typ == builtinTs['Mem']: 699 inp_mem = ins[(x, typ)] 700 mem_name = (node.fname, inp_mem) 701 mem_calls = self.scan_mem_calls (ins) 702 mem_calls = self.add_mem_call (node.fname, mem_calls) 703 outs = {} 704 for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs): 705 assert typ2 == typ 706 if self.fast_const_ret (n[0], x, typ): 707 outs[(y, typ2)] = env [(x, typ)] 708 continue 709 name = self.local_name (x, n) 710 env[(x, typ)] = self.add_var (name, typ, 711 mem_name = mem_name, 712 mem_calls = mem_calls) 713 outs[(y, typ2)] = env[(x, typ)] 714 for ((x, typ), (y, _)) in azip (node.rets, fun.outputs): 715 z = self.var_rep_request ((x, typ), 716 'Call', n, env) 717 if z != None: 718 env[(x, typ)] = z 719 outs[(y, typ)] = z 720 self.add_func (node.fname, ins, outs, success, n) 721 return [(node.cont, pc, env)] 722 else: 723 assert not 'node kind understood' 724 725 def post_emit_node_hooks (self, (n, vcount)): 726 for hook in target_objects.hooks ('post_emit_node'): 727 hook (self, (n, vcount)) 728 729 def fetch_known_eqs (self, n_vc, tag): 730 if not self.use_known_eqs: 731 return None 732 eqs = self.p.known_eqs.get ((n_vc, tag)) 733 if eqs == None: 734 return None 735 avail = [] 736 for (x, n_vc_y, tag_y, y, hyps) in eqs: 737 if hyps <= self.avail_hyps: 738 (_, env) = self.get_node_pc_env (n_vc_y, tag_y) 739 avail.append ((x, smt_expr (y, env, self.solv))) 740 self.used_hyps.update (hyps) 741 if avail: 742 return avail 743 return None 744 745 def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)): 746 eqs = self.fetch_known_eqs (n_vc, tag) 747 if eqs == None: 748 return (pc, env) 749 env = dict (env) 750 for (x, sx) in eqs: 751 if x.kind == 'Var': 752 cur_rhs = env[x.name] 753 for y in env: 754 if env[y] == cur_rhs: 755 trace ('substituted %s at %s.' % (y, n_vc)) 756 env[y] = sx 757 return (pc, env) 758 759 def apply_known_eqs_tm (self, n_vc, tag): 760 eqs = self.fetch_known_eqs (n_vc, tag) 761 if eqs == None: 762 return lambda x: x 763 eqs = dict ([(x, mk_smt_expr (sexpr, x.typ)) 764 for (x, sexpr) in eqs]) 765 return lambda tm: logic.recursive_term_subst (eqs, tm) 766 767 def rebuild (self, solv = None): 768 requests = self.pc_env_requests 769 770 self.node_pc_env_order = [] 771 self.node_pc_envs = {} 772 self.arc_pc_envs = {} 773 self.funcs = {} 774 self.pc_env_requests = set () 775 self.induct_var_env = {} 776 self.contractions = {} 777 778 if not solv: 779 solv = Solver (produce_unsat_cores 780 = self.local_defs_unsat) 781 self.solv = solv 782 783 self.add_input_envs () 784 785 self.used_hyps = set () 786 run_requests (self, requests) 787 788 def add_func (self, name, inputs, outputs, success, n_vc): 789 assert n_vc not in self.funcs 790 self.funcs[n_vc] = (inputs, outputs, success) 791 for pair in pairings.get (name, []): 792 self.funcs.setdefault (pair.name, []) 793 group = self.funcs[pair.name] 794 for n_vc2 in group: 795 if self.get_func_pairing (n_vc, n_vc2): 796 self.add_func_assert (n_vc, n_vc2) 797 group.append (n_vc) 798 799 def get_func (self, n_vc, tag = None): 800 """returns (input_env, output_env, success_var) for 801 function call at given n_vc.""" 802 tag, vc = self.get_tag_vcount (n_vc, tag) 803 n_vc = (n_vc[0], vc) 804 assert self.p.nodes[n_vc[0]].kind == 'Call' 805 806 if n_vc not in self.funcs: 807 # try to ensure n_vc has been emitted 808 cont = self.get_cont (n_vc) 809 self.get_node_pc_env (cont, tag = tag) 810 811 return self.funcs[n_vc] 812 813 def get_func_pairing_nocheck (self, n_vc, n_vc2): 814 fnames = [self.p.nodes[n_vc[0]].fname, 815 self.p.nodes[n_vc2[0]].fname] 816 pairs = [pair for pair in pairings[list (fnames)[0]] 817 if set (pair.funs.values ()) == set (fnames)] 818 if not pairs: 819 return None 820 [pair] = pairs 821 if pair.funs[pair.tags[0]] == fnames[0]: 822 return (pair, n_vc, n_vc2) 823 else: 824 return (pair, n_vc2, n_vc) 825 826 def get_func_pairing (self, n_vc, n_vc2): 827 res = self.get_func_pairing_nocheck (n_vc, n_vc2) 828 if not res: 829 return res 830 (pair, l_n_vc, r_n_vc) = res 831 (lin, _, _) = self.funcs[l_n_vc] 832 (rin, _, _) = self.funcs[r_n_vc] 833 l_mem_calls = self.scan_mem_calls (lin) 834 r_mem_calls = self.scan_mem_calls (rin) 835 tags = pair.tags 836 (c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls) 837 if not c: 838 trace ('skipped emitting func pairing %s -> %s' 839 % (l_n_vc, r_n_vc)) 840 trace (' ' + s) 841 return None 842 return res 843 844 def get_func_assert (self, n_vc, n_vc2): 845 (pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2) 846 (ltag, rtag) = pair.tags 847 (inp_eqs, out_eqs) = pair.eqs 848 (lin, lout, lsucc) = self.funcs[l_n_vc] 849 (rin, rout, rsucc) = self.funcs[r_n_vc] 850 lpc = self.get_pc (l_n_vc) 851 rpc = self.get_pc (r_n_vc) 852 envs = {ltag + '_IN': lin, rtag + '_IN': rin, 853 ltag + '_OUT': lout, rtag + '_OUT': rout} 854 inp_eqs = inst_eqs (inp_eqs, envs, self.solv) 855 out_eqs = inst_eqs (out_eqs, envs, self.solv) 856 succ_imp = mk_implies (rsucc, lsucc) 857 858 return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]), 859 foldr1 (mk_and, out_eqs + [succ_imp])) 860 861 def add_func_assert (self, n_vc, n_vc2): 862 imp = self.get_func_assert (n_vc, n_vc2) 863 imp = logic.weaken_assert (imp) 864 if self.local_defs_unsat: 865 self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq', 866 ln, rn)) 867 else: 868 self.solv.assert_fact (imp, {}) 869 870 def node_count_name (self, (n, vcount)): 871 name = str (n) 872 bits = [str (n)] + ['%s=%s' % (split, count) 873 for (split, count) in vcount] 874 return '_'.join (bits) 875 876 def get_mem_calls (self, mem_sexpr): 877 mem_sexpr = solver.parse_s_expression (mem_sexpr) 878 return self.get_mem_calls_sexpr (mem_sexpr) 879 880 def get_mem_calls_sexpr (self, mem_sexpr): 881 stores = set (['store-word32', 'store-word8', 'store-word64']) 882 if mem_sexpr in self.mem_calls: 883 return self.mem_calls[mem_sexpr] 884 elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores: 885 return self.get_mem_calls_sexpr (mem_sexpr[1]) 886 elif mem_sexpr[:1] == ('ite', ): 887 (_, _, x, y) = mem_sexpr 888 x_calls = self.get_mem_calls_sexpr (x) 889 y_calls = self.get_mem_calls_sexpr (y) 890 return merge_mem_calls (x_calls, y_calls) 891 elif mem_sexpr in self.solv.defs: 892 mem_sexpr = self.solv.defs[mem_sexpr] 893 return self.get_mem_calls_sexpr (mem_sexpr) 894 assert not "mem_calls fallthrough", mem_sexpr 895 896 def scan_mem_calls (self, env): 897 mem_vs = [env[(nm, typ)] 898 for (nm, typ) in env 899 if typ == syntax.builtinTs['Mem']] 900 mem_calls = [self.get_mem_calls (v) 901 for v in mem_vs if v[0] != 'SplitMem'] 902 if mem_calls: 903 return foldr1 (merge_mem_calls, mem_calls) 904 else: 905 return None 906 907 def add_mem_call (self, fname, mem_calls): 908 if mem_calls == None: 909 return None 910 mem_calls = dict (mem_calls) 911 (min_calls, max_calls) = mem_calls.get (fname, (0, 0)) 912 if max_calls == None: 913 mem_calls[fname] = (min_calls + 1, None) 914 else: 915 mem_calls[fname] = (min_calls + 1, max_calls + 1) 916 return mem_calls 917 918 def add_loop_mem_calls (self, split, mem_calls): 919 if mem_calls == None: 920 return None 921 fnames = set ([self.p.nodes[n].fname 922 for n in self.p.loop_body (split) 923 if self.p.nodes[n].kind == 'Call']) 924 if not fnames: 925 return mem_calls 926 mem_calls = dict (mem_calls) 927 for fname in fnames: 928 if fname not in mem_calls: 929 mem_calls[fname] = (0, None) 930 else: 931 (min_calls, max_calls) = mem_calls[fname] 932 mem_calls[fname] = (min_calls, None) 933 return mem_calls 934 935 # note these names are designed to be unique by suffix 936 # (so that smt names are independent of order of requests) 937 def local_name (self, s, n_vc): 938 return '%s_after_%s' % (s, self.node_count_name (n_vc)) 939 940 def local_name_before (self, s, n_vc): 941 return '%s_v_at_%s' % (s, self.node_count_name (n_vc)) 942 943 def cond_name (self, n_vc): 944 return 'cond_at_%s' % self.node_count_name (n_vc) 945 946 def path_cond_name (self, n_vc, tag): 947 return 'path_cond_to_%s_%s' % ( 948 self.node_count_name (n_vc), tag) 949 950 def success_name (self, fname, n_vc): 951 bits = fname.split ('.') 952 nms = ['_'.join (bits[i:]) for i in range (len (bits)) 953 if bits[i:][0].isalpha ()] 954 if nms: 955 nm = nms[-1] 956 else: 957 nm = 'fun' 958 return '%s_success_at_%s' % (nm, self.node_count_name (n_vc)) 959 960 def try_inline (self, n, pc, env): 961 if not self.inliner: 962 return False 963 964 inline = self.inliner ((self.p, n)) 965 if not inline: 966 return False 967 968 # make sure this node is reachable before inlining 969 if self.solv.test_hyp (mk_not (pc), env): 970 trace ('Skipped inlining at %d.' % n) 971 return False 972 973 trace ('Inlining at %d.' % n) 974 inline () 975 raise InlineEvent () 976 977 def incr (self, vcount, n, incr): 978 vcount2 = dict (vcount) 979 vcount2[n] = vcount2[n].incr (incr) 980 if vcount2[n] == None: 981 return None 982 return tuple (sorted (vcount2.items ())) 983 984 def get_cont (self, (n, vcount)): 985 [c] = self.p.nodes[n].get_conts () 986 vcount2 = dict (vcount) 987 if n in vcount2: 988 vcount = self.incr (vcount, n, 1) 989 cont = (c, vcount) 990 assert self.is_cont ((n, vcount), cont) 991 return cont 992 993 def is_cont (self, (n, vcount), (n2, vcount2)): 994 if n2 not in self.p.nodes[n].get_conts (): 995 trace ('Not a graph cont.') 996 return False 997 998 vcount_d = dict (vcount) 999 vcount_d2 = dict (vcount2) 1000 if n in vcount_d2: 1001 if n in vcount_d: 1002 assert vcount_d[n].kind != 'Options' 1003 vcount_d2[n] = vcount_d2[n].incr (-1) 1004 1005 if not vcount_d <= vcount_d2: 1006 trace ('Restrictions not subset.') 1007 return False 1008 1009 for (split, count) in vcount_d2.iteritems (): 1010 if split in vcount_d: 1011 continue 1012 if self.get_reachable (split, n): 1013 return False 1014 if not count.has_zero (): 1015 return False 1016 1017 return True 1018 1019 def prevs (self, (n, vcount)): 1020 prevs = [] 1021 vcount_d = dict (vcount) 1022 for p in self.p.preds[n]: 1023 if p in vcount_d: 1024 vcount2 = self.incr (vcount, p, -1) 1025 if vcount2 == None: 1026 continue 1027 prevs.append ((p, vcount2)) 1028 else: 1029 prevs.append ((p, vcount)) 1030 return prevs 1031 1032 def specialise (self, (n, vcount), split): 1033 vcount = dict (vcount) 1034 assert vcount[split].kind == 'Options' 1035 specs = [] 1036 for n in vcount[split].opts: 1037 v = dict (vcount) 1038 v[split] = n 1039 specs.append (tuple (sorted (v.items ()))) 1040 return specs 1041 1042 def get_pc (self, (n, vcount), tag = None): 1043 pc_env = self.get_node_pc_env ((n, vcount), tag = tag) 1044 if pc_env == None: 1045 trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag)) 1046 return false_term 1047 (pc, env) = pc_env 1048 return to_smt_expr (pc, env, self.solv) 1049 1050 def to_smt_expr (self, expr, (n, vcount), tag = None): 1051 pc_env = self.get_node_pc_env ((n, vcount), tag = tag) 1052 (pc, env) = pc_env 1053 return to_smt_expr (expr, env, self.solv) 1054 1055 def get_induct_var (self, (n1, n2)): 1056 if (n1, n2) not in self.induct_var_env: 1057 vname = self.solv.add_var ('induct_i_%d_%d' % (n1, n2), 1058 word32T) 1059 self.induct_var_env[(n1, n2)] = vname 1060 self.pc_env_requests.add (((n1, n2), 'InductVar')) 1061 else: 1062 vname = self.induct_var_env[(n1, n2)] 1063 return mk_smt_expr (vname, word32T) 1064 1065 def interpret_hyp (self, hyp): 1066 return hyp.interpret (self) 1067 1068 def interpret_hyp_imps (self, hyps, concl): 1069 hyps = map (self.interpret_hyp, hyps) 1070 return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl)) 1071 1072 def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False, 1073 model = None): 1074 self.avail_hyps = set (hyps) 1075 if not self.used_hyps <= self.avail_hyps: 1076 self.rebuild () 1077 1078 last_test[0] = (hyp, hyps, list (self.pc_env_requests)) 1079 1080 expr = self.interpret_hyp_imps (hyps, hyp) 1081 1082 trace ('Testing hyp whyps', push = 1) 1083 trace ('requests = %s' % self.pc_env_requests) 1084 1085 expr_s = smt_expr (expr, {}, self.solv) 1086 if cache and expr_s in cache: 1087 trace ('Cached: %s' % cache[expr_s]) 1088 return cache[expr_s] 1089 if fast: 1090 trace ('(not in cache)') 1091 return None 1092 1093 self.solv.add_pvalid_dom_assertions () 1094 1095 (result, _, _) = self.solv.parallel_test_hyps ([(None, expr)], 1096 {}, model = model) 1097 trace ('Result: %s' % result, push = -1) 1098 if cache != None: 1099 cache[expr_s] = result 1100 if not result: 1101 last_failed_test[0] = last_test[0] 1102 return result 1103 1104 def test_hyp_imp (self, hyps, hyp, model = None): 1105 return self.test_hyp_whyps (self.interpret_hyp (hyp), hyps, 1106 model = model) 1107 1108 def test_hyp_imps (self, imps): 1109 last_hyp_imps[0] = imps 1110 if imps == []: 1111 return (True, None) 1112 interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps, 1113 self.interpret_hyp (hyp)) 1114 for (hyps, hyp) in imps])) 1115 reqs = list (self.pc_env_requests) 1116 last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) 1117 self.solv.add_pvalid_dom_assertions () 1118 result = self.solv.parallel_test_hyps (interp_imps, {}) 1119 assert result[0] in [True, False], result 1120 if result[0] == False: 1121 (hyps, hyp) = imps[result[1]] 1122 last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) 1123 last_failed_test[0] = last_test[0] 1124 return result 1125 1126 def replay_requests (self, reqs): 1127 for ((n, vc), tag) in reqs: 1128 self.get_node_pc_env ((n, vc), tag = tag) 1129 1130last_test = [0] 1131last_failed_test = [0] 1132last_hyp_imps = [0] 1133 1134def to_smt_expr_under_op (expr, env, solv): 1135 if expr.kind == 'Op': 1136 vals = [to_smt_expr (v, env, solv) for v in expr.vals] 1137 return syntax.adjust_op_vals (expr, vals) 1138 else: 1139 return to_smt_expr (expr, env, solv) 1140 1141def inst_eq_with_envs ((x, env1), (y, env2), solv): 1142 x = to_smt_expr_under_op (x, env1, solv) 1143 y = to_smt_expr_under_op (y, env2, solv) 1144 if x.typ == syntax.builtinTs['RelWrapper']: 1145 return logic.apply_rel_wrapper (x, y) 1146 else: 1147 return mk_eq (x, y) 1148 1149def inst_eqs (eqs, envs, solv): 1150 return [inst_eq_with_envs ((x, envs[x_addr]), (y, envs[y_addr]), solv) 1151 for ((x, x_addr), (y, y_addr)) in eqs] 1152 1153def subst_induct (expr, induct_var): 1154 substs = {('%n', word32T): induct_var} 1155 return logic.var_subst (expr, substs, must_subst = False) 1156 1157printed_hyps = {} 1158def print_hyps (hyps): 1159 hyps = tuple (hyps) 1160 if hyps in printed_hyps: 1161 trace ('hyps = %s' % printed_hyps[hyps]) 1162 else: 1163 hname = 'hyp_set_%d' % (len (printed_hyps) + 1) 1164 trace ('%s = %s' % (hname, list (hyps))) 1165 printed_hyps[hname] = hyps 1166 trace ('hyps = %s' % hname) 1167 1168def merge_mem_calls (mem_calls_x, mem_calls_y): 1169 if mem_calls_x == mem_calls_y: 1170 return mem_calls_x 1171 mem_calls = {} 1172 for fname in set (mem_calls_x.keys () + mem_calls_y.keys ()): 1173 (min_x, max_x) = mem_calls_x.get (fname, (0, 0)) 1174 (min_y, max_y) = mem_calls_y.get (fname, (0, 0)) 1175 if None in [max_x, max_y]: 1176 max_v = None 1177 else: 1178 max_v = max (max_x, max_y) 1179 mem_calls[fname] = (min (min_x, min_y), max_v) 1180 return mem_calls 1181 1182def mem_calls_compatible (tags, l_mem_calls, r_mem_calls): 1183 if l_mem_calls == None or r_mem_calls == None: 1184 return (True, None) 1185 r_cast_calls = {} 1186 for (fname, calls) in l_mem_calls.iteritems (): 1187 pairs = [pair for pair in pairings[fname] 1188 if pair.tags == tags] 1189 if not pairs: 1190 return (None, 'no pairing for %s' % fname) 1191 assert len (pairs) <= 1, pairs 1192 [pair] = pairs 1193 r_fun = pair.funs[tags[1]] 1194 if not [nm for (nm, typ) in functions[r_fun].outputs 1195 if typ == syntax.builtinTs['Mem']]: 1196 continue 1197 r_cast_calls[pair.funs[tags[1]]] = calls 1198 for fname in set (r_cast_calls.keys () + r_mem_calls.keys ()): 1199 r_cast = r_cast_calls.get (fname, (0, 0)) 1200 r_actual = r_mem_calls.get (fname, (0, 0)) 1201 s = 'mismatch in calls to %s and pairs, %s / %s' % (fname, 1202 r_cast, r_actual) 1203 if r_cast[1] != None and r_cast[1] < r_actual[0]: 1204 return (None, s) 1205 if r_actual[1] != None and r_actual[1] < r_cast[0]: 1206 return (None, s) 1207 return (True, None) 1208 1209def mk_inp_env (n, args, rep): 1210 trace ('rep_graph setting up input env at %d' % n, push = 1) 1211 inp_env = {} 1212 1213 for (v_nm, typ) in args: 1214 inp_env[(v_nm, typ)] = rep.add_var (v_nm + '_init', typ, 1215 mem_name = 'Init', mem_calls = {}) 1216 for (v_nm, typ) in args: 1217 z = rep.var_rep_request ((v_nm, typ), 'Init', (n, ()), inp_env) 1218 if z: 1219 inp_env[(v_nm, typ)] = z 1220 1221 trace ('done setting up input env at %d' % n, push = -1) 1222 return inp_env 1223 1224def mk_graph_slice (p, inliner = None, fast = False, mk_solver = Solver): 1225 trace ('rep_graph setting up solver', push = 1) 1226 solv = mk_solver () 1227 trace ('rep_graph setting up solver', push = -1) 1228 return GraphSlice (p, solv, inliner, fast = fast) 1229 1230def run_requests (rep, requests): 1231 for (n_vc, tag) in requests: 1232 if tag == 'InductVar': 1233 rep.get_induct_var (n_vc) 1234 else: 1235 rep.get_pc (n_vc, tag = tag) 1236 rep.solv.add_pvalid_dom_assertions () 1237 1238import re 1239paren_w_re = re.compile (r"(\(|\)|\w+)") 1240 1241def mk_function_link_hyps (p, call_vis, tag, adjust_eq_seq = None): 1242 (entry, _, args) = p.get_entry_details (tag) 1243 ((call_site, restrs), call_tag) = call_vis 1244 assert p.nodes[call_site].kind == 'Call' 1245 entry_vis = ((entry, ()), p.node_tags[entry][0]) 1246 1247 args = [syntax.mk_var (nm, typ) for (nm, typ) in args] 1248 1249 pc = pc_true_hyp (call_vis) 1250 eq_seq = logic.azip (p.nodes[call_site].args, args) 1251 if adjust_eq_seq: 1252 eq_seq = adjust_eq_seq (eq_seq) 1253 hyps = [pc] + [eq_hyp ((x, call_vis), (y, entry_vis)) 1254 for (x, y) in eq_seq 1255 if x.typ.kind == 'Word' or x.typ == syntax.builtinTs['Mem'] 1256 or x.typ.kind == 'WordArray'] 1257 1258 return hyps 1259 1260