# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

# proof scripts and check process
#
# This module builds the Problem for a function pairing, represents proof
# scripts as trees of ProofNode objects (with a whitespace-separated
# serialised form), and turns a proof tree into the list of
# (hyps, goal, name) checks that must be discharged by rep_graph.

from rep_graph import mk_graph_slice, Hyp, eq_hyp, pc_true_hyp, pc_false_hyp
import rep_graph
from problem import Problem, inline_at_point
import problem

from solver import to_smt_expr
from target_objects import functions, pairings, trace, printout
import target_objects
from rep_graph import (vc_num, vc_offs, vc_double_range, vc_upto, mk_vc_opts,
    VisitCount)
import logic

from syntax import (true_term, false_term, boolT, mk_var, mk_word32, mk_word8,
    mk_plus, mk_minus, word32T, word8T, mk_and, mk_eq, mk_implies, mk_not,
    rename_expr)
import syntax


def build_problem (pairing, force_inline = None, avoid_abort = False):
    '''Build the Problem for `pairing`, inlining unmatched calls.

    Each function of the pairing is added as an entry under its tag.
    Calls with no matching pairing are inlined, and if one side is
    tagged 'C', reachable unmatched C calls are inlined too.
    force_inline, if given, is a predicate on a function name that
    forces C inlining of that callee.

    When avoid_abort is set:
      - functions without an entry (underspecified) are not inlined;
      - the final p.check_no_inner_loops () call is skipped.
    '''
    p = Problem (pairing)

    for (tag, fname) in pairing.funs.items ():
        p.add_entry_function (functions[fname], tag)

    p.do_analysis ()

    # FIXME: the inlining is heuristic, and arguably belongs in 'search'
    inline_completely_unmatched (p, skip_underspec = avoid_abort)

    # now do any C inlining
    inline_reachable_unmatched_C (p, force_inline,
        skip_underspec = avoid_abort)

    trace ('Done inlining.')

    p.pad_merge_points ()
    p.do_analysis ()

    if not avoid_abort:
        p.check_no_inner_loops ()

    return p


def inline_completely_unmatched (p, ref_tags = None, skip_underspec = False):
    '''Inline every call whose callee has no pairing over `ref_tags`.

    ref_tags defaults to p.pairing.tags. The inlining is repeated until no
    such call remains, since inlining can expose new calls. If
    skip_underspec is set, callees whose function has no entry are
    traced and left alone. p.do_analysis () is run once at the end.
    '''
    if ref_tags is None:
        ref_tags = p.pairing.tags
    while True:
        ns = [(n, skip_underspec
                and not functions[p.nodes[n].fname].entry)
            for n in p.nodes
            if p.nodes[n].kind == 'Call'
            if not any (pair.tags == ref_tags
                for pair in pairings.get (p.nodes[n].fname, []))]
        for (n, skip) in ns:
            if skip:
                trace ('Skipped inlining underspecified %s.'
                    % p.nodes[n].fname)
        ns = [n for (n, skip) in ns if not skip]
        for n in ns:
            trace ('Function %s at %d - %s - completely unmatched.'
                % (p.nodes[n].fname, n, p.node_tags[n][0]))
            # analysis is deferred until all inlining is finished
            inline_at_point (p, n, do_analysis = False)
        if not ns:
            p.do_analysis ()
            return


def inline_reachable_unmatched_C (p, force_inline = None,
        skip_underspec = False):
    '''If the pairing has a 'C' side, inline its calls not matched by
    calls on the other side (see inline_reachable_unmatched).'''
    if 'C' not in p.pairing.tags:
        return
    [compare_tag] = [tag for tag in p.pairing.tags if tag != 'C']
    inline_reachable_unmatched (p, 'C', compare_tag, force_inline,
        skip_underspec = skip_underspec)


def inline_reachable_unmatched (p, inline_tag, compare_tag,
        force_inline = None, skip_underspec = False):
    '''Inline `inline_tag` calls whose callee is not paired with any
    function called from the `compare_tag` side.

    A graph slice is built with the consider_inline policy. Every node,
    plus 'Ret' and 'Err', is then requested with each loop head limited
    to vc_double_range (3, 3). Exploration restarts whenever
    rep_graph.InlineEvent is raised; presumably the slice raises it
    after changing the graph by inlining.
    '''
    # functions on the inline_tag side that are paired with some
    # function called from the compare_tag side
    funs = set ([pair.funs[inline_tag]
        for n in p.nodes
        if p.nodes[n].kind == 'Call'
        if p.node_tags[n][0] == compare_tag
        for pair in pairings.get (p.nodes[n].fname, [])
        if inline_tag in pair.tags])

    rep = mk_graph_slice (p,
        consider_inline (funs, inline_tag, force_inline,
            skip_underspec))
    opts = vc_double_range (3, 3)
    while True:
        try:
            heads = problem.loop_heads_including_inner (p)
            limits = [(n, opts) for n in heads]

            # snapshot the node ids: inlining may add nodes
            for n in list (p.nodes):
                try:
                    rep.get_node_pc_env ((n, limits))
                except rep.TooGeneral:
                    pass

            rep.get_node_pc_env (('Ret', limits), inline_tag)
            rep.get_node_pc_env (('Err', limits), inline_tag)
            break
        except rep_graph.InlineEvent:
            continue


def consider_inline1 (p, n, matched_funs, inline_tag,
        force_inline, skip_underspec):
    '''Inlining policy for the Call node `n` of `p`.

    Returns False (do not inline) or a thunk that inlines the call.
    Only calls tagged `inline_tag` are considered. A call is inlined if
    its callee is not in matched_funs, or if force_inline (callee)
    holds. Underspecified callees are skipped when skip_underspec is
    set.
    '''
    node = p.nodes[n]
    assert node.kind == 'Call'

    if p.node_tags[n][0] != inline_tag:
        return False

    f_nm = node.fname
    if skip_underspec and not functions[f_nm].entry:
        trace ('Skipping inlining underspecified %s' % f_nm)
        return False
    if f_nm not in matched_funs or (force_inline and force_inline (f_nm)):
        return lambda: inline_at_point (p, n)
    else:
        return False


