#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# Search procedures used while building proofs: finding loop split
# limits, initial case splits, and candidate loop splits / unrollings,
# guided by solver models.

import solver
from solver import mk_smt_expr, to_smt_expr, smt_expr
import check
from check import restr_others, loops_to_split, ProofNode
from rep_graph import (mk_graph_slice, vc_num, vc_offs, vc_upto,
	vc_double_range, VisitCount, vc_offset_upto)
import rep_graph
from syntax import (mk_and, mk_cast, mk_implies, mk_not, mk_uminus, mk_var,
	foldr1, boolT, word32T, word8T, builtinTs, true_term, false_term,
	mk_word32, mk_word8, mk_times, Expr, Type, mk_or, mk_eq, mk_memacc,
	mk_num, mk_minus, mk_plus, mk_less)
import syntax
import logic

from target_objects import trace, printout
import target_objects
import itertools

last_knowledge = [1]

class NoSplit(Exception):
	"""Raised (e.g. by find_split_loop) when no loop split is found."""
	pass

def get_loop_var_analysis_at (p, n):
	"""Return the loop variable analysis for node n of problem p.

	Hooks registered as 'loop_var_analysis' are tried first and the
	first non-None answer is used. Otherwise the analysis is computed
	from p's variable dependencies. The result is cached in
	p.cached_analysis."""
	k = ('search_loop_var_analysis', n)
	if k in p.cached_analysis:
		return p.cached_analysis[k]
	for hook in target_objects.hooks ('loop_var_analysis'):
		res = hook (p, n)
		if res is not None:
			p.cached_analysis[k] = res
			return res
	var_deps = p.compute_var_dependencies ()
	res = p.get_loop_var_analysis (var_deps, n)
	p.cached_analysis[k] = res
	return res

def get_loop_vars_at (p, n):
	"""Return the sorted list of variables classified 'LoopVariable' at
	node n, always extended with the constant word32 zero."""
	vs = [var for (var, data) in get_loop_var_analysis_at (p, n)
			if data == 'LoopVariable'] + [mk_word32 (0)]
	vs.sort ()
	return vs

default_loop_N = 3

last_proof = [None]

def build_proof (p):
	"""Build a proof for problem p from its initial-point hypotheses.

	The proof is traced, printed, remembered in last_proof[0] and
	returned."""
	init_hyps = check.init_point_hyps (p)
	proof = build_proof_rec (default_searcher, p, (), list (init_hyps))

	trace ('Built proof for %s' % p.name)
	printout (repr (proof))
	last_proof[0] = proof

	return proof

def split_sample_set (bound):
	"""Return the visit counts below bound to probe when searching for a
	split limit: every count below 10, then progressively sparser
	samples up to 1000."""
	ns = (list (range (10)) + list (range (10, 20, 2))
		+ list (range (20, 40, 5)) + list (range (40, 100, 10))
		+ list (range (100, 1000, 50)))
	return [n for n in ns if n < bound]

last_find_split_limit = [0]

def find_split_limit (p, n, restrs, hyps, kind, bound = 51, must_find = True,
		hints = None, use_rep = None):
	"""Find the least visit count i in [0, bound] at which node n is
	unreachable.

	"Unreachable" means: under hyps, and assuming the error point is
	not reached (with other loops restricted by restr_others (..., 2)),
	the path condition of visit VisitCount (kind, i) to n is false.

	The values in hints are probed first, then split_sample_set (bound)
	and bound itself. Each probe narrows the interval of candidates,
	which is then bisected. The search presumes that once a visit is
	unreachable all later ones are too; this is not checked here.

	Returns the limit, or None if none is found. When must_find is
	set, a missing limit trips an assertion instead."""
	tag = p.node_tags[n][0]
	trace ('Finding split limit: %d (%s)' % (n, tag))
	last_find_split_limit[0] = (p, n, restrs, hyps, kind)
	if use_rep is None:
		rep = mk_graph_slice (p, fast = True)
	else:
		rep = use_rep
	if hints is None:
		hints = []
	check_order = list (hints) + split_sample_set (bound) + [bound]
	# bounds strictly outside this range won't be considered
	bound_range = [0, bound]
	best_bound_found = [None]

	def check_bound (i):
		# candidates outside the current range need no solver call
		if i < bound_range[0] or i > bound_range[1]:
			return
		restrs2 = restrs + ((n, VisitCount (kind, i)), )
		pc = rep.get_pc ((n, restrs2))
		restrs3 = restr_others (p, restrs2, 2)
		epc = rep.get_pc (('Err', restrs3), tag = tag)
		# i is a limit if avoiding the error path rules out this visit
		hyp = mk_implies (mk_not (epc), mk_not (pc))
		if rep.test_hyp_whyps (hyp, hyps):
			trace ('split limit found: %d' % i)
			bound_range[1] = i - 1
			best_bound_found[0] = i
		else:
			bound_range[0] = i + 1

	for i in check_order:
		check_bound (i)
	while bound_range[0] <= bound_range[1]:
		check_bound ((bound_range[0] + bound_range[1]) // 2)

	bound = best_bound_found[0]
	if bound is None:
		trace ('No split limit found for %d (%s).' % (n, tag))
		if must_find:
			assert not 'split limit found'
	return bound

def get_split_limit (p, n, restrs, hyps, kind, bound = 51,
		must_find = True, est_bound = 1, hints = None):
	"""Cached wrapper around find_split_limit.

	A cached answer is reused if it found a limit, or if it was
	computed with a search bound at least as large as bound. When
	hints is not given, est_bound, est_bound + 1 and est_bound + 2 are
	probed first."""
	k = ('SplitLimit', n, restrs, tuple (hyps), kind)
	if k in p.cached_analysis:
		(lim, prev_bound) = p.cached_analysis[k]
		if lim is not None or bound <= prev_bound:
			return lim
	if hints is None:
		hints = [est_bound, est_bound + 1, est_bound + 2]
	res = find_split_limit (p, n, restrs, hyps, kind,
		hints = hints, must_find = must_find, bound = bound)
	p.cached_analysis[k] = (res, bound)
	return res

def init_case_splits (p, hyps, tags = None):
	"""Find candidate initial case splits as (left node, right node)
	pairs.

	The continuations (outside loops) of the possible graph divergence
	points are classified by their path conditions using
	EqSearchKnowledge, with both sides' error points assumed
	unreached. Each class containing nodes of both tags contributes
	the first node of each tag.

	Returns None if the divergence points do not involve at least two
	tags. The list result is cached on p under a key that does not
	include hyps or tags."""
	if 'init_case_splits' in p.cached_analysis:
		return p.cached_analysis['init_case_splits']
	if tags is None:
		tags = p.pairing.tags
	poss = logic.possible_graph_divs (p)
	if len (set ([p.node_tags[n][0] for n in poss])) < 2:
		return None
	rep = rep_graph.mk_graph_slice (p)
	assert all ([p.nodes[n].kind == 'Cond' for n in poss])
	pc_map = logic.dict_list ([(rep.get_pc ((c, ())), c)
		for n in poss for c in p.nodes[n].get_conts ()
		if c not in p.loop_data])
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	knowledge = EqSearchKnowledge (rep, hyps + err_pc_hyps, list (pc_map))
	last_knowledge[0] = knowledge
	pc_ids = knowledge.classify_vs ()
	id_n_map = logic.dict_list ([(i, n) for (pc, i) in pc_ids.iteritems ()
		for n in pc_map[pc]])
	tag_div_ns = [[[n for n in ns if p.node_tags[n][0] == t] for t in tags]
		for (i, ns) in id_n_map.iteritems ()]
	split_pairs = [(l_ns[0], r_ns[0]) for (l_ns, r_ns) in tag_div_ns
		if l_ns and r_ns]
	p.cached_analysis['init_case_splits'] = split_pairs
	return split_pairs

case_split_tr = []

def init_proof_case_split (p, restrs, hyps):
	"""Propose an initial ('CaseSplit', ...) proof step, or None.

	The first candidate from init_case_splits whose path condition is
	neither provably true nor provably false under hyps is used. In
	that check, both sides' error points are assumed unreached. When
	no candidate qualifies, the set of restrs and hyps is remembered in
	p.cached_analysis['finished_init_case_splits']. Any later context
	that includes a remembered set returns None immediately."""
	ps = init_case_splits (p, hyps)
	if ps is None:
		return None
	fin = p.cached_analysis.setdefault ('finished_init_case_splits', [])
	known_s = set.union (set (restrs), set (hyps))
	for rs in fin:
		if rs <= known_s:
			return None
	rep = rep_graph.mk_graph_slice (p)
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	for (n1, n2) in ps:
		pc = rep.get_pc ((n1, ()))
		if rep.test_hyp_whyps (pc, hyps + err_pc_hyps):
			continue
		if rep.test_hyp_whyps (mk_not (pc), hyps + err_pc_hyps):
			continue
		case_split_tr.append ((n1, restrs, hyps))
		return ('CaseSplit', ((n1, p.node_tags[n1][0]), [n1, n2]))
	fin.append (known_s)
	return None

# TODO: deal with all the code duplication between these two searches
class EqSearchKnowledge:
	"""Partition expressions vs into classes that no model found so far
	distinguishes.

	v_ids maps each expression to its class id. Models are requested
	from the solver under the premise built from hyps."""

	def __init__ (self, rep, hyps, vs):
		self.rep = rep
		self.hyps = hyps
		# every expression starts in the same class
		self.v_ids = dict ([(v, 1) for v in vs])
		self.model_trace = []
		self.facts = set ()
		self.premise = foldr1 (mk_and,
			[rep.interpret_hyp (hyp) for hyp in hyps])

	def add_model (self, m):
		"""Record model m and refine the classes with it."""
		self.model_trace.append (m)
		update_v_ids_for_model2 (self, self.v_ids, m)

	def hyps_add_model (self, hyps):
		"""Ask the solver whether the premise implies all of hyps.

		On 'unsat' the hyps are recorded as facts. On 'sat' the model
		returned is used to refine the classes."""
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: EqSearchKnowledge: premise unsat.')
				trace (" ... learning procedure isn't going to work.")
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)

	def classify_vs (self):
		"""Refine the classes until the current class equalities are
		proven (or every class is a singleton). Returns v_ids."""
		while not self.facts:
			hyps = v_id_eq_hyps (self.v_ids)
			if not hyps:
				break
			self.hyps_add_model (hyps)
		return self.v_ids

