1# * Copyright 2015, NICTA 2# * 3# * This software may be distributed and modified according to the terms of 4# * the BSD 2-Clause license. Note that NO WARRANTY is provided. 5# * See "LICENSE_BSD2.txt" for details. 6# * 7# * @TAG(NICTA_BSD) 8 9import solver 10from solver import mk_smt_expr, to_smt_expr, smt_expr 11import check 12from check import restr_others, loops_to_split, ProofNode 13from rep_graph import (mk_graph_slice, vc_num, vc_offs, vc_upto, 14 vc_double_range, VisitCount, vc_offset_upto) 15import rep_graph 16from syntax import (mk_and, mk_cast, mk_implies, mk_not, mk_uminus, mk_var, 17 foldr1, boolT, word32T, word8T, builtinTs, true_term, false_term, 18 mk_word32, mk_word8, mk_times, Expr, Type, mk_or, mk_eq, mk_memacc, 19 mk_num, mk_minus, mk_plus, mk_less) 20import syntax 21import logic 22 23from target_objects import trace, printout 24import target_objects 25import itertools 26 27last_knowledge = [1] 28 29class NoSplit(Exception): 30 pass 31 32def get_loop_var_analysis_at (p, n): 33 k = ('search_loop_var_analysis', n) 34 if k in p.cached_analysis: 35 return p.cached_analysis[k] 36 for hook in target_objects.hooks ('loop_var_analysis'): 37 res = hook (p, n) 38 if res != None: 39 p.cached_analysis[k] = res 40 return res 41 var_deps = p.compute_var_dependencies () 42 res = p.get_loop_var_analysis (var_deps, n) 43 p.cached_analysis[k] = res 44 return res 45 46def get_loop_vars_at (p, n): 47 vs = [var for (var, data) in get_loop_var_analysis_at (p, n) 48 if data == 'LoopVariable'] + [mk_word32 (0)] 49 vs.sort () 50 return vs 51 52default_loop_N = 3 53 54last_proof = [None] 55 56def build_proof (p): 57 init_hyps = check.init_point_hyps (p) 58 proof = build_proof_rec (default_searcher, p, (), list (init_hyps)) 59 60 trace ('Built proof for %s' % p.name) 61 printout (repr (proof)) 62 last_proof[0] = proof 63 64 return proof 65 66def split_sample_set (bound): 67 ns = (range (10) + range (10, 20, 2) 68 + range (20, 40, 5) + range (40, 100, 10) 69 + range (100, 
1000, 50)) 70 return [n for n in ns if n < bound] 71 72last_find_split_limit = [0] 73 74def find_split_limit (p, n, restrs, hyps, kind, bound = 51, must_find = True, 75 hints = [], use_rep = None): 76 tag = p.node_tags[n][0] 77 trace ('Finding split limit: %d (%s)' % (n, tag)) 78 last_find_split_limit[0] = (p, n, restrs, hyps, kind) 79 if use_rep == None: 80 rep = mk_graph_slice (p, fast = True) 81 else: 82 rep = use_rep 83 check_order = hints + split_sample_set (bound) + [bound] 84 # bounds strictly outside this range won't be considered 85 bound_range = [0, bound] 86 best_bound_found = [None] 87 def check (i): 88 if i < bound_range[0]: 89 return True 90 if i > bound_range[1]: 91 return False 92 restrs2 = restrs + ((n, VisitCount (kind, i)), ) 93 pc = rep.get_pc ((n, restrs2)) 94 restrs3 = restr_others (p, restrs2, 2) 95 epc = rep.get_pc (('Err', restrs3), tag = tag) 96 hyp = mk_implies (mk_not (epc), mk_not (pc)) 97 res = rep.test_hyp_whyps (hyp, hyps) 98 if res: 99 trace ('split limit found: %d' % i) 100 bound_range[1] = i - 1 101 best_bound_found[0] = i 102 else: 103 bound_range[0] = i + 1 104 return res 105 106 map (check, check_order) 107 while bound_range[0] <= bound_range[1]: 108 split = (bound_range[0] + bound_range[1]) / 2 109 check (split) 110 111 bound = best_bound_found[0] 112 if bound == None: 113 trace ('No split limit found for %d (%s).' 
% (n, tag)) 114 if must_find: 115 assert not 'split limit found' 116 return bound 117 118def get_split_limit (p, n, restrs, hyps, kind, bound = 51, 119 must_find = True, est_bound = 1, hints = None): 120 k = ('SplitLimit', n, restrs, tuple (hyps), kind) 121 if k in p.cached_analysis: 122 (lim, prev_bound) = p.cached_analysis[k] 123 if lim != None or bound <= prev_bound: 124 return lim 125 if hints == None: 126 hints = [est_bound, est_bound + 1, est_bound + 2] 127 res = find_split_limit (p, n, restrs, hyps, kind, 128 hints = hints, must_find = must_find, bound = bound) 129 p.cached_analysis[k] = (res, bound) 130 return res 131 132def init_case_splits (p, hyps, tags = None): 133 if 'init_case_splits' in p.cached_analysis: 134 return p.cached_analysis['init_case_splits'] 135 if tags == None: 136 tags = p.pairing.tags 137 poss = logic.possible_graph_divs (p) 138 if len (set ([p.node_tags[n][0] for n in poss])) < 2: 139 return None 140 rep = rep_graph.mk_graph_slice (p) 141 assert all ([p.nodes[n].kind == 'Cond' for n in poss]) 142 pc_map = logic.dict_list ([(rep.get_pc ((c, ())), c) 143 for n in poss for c in p.nodes[n].get_conts () 144 if c not in p.loop_data]) 145 no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()]) 146 err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag)) 147 for tag in p.pairing.tags] 148 knowledge = EqSearchKnowledge (rep, hyps + err_pc_hyps, list (pc_map)) 149 last_knowledge[0] = knowledge 150 pc_ids = knowledge.classify_vs () 151 id_n_map = logic.dict_list ([(i, n) for (pc, i) in pc_ids.iteritems () 152 for n in pc_map[pc]]) 153 tag_div_ns = [[[n for n in ns if p.node_tags[n][0] == t] for t in tags] 154 for (i, ns) in id_n_map.iteritems ()] 155 split_pairs = [(l_ns[0], r_ns[0]) for (l_ns, r_ns) in tag_div_ns 156 if l_ns and r_ns] 157 p.cached_analysis['init_case_splits'] = split_pairs 158 return split_pairs 159 160case_split_tr = [] 161 162def init_proof_case_split (p, restrs, hyps): 163 ps = init_case_splits (p, 
hyps) 164 if ps == None: 165 return None 166 p.cached_analysis.setdefault ('finished_init_case_splits', []) 167 fin = p.cached_analysis['finished_init_case_splits'] 168 known_s = set.union (set (restrs), set (hyps)) 169 for rs in fin: 170 if rs <= known_s: 171 return None 172 rep = rep_graph.mk_graph_slice (p) 173 no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()]) 174 err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag)) 175 for tag in p.pairing.tags] 176 for (n1, n2) in ps: 177 pc = rep.get_pc ((n1, ())) 178 if rep.test_hyp_whyps (pc, hyps + err_pc_hyps): 179 continue 180 if rep.test_hyp_whyps (mk_not (pc), hyps + err_pc_hyps): 181 continue 182 case_split_tr.append ((n1, restrs, hyps)) 183 return ('CaseSplit', ((n1, p.node_tags[n1][0]), [n1, n2])) 184 fin.append (known_s) 185 return None 186 187# TODO: deal with all the code duplication between these two searches 188class EqSearchKnowledge: 189 def __init__ (self, rep, hyps, vs): 190 self.rep = rep 191 self.hyps = hyps 192 self.v_ids = dict ([(v, 1) for v in vs]) 193 self.model_trace = [] 194 self.facts = set () 195 self.premise = foldr1 (mk_and, map (rep.interpret_hyp, hyps)) 196 197 def add_model (self, m): 198 self.model_trace.append (m) 199 update_v_ids_for_model2 (self, self.v_ids, m) 200 201 def hyps_add_model (self, hyps): 202 if hyps: 203 test_expr = foldr1 (mk_and, hyps) 204 else: 205 # we want to learn something, either a new model, or 206 # that all hyps are true. if there are no hyps, 207 # learning they're all true is learning nothing. 208 # instead force a model 209 test_expr = false_term 210 test_expr = mk_implies (self.premise, test_expr) 211 m = {} 212 (r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)], 213 {}, model = m) 214 if r == 'unsat': 215 if not hyps: 216 trace ('WARNING: EqSearchKnowledge: premise unsat.') 217 trace (" ... 
learning procedure isn't going to work.") 218 for hyp in hyps: 219 self.facts.add (hyp) 220 else: 221 assert r == 'sat', r 222 self.add_model (m) 223 224 def classify_vs (self): 225 while not self.facts: 226 hyps = v_id_eq_hyps (self.v_ids) 227 if not hyps: 228 break 229 self.hyps_add_model (hyps) 230 return self.v_ids 231 232def update_v_ids_for_model2 (knowledge, v_ids, m): 233 # first update the live variables 234 ev = lambda v: eval_model_expr (m, knowledge.rep.solv, v) 235 groups = logic.dict_list ([((k, ev (v)), v) 236 for (v, k) in v_ids.iteritems ()]) 237 v_ids.clear () 238 for (i, kt) in enumerate (sorted (groups)): 239 for v in groups[kt]: 240 v_ids[v] = i 241 242def v_id_eq_hyps (v_ids): 243 groups = logic.dict_list ([(k, v) for (v, k) in v_ids.iteritems ()]) 244 hyps = [] 245 for vs in groups.itervalues (): 246 for v in vs[1:]: 247 hyps.append (mk_eq (v, vs[0])) 248 return hyps 249 250class SearchKnowledge: 251 def __init__ (self, rep, name, restrs, hyps, tags, cand_elts = None): 252 self.rep = rep 253 self.name = name 254 self.restrs = restrs 255 self.hyps = hyps 256 self.tags = tags 257 if cand_elts != None: 258 (loop_elts, r_elts) = cand_elts 259 else: 260 (loop_elts, r_elts) = ([], []) 261 (pairs, vs) = init_knowledge_pairs (rep, loop_elts, r_elts) 262 self.pairs = pairs 263 self.v_ids = vs 264 self.model_trace = [] 265 self.facts = set () 266 self.weak_splits = set () 267 self.premise = syntax.true_term 268 self.live_pairs_trace = [] 269 270 def add_model (self, m): 271 self.model_trace.append (m) 272 update_v_ids_for_model (self, self.pairs, self.v_ids, m) 273 274 def hyps_add_model (self, hyps, assert_progress = True): 275 if hyps: 276 test_expr = foldr1 (mk_and, hyps) 277 else: 278 # we want to learn something, either a new model, or 279 # that all hyps are true. if there are no hyps, 280 # learning they're all true is learning nothing. 
281 # instead force a model 282 test_expr = false_term 283 test_expr = mk_implies (self.premise, test_expr) 284 m = {} 285 (r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)], 286 {}, model = m) 287 if r == 'unsat': 288 if not hyps: 289 trace ('WARNING: SearchKnowledge: premise unsat.') 290 trace (" ... learning procedure isn't going to work.") 291 return 292 if assert_progress: 293 assert not (set (hyps) <= self.facts), hyps 294 for hyp in hyps: 295 self.facts.add (hyp) 296 else: 297 assert r == 'sat', r 298 self.add_model (m) 299 if assert_progress: 300 assert self.model_trace[-2:-1] != [m] 301 302 def eqs_add_model (self, eqs, assert_progress = True): 303 preds = [pred for vpair in eqs 304 for pred in expand_var_eqs (self, vpair) 305 if pred not in self.facts] 306 307 self.hyps_add_model (preds, 308 assert_progress = assert_progress) 309 310 def add_weak_split (self, eqs): 311 preds = [pred for vpair in eqs 312 for pred in expand_var_eqs (self, vpair)] 313 self.weak_splits.add (tuple (sorted (preds))) 314 315 def is_weak_split (self, eqs): 316 preds = [pred for vpair in eqs 317 for pred in expand_var_eqs (self, vpair)] 318 return tuple (sorted (preds)) in self.weak_splits 319 320def init_knowledge_pairs (rep, loop_elts, cand_r_loop_elts): 321 trace ('Doing search knowledge setup now.') 322 v_is = [(i, i_offs, i_step, 323 [(v, i, i_offs, i_step) for v in get_loop_vars_at (rep.p, i)]) 324 for (i, i_offs, i_step) in sorted (loop_elts)] 325 l_vtyps = set ([v[0].typ for (_, _, _, vs) in v_is for v in vs]) 326 v_js = [(j, j_offs, j_step, 327 [(v, j, j_offs, j_step) for v in get_loop_vars_at (rep.p, j) 328 if v.typ in l_vtyps]) 329 for (j, j_offs, j_step) in sorted (cand_r_loop_elts)] 330 vs = {} 331 for (_, _, _, var_vs) in v_is + v_js: 332 for v in var_vs: 333 vs[v] = (v[0].typ, True) 334 pairs = {} 335 for (i, i_offs, i_step, i_vs) in v_is: 336 for (j, j_offs, j_step, j_vs) in v_js: 337 pair = ((i, i_offs, i_step), (j, j_offs, j_step)) 338 pairs[pair] = 
(i_vs, j_vs) 339 trace ('... done.') 340 return (pairs, vs) 341 342def update_v_ids_for_model (knowledge, pairs, vs, m): 343 rep = knowledge.rep 344 # first update the live variables 345 groups = {} 346 for v in vs: 347 (k, const) = vs[v] 348 groups.setdefault (k, []) 349 groups[k].append ((v, const)) 350 k_counter = 1 351 vs.clear () 352 for k in groups: 353 for (const, xs) in split_group (knowledge, m, groups[k]): 354 for x in xs: 355 vs[x] = (k_counter, const) 356 k_counter += 1 357 # then figure out which pairings are still viable 358 needed_ks = set () 359 zero = syntax.mk_word32 (0) 360 for (pair, data) in pairs.items (): 361 if data[0] == 'Failed': 362 continue 363 (lvs, rvs) = data 364 lv_ks = set ([vs[v][0] for v in lvs 365 if v[0] == zero or not vs[v][1]]) 366 rv_ks = set ([vs[v][0] for v in rvs]) 367 miss_vars = lv_ks - rv_ks 368 if miss_vars: 369 lv_miss = [v[0] for v in lvs if vs[v][0] in miss_vars] 370 pairs[pair] = ('Failed', lv_miss.pop ()) 371 else: 372 needed_ks.update ([vs[v][0] for v in lvs + rvs]) 373 # then drop any vars which are no longer relevant 374 for v in vs.keys (): 375 if vs[v][0] not in needed_ks: 376 del vs[v] 377 378def get_entry_visits_up_to (rep, head, restrs, hyps): 379 """get the set of nodes visited on the entry path entry 380 to the loop, up to and including the head point.""" 381 k = ('loop_visits_up_to', head, restrs, tuple (hyps)) 382 if k in rep.p.cached_analysis: 383 return rep.p.cached_analysis[k] 384 385 [entry] = get_loop_entry_sites (rep, restrs, hyps, head) 386 frontier = set ([entry]) 387 up_to = set () 388 loop = rep.p.loop_body (head) 389 while frontier: 390 n = frontier.pop () 391 if n == head: 392 continue 393 new_conts = [n2 for n2 in rep.p.nodes[n].get_conts () 394 if n2 in loop if n2 not in up_to] 395 up_to.update (new_conts) 396 frontier.update (new_conts) 397 rep.p.cached_analysis[k] = up_to 398 return up_to 399 400def get_nth_visit_restrs (rep, restrs, hyps, i, visit_num): 401 """get the nth 
(visit_num-th) visit to node i, using its loop head 402 as a restriction point. tricky because there may be a loop entry point 403 that brings us in with the loop head before i, or vice-versa.""" 404 head = rep.p.loop_id (i) 405 if i in get_entry_visits_up_to (rep, head, restrs, hyps): 406 # node i is in the set visited on the entry path, so 407 # the head is visited no more often than it 408 offs = 0 409 else: 410 # these are visited after the head point on the entry path, 411 # so the head point is visited 1 more time than it. 412 offs = 1 413 return ((head, vc_num (visit_num + offs)), ) + restrs 414 415def get_var_pc_var_list (knowledge, v_i): 416 rep = knowledge.rep 417 (v_i, i, i_offs, i_step) = v_i 418 def get_var (k): 419 restrs2 = get_nth_visit_restrs (rep, knowledge.restrs, 420 knowledge.hyps, i, k) 421 (pc, env) = rep.get_node_pc_env ((i, restrs2)) 422 return (to_smt_expr (pc, env, rep.solv), 423 to_smt_expr (v_i, env, rep.solv)) 424 return [get_var (i_offs + (k * i_step)) 425 for k in [0, 1, 2]] 426 427def expand_var_eqs (knowledge, (v_i, v_j)): 428 if v_j == 'Const': 429 pc_vs = get_var_pc_var_list (knowledge, v_i) 430 (_, v0) = pc_vs[0] 431 return [mk_implies (pc, mk_eq (v, v0)) 432 for (pc, v) in pc_vs[1:]] 433 # sorting the vars guarantees we generate the same 434 # mem eqs each time which is important for the solver 435 (v_i, v_j) = sorted ([v_i, v_j]) 436 pc_vs = zip (get_var_pc_var_list (knowledge, v_i), 437 get_var_pc_var_list (knowledge, v_j)) 438 return [pred for ((pc_i, v_i), (pc_j, v_j)) in pc_vs 439 for pred in [mk_eq (pc_i, pc_j), 440 mk_implies (pc_i, logic.mk_eq_with_cast (v_i, v_j))]] 441 442word_ops = {'bvadd':lambda x, y: x + y, 'bvsub':lambda x, y: x - y, 443 'bvmul':lambda x, y: x * y, 'bvurem':lambda x, y: x % y, 444 'bvudiv':lambda x, y: x / y, 'bvand':lambda x, y: x & y, 445 'bvor':lambda x, y: x | y, 'bvxor': lambda x, y: x ^ y, 446 'bvnot': lambda x: ~ x, 'bvneg': lambda x: - x, 447 'bvshl': lambda x, y: x << y, 'bvlshr': lambda 
x, y: x >> y} 448 449bool_ops = {'=>':lambda x, y: (not x) or y, '=': lambda x, y: x == y, 450 'not': lambda x: not x, 'true': lambda: True, 'false': lambda: False} 451 452word_ineq_ops = {'=': (lambda x, y: x == y, 'Unsigned'), 453 'bvult': (lambda x, y: x < y, 'Unsigned'), 454 'word32-eq': (lambda x, y: x == y, 'Unsigned'), 455 'bvule': (lambda x, y: x <= y, 'Unsigned'), 456 'bvsle': (lambda x, y: x <= y, 'Signed'), 457 'bvslt': (lambda x, y: x < y, 'Signed'), 458} 459 460def eval_model (m, s, toplevel = None): 461 if s in m: 462 return m[s] 463 if toplevel == None: 464 toplevel = s 465 if type (s) == str: 466 try: 467 result = solver.smt_to_val (s) 468 except Exception, e: 469 trace ('Error with eval_model') 470 trace (toplevel) 471 raise e 472 return result 473 474 op = s[0] 475 476 if op == 'ite': 477 [_, b, x, y] = s 478 b = eval_model (m, b, toplevel) 479 assert b in [false_term, true_term] 480 if b == true_term: 481 result = eval_model (m, x, toplevel) 482 else: 483 result = eval_model (m, y, toplevel) 484 m[s] = result 485 return result 486 487 xs = [eval_model (m, x, toplevel) for x in s[1:]] 488 489 if op[0] == '_' and op[1] in ['zero_extend', 'sign_extend']: 490 [_, ex_kind, n_extend] = op 491 n_extend = int (n_extend) 492 [x] = xs 493 assert x.typ.kind == 'Word' and x.kind == 'Num' 494 if ex_kind == 'sign_extend': 495 val = get_signed_val (x) 496 else: 497 val = get_unsigned_val (x) 498 result = mk_num (val, x.typ.num + n_extend) 499 elif op[0] == '_' and op[1] == 'extract': 500 [_, _, n_top, n_bot] = op 501 n_top = int (n_top) 502 n_bot = int (n_bot) 503 [x] = xs 504 assert x.typ.kind == 'Word' and x.kind == 'Num' 505 length = (n_top - n_bot) + 1 506 result = mk_num ((x.val >> n_bot) & ((1 << length) - 1), length) 507 elif op[0] == 'store-word32': 508 (m, p, v) = xs 509 (naming, eqs) = m 510 eqs = dict (eqs) 511 eqs[p.val] = v.val 512 eqs = tuple (sorted (eqs.items ())) 513 result = (naming, eqs) 514 elif op[0] == 'store-word8': 515 (m, p, v) = xs 516 
p_al = p.val & -4 517 shift = (p.val & 3) * 8 518 (naming, eqs) = m 519 eqs = dict (eqs) 520 prev_v = eqs[p_al] 521 mask_v = prev_v & (((1 << 32) - 1) ^ (255 << shift)) 522 new_v = mask_v | ((v.val & 255) << shift) 523 eqs[p.val] = new_v 524 eqs = tuple (sorted (eqs.items ())) 525 result = (naming, eqs) 526 elif op[0] == 'load-word32': 527 (m, p) = xs 528 (naming, eqs) = m 529 eqs = dict (eqs) 530 result = syntax.mk_word32 (eqs[p.val]) 531 elif op[0] == 'load-word8': 532 (m, p) = xs 533 p_al = p.val & -4 534 shift = (p.val & 3) * 8 535 (naming, eqs) = m 536 eqs = dict (eqs) 537 v = (eqs[p_al] >> shift) & 255 538 result = syntax.mk_word8 (v) 539 elif xs and xs[0].typ.kind == 'Word' and op in word_ops: 540 for x in xs: 541 assert x.kind == 'Num', (s, op, x) 542 result = word_ops[op](* [x.val for x in xs]) 543 result = result & ((1 << xs[0].typ.num) - 1) 544 result = Expr ('Num', xs[0].typ, val = result) 545 elif xs and xs[0].typ.kind == 'Word' and op in word_ineq_ops: 546 (oper, signed) = word_ineq_ops[op] 547 if signed == 'Signed': 548 result = oper (* map (get_signed_val, xs)) 549 else: 550 assert signed == 'Unsigned' 551 result = oper (* [x.val for x in xs]) 552 result = {True: true_term, False: false_term}[result] 553 elif op == 'and': 554 result = all ([x == true_term for x in xs]) 555 result = {True: true_term, False: false_term}[result] 556 elif op == 'or': 557 result = bool ([x for x in xs if x == true_term]) 558 result = {True: true_term, False: false_term}[result] 559 elif op in bool_ops: 560 assert all ([x.typ == boolT for x in xs]) 561 result = bool_ops[op](* [x == true_term for x in xs]) 562 result = {True: true_term, False: false_term}[result] 563 else: 564 assert not 's_expr handled', (s, op) 565 m[s] = result 566 return result 567 568def get_unsigned_val (x): 569 assert x.typ.kind == 'Word' 570 assert x.kind == 'Num' 571 bits = x.typ.num 572 v = x.val & ((1 << bits) - 1) 573 return v 574 575def get_signed_val (x): 576 assert x.typ.kind == 'Word' 577 
assert x.kind == 'Num' 578 bits = x.typ.num 579 v = x.val & ((1 << bits) - 1) 580 if v >= (1 << (bits - 1)): 581 v = v - (1 << bits) 582 return v 583 584def short_array_str (arr): 585 items = [('%x: %x' % (p.val * 4, v.val)) 586 for (p, v) in arr.iteritems () 587 if type (p) != str] 588 items.sort () 589 return '{' + ', '.join (items) + '}' 590 591def eval_model_expr (m, solv, v): 592 s = solver.smt_expr (v, {}, solv) 593 s_x = solver.parse_s_expression (s) 594 595 return eval_model (m, s_x) 596 597def model_equal (m, knowledge, vpair): 598 preds = expand_var_eqs (knowledge, vpair) 599 for pred in preds: 600 x = eval_model_expr (m, knowledge.rep.solv, pred) 601 assert x in [syntax.true_term, syntax.false_term] 602 if x == syntax.false_term: 603 return False 604 return True 605 606def get_model_trace (knowledge, m, v): 607 rep = knowledge.rep 608 pc_vs = get_var_pc_var_list (knowledge, v) 609 trace = [] 610 for (pc, v) in pc_vs: 611 x = eval_model_expr (m, rep.solv, pc) 612 assert x in [syntax.true_term, syntax.false_term] 613 if x == syntax.false_term: 614 trace.append (None) 615 else: 616 trace.append (eval_model_expr (m, rep.solv, v)) 617 return tuple (trace) 618 619def split_group (knowledge, m, group): 620 group = list (set (group)) 621 if group[0][0][0].typ == syntax.builtinTs['Mem']: 622 bins = [] 623 for (v, const) in group: 624 for i in range (len (bins)): 625 if model_equal (m, knowledge, 626 (v, bins[i][1][0])): 627 bins[i][1].append (v) 628 break 629 else: 630 if const: 631 const = model_equal (m, knowledge, 632 (v, 'Const')) 633 bins.append ((const, [v])) 634 return bins 635 else: 636 bins = {} 637 for (v, const) in group: 638 trace = get_model_trace (knowledge, m, v) 639 if trace not in bins: 640 tconst = len (set (trace) - set ([None])) <= 1 641 bins[trace] = (const and tconst, []) 642 bins[trace][1].append (v) 643 return bins.values () 644 645def mk_pairing_v_eqs (knowledge, pair, endorsed = True): 646 v_eqs = [] 647 (lvs, rvs) = 
knowledge.pairs[pair] 648 zero = mk_word32 (0) 649 for v_i in lvs: 650 (k, const) = knowledge.v_ids[v_i] 651 if const and v_i[0] != zero: 652 if not endorsed or eq_known (knowledge, (v_i, 'Const')): 653 v_eqs.append ((v_i, 'Const')) 654 continue 655 vs_j = [v_j for v_j in rvs if knowledge.v_ids[v_j][0] == k] 656 if endorsed: 657 vs_j = [v_j for v_j in vs_j 658 if eq_known (knowledge, (v_i, v_j))] 659 if not vs_j: 660 return None 661 v_j = vs_j[0] 662 v_eqs.append ((v_i, v_j)) 663 return v_eqs 664 665def eq_known (knowledge, vpair): 666 preds = expand_var_eqs (knowledge, vpair) 667 return set (preds) <= knowledge.facts 668 669def find_split_loop (p, head, restrs, hyps, unfold_limit = 9, 670 node_restrs = None, trace_ind_fails = None): 671 assert p.loop_data[head][0] == 'Head' 672 assert p.node_tags[head][0] == p.pairing.tags[0] 673 674 # the idea is to loop through testable hyps, starting with ones that 675 # need smaller models (the most unfolded models will time out for 676 # large problems like finaliseSlot) 677 678 rep = mk_graph_slice (p, fast = True) 679 680 nec = get_necessary_split_opts (p, head, restrs, hyps) 681 if nec and nec[0] in ['CaseSplit', 'LoopUnroll']: 682 return nec 683 elif nec: 684 i_j_opts = nec 685 else: 686 i_j_opts = default_i_j_opts (unfold_limit) 687 688 if trace_ind_fails == None: 689 ind_fails = [] 690 else: 691 ind_fails = trace_ind_fails 692 for (i_opts, j_opts) in i_j_opts: 693 result = find_split (rep, head, restrs, hyps, 694 i_opts, j_opts, node_restrs = node_restrs) 695 if result[0] != None: 696 return result 697 ind_fails.extend (result[1]) 698 699 if ind_fails: 700 trace ('Warning: inductive failures: %s' % ind_fails) 701 raise NoSplit () 702 703def default_i_j_opts (unfold_limit = 9): 704 return mk_i_j_opts (unfold_limit = unfold_limit) 705 706def mk_i_j_opts (i_seq_opts = None, j_seq_opts = None, unfold_limit = 9): 707 if i_seq_opts == None: 708 i_seq_opts = [(0, 1), (1, 1), (2, 1), (3, 1)] 709 if j_seq_opts == None: 710 
j_seq_opts = [(0, 1), (0, 2), (1, 1), (1, 2), 711 (2, 1), (2, 2), (3, 1)] 712 all_opts = set (i_seq_opts + j_seq_opts) 713 714 def filt (opts, lim): 715 return [(start, step) for (start, step) in opts 716 if start + (2 * step) + 1 <= lim] 717 718 lims = [(filt (i_seq_opts, lim), filt (j_seq_opts, lim)) 719 for lim in range (unfold_limit) 720 if [1 for (start, step) in all_opts 721 if start + (2 * step) + 1 == lim]] 722 lims = [(i_opts, j_opts) for (i_opts, j_opts) in lims 723 if i_opts and j_opts] 724 return lims 725 726necessary_split_opts_trace = [] 727 728def get_interesting_linear_series_exprs (p, head): 729 k = ('interesting_linear_series', head) 730 if k in p.cached_analysis: 731 return p.cached_analysis[k] 732 res = logic.interesting_linear_series_exprs (p, head, 733 get_loop_var_analysis_at (p, head)) 734 p.cached_analysis[k] = res 735 return res 736 737def split_opt_test (p, tags = None): 738 if not tags: 739 tags = p.pairing.tags 740 heads = [head for head in init_loops_to_split (p, ()) 741 if p.node_tags[head][0] == tags[0]] 742 hyps = check.init_point_hyps (p) 743 return [(head, get_necessary_split_opts (p, head, (), hyps)) 744 for head in heads] 745 746def interesting_linear_test (p): 747 p.do_analysis () 748 for head in p.loop_heads (): 749 inter = get_interesting_linear_series_exprs (p, head) 750 hooks = target_objects.hooks ('loop_var_analysis') 751 n_exprs = [(n, expr, offs) for (n, vs) in inter.iteritems () 752 if not [hook for hook in hooks if hook (p, n) != None] 753 for (kind, expr, offs) in vs] 754 if n_exprs: 755 rep = rep_graph.mk_graph_slice (p) 756 for (n, expr, offs) in n_exprs: 757 restrs = tuple ([(n2, vc) for (n2, vc) 758 in restr_others_both (p, (), 2, 2) 759 if p.loop_id (n2) != p.loop_id (head)]) 760 vis1 = (n, ((head, vc_offs (1)), ) + restrs) 761 vis2 = (n, ((head, vc_offs (2)), ) + restrs) 762 pc = rep.get_pc (vis2) 763 imp = mk_implies (pc, mk_eq (rep.to_smt_expr (expr, vis2), 764 rep.to_smt_expr (mk_plus (expr, offs), vis1))) 
765 assert rep.test_hyp_whyps (imp, []) 766 return True 767 768last_necessary_split_opts = [0] 769 770def get_necessary_split_opts (p, head, restrs, hyps, tags = None): 771 if not tags: 772 tags = p.pairing.tags 773 [l_tag, r_tag] = tags 774 last_necessary_split_opts[0] = (p, head, restrs, hyps, tags) 775 776 rep = rep_graph.mk_graph_slice (p, fast = True) 777 entries = get_loop_entry_sites (rep, restrs, hyps, head) 778 if len (entries) > 1: 779 return ('CaseSplit', ((entries[0], tags[0]), [entries[0]])) 780 for n in init_loops_to_split (p, restrs): 781 if p.node_tags[n][0] != r_tag: 782 continue 783 entries = get_loop_entry_sites (rep, restrs, hyps, n) 784 if len (entries) > 1: 785 return ('CaseSplit', ((entries[0], r_tag), 786 [entries[0]])) 787 788 stuff = linear_setup_stuff (rep, head, restrs, hyps, tags) 789 if stuff == None: 790 return None 791 seq_eqs = get_matching_linear_seqs (rep, head, restrs, hyps, tags) 792 793 vis = stuff['vis'] 794 for v in seq_eqs: 795 if v[0] == 'LoopUnroll': 796 (_, n, est_bound) = v 797 lim = get_split_limit (p, n, restrs, hyps, 'Number', 798 est_bound = est_bound, must_find = False) 799 if lim != None: 800 return ('LoopUnroll', n) 801 continue 802 ((n, expr), (n2, expr2), (l_start, l_step), (r_start, r_step), 803 _, _) = v 804 eqs = [rep_graph.eq_hyp ((expr, 805 (vis (n, l_start + (i * l_step)), l_tag)), 806 (expr2, (vis (n2, r_start + (i * r_step)), r_tag))) 807 for i in range (2)] 808 vis_hyp = rep_graph.pc_true_hyp ((vis (n, l_start), l_tag)) 809 vis_hyps = [vis_hyp] + stuff['hyps'] 810 eq = foldr1 (mk_and, map (rep.interpret_hyp, eqs)) 811 m = {} 812 if rep.test_hyp_whyps (eq, vis_hyps, model = m): 813 trace ('found necessary split info: (%s, %s), (%s, %s)' 814 % (l_start, l_step, r_start, r_step)) 815 return mk_i_j_opts ([(l_start + i, l_step) 816 for i in range (r_step + 1)], 817 [(r_start + i, r_step) 818 for i in range (l_step + 1)], 819 unfold_limit = 100) 820 n_vcs = entry_path_no_loops (rep, l_tag, m, head) 821 
path_hyps = [rep_graph.pc_true_hyp ((n_vc, l_tag)) for n_vc in n_vcs] 822 if rep.test_hyp_whyps (eq, stuff['hyps'] + path_hyps): 823 # immediate case split on difference between entry paths 824 checks = [(stuff['hyps'], eq_hyp, 'eq') 825 for eq_hyp in eqs] 826 return derive_case_split (rep, n_vcs, checks) 827 necessary_split_opts_trace.append ((n, expr, (l_start, l_step), 828 (r_start, r_step), 'Seq check failed')) 829 return None 830 831def linear_setup_stuff (rep, head, restrs, hyps, tags): 832 [l_tag, r_tag] = tags 833 k = ('linear_seq setup', head, restrs, tuple (hyps), tuple (tags)) 834 p = rep.p 835 if k in p.cached_analysis: 836 return p.cached_analysis[k] 837 838 assert p.node_tags[head][0] == l_tag 839 l_seq_vs = get_interesting_linear_series_exprs (p, head) 840 if not l_seq_vs: 841 return None 842 r_seq_vs = {} 843 restr_env = {p.loop_id (head): restrs} 844 for n in init_loops_to_split (p, restrs): 845 if p.node_tags[n][0] != r_tag: 846 continue 847 vs = get_interesting_linear_series_exprs (p, n) 848 r_seq_vs.update (vs) 849 if not r_seq_vs: 850 return None 851 852 def vis (n, i): 853 restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, i) 854 return (n, restrs2) 855 smt = lambda expr, n, i: rep.to_smt_expr (expr, vis (n, i)) 856 smt_pc = lambda n, i: rep.get_pc (vis (n, i)) 857 858 # remove duplicates by concretising 859 l_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset)) 860 for n in l_seq_vs 861 for (kind, expr, offs, oset) in l_seq_vs[n]]).values () 862 r_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset)) 863 for n in r_seq_vs 864 for (kind, expr, offs, oset) in r_seq_vs[n]]).values () 865 866 hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), r_tag)) 867 for n in set ([n for (_, n, _, _, _) in r_seq_vs])] 868 hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), l_tag)) 869 for n in set ([n for (_, n, _, _, _) in l_seq_vs])] 870 hyps = hyps + [check.non_r_err_pc_hyp (tags, 871 restr_others (p, restrs, 2))] 872 873 r = 
{'l_seq_vs': l_seq_vs, 'r_seq_vs': r_seq_vs, 874 'hyps': hyps, 'vis': vis, 'smt': smt, 'smt_pc': smt_pc} 875 p.cached_analysis[k] = r 876 return r 877 878def get_matching_linear_seqs (rep, head, restrs, hyps, tags): 879 k = ('matching linear seqs', head, restrs, tuple (hyps), tuple (tags)) 880 p = rep.p 881 if k in p.cached_analysis: 882 v = p.cached_analysis[k] 883 (x, y) = itertools.tee (v[0]) 884 v[0] = x 885 return y 886 887 [l_tag, r_tag] = tags 888 stuff = linear_setup_stuff (rep, head, restrs, hyps, tags) 889 if stuff == None: 890 return [] 891 892 hyps = stuff['hyps'] 893 vis = stuff['vis'] 894 895 def get_model (n, offs): 896 m = {} 897 offs_smt = stuff['smt'] (offs, n, 1) 898 eq = mk_eq (mk_times (offs_smt, mk_num (4, offs_smt.typ)), 899 mk_num (0, offs_smt.typ)) 900 ex_hyps = [rep_graph.pc_true_hyp ((vis (n, 1), l_tag)), 901 rep_graph.pc_true_hyp ((vis (n, 2), l_tag))] 902 res = rep.test_hyp_whyps (eq, hyps + ex_hyps, model = m) 903 if not m: 904 necessary_split_opts_trace.append ((n, kind, 'NoModel')) 905 return None 906 return m 907 908 r = (seq_eq 909 for (kind, n, expr, offs, oset) in sorted (stuff['l_seq_vs']) 910 if [v for v in stuff['r_seq_vs'] if v[0] == kind] 911 for m in [get_model (n, offs)] 912 if m 913 for seq_eq in [get_linear_seq_eq (rep, m, stuff, 914 (kind, n, expr, offs, oset)), 915 get_model_r_side_unroll (rep, tags, m, 916 restrs, hyps, stuff)] 917 if seq_eq != None) 918 (x, y) = itertools.tee (r) 919 p.cached_analysis[k] = [y] 920 return x 921 922def get_linear_seq_eq (rep, m, stuff, expr_t1): 923 def get_int_min (expr): 924 v = eval_model_expr (m, rep.solv, expr) 925 assert v.kind == 'Num', v 926 vs = [v.val + (i << v.typ.num) for i in range (-2, 3)] 927 (_, v) = min ([(abs (v), v) for v in vs]) 928 return v 929 (kind, n1, expr1, offs1, oset1) = expr_t1 930 smt = stuff['smt'] 931 expr_init = smt (expr1, n1, 0) 932 expr_v = get_int_min (expr_init) 933 offs_v = get_int_min (smt (offs1, n1, 1)) 934 r_seqs = [(n, expr, offs, oset2, 935 
last_r_side_unroll = [None]

def get_model_r_side_unroll (rep, tags, m, restrs, hyps, stuff):
    """Check model m for a right-hand sequence that needs loop unrolling.

    Collects word-typed "interesting" expressions at left-hand nodes that
    the model visits exactly once, evaluates them in the model, and then
    checks whether any right-hand linear sequence (start + i * offset,
    for i below 64) reaches one of those values.

    Returns ('LoopUnroll', loop head, largest matching i) for the first
    right-hand sequence with a match at some i greater than 4, and None
    if there is no such sequence.
    """
    p = rep.p
    [l_tag, r_tag] = tags
    # remember the most recent arguments (module-level debugging aid).
    last_r_side_unroll[0] = (rep, tags, m, restrs, hyps, stuff)

    r_kinds = set ()
    for (kind, n, _, _, _) in stuff['r_seq_vs']:
        r_kinds.add (kind)

    # left-hand (node, visit) pairs whose path condition holds in m.
    visited = []
    for (tag, n, vc) in rep.node_pc_env_order:
        if tag == l_tag and eval_pc (rep, m, (n, vc)):
            visited.append ((n, vc))
    l_visited_ns_vcs = logic.dict_list (visited)

    l_arc_interesting = []
    for (n, vcs) in l_visited_ns_vcs.iteritems ():
        # only nodes reached by exactly one visit are considered.
        if len (vcs) != 1:
            continue
        for vc in vcs:
            for (kind, expr) in logic.interesting_node_exprs (p, n,
                    tags = tags):
                if kind in r_kinds and expr.typ.kind == 'Word':
                    l_arc_interesting.append ((n, vc, kind, expr))
    l_kinds = set ([kind for (_, _, kind, _) in l_arc_interesting])

    # FIXME: cloned
    def canon_n (val, typ):
        # among val shifted by up to two multiples of 2 ** typ.num in
        # either direction, pick the one closest to zero (the smaller
        # value wins a tie).
        candidates = [val + (i << typ.num) for i in range (-2, 3)]
        return min (candidates, key = lambda c: (abs (c), c))

    def model_int (expr):
        v = eval_model_expr (m, rep.solv, expr)
        assert v.kind == 'Num', v
        return canon_n (v.val, v.typ)

    def value_at (expr, n, vc):
        return model_int (rep.to_smt_expr (expr, (n, vc)))

    val_interesting_map = logic.dict_list (
        [((kind, value_at (expr, n, vc)), n)
            for (n, vc, kind, expr) in l_arc_interesting])

    smt = stuff['smt']

    for (kind, n, expr, offs, _) in stuff['r_seq_vs']:
        if kind not in l_kinds or expr.typ.kind != 'Word':
            continue
        expr_n = model_int (smt (expr, n, 0))
        offs_n = model_int (smt (offs, n, 0))
        hit = [i for i in range (64)
            if (kind, canon_n (expr_n + (offs_n * i), expr.typ))
                in val_interesting_map]
        # hit is ascending, so "some element above 4" is a test on its max.
        if hit and max (hit) > 4:
            return ('LoopUnroll', p.loop_id (n), max (hit))
    return None
def split_search (head, knowledge):
    """Search the candidate pairings in knowledge for a provable loop split.

    Returns one of:
      ('Split', split)        - a split was built and checked;
      the tuple returned by build_and_check_split when its first element
                                is 'SingleRevInduct';
      (None, ind_fails)       - every pairing has been marked 'Failed';
                                ind_fails is the list reported by
                                trace_search_fail.

    knowledge is mutated throughout: pairings are marked as failed, weak
    splits are recorded and new models are added.
    """
    rep = knowledge.rep
    p = rep.p

    # test any relevant cached solutions: equalities that produced a split
    # for this head before are replayed into the current knowledge.
    p.cached_analysis.setdefault (('v_eqs', head), set ())
    v_eq_cache = p.cached_analysis[('v_eqs', head)]
    for (pair, eqs) in v_eq_cache:
        if pair in knowledge.pairs:
            knowledge.eqs_add_model (list (eqs),
                assert_progress = False)

    while True:
        trace ('In %s' % knowledge.name)
        trace ('Computing live pairings')
        # every pairing not yet marked 'Failed', with its current
        # equalities (None when mk_pairing_v_eqs does not endorse it).
        pair_eqs = [(pair, mk_pairing_v_eqs (knowledge, pair))
            for pair in sorted (knowledge.pairs)
            if knowledge.pairs[pair][0] != 'Failed']
        if not pair_eqs:
            # nothing left to try: report the failures and give up.
            ind_fails = trace_search_fail (knowledge)
            return (None, ind_fails)

        endorsed = [(pair, eqs) for (pair, eqs) in pair_eqs
            if eqs != None]
        trace (' ... %d live pairings, %d endorsed' %
            (len (pair_eqs), len (endorsed)))
        knowledge.live_pairs_trace.append (len (pair_eqs))
        for (pair, eqs) in endorsed:
            if knowledge.is_weak_split (eqs):
                trace (' dropping endorsed - probably weak.')
                knowledge.pairs[pair] = ('Failed',
                    'ExpectedSplitWeak', eqs)
                continue
            split = build_and_check_split (p, pair, eqs,
                knowledge.restrs, knowledge.hyps,
                knowledge.tags)
            if split == None:
                # no split could be built; remember these equalities so
                # is_weak_split can reject them next time.
                knowledge.pairs[pair] = ('Failed',
                    'SplitWeak', eqs)
                knowledge.add_weak_split (eqs)
                continue
            elif split == 'InductFailed':
                knowledge.pairs[pair] = ('Failed',
                    'InductFailed', eqs)
            elif split[0] == 'SingleRevInduct':
                return split
            else:
                # success: cache the equalities for later searches at
                # this head.
                v_eq_cache.add ((pair, tuple (eqs)))
                trace ('Found split!')
                return ('Split', split)
        if endorsed:
            # every endorsed pairing was just marked failed; recompute
            # the live pairings before guessing.
            continue

        # nothing is endorsed yet: take the first live pairing and feed
        # its un-endorsed equalities to the model search.
        (pair, _) = pair_eqs[0]
        trace ('Testing guess for pair: %s' % str (pair))
        eqs = mk_pairing_v_eqs (knowledge, pair, endorsed = False)
        assert eqs, pair
        knowledge.eqs_add_model (eqs)
def build_and_check_split (p, pair, eqs, restrs, hyps, tags):
    """Build and check a split for pair, retrying if induction fails.

    Returns whatever build_and_check_split_inner returns when that is not
    'InductFailed'. Otherwise two rescues are attempted in order:
      1. if new extra linear sequence equalities are found for the left
         split point, the whole procedure is retried (the equalities are
         picked up through the analysis cache);
      2. if a new speculative right-hand inequality is found, the
         procedure is retried with its resulting hypothesis added, and
         ('SingleRevInduct', spec) is returned when that retry does not
         end in 'InductFailed'.
    Returns 'InductFailed' when neither rescue helps.
    """
    outcome = build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags)
    if outcome != 'InductFailed':
        return outcome

    # rescue 1: extra linear equalities at the left split point.
    ((l_split, _, l_step), _) = pair
    if get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step):
        # the additional linear eqs get built into the result
        outcome = build_and_check_split (p, pair, eqs, restrs, hyps, tags)
        if outcome != 'InductFailed':
            return outcome

    # rescue 2: speculate an inequality on the right-hand loop.
    (_, (r_split, _, _)) = pair
    spec = get_new_rhs_speculate_ineq (p, restrs, hyps,
        p.loop_id (r_split))
    if not spec:
        return 'InductFailed'
    spec_hyp = check.single_induct_resulting_hyp (p, restrs, spec)
    outcome = build_and_check_split (p, pair, eqs, restrs,
        hyps + [spec_hyp], tags)
    if outcome != 'InductFailed':
        return ('SingleRevInduct', spec)
    return 'InductFailed'
def find_split (rep, head, restrs, hyps, i_opts, j_opts,
        unfold_limit = None, tags = None,
        node_restrs = None):
    """Search for a split at head, falling back to a case split.

    A first search runs with the given hypotheses and its result is
    returned if the first element is truthy. Otherwise the most common
    entry path among the models seen is assumed and the search is
    repeated. If that second search succeeds, the outcome is turned into
    a case split by derive_case_split. If either fallback step yields
    nothing, the failing search result is returned unchanged.
    """
    first_try = setup_split_search (rep, head, restrs, hyps,
        i_opts, j_opts, unfold_limit = unfold_limit,
        tags = tags, node_restrs = node_restrs)
    outcome = split_search (head, first_try)
    if outcome[0]:
        return outcome

    (models, facts, n_vcs) = most_common_path (head, first_try)
    if not n_vcs:
        return outcome

    # assume the common path was taken and search again, seeded with
    # the facts and models gathered so far.
    [tag, _] = first_try.tags
    path_hyps = [rep_graph.pc_true_hyp ((n_vc, tag)) for n_vc in n_vcs]
    second_try = setup_split_search (rep, head, restrs,
        hyps + path_hyps,
        i_opts, j_opts, unfold_limit, tags, node_restrs = node_restrs)
    second_try.facts.update (facts)
    for model in models:
        second_try.add_model (model)
    outcome = split_search (head, second_try)
    if outcome[0] == None:
        return outcome

    (_, split) = outcome
    checks = check.split_init_step_checks (rep.p, restrs, hyps, split)
    return derive_case_split (rep, n_vcs, checks)
last_derive_case_split = [0]

def derive_case_split (rep, n_vcs, checks):
    """Bisect n_vcs down to one visit whose assumption makes checks pass.

    rep    - graph representation; rep.get_pc and rep.p.node_tags are used.
    n_vcs  - non-empty list of (node, visit-restrictions) candidates.
    checks - (hyps, hyp, name) triples as accepted by check.test_hyp_group.

    Candidates with identical path conditions are merged first (the last
    one with a given path condition is kept, original order otherwise).
    The survivors are repeatedly halved. If the checks pass assuming the
    first half (plus everything retained so far), the second half is
    dropped. Otherwise the first half is retained as an assumption and
    the search continues in the second half.

    Returns ('CaseSplit', ((n, tag), [n])) for the one remaining node n.
    """
    # remember the latest arguments (module-level debugging aid).
    last_derive_case_split[0] = (rep.p, n_vcs, checks)
    # remove duplicate pcs
    n_vcs_uniq = dict ([(rep.get_pc (n_vc), (i, n_vc))
        for (i, n_vc) in enumerate (n_vcs)]).values ()
    n_vcs = [n_vc for (i, n_vc) in sorted (n_vcs_uniq)]
    assert n_vcs
    tag = rep.p.node_tags[n_vcs[0][0]][0]
    keep_n_vcs = []
    test_n_vcs = n_vcs

    def mk_thyps (assumed):
        return [rep_graph.pc_true_hyp ((n_vc, tag)) for n_vc in assumed]

    while len (test_n_vcs) > 1:
        # explicit floor division: same result as '/' on Python 2 ints,
        # and still an integer slice index under true division.
        i = len (test_n_vcs) // 2
        test_in = test_n_vcs[:i]
        test_out = test_n_vcs[i:]
        checks2 = [(hyps + mk_thyps (test_in + keep_n_vcs), hyp, nm)
            for (hyps, hyp, nm) in checks]
        (verdict, _) = check.test_hyp_group (rep, checks2)
        if verdict:
            # forget n_vcs that were tested out
            test_n_vcs = test_in
        else:
            # focus on n_vcs that were tested out
            test_n_vcs = test_out
            keep_n_vcs.extend (test_in)
    [(n, vc)] = test_n_vcs
    return ('CaseSplit', ((n, tag), [n]))
def c_memory_loop_invariant (p, c_sp, a_sp):
    """Pick the C-side memory variables to treat as loop-invariant.

    Returns an empty list if the loop analysis at a_sp reports any
    memory-typed 'LoopVariable'. Otherwise returns the memory-typed
    'LoopVariable' entries found at c_sp.
    """
    def varying_memory (split):
        found = []
        for (var, data) in get_loop_var_analysis_at (p, split):
            if var.typ == builtinTs['Mem']:
                if data == 'LoopVariable':
                    found.append (var)
        return found

    # if ASM keeps memory constant through the loop, it is implying this
    # is semantically possible in C also, though it may not be
    # syntactically the case
    # anyway, we have to assert C memory equals *something* inductively
    # so we pick C initial memory.
    if not varying_memory (a_sp):
        return varying_memory (c_sp)
    return []
def linear_const_comparison (p, n, cond):
    """Classify cond as a comparison of a linear series with a constant.

    cond is split into its two sides by split_linear_eq. The loop-linear
    offset of each side is then looked up. When exactly one side has a
    zero offset and the other has a known non-zero offset, the result is
    (linear side, constant side). Returns None otherwise.
    """
    sides = split_linear_eq (cond)
    loop_head = p.loop_id (n)
    if not sides:
        return None

    (lhs, rhs) = sides
    zero = mk_num (0, lhs.typ)
    offs = logic.get_loop_linear_offs (p, loop_head)
    lhs_offs = offs (n, lhs)
    rhs_offs = offs (n, rhs)
    offsets = set ([lhs_offs, rhs_offs])
    # need one constant side (offset zero), no unknown offsets, and the
    # two offsets must differ.
    if zero not in offsets or None in offsets or len (offsets) <= 1:
        return None
    if lhs_offs == zero:
        return (rhs, lhs)
    return (lhs, rhs)
def rhs_speculate_ineq (p, restrs, loop_head):
    """Speculate a bound on the loop-exit constant of a RHS loop.

    Handles an interesting case in which the compiler knows that the RHS
    program might fail in the future. For instance, consider a loop that
    cannot be exited until iterator i reaches value n. Any error
    condition which implies i < b must hold of i - 1, thus n <= b.

    Detects this case and identifies the inequality n <= b.

    Returns None when the loop contains calls, has more than one exit
    node, has an exit condition that is not a linear/constant comparison,
    has no suitable error conditions, or when neither bound search
    succeeds. Otherwise returns
    (loop_head, (eqs, 1), (predicate on the constant, large)).
    """
    body = p.loop_body (loop_head)

    # if the loop contains function calls, skip it,
    # otherwise we need to figure out whether they terminate
    if [n for n in body if p.nodes[n].kind == 'Call']:
        return None

    # nodes in the body with a non-error continuation outside the body.
    exit_nodes = set ([n for n in body for n2 in p.nodes[n].get_conts ()
        if n2 != 'Err' if n2 not in body])
    assert set ([p.nodes[n].kind for n in exit_nodes]) <= set (['Cond'])

    # if there are multiple exit conditions, too hard for now
    if len (exit_nodes) > 1:
        return None

    # NOTE(review): this unpacking raises if the loop has no exit node.
    [exit_n] = list (exit_nodes)
    rv = linear_const_comparison (p, exit_n, p.nodes[exit_n].cond)
    if not rv:
        return None
    (linear, const) = rv

    # gather the possibly-linear inequalities among the negated error
    # conditions of the body; only their existence is used below.
    err_cond_sites = [(n, p.nodes[n].err_cond ()) for n in body]
    err_conds = set ([(n, cond) for (n, err_cond) in err_cond_sites
        if err_cond
        for assn in logic.split_conjuncts (mk_not (err_cond))
        for cond in get_extra_assn_linear_conds (assn)
        if possibly_linear_ineq (cond)])
    if not err_conds:
        return None

    assert const.typ.kind == 'Word'
    rep = rep_graph.mk_graph_slice (p)
    eqs = mk_seq_eqs (p, exit_n, 1, False)
    import loop_bounds
    eqs += loop_bounds.get_linear_series_eqs (p, exit_n,
        restrs, [], omit_standard = True)

    # upper end of both searches: three below 2 ** (width of const).
    large = (2 ** const.typ.num) - 3
    # const_less (n) is the predicate "const < n"; const_ge (n) is built
    # as "n < const" (strictly greater, despite the name).
    const_less = lambda n: mk_less (const, mk_num (n, const.typ))
    less_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
        eqs, const_less (n), large)
    const_ge = lambda n: mk_less (mk_num (n, const.typ), const)
    ge_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
        eqs, const_ge (n), large)

    # try for an upper bound on the constant first, then a lower bound.
    res = logic.binary_search_least (less_test, 1, large)
    if res:
        return (loop_head, (eqs, 1), (const_less (res), large))
    res = logic.binary_search_greatest (ge_test, 0, large)
    if res:
        return (loop_head, (eqs, 1), (const_ge (res), large))
    return None
def restr_point_name (p, n):
    """Describe node n for the proof-search progress messages.

    p.loop_id (n) gives the head of the loop containing n, or None when
    n is not in a loop. The result is 'N (loop head)' when n is itself
    that head, 'N (in loop H)' for any other node of loop H, and plain
    'N' outside loops.

    The two loop branches used to test the same condition, so every
    loop node was reported as a loop head and the 'in loop' branch was
    unreachable.
    """
    head = p.loop_id (n)
    if head is None:
        return str (n)
    if head == n:
        return '%s (loop head)' % n
    return '%s (in loop %d)' % (n, head)
details 1708 printout ("Discovered that points [%s] can be bounded" 1709 % ', '.join ([restr_point_name (p, n) 1710 for n in restr_points])) 1711 printout (" (in %s)" % name) 1712 restr_hints = [(n, restr_kind, True) for n in restr_points] 1713 return build_proof_rec_with_restrs (restr_hints, 1714 searcher, p, restrs, hyps, name = name) 1715 elif kind == 'Leaf': 1716 return ProofNode ('Leaf', None, ()) 1717 assert kind in ['CaseSplit', 'Split', 'SingleRevInduct'], kind 1718 if kind == 'CaseSplit': 1719 (details, hints) = details 1720 probs = check.proof_subproblems (p, kind, details, restrs, hyps, name) 1721 if kind == 'CaseSplit': 1722 printout ("Decided to case split at %s" % str (details)) 1723 printout (" (in %s)" % name) 1724 restr_hints = [[(n, 'Number', False) for n in hints] 1725 for cases in [0, 1]] 1726 elif kind == 'SingleRevInduct': 1727 printout ('Found a future induction at %s' % str (details[0])) 1728 restr_hints = [[]] 1729 else: 1730 restr_points = check.split_heads (details) 1731 restr_hints = [[(n, rkind, True) for n in restr_points] 1732 for rkind in ['Number', 'Offset']] 1733 printout ("Discovered a loop relation for split points %s" 1734 % list (restr_points)) 1735 printout (" (in %s)" % name) 1736 subpfs = [] 1737 for ((restrs, hyps, name), hints) in logic.azip (probs, restr_hints): 1738 printout ('Now doing proof search in %s.' 
def build_proof_rec_with_restrs (split_hints, searcher, p, restrs,
        hyps, name = "problem"):
    """Apply restriction hints one at a time, then resume proof search.

    split_hints is a list of (split point, kind, must_find) triples.
    For the first hint a visit-count range is looked up (a loop limit for
    points inside a loop, a visit restriction otherwise). When a range
    is found, the point is restricted accordingly and the result is
    wrapped in a 'Restr' ProofNode around the proof built from the
    remaining hints. When none is found, the hint is skipped (only
    allowed when must_find is false). With no hints left, control
    returns to build_proof_rec.
    """
    if not split_hints:
        return build_proof_rec (searcher, p, restrs, hyps, name = name)

    (sp, kind, must_find) = split_hints[0]
    later_hints = split_hints[1:]

    # for points not on the right-hand side, the limit search may also
    # assume that the right-hand side does not reach an error.
    limit_hyps = list (hyps)
    if p.node_tags[sp][0] != p.pairing.tags[1]:
        limit_hyps.append (check.non_r_err_pc_hyp (p.pairing.tags,
            restr_others (p, restrs, 2)))

    if p.loop_id (sp):
        find_limit = get_proof_split_limit
    else:
        find_limit = get_proof_visit_restr
    lim_pair = find_limit (p, sp, restrs, limit_hyps,
        kind, must_find = must_find)

    if not lim_pair:
        assert not must_find
        return build_proof_rec_with_restrs (later_hints,
            searcher, p, restrs, hyps, name = name)

    (min_v, max_v) = lim_pair
    visits = range (min_v, max_v)
    if kind == 'Number':
        vc_opts = rep_graph.vc_options (visits, [])
    else:
        vc_opts = rep_graph.vc_options ([], visits)

    restrs = restrs + ((sp, vc_opts), )
    subproof = build_proof_rec_with_restrs (later_hints,
        searcher, p, restrs, hyps, name = name)

    return ProofNode ('Restr', (sp, (kind, (min_v, max_v))), [subproof])
def default_searcher (p, restrs, hyps):
    """Choose the next proof step for problem p under restrs and hyps.

    Returns a (kind, details) pair:
      - the result of init_proof_case_split, when it produces one;
      - ('Restr', ('Number', [n])) when loop n should be bounded, either
        because it has no counterpart on the other side or because
        find_split_loop asked for an unroll;
      - ('SingleRevInduct', spec) when a right-hand loop only stops
        matching once a speculative inequality is assumed;
      - whatever find_split_loop returns for the first left-hand loop;
      - ('Leaf', None) when no loops remain to be split.
    """
    # use any handy init splits
    res = init_proof_case_split (p, restrs, hyps)
    if res:
        return res

    # detect any un-split loops
    to_split_init = init_loops_to_split (p, restrs)
    rep = mk_graph_slice (p, fast = True)

    l_tag, r_tag = p.pairing.tags
    l_to_split = [n for n in to_split_init if p.node_tags[n][0] == l_tag]
    r_to_split = [n for n in to_split_init if p.node_tags[n][0] == r_tag]
    # NOTE(review): the entry points are looked up but never used below;
    # the calls are kept in case get_entry validates the tags.
    l_ep = p.get_entry (l_tag)
    r_ep = p.get_entry (r_tag)

    for r_sp in r_to_split:
        trace ('checking loop_no_match at %d' % r_sp, push = 1)
        res = loop_no_match (rep, restrs, hyps, r_sp, l_tag,
            check_speculate_ineq = True)
        # every exit from this iteration closes the trace level opened
        # above, as the lhs branch below already does; the two early
        # returns used to leave it open.
        if res == 'Restr':
            trace ('loop does not match!', push = -1)
            return ('Restr', ('Number', [r_sp]))
        elif res == 'SingleRevInduct':
            spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_sp)
            assert spec
            trace ('loop needs a future induction.', push = -1)
            return ('SingleRevInduct', spec)
        trace (' .. done checking loop no match', push = -1)

    if l_to_split and not r_to_split:
        n = l_to_split[0]
        trace ('lhs loop alone, limit must be found.')
        return ('Restr', ('Number', [n]))

    if l_to_split:
        n = l_to_split[0]
        trace ('checking lhs loop_no_match at %d' % n, push = 1)
        if loop_no_match (rep, restrs, hyps, n, r_tag):
            trace ('loop does not match!', push = -1)
            return ('Restr', ('Number', [n]))
        trace (' .. done checking loop no match', push = -1)

        (kind, split) = find_split_loop (p, n, restrs, hyps)
        if kind == 'LoopUnroll':
            # here split is the loop to unroll, not a split.
            return ('Restr', ('Number', [split]))
        return (kind, split)

    if r_to_split:
        n = r_to_split[0]
        trace ('rhs loop alone, limit must be found.')
        return ('Restr', ('Number', [n]))

    return ('Leaf', None)