def consider_inline (matched_funs, tag, force_inline, skip_underspec = False):
    '''Return an inlining policy taking a single (p, n) pair; see
    consider_inline1.'''
    def consider (p_n):
        (p, n) = p_n
        return consider_inline1 (p, n, matched_funs, tag,
            force_inline, skip_underspec)
    return consider


def inst_eqs (p, restrs, eqs, tag_map = None):
    '''Instantiate pairing equations `eqs` as eq_hyps on problem `p`.

    Each equation is a pair of (expr, addr) sides:
      - an addr 'X_IN' refers to tag X at its entry node (no
        restrictions);
      - an addr 'X_OUT' refers to tag X at 'Ret' under `restrs`.
    tag_map maps pairing tags to problem tags and defaults to the
    identity on p.tags ().
    '''
    addr_map = {}
    if not tag_map:
        tag_map = dict ([(tag, tag) for tag in p.tags ()])
    for (pair_tag, p_tag) in tag_map.items ():
        addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag)
        addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag)
    renames = p.entry_exit_renames (list (tag_map.values ()))
    for (pair_tag, p_tag) in tag_map.items ():
        renames[pair_tag + '_IN'] = renames[p_tag + '_IN']
        renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT']
    hyps = []
    for (lhs, rhs) in eqs:
        vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr])
            for (x, x_addr) in (lhs, rhs)]
        hyps.append (eq_hyp (vals[0], vals[1]))
    return hyps


def init_point_hyps (p):
    '''Hypotheses from the pairing's input equations, at function entry.'''
    (inp_eqs, _) = p.pairing.eqs
    return inst_eqs (p, (), inp_eqs)


class ProofNode:
    '''A node of a proof script tree.

    The kind determines the layout of `args` and the number of
    subproofs:
      'Leaf'            -- no args, no subproofs
      'Restr'           -- (point, (kind, (x, y))), 1 subproof
      'SingleRevInduct' -- (point, (eqs, n), (pred, n_bound)), 1 subproof
      'Split'           -- (l_details, r_details, eqs, n, loop_r_max),
                           2 subproofs
      'CaseSplit'       -- (point, tag), 2 subproofs
    '''
    def __init__ (self, kind, args = None, subproofs = ()):
        self.kind = kind
        self.args = args
        self.subproofs = tuple (subproofs)
        if self.kind == 'Leaf':
            assert args is None
            assert list (subproofs) == []
        elif self.kind == 'Restr':
            (self.point, self.restr_range) = args
            assert len (subproofs) == 1
        elif self.kind == 'SingleRevInduct':
            (self.point, self.eqs_proof, self.rev_proof) = args
            assert len (subproofs) == 1
        elif self.kind == 'Split':
            self.split = args
            # arity check only
            (l_details, r_details, eqs, n, loop_r_max) = args
            assert len (subproofs) == 2
        elif self.kind == 'CaseSplit':
            (self.point, self.tag) = args
            assert len (subproofs) == 2
        else:
            assert not 'proof node kind understood', kind

    def __repr__ (self):
        return 'ProofNode (%r, %r, %r)' % (self.kind,
            self.args, self.subproofs)

    def serialise (self, p, ss):
        '''Append the serialised tokens of this proof tree to list `ss`.

        `p` supplies the node tags written for 'Restr' and
        'SingleRevInduct' points. Subproofs are written in pre-order.
        '''
        if self.kind == 'Leaf':
            ss.append ('Leaf')
        elif self.kind == 'Restr':
            (kind, (x, y)) = self.restr_range
            tag = p.node_tags[self.point][0]
            ss.extend (['Restr', '%d' % self.point,
                tag, kind, '%d' % x, '%d' % y])
        elif self.kind == 'SingleRevInduct':
            tag = p.node_tags[self.point][0]
            (eqs, n) = self.eqs_proof
            ss.extend (['SingleRevInduct', '%d' % self.point,
                tag, '%d' % n, '%d' % len (eqs)])
            # these invariants are single lambda terms (not pairs), as
            # used by all_rev_induct_checks and the proof report
            for eq in eqs:
                serialise_lambda (eq, ss)
            (pred, n_bound) = self.rev_proof
            pred.serialise (ss)
            ss.append ('%d' % n_bound)
        elif self.kind == 'Split':
            (l_details, r_details, eqs, n, loop_r_max) = self.args
            ss.extend (['Split', '%d' % n, '%d' % loop_r_max])
            serialise_details (l_details, ss)
            serialise_details (r_details, ss)
            ss.append ('%d' % len (eqs))
            for (x, y) in eqs:
                serialise_lambda (x, ss)
                serialise_lambda (y, ss)
        elif self.kind == 'CaseSplit':
            ss.extend (['CaseSplit', '%d' % self.point, self.tag])
        else:
            assert not 'proof node kind understood', self.kind
        for proof in self.subproofs:
            proof.serialise (p, ss)

    def all_subproofs (self):
        '''This node followed by all of its descendants, pre-order.'''
        return [self] + [proof for proof1 in self.subproofs
            for proof in proof1.all_subproofs ()]

    def all_subproblems (self, p, restrs, hyps, name):
        '''List (node, restrs, hyps) for this node and every descendant,
        where restrs/hyps are the context each node is checked in.'''
        subproblems = proof_subproblems (p, self.kind,
            self.args, restrs, hyps, name)
        subproofs = logic.azip (subproblems, self.subproofs)
        return [(self, restrs, hyps)] + [subproblem
            for ((restrs2, hyps2, name2), proof) in subproofs
            for subproblem in proof.all_subproblems (p, restrs2,
                hyps2, name2)]

    def save_serialise (self, p, fname):
        '''Write the serialised proof to `fname` as one line.'''
        ss = []
        self.serialise (p, ss)
        with open (fname, 'w') as f:
            f.write (' '.join (ss) + '\n')

    def __hash__ (self):
        return syntax.hash_tuplify (self.kind, self.args,
            self.subproofs)