def update_v_ids_for_model2 (knowledge, v_ids, m):
	"""Refine the class map v_ids in place using model m.

	Expressions are regrouped by (old class id, value in m), and the
	groups are renumbered 0, 1, ... in sorted key order."""
	ev = lambda v: eval_model_expr (m, knowledge.rep.solv, v)
	groups = logic.dict_list ([((k, ev (v)), v)
		for (v, k) in v_ids.iteritems ()])
	v_ids.clear ()
	for (i, kt) in enumerate (sorted (groups)):
		for v in groups[kt]:
			v_ids[v] = i

def v_id_eq_hyps (v_ids):
	"""Return equalities stating that every member of each class in
	v_ids equals the first member of that class."""
	groups = logic.dict_list ([(k, v) for (v, k) in v_ids.iteritems ()])
	hyps = []
	for vs in groups.itervalues ():
		for v in vs[1:]:
			hyps.append (mk_eq (v, vs[0]))
	return hyps

class SearchKnowledge:
	"""State of a loop split search.

	pairs maps each (left elt, right elt) candidate pairing to its
	(left vars, right vars), or to ('Failed', var) once it is ruled
	out. v_ids maps each variable to (class id, possibly-constant
	flag). facts holds predicates already proven, model_trace holds
	the models seen, and weak_splits holds predicate sets marked weak.
	Variables are (var, node, offset, step) tuples."""

	def __init__ (self, rep, name, restrs, hyps, tags, cand_elts = None):
		self.rep = rep
		self.name = name
		self.restrs = restrs
		self.hyps = hyps
		self.tags = tags
		if cand_elts is not None:
			(loop_elts, r_elts) = cand_elts
		else:
			(loop_elts, r_elts) = ([], [])
		(pairs, vs) = init_knowledge_pairs (rep, loop_elts, r_elts)
		self.pairs = pairs
		self.v_ids = vs
		self.model_trace = []
		self.facts = set ()
		self.weak_splits = set ()
		self.premise = syntax.true_term
		self.live_pairs_trace = []

	def add_model (self, m):
		"""Record model m and refine classes and pairings with it."""
		self.model_trace.append (m)
		update_v_ids_for_model (self, self.pairs, self.v_ids, m)

	def hyps_add_model (self, hyps, assert_progress = True):
		"""Ask the solver whether the premise implies all of hyps.

		On 'unsat' the hyps become facts. On 'sat' the returned model
		is added. With assert_progress, it is asserted that something
		new was learned."""
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: SearchKnowledge: premise unsat.')
				trace (" ... learning procedure isn't going to work.")
				return
			if assert_progress:
				assert not (set (hyps) <= self.facts), hyps
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)
			if assert_progress:
				# the new model must differ from the previous one
				assert self.model_trace[-2:-1] != [m]

	def eqs_add_model (self, eqs, assert_progress = True):
		"""Test the not-yet-known predicates of the variable equalities
		eqs (see expand_var_eqs) via hyps_add_model."""
		preds = [pred for vpair in eqs
			for pred in expand_var_eqs (self, vpair)
			if pred not in self.facts]

		self.hyps_add_model (preds,
			assert_progress = assert_progress)

	def add_weak_split (self, eqs):
		"""Remember the predicate set of eqs as a weak split."""
		preds = [pred for vpair in eqs
			for pred in expand_var_eqs (self, vpair)]
		self.weak_splits.add (tuple (sorted (preds)))

	def is_weak_split (self, eqs):
		"""True iff the predicate set of eqs was recorded by
		add_weak_split."""
		preds = [pred for vpair in eqs
			for pred in expand_var_eqs (self, vpair)]
		return tuple (sorted (preds)) in self.weak_splits

def init_knowledge_pairs (rep, loop_elts, cand_r_loop_elts):
	"""Build the initial pairings and variable classes for a search.

	loop_elts and cand_r_loop_elts are (node, offset, step) triples.
	Right-side variables are kept only if their type occurs among the
	left-side variables. Returns (pairs, vs): pairs maps every
	(left elt, right elt) to (left vars, right vars), and vs maps
	every variable to (its type, True), i.e. an initial class key and
	the possibly-constant flag."""
	trace ('Doing search knowledge setup now.')
	v_is = [(i, i_offs, i_step,
		[(v, i, i_offs, i_step) for v in get_loop_vars_at (rep.p, i)])
		for (i, i_offs, i_step) in sorted (loop_elts)]
	l_vtyps = set ([v[0].typ for (_, _, _, l_vs) in v_is for v in l_vs])
	v_js = [(j, j_offs, j_step,
		[(v, j, j_offs, j_step) for v in get_loop_vars_at (rep.p, j)
			if v.typ in l_vtyps])
		for (j, j_offs, j_step) in sorted (cand_r_loop_elts)]
	vs = {}
	for (_, _, _, var_vs) in v_is + v_js:
		for v in var_vs:
			vs[v] = (v[0].typ, True)
	pairs = {}
	for (i, i_offs, i_step, i_vs) in v_is:
		for (j, j_offs, j_step, j_vs) in v_js:
			pair = ((i, i_offs, i_step), (j, j_offs, j_step))
			pairs[pair] = (i_vs, j_vs)
	trace ('... done.')
	return (pairs, vs)

def update_v_ids_for_model (knowledge, pairs, vs, m):
	"""Refine the variable classes vs and prune pairings using model m.

	Each class is split by split_group and the pieces are given fresh
	ids. A pairing fails, and is recorded as ('Failed', var), when some
	left variable that is the zero constant or is not constant has a
	class containing none of its right variables. Variables whose class
	is not used by any surviving pairing are then dropped from vs."""
	# first update the live variables
	groups = {}
	for v in vs:
		(k, const) = vs[v]
		groups.setdefault (k, [])
		groups[k].append ((v, const))
	k_counter = 1
	vs.clear ()
	for k in groups:
		for (const, xs) in split_group (knowledge, m, groups[k]):
			for x in xs:
				vs[x] = (k_counter, const)
			k_counter += 1
	# then figure out which pairings are still viable
	needed_ks = set ()
	zero = syntax.mk_word32 (0)
	for (pair, data) in pairs.items ():
		if data[0] == 'Failed':
			continue
		(lvs, rvs) = data
		lv_ks = set ([vs[v][0] for v in lvs
			if v[0] == zero or not vs[v][1]])
		rv_ks = set ([vs[v][0] for v in rvs])
		miss_vars = lv_ks - rv_ks
		if miss_vars:
			lv_miss = [v[0] for v in lvs if vs[v][0] in miss_vars]
			pairs[pair] = ('Failed', lv_miss.pop ())
		else:
			needed_ks.update ([vs[v][0] for v in lvs + rvs])
	# then drop any vars which are no longer relevant
	for v in list (vs):
		if vs[v][0] not in needed_ks:
			del vs[v]

def get_entry_visits_up_to (rep, head, restrs, hyps):
	"""Get the set of loop nodes visited on the entry path into the
	loop, up to and including the head point.

	The walk starts at the (single) loop entry site and does not
	continue past head. The result is cached."""
	k = ('loop_visits_up_to', head, restrs, tuple (hyps))
	if k in rep.p.cached_analysis:
		return rep.p.cached_analysis[k]

	[entry] = get_loop_entry_sites (rep, restrs, hyps, head)
	frontier = set ([entry])
	up_to = set ()
	loop = rep.p.loop_body (head)
	while frontier:
		n = frontier.pop ()
		if n == head:
			continue
		new_conts = [n2 for n2 in rep.p.nodes[n].get_conts ()
			if n2 in loop if n2 not in up_to]
		up_to.update (new_conts)
		frontier.update (new_conts)
	rep.p.cached_analysis[k] = up_to
	return up_to

def get_nth_visit_restrs (rep, restrs, hyps, i, visit_num):
	"""Get the restrictions selecting the nth (visit_num-th) visit to
	node i, using its loop head as the restriction point.

	This is tricky because a loop entry point may bring us in with the
	loop head visited before i, or vice-versa."""
	head = rep.p.loop_id (i)
	if i in get_entry_visits_up_to (rep, head, restrs, hyps):
		# node i is in the set visited on the entry path, so
		# the head is visited no more often than it
		offs = 0
	else:
		# these are visited after the head point on the entry path,
		# so the head point is visited 1 more time than it.
		offs = 1
	return ((head, vc_num (visit_num + offs)), ) + restrs

def get_var_pc_var_list (knowledge, v_i):
	"""For a search variable (var, node, offs, step), return the
	(path condition, value) SMT expressions at its visits offs,
	offs + step and offs + 2 * step."""
	rep = knowledge.rep
	(var, i, i_offs, i_step) = v_i
	def get_var (k):
		restrs2 = get_nth_visit_restrs (rep, knowledge.restrs,
				knowledge.hyps, i, k)
		(pc, env) = rep.get_node_pc_env ((i, restrs2))
		return (to_smt_expr (pc, env, rep.solv),
			to_smt_expr (var, env, rep.solv))
	return [get_var (i_offs + (k * i_step))
		for k in [0, 1, 2]]

def expand_var_eqs (knowledge, vpair):
	"""Expand a candidate equality vpair = (v_i, v_j) into predicates.

	If v_j is 'Const': at the 2nd and 3rd sampled visits, being reached
	implies the value equals the value at the 1st visit. Otherwise, at
	each of the three sampled visits, the two path conditions are equal
	and, when reached, the values are equal (with cast)."""
	(v_i, v_j) = vpair
	if v_j == 'Const':
		pc_vs = get_var_pc_var_list (knowledge, v_i)
		(_, v0) = pc_vs[0]
		return [mk_implies (pc, mk_eq (v, v0))
			for (pc, v) in pc_vs[1:]]
	# sorting the vars guarantees we generate the same
	# mem eqs each time which is important for the solver
	(v_i, v_j) = sorted ([v_i, v_j])
	pc_vs = zip (get_var_pc_var_list (knowledge, v_i),
		get_var_pc_var_list (knowledge, v_j))
	return [pred for ((pc_i, x_i), (pc_j, x_j)) in pc_vs
		for pred in [mk_eq (pc_i, pc_j),
			mk_implies (pc_i, logic.mk_eq_with_cast (x_i, x_j))]]

# bvudiv/bvurem by zero follow SMT-LIB: all ones (-1, masked to width by
# eval_model) and the dividend respectively.
word_ops = {'bvadd':lambda x, y: x + y, 'bvsub':lambda x, y: x - y,
	'bvmul':lambda x, y: x * y,
	'bvurem':lambda x, y: x % y if y else x,
	'bvudiv':lambda x, y: x // y if y else -1,
	'bvand':lambda x, y: x & y,
	'bvor':lambda x, y: x | y, 'bvxor': lambda x, y: x ^ y,
	'bvnot': lambda x: ~ x, 'bvneg': lambda x: - x,
	'bvshl': lambda x, y: x << y, 'bvlshr': lambda x, y: x >> y}