def serialise_details (details, ss):
    '''Serialise split details (split, (seq_start, step), eqs) to `ss`.'''
    (split, (seq_start, step), eqs) = details
    ss.extend (['%d' % split, '%d' % seq_start, '%d' % step])
    ss.append ('%d' % len (eqs))
    for eq in eqs:
        serialise_lambda (eq, ss)


def serialise_lambda (eq_term, ss):
    '''Serialise `eq_term` as a lambda over the word32 variable '%i'.'''
    ss.extend (['Lambda', '%i'])
    word32T.serialise (ss)
    eq_term.serialise (ss)


def deserialise_details (ss, i):
    '''Inverse of serialise_details; returns (next index, details).'''
    (split, seq_start, step) = [int (x) for x in ss[i : i + 3]]
    (i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3)
    return (i, (split, (seq_start, step), eqs))


def deserialise_lambda (ss, i):
    '''Inverse of serialise_lambda; returns (next index, term).'''
    assert ss[i : i + 2] == ['Lambda', '%i'], (ss, i)
    (i, typ) = syntax.parse_typ (ss, i + 2)
    assert typ == word32T, typ
    (i, eq_term) = syntax.parse_expr (ss, i)
    return (i, eq_term)


def deserialise_double_lambda (ss, i):
    '''Parse two consecutive lambdas as a pair (used for Split eqs).'''
    (i, x) = deserialise_lambda (ss, i)
    (i, y) = deserialise_lambda (ss, i)
    return (i, (x, y))


def deserialise_inner (ss, i):
    '''Parse one proof tree from tokens `ss` starting at index `i`.

    Inverse of ProofNode.serialise; returns (next index, ProofNode).
    Tags stored in the token stream are read but not needed to rebuild
    the nodes.
    '''
    if ss[i] == 'Leaf':
        return (i + 1, ProofNode ('Leaf'))
    elif ss[i] == 'Restr':
        point = int (ss[i + 1])
        tag = ss[i + 2]
        kind = ss[i + 3]
        assert kind in ['Number', 'Offset'], (kind, i)
        x = int (ss[i + 4])
        y = int (ss[i + 5])
        (i, p1) = deserialise_inner (ss, i + 6)
        return (i, ProofNode ('Restr', (point, (kind, (x, y))), [p1]))
    elif ss[i] == 'SingleRevInduct':
        point = int (ss[i + 1])
        tag = ss[i + 2]
        n = int (ss[i + 3])
        # single lambdas, matching ProofNode.serialise
        (i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 4)
        # pred was written by Expr.serialise, whose inverse is parse_expr
        (i, pred) = syntax.parse_expr (ss, i)
        n_bound = int (ss[i])
        (i, p1) = deserialise_inner (ss, i + 1)
        return (i, ProofNode ('SingleRevInduct', (point, (eqs, n),
            (pred, n_bound)), [p1]))
    elif ss[i] == 'Split':
        n = int (ss[i + 1])
        loop_r_max = int (ss[i + 2])
        (i, l_details) = deserialise_details (ss, i + 3)
        (i, r_details) = deserialise_details (ss, i)
        (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i)
        (i, p1) = deserialise_inner (ss, i)
        (i, p2) = deserialise_inner (ss, i)
        return (i, ProofNode ('Split', (l_details, r_details, eqs,
            n, loop_r_max), [p1, p2]))
    elif ss[i] == 'CaseSplit':
        n = int (ss[i + 1])
        tag = ss[i + 2]
        (i, p1) = deserialise_inner (ss, i + 3)
        (i, p2) = deserialise_inner (ss, i)
        return (i, ProofNode ('CaseSplit', (n, tag), [p1, p2]))
    else:
        assert not 'proof node type understood', (ss, i)


def deserialise (line):
    '''Parse a whole proof from one serialised line.'''
    ss = line.split ()
    (i, proof) = deserialise_inner (ss, 0)
    assert i == len (ss), (ss, i)
    return proof


def proof_subproblems (p, kind, args, restrs, hyps, path):
    '''The (restrs, hyps, name) contexts of a proof node's subproofs.

    The list is in the same order as the node's subproofs.
    '''
    tags = p.pairing.tags
    if kind == 'Leaf':
        return []
    elif kind == 'Restr':
        restr = get_proof_restr (args[0], args[1])
        hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)]
        return [((restr,) + restrs, hyps,
            '%s (%d limited)' % (path, args[0]))]
    elif kind == 'SingleRevInduct':
        hyp = single_induct_resulting_hyp (p, restrs, args)
        return [(restrs, hyps + [hyp], path)]
    elif kind == 'Split':
        split = args
        return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs),
                '%d init case in %s' % (split[0][0], path)),
            (restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True),
                '%d loop case in %s' % (split[0][0], path))]
    elif kind == 'CaseSplit':
        (point, tag) = args
        visit = ((point, restrs), tag)
        true_hyps = hyps + [pc_true_hyp (visit)]
        false_hyps = hyps + [pc_false_hyp (visit)]
        return [(restrs, true_hyps,
                'true case (%d visited) in %s' % (point, path)),
            (restrs, false_hyps,
                'false case (%d not visited) in %s' % (point, path))]
    else:
        assert not 'proof node kind understood', kind


def split_heads (split):
    '''The two split points [l_split, r_split] of a Split proof's args.'''
    (l_details, r_details, eqs, n, _) = split
    (l_split, _, _) = l_details
    (r_split, _, _) = r_details
    return [l_split, r_split]


def split_no_loop_hyps (tags, split, restrs):
    '''Hyps for the "init" case of a split: the left split point is not
    reached at sequence visit n.'''
    ((_, (l_seq_start, l_step), _), _, _, n, _) = split

    (l_visit, _) = split_visit_visits (tags, split, restrs, vc_num (n))

    return [pc_false_hyp (l_visit)]