bool_ops = {'=>':lambda x, y: (not x) or y, '=': lambda x, y: x == y,
	'not': lambda x: not x, 'true': lambda: True, 'false': lambda: False}

word_ineq_ops = {'=': (lambda x, y: x == y, 'Unsigned'),
	'bvult': (lambda x, y: x < y, 'Unsigned'),
	'word32-eq': (lambda x, y: x == y, 'Unsigned'),
	'bvule': (lambda x, y: x <= y, 'Unsigned'),
	'bvsle': (lambda x, y: x <= y, 'Signed'),
	'bvslt': (lambda x, y: x < y, 'Signed'),
}

def eval_model (m, s, toplevel = None):
	"""Evaluate the parsed SMT s-expression s in model m.

	m maps s-expressions to values and is also used to memoise
	intermediate results. toplevel is the outermost expression and is
	used only for error reporting. Unsupported operators trip an
	assertion."""
	if s in m:
		return m[s]
	if toplevel is None:
		toplevel = s
	if type (s) == str:
		try:
			result = solver.smt_to_val (s)
		except Exception:
			trace ('Error with eval_model')
			trace (toplevel)
			raise
		return result

	# op is the head atom (a string) for ordinary applications, or a
	# sequence like ('_', 'extract', hi, lo) for indexed operators
	op = s[0]

	if op == 'ite':
		[_, b, x, y] = s
		b = eval_model (m, b, toplevel)
		assert b in [false_term, true_term]
		if b == true_term:
			result = eval_model (m, x, toplevel)
		else:
			result = eval_model (m, y, toplevel)
		m[s] = result
		return result

	xs = [eval_model (m, x, toplevel) for x in s[1:]]

	if op[0] == '_' and op[1] in ['zero_extend', 'sign_extend']:
		[_, ex_kind, n_extend] = op
		n_extend = int (n_extend)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		if ex_kind == 'sign_extend':
			val = get_signed_val (x)
		else:
			val = get_unsigned_val (x)
		result = mk_num (val, x.typ.num + n_extend)
	elif op[0] == '_' and op[1] == 'extract':
		[_, _, n_top, n_bot] = op
		n_top = int (n_top)
		n_bot = int (n_bot)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		length = (n_top - n_bot) + 1
		result = mk_num ((x.val >> n_bot) & ((1 << length) - 1), length)
	elif op == 'store-word32':
		# the memory value is bound to 'mem', not 'm', so the model
		# dictionary is still available for memoisation below
		(mem, ptr, v) = xs
		(naming, eqs) = mem
		eqs = dict (eqs)
		eqs[ptr.val] = v.val
		result = (naming, tuple (sorted (eqs.items ())))
	elif op == 'store-word8':
		(mem, ptr, v) = xs
		p_al = ptr.val & -4
		shift = (ptr.val & 3) * 8
		(naming, eqs) = mem
		eqs = dict (eqs)
		prev_v = eqs[p_al]
		mask_v = prev_v & (((1 << 32) - 1) ^ (255 << shift))
		new_v = mask_v | ((v.val & 255) << shift)
		# write back to the aligned word the byte was merged into
		eqs[p_al] = new_v
		result = (naming, tuple (sorted (eqs.items ())))
	elif op == 'load-word32':
		(mem, ptr) = xs
		(naming, eqs) = mem
		result = syntax.mk_word32 (dict (eqs)[ptr.val])
	elif op == 'load-word8':
		(mem, ptr) = xs
		p_al = ptr.val & -4
		shift = (ptr.val & 3) * 8
		(naming, eqs) = mem
		v = (dict (eqs)[p_al] >> shift) & 255
		result = syntax.mk_word8 (v)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ops:
		for x in xs:
			assert x.kind == 'Num', (s, op, x)
		result = word_ops[op](* [x.val for x in xs])
		result = result & ((1 << xs[0].typ.num) - 1)
		result = Expr ('Num', xs[0].typ, val = result)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ineq_ops:
		(oper, signed) = word_ineq_ops[op]
		if signed == 'Signed':
			result = oper (* [get_signed_val (x) for x in xs])
		else:
			assert signed == 'Unsigned'
			result = oper (* [x.val for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'and':
		result = all ([x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'or':
		result = bool ([x for x in xs if x == true_term])
		result = {True: true_term, False: false_term}[result]
	elif op in bool_ops:
		assert all ([x.typ == boolT for x in xs])
		result = bool_ops[op](* [x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	else:
		assert not 's_expr handled', (s, op)
	m[s] = result
	return result

def get_unsigned_val (x):
	"""Return the value of word constant x reduced to its bit width."""
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	bits = x.typ.num
	v = x.val & ((1 << bits) - 1)
	return v

def get_signed_val (x):
	"""Return the two's complement signed value of word constant x."""
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	bits = x.typ.num
	v = x.val & ((1 << bits) - 1)
	if v >= (1 << (bits - 1)):
		v = v - (1 << bits)
	return v

def short_array_str (arr):
	"""Format the non-string-keyed entries of arr as sorted hex
	'key: value' strings, with keys scaled by 4."""
	items = [('%x: %x' % (p.val * 4, v.val))
		for (p, v) in arr.iteritems ()
		if type (p) != str]
	items.sort ()
	return '{' + ', '.join (items) + '}'

def eval_model_expr (m, solv, v):
	"""Evaluate expression v in model m via its SMT rendering."""
	s = solver.smt_expr (v, {}, solv)
	s_x = solver.parse_s_expression (s)

	return eval_model (m, s_x)

def model_equal (m, knowledge, vpair):
	"""True iff every predicate from expand_var_eqs for vpair holds in
	model m."""
	preds = expand_var_eqs (knowledge, vpair)
	for pred in preds:
		x = eval_model_expr (m, knowledge.rep.solv, pred)
		assert x in [syntax.true_term, syntax.false_term]
		if x == syntax.false_term:
			return False
	return True

def get_model_trace (knowledge, m, v):
	"""Return a tuple of search variable v's values in model m at its
	three sampled visits, with None for visits not reached."""
	rep = knowledge.rep
	pc_vs = get_var_pc_var_list (knowledge, v)
	model_trace = []
	for (pc, x) in pc_vs:
		pc_val = eval_model_expr (m, rep.solv, pc)
		assert pc_val in [syntax.true_term, syntax.false_term]
		if pc_val == syntax.false_term:
			model_trace.append (None)
		else:
			model_trace.append (eval_model_expr (m, rep.solv, x))
	return tuple (model_trace)

def split_group (knowledge, m, group):
	"""Split a class of (var, const) entries by their behaviour in m.

	Returns a list of (const, [vars]) bins. Memory variables are binned
	by model_equal against each bin's first member; other variables by
	their model trace. A bin stays constant only if its entries were
	constant and the values did not vary across visits in m."""
	group = list (set (group))
	if group[0][0][0].typ == syntax.builtinTs['Mem']:
		bins = []
		for (v, const) in group:
			for (_, members) in bins:
				if model_equal (m, knowledge, (v, members[0])):
					members.append (v)
					break
			else:
				if const:
					const = model_equal (m, knowledge,
						(v, 'Const'))
				bins.append ((const, [v]))
		return bins
	else:
		bins = {}
		for (v, const) in group:
			v_trace = get_model_trace (knowledge, m, v)
			if v_trace not in bins:
				tconst = len (set (v_trace) - set ([None])) <= 1
				bins[v_trace] = (const and tconst, [])
			bins[v_trace][1].append (v)
		return bins.values ()

def mk_pairing_v_eqs (knowledge, pair, endorsed = True):
	"""Build the variable equalities for a candidate pairing.

	Each non-zero constant left variable gets (v_i, 'Const'). Every
	other left variable is matched with the first right variable in
	the same class. With endorsed, only equalities whose predicates are
	already known facts are used. Returns None if some left variable
	has no match."""
	v_eqs = []
	(lvs, rvs) = knowledge.pairs[pair]
	zero = mk_word32 (0)
	for v_i in lvs:
		(k, const) = knowledge.v_ids[v_i]
		if const and v_i[0] != zero:
			if not endorsed or eq_known (knowledge, (v_i, 'Const')):
				v_eqs.append ((v_i, 'Const'))
				continue
		vs_j = [v_j for v_j in rvs if knowledge.v_ids[v_j][0] == k]
		if endorsed:
			vs_j = [v_j for v_j in vs_j
				if eq_known (knowledge, (v_i, v_j))]
		if not vs_j:
			return None
		v_j = vs_j[0]
		v_eqs.append ((v_i, v_j))
	return v_eqs

def eq_known (knowledge, vpair):
	"""True iff all predicates of vpair are already known facts."""
	preds = expand_var_eqs (knowledge, vpair)
	return set (preds) <= knowledge.facts

def find_split_loop (p, head, restrs, hyps, unfold_limit = 9,
		node_restrs = None, trace_ind_fails = None):
	"""Find a split for the loop at head, which must be a loop head on
	the first tag's side.

	A 'CaseSplit' or 'LoopUnroll' answer from get_necessary_split_opts
	is returned directly. Other necessary options replace the default
	(i_opts, j_opts) list. Each option pair is tried with find_split
	and the first result whose first element is not None is returned.
	Inductive failures are collected (into trace_ind_fails, if given).
	NoSplit is raised if nothing is found."""
	assert p.loop_data[head][0] == 'Head'
	assert p.node_tags[head][0] == p.pairing.tags[0]

	# the idea is to loop through testable hyps, starting with ones that
	# need smaller models (the most unfolded models will time out for
	# large problems like finaliseSlot)

	rep = mk_graph_slice (p, fast = True)

	nec = get_necessary_split_opts (p, head, restrs, hyps)
	if nec and nec[0] in ['CaseSplit', 'LoopUnroll']:
		return nec
	elif nec:
		i_j_opts = nec
	else:
		i_j_opts = default_i_j_opts (unfold_limit)

	if trace_ind_fails is None:
		ind_fails = []
	else:
		ind_fails = trace_ind_fails
	for (i_opts, j_opts) in i_j_opts:
		result = find_split (rep, head, restrs, hyps,
			i_opts, j_opts, node_restrs = node_restrs)
		if result[0] is not None:
			return result
		ind_fails.extend (result[1])

	if ind_fails:
		trace ('Warning: inductive failures: %s' % ind_fails)
	raise NoSplit ()

def default_i_j_opts (unfold_limit = 9):
	"""mk_i_j_opts with its default option sequences."""
	return mk_i_j_opts (unfold_limit = unfold_limit)

def mk_i_j_opts (i_seq_opts = None, j_seq_opts = None, unfold_limit = 9):
	"""Group (start, step) options by the unfolding they need.

	An option needs start + 2 * step + 1 unfoldings. For each lim below
	unfold_limit that some option needs exactly, this yields the left
	and right options fitting within lim. Pairs where either list is
	empty are dropped, so the result is ordered by increasing need."""
	if i_seq_opts is None:
		i_seq_opts = [(0, 1), (1, 1), (2, 1), (3, 1)]
	if j_seq_opts is None:
		j_seq_opts = [(0, 1), (0, 2), (1, 1), (1, 2),
			(2, 1), (2, 2), (3, 1)]
	all_opts = set (i_seq_opts + j_seq_opts)

	def filt (opts, lim):
		return [(start, step) for (start, step) in opts
			if start + (2 * step) + 1 <= lim]

	lims = [(filt (i_seq_opts, lim), filt (j_seq_opts, lim))
		for lim in range (unfold_limit)
		if any ([start + (2 * step) + 1 == lim
			for (start, step) in all_opts])]
	lims = [(i_opts, j_opts) for (i_opts, j_opts) in lims
		if i_opts and j_opts]
	return lims

necessary_split_opts_trace = []

def get_interesting_linear_series_exprs (p, head):
	"""Cached logic.interesting_linear_series_exprs for loop head."""
	k = ('interesting_linear_series', head)
	if k in p.cached_analysis:
		return p.cached_analysis[k]
	res = logic.interesting_linear_series_exprs (p, head,
		get_loop_var_analysis_at (p, head))
	p.cached_analysis[k] = res
	return res

def split_opt_test (p, tags = None):
	"""Debug helper: the necessary split options for each initial
	left-side loop head of p."""
	if not tags:
		tags = p.pairing.tags
	heads = [head for head in init_loops_to_split (p, ())
		if p.node_tags[head][0] == tags[0]]
	hyps = check.init_point_hyps (p)
	return [(head, get_necessary_split_opts (p, head, (), hyps))
		for head in heads]

def interesting_linear_test (p):
	"""Debug check: each interesting linear series expression (at nodes
	not covered by a loop_var_analysis hook) grows by its offset from
	the 1st to the 2nd visit. Asserts on failure; returns True."""
	p.do_analysis ()
	for head in p.loop_heads ():
		inter = get_interesting_linear_series_exprs (p, head)
		hooks = target_objects.hooks ('loop_var_analysis')
		n_exprs = [(n, expr, offs) for (n, vs) in inter.iteritems ()
			if not [hook for hook in hooks if hook (p, n) is not None]
			for (kind, expr, offs) in vs]
		if n_exprs:
			rep = rep_graph.mk_graph_slice (p)
			for (n, expr, offs) in n_exprs:
				# NOTE(review): restr_others_both is not imported in
				# this module; confirm where it is defined.
				restrs = tuple ([(n2, vc) for (n2, vc)
					in restr_others_both (p, (), 2, 2)
					if p.loop_id (n2) != p.loop_id (head)])
				vis1 = (n, ((head, vc_offs (1)), ) + restrs)
				vis2 = (n, ((head, vc_offs (2)), ) + restrs)
				pc = rep.get_pc (vis2)
				imp = mk_implies (pc, mk_eq (rep.to_smt_expr (expr, vis2),
					rep.to_smt_expr (mk_plus (expr, offs), vis1)))
				assert rep.test_hyp_whyps (imp, [])
	return True

last_necessary_split_opts = [0]

def get_necessary_split_opts (p, head, restrs, hyps, tags = None):
	"""Look for split options forced by the problem.

	Returns ('CaseSplit', ...) if this loop, or an initial right-side
	loop, has several feasible entry sites. It returns
	('LoopUnroll', n) when a right-side unroll guess has a split
	limit. When a matching pair of linear series agrees on its first
	two sampled visits, it returns i/j option lists from mk_i_j_opts.
	When the match only holds along one entry path, it returns a case
	split from derive_case_split. Otherwise it returns None."""
	if not tags:
		tags = p.pairing.tags
	[l_tag, r_tag] = tags
	last_necessary_split_opts[0] = (p, head, restrs, hyps, tags)

	rep = rep_graph.mk_graph_slice (p, fast = True)
	entries = get_loop_entry_sites (rep, restrs, hyps, head)
	if len (entries) > 1:
		return ('CaseSplit', ((entries[0], tags[0]), [entries[0]]))
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		entries = get_loop_entry_sites (rep, restrs, hyps, n)
		if len (entries) > 1:
			return ('CaseSplit', ((entries[0], r_tag),
				[entries[0]]))

	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff is None:
		return None
	seq_eqs = get_matching_linear_seqs (rep, head, restrs, hyps, tags)

	vis = stuff['vis']
	for v in seq_eqs:
		if v[0] == 'LoopUnroll':
			(_, n, est_bound) = v
			lim = get_split_limit (p, n, restrs, hyps, 'Number',
				est_bound = est_bound, must_find = False)
			if lim is not None:
				return ('LoopUnroll', n)
			continue
		((n, expr), (n2, expr2), (l_start, l_step), (r_start, r_step),
			_, _) = v
		eqs = [rep_graph.eq_hyp ((expr,
			(vis (n, l_start + (i * l_step)), l_tag)),
			(expr2, (vis (n2, r_start + (i * r_step)), r_tag)))
			for i in range (2)]
		vis_hyp = rep_graph.pc_true_hyp ((vis (n, l_start), l_tag))
		vis_hyps = [vis_hyp] + stuff['hyps']
		eq = foldr1 (mk_and, [rep.interpret_hyp (h) for h in eqs])
		m = {}
		if rep.test_hyp_whyps (eq, vis_hyps, model = m):
			trace ('found necessary split info: (%s, %s), (%s, %s)'
				% (l_start, l_step, r_start, r_step))
			return mk_i_j_opts ([(l_start + i, l_step)
					for i in range (r_step + 1)],
				[(r_start + i, r_step)
					for i in range (l_step + 1)],
				unfold_limit = 100)
		n_vcs = entry_path_no_loops (rep, l_tag, m, head)
		path_hyps = [rep_graph.pc_true_hyp ((n_vc, l_tag)) for n_vc in n_vcs]
		if rep.test_hyp_whyps (eq, stuff['hyps'] + path_hyps):
			# immediate case split on difference between entry paths
			checks = [(stuff['hyps'], eq_hyp, 'eq')
				for eq_hyp in eqs]
			return derive_case_split (rep, n_vcs, checks)
		necessary_split_opts_trace.append ((n, expr, (l_start, l_step),
			(r_start, r_step), 'Seq check failed'))
	return None

def linear_setup_stuff (rep, head, restrs, hyps, tags):
	"""Compute (and cache) the shared setup for linear series matching.

	Returns None if either side has no interesting linear series.
	Otherwise it returns a dict with: the deduplicated left/right
	series ('l_seq_vs', 'r_seq_vs', each entry
	(kind, node, expr, offs, oset)); 'hyps', which are hyps plus
	triviality hyps for the 3rd visit of each series node and a
	non-right-error hyp; and the helpers 'vis', 'smt' and 'smt_pc'.
	Only non-None results are cached."""
	[l_tag, r_tag] = tags
	k = ('linear_seq setup', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		return p.cached_analysis[k]

	assert p.node_tags[head][0] == l_tag
	l_seq_vs = get_interesting_linear_series_exprs (p, head)
	if not l_seq_vs:
		return None
	r_seq_vs = {}
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		vs = get_interesting_linear_series_exprs (p, n)
		r_seq_vs.update (vs)
	if not r_seq_vs:
		return None

	def vis (n, i):
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, i)
		return (n, restrs2)
	smt = lambda expr, n, i: rep.to_smt_expr (expr, vis (n, i))
	smt_pc = lambda n, i: rep.get_pc (vis (n, i))

	# remove duplicates by concretising
	l_seq_vs = list (dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
		for n in l_seq_vs
		for (kind, expr, offs, oset) in l_seq_vs[n]]).values ())
	r_seq_vs = list (dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
		for n in r_seq_vs
		for (kind, expr, offs, oset) in r_seq_vs[n]]).values ())

	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), r_tag))
		for n in set ([n for (_, n, _, _, _) in r_seq_vs])]
	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), l_tag))
		for n in set ([n for (_, n, _, _, _) in l_seq_vs])]
	hyps = hyps + [check.non_r_err_pc_hyp (tags,
			restr_others (p, restrs, 2))]

	r = {'l_seq_vs': l_seq_vs, 'r_seq_vs': r_seq_vs,
		'hyps': hyps, 'vis': vis, 'smt': smt, 'smt_pc': smt_pc}
	p.cached_analysis[k] = r
	return r