def split_visit_one_visit (tag, details, restrs, visit):
    '''The visit of split point `details[0]` that corresponds to entry
    `visit` of its visit sequence, or None if details is None.'''
    if details is None:
        return None
    (split, (seq_start, step), eqs) = details

    # the split point sequence at low numbers ('Number') is offset
    # by the point the sequence starts. At symbolic offsets we ignore
    # that, instead having the loop counter for the two sequences
    # be the same number of iterations after the sequence start.
    if visit.kind == 'Offset':
        visit = vc_offs (visit.n * step)
    else:
        visit = vc_num (seq_start + (visit.n * step))

    visit = ((split, ((split, visit), ) + restrs), tag)
    return visit


def split_visit_visits (tags, split, restrs, visit):
    '''Left and right visits for sequence entry `visit` of a split.'''
    (ltag, rtag) = tags
    (l_details, r_details, eqs, _, _) = split

    l_visit = split_visit_one_visit (ltag, l_details, restrs, visit)
    r_visit = split_visit_one_visit (rtag, r_details, restrs, visit)

    return (l_visit, r_visit)


def split_hyps_at_visit (tags, split, restrs, visit):
    '''The (hyp, description) facts a split claims at sequence `visit`.

    The facts are:
      - path-condition implications: the left visit implies the right
        visit, and each side's visit implies its start visit;
      - each side's invariants equal their value at sequence start;
      - the paired equalities.
    The lambda variable '%i' is instantiated with the visit number. For
    offset visits it becomes '%n' + offset.
    '''
    (l_details, r_details, eqs, _, _) = split
    (l_split, (l_seq_start, l_step), l_eqs) = l_details
    (r_split, (r_seq_start, r_step), r_eqs) = r_details

    (l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit)
    (l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0))
    (l_tag, r_tag) = tags

    def mksub (v):
        return lambda exp: logic.var_subst (exp, {('%i', word32T) : v},
            must_subst = False)
    def inst (exp):
        return logic.inst_eq_at_visit (exp, visit)
    zsub = mksub (mk_word32 (0))
    if visit.kind == 'Number':
        lsub = mksub (mk_word32 (visit.n))
    else:
        lsub = mksub (mk_plus (mk_var ('%n', word32T),
            mk_word32 (visit.n)))

    hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'),
        (Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag),
        (Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)]
    hyps += [(eq_hyp ((zsub (l_exp), l_start), (lsub (l_exp), l_visit),
                (l_split, r_split)), '%s const' % l_tag)
        for l_exp in l_eqs if inst (l_exp)]
    hyps += [(eq_hyp ((zsub (r_exp), r_start), (lsub (r_exp), r_visit),
                (l_split, r_split)), '%s const' % r_tag)
        for r_exp in r_eqs if inst (r_exp)]
    hyps += [(eq_hyp ((lsub (l_exp), l_visit), (lsub (r_exp), r_visit),
                (l_split, r_split)), 'eq')
        for (l_exp, r_exp) in eqs
        if inst (l_exp) and inst (r_exp)]
    return hyps


def split_loop_hyps (tags, split, restrs, exit):
    '''Hyps for the "loop" case of a split.

    The hyps are:
      - the left side reaches symbolic offset n - 1;
      - when `exit` is set, it does not reach offset n;
      - the split facts hold at offsets 0 .. n - 1.
    '''
    ((_, _, _), _, _, n, _) = split
    (l_visit, _) = split_visit_visits (tags, split, restrs, vc_offs (n - 1))
    (l_cont, _) = split_visit_visits (tags, split, restrs, vc_offs (n))

    l_enter = pc_true_hyp (l_visit)
    l_exit = pc_false_hyp (l_cont)
    if exit:
        hyps = [l_enter, l_exit]
    else:
        hyps = [l_enter]
    return hyps + [hyp for offs in map (vc_offs, range (n))
        for (hyp, _) in split_hyps_at_visit (tags, split, restrs, offs)]


def loops_to_split (p, restrs):
    '''Loop heads of `p` not yet covered by a restriction in `restrs`.

    Returns a set or a list. If a restricted node must be visited (its
    visit set excludes zero), only heads reachable from it, or on the
    other tag, are kept.
    '''
    loop_heads_with_split = set ([p.loop_id (n)
        for (n, visit_set) in restrs])
    rem_loop_heads = set (p.loop_heads ()) - loop_heads_with_split
    for (n, visit_set) in restrs:
        if not visit_set.has_zero ():
            # n must be visited, so loop heads must be
            # reachable from n (or on another tag)
            rem_loop_heads = [lh for lh in rem_loop_heads
                if p.is_reachable_from (n, lh)
                or p.node_tags[n][0] != p.node_tags[lh][0]]
    return rem_loop_heads


def restr_others (p, restrs, n):
    '''Extend `restrs` with vc_upto (n) for every remaining loop head.'''
    extras = [(sp, vc_upto (n)) for sp in loops_to_split (p, restrs)]
    return restrs + tuple (extras)


def non_r_err_pc_hyp (tags, restrs):
    '''Hyp: the 'Err' node of the second tag is not reached.'''
    return pc_false_hyp ((('Err', restrs), tags[1]))


def split_r_err_pc_hyp (p, split, restrs, tags = None):
    '''Hyp: the right side does not reach 'Err'.

    The right split point is restricted to a double range starting at
    visit r_seq_start + n * r_step, with loop_r_max + 2 as the second
    vc_double_range argument. The other loops are bounded by
    restr_others (.., 2).
    '''
    (_, r_details, _, n, loop_r_max) = split
    (r_split, (r_seq_start, r_step), _) = r_details

    nc = n * r_step
    vc = vc_double_range (r_seq_start + nc, loop_r_max + 2)

    restrs = restr_others (p, ((r_split, vc), ) + restrs, 2)

    if tags is None:
        tags = p.pairing.tags

    return non_r_err_pc_hyp (tags, restrs)


# extra visit counts added above the top of every Restr range
# (0 means the range is used exactly as written)
restr_bump = 0


def get_proof_restr (n, restr_range):
    '''The restriction (n, visit options) for a Restr proof's range
    (kind, (x, y)): visit counts x .. y - 1 (+ restr_bump).'''
    (kind, (x, y)) = restr_range
    return (n, mk_vc_opts ([VisitCount (kind, i)
        for i in range (x, y + restr_bump)]))


def restr_trivial_hyp (p, n, restr_range, restrs):
    '''Trivial path-condition hyp at visit y - 1 of `n`; presumably this
    ensures the representation includes that visit (see the 'trivial'
    hyp comments elsewhere in this module).'''
    (kind, (x, y)) = restr_range
    restr = (n, VisitCount (kind, y - 1))
    return rep_graph.pc_triv_hyp (((n, (restr, ) + restrs),
        p.node_tags[n][0]))


def proof_restr_checks (n, restr_range, p, restrs, hyps):
    '''Checks justifying a Restr proof step on point `n`.

    The checks are:
      - optionally, that the minimum visit x - 1 is reachable;
      - always, that visit y - 1 is unreachable.
    '''
    (kind, (x, y)) = restr_range
    restr = get_proof_restr (n, restr_range)
    ncerr_hyp = non_r_err_pc_hyp (p.pairing.tags,
        restr_others (p, (restr, ) + restrs, 2))
    hyps = [ncerr_hyp] + hyps
    def visit (vc):
        return ((n, ((n, vc), ) + restrs), p.node_tags[n][0])

    # this cannot be more uniform because the representation of visit
    # at offset 0 is all a bit odd, with n being the only node so visited:
    if kind == 'Offset':
        min_vc = vc_offs (max (0, x - 1))
    elif x > 1:
        min_vc = vc_num (x - 1)
    else:
        min_vc = None
    if min_vc:
        init_check = [(hyps, pc_true_hyp (visit (min_vc)),
            'Check of restr min %d %s for %d' % (x, kind, n))]
    else:
        init_check = []

    # if we can reach node n with (y - 1) visits to n, then the next
    # node will have y visits to n, which we are disallowing
    # thus we show that this visit is impossible
    top_vc = VisitCount (kind, y - 1)
    top_check = (hyps, pc_false_hyp (visit (top_vc)),
        'Check of restr max %d %s for %d' % (y, kind, n))
    return init_check + [top_check]


def split_init_step_checks (p, restrs, hyps, split, tags = None):
    '''Base-case checks: the split facts at concrete visits 0 .. n - 1,
    each assuming the left side reaches that visit.'''
    (_, _, _, n, _) = split
    if tags is None:
        tags = p.pairing.tags

    err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags)
    hyps = [err_hyp] + hyps
    checks = []
    for i in range (n):
        (l_visit, r_visit) = split_visit_visits (tags, split,
            restrs, vc_num (i))
        lpc_hyp = pc_true_hyp (l_visit)
        # this trivial 'hyp' ensures the rep is built to include
        # the matching rhs visits when checking lhs consts
        rpc_triv_hyp = rep_graph.pc_triv_hyp (r_visit)
        vis_hyps = split_hyps_at_visit (tags, split, restrs, vc_num (i))

        for (hyp, desc) in vis_hyps:
            checks.append ((hyps + [lpc_hyp, rpc_triv_hyp], hyp,
                'Induct check at visit %d: %s' % (i, desc)))
    return checks


def split_induct_step_checks (p, restrs, hyps, split, tags = None):
    '''Inductive-step checks: assuming the split facts at offsets
    0 .. n - 1 and that offset n is reached, prove them at offset n.'''
    ((l_split, _, _), _, _, n, _) = split
    if tags is None:
        tags = p.pairing.tags

    err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags)
    (cont, r_cont) = split_visit_visits (tags, split, restrs, vc_offs (n))
    # the 'trivial' hyp here ensures the representation includes a loop
    # of the rhs when proving const equations on the lhs
    hyps = ([err_hyp, pc_true_hyp (cont),
            rep_graph.pc_triv_hyp (r_cont)] + hyps
        + split_loop_hyps (tags, split, restrs, exit = False))

    return [(hyps, hyp, 'Induct check (%s) at inductive step for %d'
            % (desc, l_split))
        for (hyp, desc) in split_hyps_at_visit (tags, split,
            restrs, vc_offs (n))]


def check_split_induct_step_group (rep, restrs, hyps, split, tags = None):
    '''True iff every group of inductive-step checks passes on `rep`.'''
    checks = split_induct_step_checks (rep.p, restrs, hyps, split,
        tags = tags)
    groups = proof_check_groups (checks)
    for group in groups:
        (verdict, _) = test_hyp_group (rep, group)
        if not verdict:
            return False
    return True


def split_checks (p, restrs, hyps, split, tags = None):
    '''All checks for a Split step: base cases then the inductive step.'''
    return (split_init_step_checks (p, restrs, hyps, split, tags = tags)
        + split_induct_step_checks (p, restrs, hyps, split, tags = tags))


def loop_eq_hyps_at_visit (tag, split, eqs, restrs, visit_num,
        use_if_at = False):
    '''Single-loop analogue of split_hyps_at_visit.

    The facts are the pc implication to the loop start, and each
    invariant in `eqs` equal to its value at visit 0.
    '''
    details = (split, (0, 1), eqs)
    visit = split_visit_one_visit (tag, details, restrs, visit_num)
    start = split_visit_one_visit (tag, details, restrs, vc_num (0))

    def mksub (v):
        return lambda exp: logic.var_subst (exp, {('%i', word32T) : v},
            must_subst = False)
    zsub = mksub (mk_word32 (0))
    if visit_num.kind == 'Number':
        isub = mksub (mk_word32 (visit_num.n))
    else:
        isub = mksub (mk_plus (mk_var ('%n', word32T),
            mk_word32 (visit_num.n)))

    hyps = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)]
    hyps += [(eq_hyp ((zsub (exp), start), (isub (exp), visit),
                (split, 0), use_if_at = use_if_at), '%s const' % tag)
        for exp in eqs if logic.inst_eq_at_visit (exp, visit_num)]

    return hyps