def get_matching_linear_seqs (rep, head, restrs, hyps, tags):
	"""Return a lazy iterator of candidate linear series matches.

	For each left series with a same-kind right series, a model in
	which the left step is non-zero is requested. That model drives
	get_linear_seq_eq and get_model_r_side_unroll, and their non-None
	results are yielded. The cache keeps a spare itertools.tee copy of
	the iterator, so later calls replay the same sequence."""
	k = ('matching linear seqs', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		v = p.cached_analysis[k]
		(x, y) = itertools.tee (v[0])
		v[0] = x
		return y

	[l_tag, r_tag] = tags
	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff is None:
		return []

	hyps = stuff['hyps']
	vis = stuff['vis']

	def get_model (n, offs, kind):
		# ask for a model where the step offs at node n is non-zero
		# (times 4) with the 1st and 2nd visits both reached. kind is
		# passed in because the generator below binds it in its own
		# scope, invisible to this closure
		m = {}
		offs_smt = stuff['smt'] (offs, n, 1)
		eq = mk_eq (mk_times (offs_smt, mk_num (4, offs_smt.typ)),
			mk_num (0, offs_smt.typ))
		ex_hyps = [rep_graph.pc_true_hyp ((vis (n, 1), l_tag)),
			rep_graph.pc_true_hyp ((vis (n, 2), l_tag))]
		rep.test_hyp_whyps (eq, hyps + ex_hyps, model = m)
		if not m:
			necessary_split_opts_trace.append ((n, kind, 'NoModel'))
			return None
		return m

	r = (seq_eq
		for (kind, n, expr, offs, oset) in sorted (stuff['l_seq_vs'])
		if [v for v in stuff['r_seq_vs'] if v[0] == kind]
		for m in [get_model (n, offs, kind)]
		if m
		for seq_eq in [get_linear_seq_eq (rep, m, stuff,
					(kind, n, expr, offs, oset)),
			get_model_r_side_unroll (rep, tags, m,
				restrs, hyps, stuff)]
		if seq_eq is not None)
	(x, y) = itertools.tee (r)
	p.cached_analysis[k] = [y]
	return x

def get_linear_seq_eq (rep, m, stuff, expr_t1):
	"""Try to match left linear series expr_t1 with a right series of
	the same kind, using the values in model m.

	The right step must divide the left step, by a positive multiple of
	at most 8. The start difference must be a non-negative multiple, at
	most 8, of the right step. Returns
	((n1, expr1), (n, expr), (0, 1), (r_start, mult), offsets, osets)
	for the first match in sorted order, or None. Rejections are
	logged to necessary_split_opts_trace."""
	def get_int_min (expr):
		# the representative of expr's value mod 2^width nearest 0
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		vs = [v.val + (i << v.typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (x), x) for x in vs])
		return v
	(kind, n1, expr1, offs1, oset1) = expr_t1
	smt = stuff['smt']
	expr_init = smt (expr1, n1, 0)
	offs_v = get_int_min (smt (offs1, n1, 1))
	r_seqs = [(n, expr, offs, oset2,
			get_int_min (mk_minus (expr_init, smt (expr, n, 0))),
			get_int_min (smt (offs, n, 0)))
		for (kind2, n, expr, offs, oset2) in sorted (stuff['r_seq_vs'])
		if kind2 == kind]

	for (n, expr, offs2, oset2, diff, offs_v2) in sorted (r_seqs):
		# a zero right step cannot be divided by, and a negative
		# multiple would run the right sequence backwards
		if (offs_v2 == 0 or offs_v % offs_v2 != 0
				or not (0 < offs_v // offs_v2 <= 8)):
			necessary_split_opts_trace.append ((n, expr,
				'StepWrong', offs_v, offs_v2))
		elif diff % offs_v2 != 0 or (diff * offs_v2) < 0 or (diff // offs_v2) > 8:
			necessary_split_opts_trace.append ((n, expr,
				'StartWrong', diff, offs_v2))
		else:
			mult = offs_v // offs_v2
			return ((n1, expr1), (n, expr), (0, 1),
				(diff // offs_v2, mult), (offs1, offs2),
				(oset1, oset2))
	return None

last_r_side_unroll = [None]

def get_model_r_side_unroll (rep, tags, m, restrs, hyps, stuff):
	"""Guess a right-side loop unrolling from model m.

	Left-side word expressions of the right series' kinds, at nodes
	visited exactly once in m, give a set of interesting values. Each
	right series is stepped up to 63 times looking for those values.
	If some hit index exceeds 4, returns
	('LoopUnroll', loop id, largest hit). Otherwise returns None."""
	p = rep.p
	[l_tag, r_tag] = tags
	last_r_side_unroll[0] = (rep, tags, m, restrs, hyps, stuff)

	r_kinds = set ([kind for (kind, n, _, _, _) in stuff['r_seq_vs']])
	l_visited_ns_vcs = logic.dict_list ([(n, vc)
		for (tag, n, vc) in rep.node_pc_env_order
		if tag == l_tag
		if eval_pc (rep, m, (n, vc))])
	l_arc_interesting = [(n, vc, kind, expr)
		for (n, vcs) in l_visited_ns_vcs.iteritems ()
		if len (vcs) == 1
		for vc in vcs
		for (kind, expr)
			in logic.interesting_node_exprs (p, n, tags = tags)
		if kind in r_kinds
		if expr.typ.kind == 'Word']
	l_kinds = set ([kind for (n, vc, kind, _) in l_arc_interesting])

	# FIXME: cloned
	def canon_n (n, typ):
		vs = [n + (i << typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (x), x) for x in vs])
		return v
	def get_int_min (expr):
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		return canon_n (v.val, v.typ)
	def eval_at (expr, n, vc):
		expr = rep.to_smt_expr (expr, (n, vc))
		return get_int_min (expr)

	val_interesting_map = logic.dict_list ([((kind, eval_at (expr, n, vc)), n)
		for (n, vc, kind, expr) in l_arc_interesting])

	smt = stuff['smt']

	for (kind, n, expr, offs, _) in stuff['r_seq_vs']:
		if kind not in l_kinds:
			continue
		if expr.typ.kind != 'Word':
			continue
		expr_n = get_int_min (smt (expr, n, 0))
		offs_n = get_int_min (smt (offs, n, 0))
		hit = ([i for i in range (64)
			if (kind, canon_n (expr_n + (offs_n * i), expr.typ))
				in val_interesting_map])
		if [i for i in hit if i > 4]:
			return ('LoopUnroll', p.loop_id (n), max (hit))
	return None

last_failed_pairings = []

def setup_split_search (rep, head, restrs, hyps,
		i_opts, j_opts, unfold_limit = None, tags = None,
		node_restrs = None):
	"""Create the SearchKnowledge for a split search at loop head.

	Left elements are splittable points of head's loop combined with
	i_opts. Right elements are splittable points of the initial
	right-side loops combined with j_opts. Both are limited to
	node_restrs. The premise assumes hyps and that the right side's
	error point is not reached within unfold_limit (by default the
	unfolding the options need). The solver expressions for all
	memory and Const equalities are generated up front."""
	p = rep.p

	if not tags:
		tags = p.pairing.tags
	if node_restrs is None:
		node_restrs = set (p.nodes)
	if unfold_limit is None:
		unfold_limit = max ([start + (2 * step) + 1
			for (start, step) in i_opts + j_opts])

	trace ('Split search at %d, unfold limit %d.' % (head, unfold_limit))

	l_tag, r_tag = tags
	loop_elts = [(n, start, step) for n in p.splittable_points (head)
		if n in node_restrs
		for (start, step) in i_opts]
	init_to_split = init_loops_to_split (p, restrs)
	r_to_split = [n for n in init_to_split if p.node_tags[n][0] == r_tag]
	cand_r_loop_elts = [(n2, start, step) for n in r_to_split
		for n2 in p.splittable_points (n)
		if n2 in node_restrs
		for (start, step) in j_opts]

	err_restrs = restr_others (p, tuple ([(sp, vc_upto (unfold_limit))
		for sp in r_to_split]) + restrs, 1)
	nrerr_pc = mk_not (rep.get_pc (('Err', err_restrs), tag = r_tag))

	def get_pc (n, k):
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, k)
		return rep.get_pc ((n, restrs2))

	for n in r_to_split:
		get_pc (n, unfold_limit)
	get_pc (head, unfold_limit)

	premise = foldr1 (mk_and,
		[nrerr_pc] + [rep.interpret_hyp (hyp) for hyp in hyps])
	premise = logic.weaken_assert (premise)

	knowledge = SearchKnowledge (rep,
		'search at %d (unfold limit %d)' % (head, unfold_limit),
		restrs, hyps, tags, (loop_elts, cand_r_loop_elts))
	knowledge.premise = premise
	last_knowledge[0] = knowledge

	# make sure the representation is in sync
	rep.test_hyp_whyps (true_term, hyps)

	# make sure all mem eqs are being tracked
	mem_vs = [v for v in knowledge.v_ids if v[0].typ == builtinTs['Mem']]
	for (i, v) in enumerate (mem_vs):
		for v2 in mem_vs[:i]:
			for pred in expand_var_eqs (knowledge, (v, v2)):
				smt_expr (pred, {}, rep.solv)
	for v in knowledge.v_ids:
		for pred in expand_var_eqs (knowledge, (v, 'Const')):
			smt_expr (pred, {}, rep.solv)

	return knowledge

def get_loop_entry_sites (rep, restrs, hyps, head):
	"""Return the nodes outside any loop that precede a node of head's
	loop and whose reachability is not refuted by hyps.

	The reachability check uses restrs minus any restriction on that
	node. The result is cached."""
	k = ('loop_entry_sites', restrs, tuple (hyps), rep.p.loop_id (head))
	if k in rep.p.cached_analysis:
		return rep.p.cached_analysis[k]
	ns = set ([n for n2 in rep.p.loop_body (head)
		for n in rep.p.preds[n2]
		if rep.p.loop_id (n) is None])
	def npc (n):
		return rep_graph.pc_false_hyp (((n, tuple ([(n2, restr)
			for (n2, restr) in restrs if n2 != n])),
			rep.p.node_tags[n][0]))
	res = [n for n in ns if not rep.test_hyp_imp (hyps, npc (n))]
	rep.p.cached_analysis[k] = res
	return res

def rebuild_knowledge (head, knowledge):
	"""Rebuild a SearchKnowledge from scratch with the same (start,
	step) options, then replay the old facts and models into it."""
	i_opts = sorted (set ([(start, step)
		for ((_, start, step), _) in knowledge.pairs]))
	j_opts = sorted (set ([(start, step)
		for (_, (_, start, step)) in knowledge.pairs]))
	knowledge2 = setup_split_search (knowledge.rep, head, knowledge.restrs,
		knowledge.hyps, i_opts, j_opts)
	knowledge2.facts.update (knowledge.facts)
	for m in knowledge.model_trace:
		knowledge2.add_model (m)
	return knowledge2

def split_search (head, knowledge):
	rep = knowledge.rep
	p = rep.p

	# test any relevant cached solutions.
1102 p.cached_analysis.setdefault (('v_eqs', head), set ()) 1103 v_eq_cache = p.cached_analysis[('v_eqs', head)] 1104 for (pair, eqs) in v_eq_cache: 1105 if pair in knowledge.pairs: 1106 knowledge.eqs_add_model (list (eqs), 1107 assert_progress = False) 1108 1109 while True: 1110 trace ('In %s' % knowledge.name) 1111 trace ('Computing live pairings') 1112 pair_eqs = [(pair, mk_pairing_v_eqs (knowledge, pair)) 1113 for pair in sorted (knowledge.pairs) 1114 if knowledge.pairs[pair][0] != 'Failed'] 1115 if not pair_eqs: 1116 ind_fails = trace_search_fail (knowledge) 1117 return (None, ind_fails) 1118 1119 endorsed = [(pair, eqs) for (pair, eqs) in pair_eqs 1120 if eqs != None] 1121 trace (' ... %d live pairings, %d endorsed' % 1122 (len (pair_eqs), len (endorsed))) 1123 knowledge.live_pairs_trace.append (len (pair_eqs)) 1124 for (pair, eqs) in endorsed: 1125 if knowledge.is_weak_split (eqs): 1126 trace (' dropping endorsed - probably weak.') 1127 knowledge.pairs[pair] = ('Failed', 1128 'ExpectedSplitWeak', eqs) 1129 continue 1130 split = build_and_check_split (p, pair, eqs, 1131 knowledge.restrs, knowledge.hyps, 1132 knowledge.tags) 1133 if split == None: 1134 knowledge.pairs[pair] = ('Failed', 1135 'SplitWeak', eqs) 1136 knowledge.add_weak_split (eqs) 1137 continue 1138 elif split == 'InductFailed': 1139 knowledge.pairs[pair] = ('Failed', 1140 'InductFailed', eqs) 1141 elif split[0] == 'SingleRevInduct': 1142 return split 1143 else: 1144 v_eq_cache.add ((pair, tuple (eqs))) 1145 trace ('Found split!') 1146 return ('Split', split) 1147 if endorsed: 1148 continue 1149 1150 (pair, _) = pair_eqs[0] 1151 trace ('Testing guess for pair: %s' % str (pair)) 1152 eqs = mk_pairing_v_eqs (knowledge, pair, endorsed = False) 1153 assert eqs, pair 1154 knowledge.eqs_add_model (eqs) 1155 1156def build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags): 1157 split = v_eqs_to_split (p, pair, eqs, restrs, hyps, tags = tags) 1158 if split == None: 1159 return None 1160 res = 
check_split_induct (p, restrs, hyps, split, tags = tags) 1161 if res: 1162 return split 1163 else: 1164 return 'InductFailed' 1165 1166def build_and_check_split (p, pair, eqs, restrs, hyps, tags): 1167 res = build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags) 1168 if res != 'InductFailed': 1169 return res 1170 1171 # induction has failed at this point, but we might be able to rescue 1172 # it one of two different ways. 1173 ((l_split, _, l_step), _) = pair 1174 extra = get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step) 1175 if extra: 1176 res = build_and_check_split (p, pair, eqs, restrs, hyps, tags) 1177 # the additional linear eqs get built into the result 1178 if res != 'InductFailed': 1179 return res 1180 1181 (_, (r_split, _, _)) = pair 1182 r_loop = p.loop_id (r_split) 1183 spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_loop) 1184 if spec: 1185 hyp = check.single_induct_resulting_hyp (p, restrs, spec) 1186 hyps2 = hyps + [hyp] 1187 res = build_and_check_split (p, pair, eqs, restrs, hyps2, tags) 1188 if res != 'InductFailed': 1189 return ('SingleRevInduct', spec) 1190 return 'InductFailed' 1191 1192def get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step): 1193 k = ('extra_linear_seq_eqs', l_split, l_step) 1194 if k in p.cached_analysis: 1195 return [] 1196 if not [v for (v, data) in get_loop_var_analysis_at (p, l_split) 1197 if data[0] == 'LoopLinearSeries']: 1198 return [] 1199 import loop_bounds 1200 lin_series_eqs = loop_bounds.get_linear_series_eqs (p, l_split, 1201 restrs, [], omit_standard = True) 1202 p.cached_analysis[k] = lin_series_eqs 1203 return lin_series_eqs 1204 1205def trace_search_fail (knowledge): 1206 trace (('Exhausted split candidates for %s' % knowledge.name)) 1207 fails = [it for it in knowledge.pairs.items () 1208 if it[1][0] == 'Failed'] 1209 last_failed_pairings.append (fails) 1210 del last_failed_pairings[:-10] 1211 fails10 = fails[:10] 1212 trace (' %d of %d failed pairings:' % (len (fails10), 1213 len 
(fails))) 1214 for f in fails10: 1215 trace (' %s' % (f,)) 1216 ind_fails = [it for it in fails 1217 if str (it[1][1]) == 'InductFailed'] 1218 if ind_fails: 1219 trace ( 'Inductive failures!') 1220 else: 1221 trace ( 'No inductive failures.') 1222 for f in ind_fails: 1223 trace (' %s' % (f,)) 1224 return ind_fails 1225 1226def find_split (rep, head, restrs, hyps, i_opts, j_opts, 1227 unfold_limit = None, tags = None, 1228 node_restrs = None): 1229 knowledge = setup_split_search (rep, head, restrs, hyps, 1230 i_opts, j_opts, unfold_limit = unfold_limit, 1231 tags = tags, node_restrs = node_restrs) 1232 1233 res = split_search (head, knowledge) 1234 1235 if res[0]: 1236 return res 1237 1238 (models, facts, n_vcs) = most_common_path (head, knowledge) 1239 if not n_vcs: 1240 return res 1241 1242 [tag, _] = knowledge.tags 1243 knowledge = setup_split_search (rep, head, restrs, 1244 hyps + [rep_graph.pc_true_hyp ((n_vc, tag)) for n_vc in n_vcs], 1245 i_opts, j_opts, unfold_limit, tags, node_restrs = node_restrs) 1246 knowledge.facts.update (facts) 1247 for m in models: 1248 knowledge.add_model (m) 1249 res = split_search (head, knowledge) 1250 1251 if res[0] == None: 1252 return res 1253 (_, split) = res 1254 checks = check.split_init_step_checks (rep.p, restrs, 1255 hyps, split) 1256 1257 return derive_case_split (rep, n_vcs, checks) 1258 1259def most_common_path (head, knowledge): 1260 rep = knowledge.rep 1261 [tag, _] = knowledge.tags 1262 data = logic.dict_list ([(tuple (entry_path_no_loops (rep, 1263 tag, m, head)), m) 1264 for m in knowledge.model_trace]) 1265 if len (data) < 2: 1266 return (None, None, None) 1267 1268 (_, path) = max ([(len (data[path]), path) for path in data]) 1269 models = data[path] 1270 facts = knowledge.facts 1271 other_n_vcs = set.intersection (* [set (path2) for path2 in data 1272 if path2 != path]) 1273 1274 n_vcs = [] 1275 pcs = set () 1276 for n_vc in path: 1277 if n_vc in other_n_vcs: 1278 continue 1279 if rep.p.loop_id (n_vc[0]): 1280 
continue 1281 pc = rep.get_pc (n_vc) 1282 if pc not in pcs: 1283 pcs.add (pc) 1284 n_vcs.append (n_vc) 1285 assert n_vcs 1286 1287 return (models, facts, n_vcs) 1288 1289def eval_pc (rep, m, n_vc, tag = None): 1290 hit = eval_model_expr (m, rep.solv, rep.get_pc (n_vc, tag = tag)) 1291 assert hit in [syntax.true_term, syntax.false_term], (n_vc, hit) 1292 return hit == syntax.true_term 1293 1294def entry_path (rep, tag, m, head): 1295 n_vcs = [] 1296 for (tag2, n, vc) in rep.node_pc_env_order: 1297 if n == head: 1298 break 1299 if tag2 != tag: 1300 continue 1301 if eval_pc (rep, m, (n, vc), tag): 1302 n_vcs.append ((n, vc)) 1303 return n_vcs 1304 1305def entry_path_no_loops (rep, tag, m, head = None): 1306 n_vcs = entry_path (rep, tag, m, head) 1307 return [(n, vc) for (n, vc) in n_vcs 1308 if not rep.p.loop_id (n)] 1309 1310last_derive_case_split = [0] 1311 1312def derive_case_split (rep, n_vcs, checks): 1313 last_derive_case_split[0] = (rep.p, n_vcs, checks) 1314 # remove duplicate pcs 1315 n_vcs_uniq = dict ([(rep.get_pc (n_vc), (i, n_vc)) 1316 for (i, n_vc) in enumerate (n_vcs)]).values () 1317 n_vcs = [n_vc for (i, n_vc) in sorted (n_vcs_uniq)] 1318 assert n_vcs 1319 tag = rep.p.node_tags[n_vcs[0][0]][0] 1320 keep_n_vcs = [] 1321 test_n_vcs = n_vcs 1322 mk_thyps = lambda n_vcs: [rep_graph.pc_true_hyp ((n_vc, tag)) 1323 for n_vc in n_vcs] 1324 while len (test_n_vcs) > 1: 1325 i = len (test_n_vcs) / 2 1326 test_in = test_n_vcs[:i] 1327 test_out = test_n_vcs[i:] 1328 checks2 = [(hyps + mk_thyps (test_in + keep_n_vcs), hyp, nm) 1329 for (hyps, hyp, nm) in checks] 1330 (verdict, _) = check.test_hyp_group (rep, checks2) 1331 if verdict: 1332 # forget n_vcs that were tested out 1333 test_n_vcs = test_in 1334 else: 1335 # focus on n_vcs that were tested out 1336 test_n_vcs = test_out 1337 keep_n_vcs.extend (test_in) 1338 [(n, vc)] = test_n_vcs 1339 return ('CaseSplit', ((n, tag), [n])) 1340 1341def mk_seq_eqs (p, split, step, with_rodata): 1342 # eqs take the form of a 
number of constant expressions 1343 eqs = [] 1344 1345 # the variable 'loop' will be converted to the point in 1346 # the sequence - note this should be multiplied by the step size 1347 loop = mk_var ('%i', word32T) 1348 if step == 1: 1349 minus_loop_step = mk_uminus (loop) 1350 else: 1351 minus_loop_step = mk_times (loop, mk_word32 (- step)) 1352 1353 for (var, data) in get_loop_var_analysis_at (p, split): 1354 if data == 'LoopVariable': 1355 if with_rodata and var.typ == builtinTs['Mem']: 1356 eqs.append (logic.mk_rodata (var)) 1357 elif data == 'LoopConst': 1358 if var.typ not in syntax.phantom_types: 1359 eqs.append (var) 1360 elif data == 'LoopLeaf': 1361 continue 1362 elif data[0] == 'LoopLinearSeries': 1363 (_, form, _) = data 1364 eqs.append (form (var, 1365 mk_cast (minus_loop_step, var.typ))) 1366 else: 1367 assert not 'var_deps type understood' 1368 1369 k = ('extra_linear_seq_eqs', split, step) 1370 eqs += p.cached_analysis.get (k, []) 1371 1372 return eqs 1373 1374def c_memory_loop_invariant (p, c_sp, a_sp): 1375 def mem_vars (split): 1376 return [v for (v, data) in get_loop_var_analysis_at (p, split) 1377 if v.typ == builtinTs['Mem'] 1378 if data == 'LoopVariable'] 1379 1380 if mem_vars (a_sp): 1381 return [] 1382 # if ASM keeps memory constant through the loop, it is implying this 1383 # is semantically possible in C also, though it may not be 1384 # syntactically the case 1385 # anyway, we have to assert C memory equals *something* inductively 1386 # so we pick C initial memory. 
1387 return mem_vars (c_sp) 1388 1389def v_eqs_to_split (p, pair, v_eqs, restrs, hyps, tags = None): 1390 trace ('v_eqs_to_split: (%s, %s)' % pair) 1391 1392 ((l_n, l_init, l_step), (r_n, r_init, r_step)) = pair 1393 l_details = (l_n, (l_init, l_step), mk_seq_eqs (p, l_n, l_step, True) 1394 + [v_i[0] for (v_i, v_j) in v_eqs if v_j == 'Const']) 1395 r_details = (r_n, (r_init, r_step), mk_seq_eqs (p, r_n, r_step, False) 1396 + c_memory_loop_invariant (p, r_n, l_n)) 1397 1398 eqs = [(v_i[0], mk_cast (v_j[0], v_i[0].typ)) 1399 for (v_i, v_j) in v_eqs if v_j != 'Const' 1400 if v_i[0] != syntax.mk_word32 (0)] 1401 1402 n = 2 1403 split = (l_details, r_details, eqs, n, (n * r_step) - 1) 1404 trace ('Split: %s' % (split, )) 1405 if tags == None: 1406 tags = p.pairing.tags 1407 hyps = hyps + check.split_loop_hyps (tags, split, restrs, exit = True) 1408 1409 r_max = get_split_limit (p, r_n, restrs, hyps, 'Offset', 1410 bound = (n + 2) * r_step, must_find = False, 1411 hints = [n * r_step, n * r_step + 1]) 1412 if r_max == None: 1413 trace ('v_eqs_to_split: no RHS limit') 1414 return None 1415 1416 if r_max > n * r_step: 1417 trace ('v_eqs_to_split: RHS limit not %d' % (n * r_step)) 1418 return None 1419 trace ('v_eqs_to_split: split %s' % (split,)) 1420 return split 1421 1422def get_n_offset_successes (rep, sp, step, restrs): 1423 loop = rep.p.loop_body (sp) 1424 ns = [n for n in loop if rep.p.nodes[n].kind == 'Call'] 1425 succs = [] 1426 for i in range (step): 1427 for n in ns: 1428 vc = vc_offs (i + 1) 1429 if n == sp: 1430 vc = vc_offs (i) 1431 n_vc = (n, restrs + tuple ([(sp, vc)])) 1432 (_, _, succ) = rep.get_func (n_vc) 1433 pc = rep.get_pc (n_vc) 1434 succs.append (syntax.mk_implies (pc, succ)) 1435 return succs 1436 1437eq_ineq_ops = set (['Equals', 'Less', 'LessEquals', 1438 'SignedLess', 'SignedLessEquals']) 1439 1440def split_linear_eq (cond): 1441 if cond.is_op ('Not'): 1442 [c] = cond.vals 1443 return split_linear_eq (c) 1444 elif cond.is_op (eq_ineq_ops): 1445 
return (cond.vals[0], cond.vals[1]) 1446 elif cond.is_op ('PArrayValid'): 1447 [htd, typ_expr, p, num] = cond.vals 1448 assert typ_expr.kind == 'Type' 1449 typ = typ_expr.val 1450 return split_linear_eq (logic.mk_array_size_ineq (typ, num, p)) 1451 else: 1452 return None 1453 1454def possibly_linear_ineq (cond): 1455 rv = split_linear_eq (cond) 1456 if not rv: 1457 return False 1458 (lhs, rhs) = rv 1459 return logic.possibly_linear (lhs) and logic.possibly_linear (rhs) 1460 1461def linear_const_comparison (p, n, cond): 1462 """examines a condition. if it is a linear (e.g. Less) comparison 1463 between a linear series variable and a loop-constant expression, 1464 return (linear side, const side), or None if not the case.""" 1465 rv = split_linear_eq (cond) 1466 loop_head = p.loop_id (n) 1467 if not rv: 1468 return None 1469 (lhs, rhs) = rv 1470 zero = mk_num (0, lhs.typ) 1471 offs = logic.get_loop_linear_offs (p, loop_head) 1472 (lhs_offs, rhs_offs) = [offs (n, expr) for expr in [lhs, rhs]] 1473 oset = set ([lhs_offs, rhs_offs]) 1474 if zero in oset and None not in oset and len (oset) > 1: 1475 if lhs_offs == zero: 1476 return (rhs, lhs) 1477 else: 1478 return (lhs, rhs) 1479 return None 1480 1481def do_linear_rev_test (rep, restrs, hyps, split, eqs_assume, pred, large): 1482 p = rep.p 1483 (tag, _) = p.node_tags[split] 1484 checks = (check.single_loop_rev_induct_checks (p, restrs, hyps, tag, 1485 split, eqs_assume, pred) 1486 + check.single_loop_rev_induct_base_checks (p, restrs, hyps, 1487 tag, split, large, eqs_assume, pred)) 1488 1489 groups = check.proof_check_groups (checks) 1490 for group in groups: 1491 (res, _) = check.test_hyp_group (rep, group) 1492 if not res: 1493 return False 1494 return True 1495 1496def get_extra_assn_linear_conds (expr): 1497 if expr.is_op ('And'): 1498 return [cond for conj in logic.split_conjuncts (expr) 1499 for cond in get_extra_assn_linear_conds (conj)] 1500 if not expr.is_op ('Or'): 1501 return [expr] 1502 arr_vs = [v for v in 
expr.vals if v.is_op ('PArrayValid')] 1503 if not arr_vs: 1504 return [expr] 1505 [htd, typ_expr, p, num] = arr_vs[0].vals 1506 assert typ_expr.kind == 'Type' 1507 typ = typ_expr.val 1508 less_eq = logic.mk_array_size_ineq (typ, num, p) 1509 assn = logic.mk_align_valid_ineq (('Array', typ, num), p) 1510 return get_extra_assn_linear_conds (assn) + [less_eq] 1511 1512def get_rhs_speculate_ineq (p, restrs, loop_head): 1513 assert p.loop_id (loop_head), loop_head 1514 loop_head = p.loop_id (loop_head) 1515 restrs = tuple ([(n, vc) for (n, vc) in restrs 1516 if p.node_tags[n][0] == p.node_tags[loop_head][0]]) 1517 key = ('rhs_speculate_ineq', restrs, loop_head) 1518 if key in p.cached_analysis: 1519 return p.cached_analysis[key] 1520 1521 res = rhs_speculate_ineq (p, restrs, loop_head) 1522 p.cached_analysis[key] = res 1523 return res 1524 1525def get_new_rhs_speculate_ineq (p, restrs, hyps, loop_head): 1526 res = get_rhs_speculate_ineq (p, restrs, loop_head) 1527 if res == None: 1528 return None 1529 (point, _, (pred, _)) = res 1530 hs = [h for h in hyps if point in [n for ((n, _), _) in h.visits ()] 1531 if pred in h.get_vals ()] 1532 if hs: 1533 return None 1534 return res 1535 1536def rhs_speculate_ineq (p, restrs, loop_head): 1537 """code for handling an interesting case in which the compiler 1538 knows that the RHS program might fail in the future. for instance, 1539 consider a loop that cannot be exited until iterator i reaches value n. 1540 any error condition which implies i < b must hold of i - 1, thus 1541 n <= b. 
1542 1543 detects this case and identifies the inequality n <= b""" 1544 body = p.loop_body (loop_head) 1545 1546 # if the loop contains function calls, skip it, 1547 # otherwise we need to figure out whether they terminate 1548 if [n for n in body if p.nodes[n].kind == 'Call']: 1549 return None 1550 1551 exit_nodes = set ([n for n in body for n2 in p.nodes[n].get_conts () 1552 if n2 != 'Err' if n2 not in body]) 1553 assert set ([p.nodes[n].kind for n in exit_nodes]) <= set (['Cond']) 1554 1555 # if there are multiple exit conditions, too hard for now 1556 if len (exit_nodes) > 1: 1557 return None 1558 1559 [exit_n] = list (exit_nodes) 1560 rv = linear_const_comparison (p, exit_n, p.nodes[exit_n].cond) 1561 if not rv: 1562 return None 1563 (linear, const) = rv 1564 1565 err_cond_sites = [(n, p.nodes[n].err_cond ()) for n in body] 1566 err_conds = set ([(n, cond) for (n, err_cond) in err_cond_sites 1567 if err_cond 1568 for assn in logic.split_conjuncts (mk_not (err_cond)) 1569 for cond in get_extra_assn_linear_conds (assn) 1570 if possibly_linear_ineq (cond)]) 1571 if not err_conds: 1572 return None 1573 1574 assert const.typ.kind == 'Word' 1575 rep = rep_graph.mk_graph_slice (p) 1576 eqs = mk_seq_eqs (p, exit_n, 1, False) 1577 import loop_bounds 1578 eqs += loop_bounds.get_linear_series_eqs (p, exit_n, 1579 restrs, [], omit_standard = True) 1580 1581 large = (2 ** const.typ.num) - 3 1582 const_less = lambda n: mk_less (const, mk_num (n, const.typ)) 1583 less_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n, 1584 eqs, const_less (n), large) 1585 const_ge = lambda n: mk_less (mk_num (n, const.typ), const) 1586 ge_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n, 1587 eqs, const_ge (n), large) 1588 1589 res = logic.binary_search_least (less_test, 1, large) 1590 if res: 1591 return (loop_head, (eqs, 1), (const_less (res), large)) 1592 res = logic.binary_search_greatest (ge_test, 0, large) 1593 if res: 1594 return (loop_head, (eqs, 1), (const_ge 
(res), large)) 1595 return None 1596 1597def check_split_induct (p, restrs, hyps, split, tags = None): 1598 """perform both the induction check and a function-call based check 1599 on successes which can avoid some problematic inductions.""" 1600 ((l_split, (_, l_step), _), (r_split, (_, r_step), _), _, n, _) = split 1601 if tags == None: 1602 tags = p.pairing.tags 1603 1604 err_hyp = check.split_r_err_pc_hyp (p, split, restrs, tags = tags) 1605 hyps = [err_hyp] + hyps + check.split_loop_hyps (tags, split, 1606 restrs, exit = False) 1607 1608 rep = mk_graph_slice (p) 1609 1610 if not check.check_split_induct_step_group (rep, restrs, hyps, split, 1611 tags = tags): 1612 return False 1613 1614 l_succs = get_n_offset_successes (rep, l_split, l_step, restrs) 1615 r_succs = get_n_offset_successes (rep, r_split, r_step, restrs) 1616 1617 if not l_succs: 1618 return True 1619 1620 hyp = syntax.foldr1 (syntax.mk_and, l_succs) 1621 if r_succs: 1622 hyp = syntax.mk_implies (foldr1 (syntax.mk_and, r_succs), hyp) 1623 1624 return rep.test_hyp_whyps (hyp, hyps) 1625 1626def init_loops_to_split (p, restrs): 1627 to_split = loops_to_split (p, restrs) 1628 1629 return [n for n in to_split 1630 if not [n2 for n2 in to_split if n2 != n 1631 and p.is_reachable_from (n2, n)]] 1632 1633def restr_others_both (p, restrs, n, m): 1634 extras = [(sp, vc_double_range (n, m)) 1635 for sp in loops_to_split (p, restrs)] 1636 return restrs + tuple (extras) 1637 1638def restr_others_as_necessary (p, n, restrs, init_bound, offs_bound, 1639 skip_loops = []): 1640 extras = [(sp, vc_double_range (init_bound, offs_bound)) 1641 for sp in loops_to_split (p, restrs) 1642 if sp not in skip_loops 1643 if p.is_reachable_from (sp, n)] 1644 return restrs + tuple (extras) 1645 1646def loop_no_match_unroll (rep, restrs, hyps, split, other_tag, unroll): 1647 p = rep.p 1648 assert p.node_tags[split][0] != other_tag 1649 restr = ((split, vc_num (unroll)), ) 1650 restrs2 = restr_others (p, restr + restrs, 2) 1651 
loop_cond = rep.get_pc ((split, restr + restrs)) 1652 ret_cond = rep.get_pc (('Ret', restrs2), tag = other_tag) 1653 # loop should be reachable 1654 if rep.test_hyp_whyps (mk_not (loop_cond), hyps): 1655 trace ('Loop weak at %d (unroll count %d).' % 1656 (split, unroll)) 1657 return True 1658 # reaching the loop should imply reaching a loop on the other side 1659 hyp = mk_not (mk_and (loop_cond, ret_cond)) 1660 if not rep.test_hyp_whyps (hyp, hyps): 1661 trace ('Loop independent at %d (unroll count %d).' % 1662 (split, unroll)) 1663 return True 1664 return False 1665 1666def loop_no_match (rep, restrs, hyps, split, other_tag, 1667 check_speculate_ineq = False): 1668 if not loop_no_match_unroll (rep, restrs, hyps, split, other_tag, 4): 1669 return False 1670 if not loop_no_match_unroll (rep, restrs, hyps, split, other_tag, 8): 1671 return False 1672 if not check_speculate_ineq: 1673 return 'Restr' 1674 spec = get_new_rhs_speculate_ineq (rep.p, restrs, hyps, split) 1675 if not spec: 1676 return 'Restr' 1677 hyp = check.single_induct_resulting_hyp (rep.p, restrs, spec) 1678 hyps2 = hyps + [hyp] 1679 if not loop_no_match_unroll (rep, restrs, hyps2, split, other_tag, 8): 1680 return 'SingleRevInduct' 1681 return 'Restr' 1682 1683last_searcher_results = [] 1684 1685def restr_point_name (p, n): 1686 if p.loop_id (n): 1687 return '%s (loop head)' % n 1688 elif p.loop_id (n): 1689 return '%s (in loop %d)' % (n, p.loop_id (n)) 1690 else: 1691 return str (n) 1692 1693def fail_searcher (p, restrs, hyps): 1694 return ('Fail Searcher', None) 1695 1696def build_proof_rec (searcher, p, restrs, hyps, name = "problem"): 1697 trace ('doing build proof rec with restrs = %r, hyps = %r' % (restrs, hyps)) 1698 if searcher == None: 1699 searcher = default_searcher 1700 1701 (kind, details) = searcher (p, restrs, hyps) 1702 last_searcher_results.append ((p, restrs, hyps, kind, details, name)) 1703 del last_searcher_results[:-10] 1704 if kind == 'Restr': 1705 (restr_kind, restr_points) = 
details 1706 printout ("Discovered that points [%s] can be bounded" 1707 % ', '.join ([restr_point_name (p, n) 1708 for n in restr_points])) 1709 printout (" (in %s)" % name) 1710 restr_hints = [(n, restr_kind, True) for n in restr_points] 1711 return build_proof_rec_with_restrs (restr_hints, 1712 searcher, p, restrs, hyps, name = name) 1713 elif kind == 'Leaf': 1714 return ProofNode ('Leaf', None, ()) 1715 assert kind in ['CaseSplit', 'Split', 'SingleRevInduct'], kind 1716 if kind == 'CaseSplit': 1717 (details, hints) = details 1718 probs = check.proof_subproblems (p, kind, details, restrs, hyps, name) 1719 if kind == 'CaseSplit': 1720 printout ("Decided to case split at %s" % str (details)) 1721 printout (" (in %s)" % name) 1722 restr_hints = [[(n, 'Number', False) for n in hints] 1723 for cases in [0, 1]] 1724 elif kind == 'SingleRevInduct': 1725 printout ('Found a future induction at %s' % str (details[0])) 1726 restr_hints = [[]] 1727 else: 1728 restr_points = check.split_heads (details) 1729 restr_hints = [[(n, rkind, True) for n in restr_points] 1730 for rkind in ['Number', 'Offset']] 1731 printout ("Discovered a loop relation for split points %s" 1732 % list (restr_points)) 1733 printout (" (in %s)" % name) 1734 subpfs = [] 1735 for ((restrs, hyps, name), hints) in logic.azip (probs, restr_hints): 1736 printout ('Now doing proof search in %s.' 
% name) 1737 pf = build_proof_rec_with_restrs (hints, searcher, 1738 p, restrs, hyps, name = name) 1739 subpfs.append (pf) 1740 return ProofNode (kind, details, subpfs) 1741 1742def build_proof_rec_with_restrs (split_hints, searcher, p, restrs, 1743 hyps, name = "problem"): 1744 if not split_hints: 1745 return build_proof_rec (searcher, p, restrs, hyps, name = name) 1746 1747 (sp, kind, must_find) = split_hints[0] 1748 use_hyps = list (hyps) 1749 if p.node_tags[sp][0] != p.pairing.tags[1]: 1750 nrerr_hyp = check.non_r_err_pc_hyp (p.pairing.tags, 1751 restr_others (p, restrs, 2)) 1752 use_hyps = use_hyps + [nrerr_hyp] 1753 1754 if p.loop_id (sp): 1755 lim_pair = get_proof_split_limit (p, sp, restrs, use_hyps, 1756 kind, must_find = must_find) 1757 else: 1758 lim_pair = get_proof_visit_restr (p, sp, restrs, use_hyps, 1759 kind, must_find = must_find) 1760 1761 if not lim_pair: 1762 assert not must_find 1763 return build_proof_rec_with_restrs (split_hints[1:], 1764 searcher, p, restrs, hyps, name = name) 1765 1766 (min_v, max_v) = lim_pair 1767 if kind == 'Number': 1768 vc_opts = rep_graph.vc_options (range (min_v, max_v), []) 1769 else: 1770 vc_opts = rep_graph.vc_options ([], range (min_v, max_v)) 1771 1772 restrs = restrs + ((sp, vc_opts), ) 1773 subproof = build_proof_rec_with_restrs (split_hints[1:], 1774 searcher, p, restrs, hyps, name = name) 1775 1776 return ProofNode ('Restr', (sp, (kind, (min_v, max_v))), [subproof]) 1777 1778def get_proof_split_limit (p, sp, restrs, hyps, kind, must_find = False): 1779 limit = get_split_limit (p, sp, restrs, hyps, kind, 1780 must_find = must_find) 1781 if limit == None: 1782 return None 1783 # double-check this limit with a rep constructed without the 'fast' flag 1784 limit = find_split_limit (p, sp, restrs, hyps, kind, 1785 hints = [limit, limit + 1], use_rep = mk_graph_slice (p)) 1786 return (0, limit + 1) 1787 1788def get_proof_visit_restr (p, sp, restrs, hyps, kind, must_find = False): 1789 rep = 
rep_graph.mk_graph_slice (p) 1790 pc = rep.get_pc ((sp, restrs)) 1791 if rep.test_hyp_whyps (pc, hyps): 1792 return (1, 2) 1793 elif rep.test_hyp_whyps (mk_not (pc), hyps): 1794 return (0, 1) 1795 else: 1796 assert not must_find 1797 return None 1798 1799def default_searcher (p, restrs, hyps): 1800 # use any handy init splits 1801 res = init_proof_case_split (p, restrs, hyps) 1802 if res: 1803 return res 1804 1805 # detect any un-split loops 1806 to_split_init = init_loops_to_split (p, restrs) 1807 rep = mk_graph_slice (p, fast = True) 1808 1809 l_tag, r_tag = p.pairing.tags 1810 l_to_split = [n for n in to_split_init if p.node_tags[n][0] == l_tag] 1811 r_to_split = [n for n in to_split_init if p.node_tags[n][0] == r_tag] 1812 l_ep = p.get_entry (l_tag) 1813 r_ep = p.get_entry (r_tag) 1814 1815 for r_sp in r_to_split: 1816 trace ('checking loop_no_match at %d' % r_sp, push = 1) 1817 res = loop_no_match (rep, restrs, hyps, r_sp, l_tag, 1818 check_speculate_ineq = True) 1819 if res == 'Restr': 1820 return ('Restr', ('Number', [r_sp])) 1821 elif res == 'SingleRevInduct': 1822 spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_sp) 1823 assert spec 1824 return ('SingleRevInduct', spec) 1825 trace (' .. done checking loop no match', push = -1) 1826 1827 if l_to_split and not r_to_split: 1828 n = l_to_split[0] 1829 trace ('lhs loop alone, limit must be found.') 1830 return ('Restr', ('Number', [n])) 1831 1832 if l_to_split: 1833 n = l_to_split[0] 1834 trace ('checking lhs loop_no_match at %d' % n, push = 1) 1835 if loop_no_match (rep, restrs, hyps, n, r_tag): 1836 trace ('loop does not match!', push = -1) 1837 return ('Restr', ('Number', [n])) 1838 trace (' .. 
done checking loop no match', push = -1) 1839 1840 (kind, split) = find_split_loop (p, n, restrs, hyps) 1841 if kind == 'LoopUnroll': 1842 return ('Restr', ('Number', [split])) 1843 return (kind, split) 1844 1845 if r_to_split: 1846 n = r_to_split[0] 1847 trace ('rhs loop alone, limit must be found.') 1848 return ('Restr', ('Number', [n])) 1849 1850 return ('Leaf', None) 1851 1852def use_split_searcher (p, split): 1853 xs = set ([p.loop_id (h) for h in check.split_heads (split)]) 1854 def searcher (p, restrs, hyps): 1855 ys = set ([p.loop_id (h) 1856 for h in init_loops_to_split (p, restrs)]) 1857 if xs <= ys: 1858 return ('Split', split) 1859 else: 1860 return default_searcher (p, restrs, hyps) 1861 return searcher 1862 1863