def single_induct_resulting_hyp (p, restrs, rev_induct_args):
    '''Hyp provided by a SingleRevInduct step: pred holds (if reached)
    at visit 0 of the point.'''
    (point, _, (pred, _)) = rev_induct_args
    (tag, _) = p.node_tags[point]
    vis = ((point, restrs + tuple ([(point, vc_num (0))])), tag)
    return rep_graph.true_if_at_hyp (pred, vis)


def single_loop_induct_base_checks (p, restrs, hyps, tag, split, n, eqs):
    '''Base checks for the loop invariants at visits 0 .. n.'''
    tests = []
    details = (split, (0, 1), eqs)
    for i in range (n + 1):
        reach = split_visit_one_visit (tag, details, restrs, vc_num (i))
        nhyps = [pc_true_hyp (reach)]
        tests.extend ([(hyps + nhyps, hyp,
                'Base check (%s, %d) at induct step for %d'
                % (desc, i, split))
            for (hyp, desc) in loop_eq_hyps_at_visit (tag, split,
                eqs, restrs, vc_num (i))])
    return tests


def single_loop_induct_step_checks (p, restrs, hyps, tag, split, n,
        eqs, eqs_assume = None):
    '''Inductive step for the loop invariants `eqs` (n-induction),
    additionally assuming `eqs_assume` at offsets 0 .. n - 1.'''
    if eqs_assume is None:
        eqs_assume = []
    details = (split, (0, 1), eqs_assume + eqs)
    cont = split_visit_one_visit (tag, details, restrs, vc_offs (n))
    hyps = ([pc_true_hyp (cont)] + hyps
        + [h for i in range (n)
            for (h, _) in loop_eq_hyps_at_visit (tag, split,
                eqs_assume + eqs, restrs, vc_offs (i))])

    return [(hyps, hyp, 'Induct check (%s) at inductive step for %d'
            % (desc, split))
        for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs,
            restrs, vc_offs (n))]


def mk_loop_counter_eq_hyp (p, split, restrs, n):
    '''Hyp: the loop counter '%n' equals `n` at offset 0 of `split`.'''
    details = (split, (0, 1), [])
    (tag, _) = p.node_tags[split]
    visit = split_visit_one_visit (tag, details, restrs, vc_offs (0))
    return eq_hyp ((mk_var ('%n', word32T), visit),
        (mk_word32 (n), visit), (split, 0))


def single_loop_rev_induct_base_checks (p, restrs, hyps, tag, split,
        n_bound, eqs_assume, pred):
    '''Check pred holds at offset 1 when the counter equals n_bound.'''
    details = (split, (0, 1), eqs_assume)
    cont = split_visit_one_visit (tag, details, restrs, vc_offs (1))
    n_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound)

    split_details = (None, details, None, 1, 1)
    non_err = split_r_err_pc_hyp (p, split_details, restrs)

    hyps = (hyps + [n_hyp, pc_true_hyp (cont), non_err]
        + [h for (h, _) in loop_eq_hyps_at_visit (tag,
            split, eqs_assume, restrs, vc_offs (0))])
    goal = rep_graph.true_if_at_hyp (pred, cont)

    return [(hyps, goal, 'Pred true at %d check.' % n_bound)]


def single_loop_rev_induct_checks (p, restrs, hyps, tag, split,
        eqs_assume, pred):
    '''Reverse step: pred at offset 2 implies pred at offset 1.'''
    details = (split, (0, 1), eqs_assume)
    curr = split_visit_one_visit (tag, details, restrs, vc_offs (1))
    cont = split_visit_one_visit (tag, details, restrs, vc_offs (2))

    split_details = (None, details, None, 1, 1)
    non_err = split_r_err_pc_hyp (p, split_details, restrs)
    true_next = rep_graph.true_if_at_hyp (pred, cont)

    hyps = (hyps + [pc_true_hyp (curr), true_next, non_err]
        + [h for (h, _) in loop_eq_hyps_at_visit (tag, split,
            eqs_assume, restrs, vc_offs (1), use_if_at = True)])
    goal = rep_graph.true_if_at_hyp (pred, curr)

    return [(hyps, goal, 'Pred reverse step.')]


def all_rev_induct_checks (p, restrs, hyps, point, eqs_proof, rev_proof):
    '''All checks for a SingleRevInduct step.

    eqs_proof is (eqs, n): single invariant lambdas proved by
    n-induction. rev_proof is (pred, n_bound).
    '''
    (eqs, n) = eqs_proof
    (pred, n_bound) = rev_proof
    (tag, _) = p.node_tags[point]
    checks = (single_loop_induct_step_checks (p, restrs, hyps, tag,
            point, n, eqs)
        + single_loop_induct_base_checks (p, restrs, hyps, tag,
            point, n, eqs)
        + single_loop_rev_induct_checks (p, restrs, hyps, tag,
            point, eqs, pred)
        + single_loop_rev_induct_base_checks (p, restrs, hyps,
            tag, point, n_bound, eqs, pred))
    return checks


def leaf_condition_checks (p, restrs, hyps):
    '''checks of the final refinement conditions'''
    nrerr_pc_hyp = non_r_err_pc_hyp (p.pairing.tags, restrs)
    hyps = [nrerr_pc_hyp] + hyps
    [l_tag, r_tag] = p.pairing.tags

    nlerr_pc = pc_false_hyp ((('Err', restrs), l_tag))
    # this 'hypothesis' ensures that the representation is built all
    # the way to Ret. in particular this ensures that function relations
    # are available to use in proving single-side equalities
    ret_eq = eq_hyp ((true_term, (('Ret', restrs), l_tag)),
        (true_term, (('Ret', restrs), r_tag)))

    ### TODO: previously we considered the case where 'Ret' was unreachable
    ### (as a result of unsatisfiable hyps) and proved a simpler property.
    ### we might want to restore this
    (_, out_eqs) = p.pairing.eqs
    checks = [(hyps + [nlerr_pc, ret_eq], hyp, 'Leaf eq check') for hyp in
        inst_eqs (p, restrs, out_eqs)]
    return [(hyps + [ret_eq], nlerr_pc, 'Leaf path-cond imp')] + checks


def proof_checks (p, proof):
    '''All (hyps, goal, name) checks for `proof` on problem `p`.'''
    return proof_checks_rec (p, (), init_point_hyps (p), proof, 'root')


def proof_checks_imm (p, restrs, hyps, proof, path):
    '''The checks owned directly by the root node of `proof`.'''
    if proof.kind == 'Restr':
        checks = proof_restr_checks (proof.point, proof.restr_range,
            p, restrs, hyps)
    elif proof.kind == 'SingleRevInduct':
        checks = all_rev_induct_checks (p, restrs, hyps, proof.point,
            proof.eqs_proof, proof.rev_proof)
    elif proof.kind == 'Split':
        checks = split_checks (p, restrs, hyps, proof.split)
    elif proof.kind == 'Leaf':
        checks = leaf_condition_checks (p, restrs, hyps)
    elif proof.kind == 'CaseSplit':
        checks = []
    else:
        assert not 'proof node kind understood', proof.kind

    return [(hs, hyp, '%s on %s' % (name, path))
        for (hs, hyp, name) in checks]


def proof_checks_rec (p, restrs, hyps, proof, path):
    '''Checks of `proof` and, recursively, of all its subproofs.'''
    checks = proof_checks_imm (p, restrs, hyps, proof, path)

    subproblems = proof_subproblems (p, proof.kind,
        proof.args, restrs, hyps, path)
    for (subprob, subproof) in logic.azip (subproblems, proof.subproofs):
        (restrs, hyps, path) = subprob
        checks.extend (proof_checks_rec (p, restrs, hyps, subproof, path))
    return checks


# single-element cell holding the most recent failing check
last_failed_check = [None]


def proof_check_groups (checks):
    '''Group checks by the exact set of visits their hyps mention.'''
    groups = {}
    for (hyps, hyp, name) in checks:
        n_vcs = set ([n_vc for hyp2 in [hyp] + hyps
            for n_vc in hyp2.visits ()])
        k = tuple (sorted (n_vcs))
        groups.setdefault (k, []).append ((hyps, hyp, name))
    return list (groups.values ())


def test_hyp_group (rep, group, detail = None):
    '''Test a group of checks on `rep`.

    Returns (result, None) on success, or (result, failing_check). On
    failure the result kind is stored in detail[0] if `detail` is given.
    '''
    imps = [(hyps, hyp) for (hyps, hyp, _) in group]
    names = set ([name for (_, _, name) in group])

    trace ('Testing group of hyps: %s' % list (names), push = 1)
    (res, i, res_kind) = rep.test_hyp_imps (imps)
    trace ('Group result: %r' % res, push = -1)
    if res:
        return (res, None)
    else:
        if detail:
            detail[0] = res_kind
        return (res, group[i])


def failed_test_sets (p, checks):
    '''Names of check sets that fail, each tested on a fresh slice.'''
    failed = []
    sets = {}
    for (hyps, hyp, name) in checks:
        sets.setdefault (name, []).append ((hyps, hyp))
    for name in sets:
        rep = rep_graph.mk_graph_slice (p)
        (res, _, _) = rep.test_hyp_imps (sets[name])
        if not res:
            failed.append (name)
    return failed


# single-element cell holding an optional save (p, proof) callback,
# called by check_proof after a successful check
save_checked_proofs = [None]


def check_proof (p, proof, use_rep = None):
    '''Check every group of checks for `proof`; True iff all pass.

    A fresh graph slice is used per group unless `use_rep` is given.
    '''
    checks = proof_checks (p, proof)
    groups = proof_check_groups (checks)

    for group in groups:
        if use_rep is None:
            rep = rep_graph.mk_graph_slice (p)
        else:
            rep = use_rep

        detail = [0]
        (verdict, elt) = test_hyp_group (rep, group, detail)
        if verdict:
            continue
        (_, _, name) = elt
        last_failed_check[0] = elt
        trace ('%s: proof failed!' % name)
        trace (' (failure kind: %r)' % detail[0])
        return False
    if save_checked_proofs[0]:
        save = save_checked_proofs[0]
        save (p, proof)
    return True


def pretty_vseq (details):
    '''Human-readable description of a split's visit sequence.'''
    (split, (seq_start, seq_step), _) = details
    if (seq_start, seq_step) == (0, 1):
        return 'visits to %d' % split
    else:
        i = seq_start + 1
        j = i + seq_step
        k = j + seq_step
        return 'visits [%d, %d, %d ...] to %d' % (i, j, k, split)


def next_induct_var (n):
    '''Name of the n-th induction variable: i, j, k, a, b, c, i2, ...'''
    s = 'ijkabc'
    v = s[n % 6]
    if n >= 6:
        v += str ((n // 6) + 1)
    return v


def pretty_lambda (t):
    '''Pretty-print lambda body `t` with '%i' shown as '#seq-visits'.'''
    v = syntax.mk_var ('#seq-visits', word32T)
    t = logic.var_subst (t, {('%i', word32T) : v}, must_subst = False)
    return syntax.pretty_expr (t, print_type = True)


def check_proof_report_rec (p, restrs, hyps, proof, step_num, ctxt, inducts,
        do_check = True):
    '''Print a step-by-step report of `proof`, checking as it goes.

    inducts is (next induction var number, {split point: var name}).
    Returns (next step number, next induction var number), or None if
    a check failed.
    '''
    printout ('Step %d: %s' % (step_num, ctxt))
    if proof.kind == 'Restr':
        (kind, (x, y)) = proof.restr_range
        if kind == 'Offset':
            v = inducts[1][proof.point]
            rexpr = '{%s + %s ..< %s + %s}' % (v, x, v, y)
        else:
            rexpr = '{%s ..< %s}' % (x, y)
        printout (' Prove the number of visits to %d is in %s'
            % (proof.point, rexpr))

        checks = proof_restr_checks (proof.point, proof.restr_range,
            p, restrs, hyps)
        cases = ['']
    elif proof.kind == 'SingleRevInduct':
        printout (' Proving a predicate by future induction.')
        (eqs, n) = proof.eqs_proof
        point = proof.point
        printout (' proving these invariants by %d-induction' % n)
        for x in eqs:
            printout (' %s (@ addr %s)'
                % (pretty_lambda (x), point))
        printout (' then establishing this predicate')
        (pred, n_bound) = proof.rev_proof
        printout (' %s (@ addr %s)'
            % (pretty_lambda (pred), point))
        printout (' at large iterations (%d) and by back induction.'
            % n_bound)
        cases = ['']
        checks = all_rev_induct_checks (p, restrs, hyps, point,
            proof.eqs_proof, proof.rev_proof)
    elif proof.kind == 'Split':
        (l_dts, r_dts, eqs, n, lrmx) = proof.split
        v = next_induct_var (inducts[0])
        # copy the mapping so sibling subproofs do not see these vars
        inducts = (inducts[0] + 1, dict (inducts[1]))
        inducts[1][l_dts[0]] = v
        inducts[1][r_dts[0]] = v
        printout (' prove %s related to %s' % (pretty_vseq (l_dts),
            pretty_vseq (r_dts)))
        printout (' with equalities')
        for (x, y) in eqs:
            printout (' %s (@ addr %s)' % (pretty_lambda (x),
                l_dts[0]))
            printout (' = %s (@ addr %s)' % (pretty_lambda (y),
                r_dts[0]))
        printout (' and with invariants')
        for x in l_dts[2]:
            printout (' %s (@ addr %s)'
                % (pretty_lambda (x), l_dts[0]))
        for x in r_dts[2]:
            printout (' %s (@ addr %s)'
                % (pretty_lambda (x), r_dts[0]))
        checks = split_checks (p, restrs, hyps, proof.split)
        cases = ['case in (%d) where the length of the sequence < %d'
                % (step_num, n),
            'case in (%d) where the length of the sequence is %s + %s'
                % (step_num, v, n)]
    elif proof.kind == 'Leaf':
        printout (' prove all verification conditions')
        checks = leaf_condition_checks (p, restrs, hyps)
        cases = []
    elif proof.kind == 'CaseSplit':
        printout (' case split on whether %d is visited' % proof.point)
        checks = []
        cases = ['case in (%d) where %d is visited' % (step_num, proof.point),
            'case in (%d) where %d is not visited' % (step_num, proof.point)]
    else:
        assert not 'proof node kind understood', proof.kind

    if checks and do_check:
        groups = proof_check_groups (checks)
        for group in groups:
            rep = rep_graph.mk_graph_slice (p)
            detail = [0]
            (res, _) = test_hyp_group (rep, group, detail)
            if not res:
                printout (' .. failed to prove this.')
                printout (' (failure kind: %r)' % detail[0])
                return

        printout (' .. proven.')

    subproblems = proof_subproblems (p, proof.kind,
        proof.args, restrs, hyps, '')
    xs = logic.azip (subproblems, proof.subproofs)
    xs = logic.azip (xs, cases)
    step_num += 1
    for ((subprob, subproof), case) in xs:
        (restrs, hyps, _) = subprob
        res = check_proof_report_rec (p, restrs, hyps, subproof,
            step_num, case, inducts, do_check = do_check)
        if not res:
            return
        (step_num, induct_var_num) = res
        inducts = (induct_var_num, inducts[1])
    return (step_num, inducts[0])


def check_proof_report (p, proof, do_check = True):
    '''Print a report of `proof`; True iff every step was proven.'''
    res = check_proof_report_rec (p, (), init_point_hyps (p), proof,
        1, '', (0, {}), do_check = do_check)
    return bool (res)


def save_proofs_to_file (fname, mode = 'w'):
    '''Return a save (p, proof) callback appending proofs to `fname`.

    The file is deliberately kept open for the lifetime of the callback.
    Each record is flushed as soon as it is written.
    '''
    assert mode in ['w', 'a']
    f = open (fname, mode)

    def save (p, proof):
        f.write ('ProblemProof (%s) {\n' % p.name)
        for s in p.serialise ():
            f.write (s + '\n')
        ss = []
        proof.serialise (p, ss)
        f.write (' '.join (ss))
        f.write ('\n}\n')
        f.flush ()
    return save


def load_proofs_from_file (fname):
    '''Load proofs written by save_proofs_to_file.

    Returns {name: [(problem, proof), ...]}. Lines starting with '#'
    and blank lines are ignored.
    '''
    proofs = {}
    lines = None
    name = None
    with open (fname) as f:
        for line in f:
            line = line.strip ()
            if line.startswith ('ProblemProof'):
                assert line.endswith ('{'), line
                name_bit = line[len ('ProblemProof') : -1].strip ()
                assert name_bit.startswith ('('), name_bit
                assert name_bit.endswith (')'), name_bit
                name = name_bit[1:-1]
                lines = []
            elif line == '}':
                assert lines, ('unexpected }', fname)
                assert lines[0] == 'Problem'
                assert lines[-2] == 'EndProblem'
                trace ('loading proof from %d lines' % len (lines))
                # all but the last line describe the problem; the last
                # line is the serialised proof
                p = problem.deserialise (name, lines[:-1])
                proof = deserialise (lines[-1])
                proofs.setdefault (name, []).append ((p, proof))
                trace ('loaded proof %s' % name)
                lines = None
            elif line.startswith ('#'):
                pass
            elif line:
                assert lines is not None, ('line outside ProblemProof',
                    line)
                lines.append (line)
    assert not lines
    return proofs