#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# proof scripts and check process

from rep_graph import mk_graph_slice, Hyp, eq_hyp, pc_true_hyp, pc_false_hyp
import rep_graph
from problem import Problem, inline_at_point
import problem

from solver import to_smt_expr
from target_objects import functions, pairings, trace, printout
import target_objects
from rep_graph import (vc_num, vc_offs, vc_double_range, vc_upto, mk_vc_opts,
    VisitCount)
import logic

from syntax import (true_term, false_term, boolT, mk_var, mk_word32, mk_word8,
    mk_plus, mk_minus, word32T, word8T, mk_and, mk_eq, mk_implies, mk_not,
    rename_expr)
import syntax

def build_problem (pairing, force_inline = None, avoid_abort = False):
    '''Build and analyse the Problem for a function pairing.

    Adds the entry function for each tag of `pairing`, then inlines
    calls that have no matching pairing (inline_completely_unmatched,
    then inline_reachable_unmatched_C). Finally it pads merge points and
    re-runs the analysis.

    force_inline: optional predicate on a function name. When it returns
        true for a callee on the 'C' side, that call is inlined even if
        it is matched.
    avoid_abort: when set, calls to functions whose `entry` is unset
        (reported as 'underspecified') are left in place, and the final
        p.check_no_inner_loops () is skipped.

    Returns the Problem.'''
    p = Problem (pairing)

    for (tag, fname) in pairing.funs.items ():
        p.add_entry_function (functions[fname], tag)

    p.do_analysis ()

    # FIXME: the inlining is heuristic, and arguably belongs in 'search'
    inline_completely_unmatched (p, skip_underspec = avoid_abort)

    # now do any C inlining
    inline_reachable_unmatched_C (p, force_inline,
        skip_underspec = avoid_abort)

    trace ('Done inlining.')

    p.pad_merge_points ()
    p.do_analysis ()

    if not avoid_abort:
        p.check_no_inner_loops ()

    return p

def inline_completely_unmatched (p, ref_tags = None, skip_underspec = False):
    '''Inline every call whose callee has no pairing with tags `ref_tags`.

    ref_tags defaults to p.pairing.tags. The scan repeats until a round
    finds nothing to inline; presumably the inlined bodies can contain
    further unmatched calls. p.do_analysis () runs once at the end,
    because individual inlinings are done with do_analysis = False.
    With skip_underspec, calls to functions whose `entry` is unset are
    traced and left in place.'''
    if ref_tags is None:
        ref_tags = p.pairing.tags
    while True:
        # (node, skip?) for every call that has no pairing over ref_tags
        candidates = [(n, skip_underspec
                and not functions[p.nodes[n].fname].entry)
            for n in p.nodes
            if p.nodes[n].kind == 'Call'
            if not any (pair.tags == ref_tags
                for pair in pairings.get (p.nodes[n].fname, []))]
        for (n, skip) in candidates:
            if skip:
                trace ('Skipped inlining underspecified %s.'
                    % p.nodes[n].fname)
        ns = [n for (n, skip) in candidates if not skip]
        for n in ns:
            trace ('Function %s at %d - %s - completely unmatched.'
                % (p.nodes[n].fname, n, p.node_tags[n][0]))
            inline_at_point (p, n, do_analysis = False)
        if not ns:
            p.do_analysis ()
            return

def inline_reachable_unmatched_C (p, force_inline = None,
        skip_underspec = False):
    '''If the pairing has a 'C' side, inline the C calls that have no
    counterpart on the other side (see inline_reachable_unmatched).'''
    if 'C' not in p.pairing.tags:
        return
    [compare_tag] = [tag for tag in p.pairing.tags if tag != 'C']
    inline_reachable_unmatched (p, 'C', compare_tag, force_inline,
        skip_underspec = skip_underspec)

def inline_reachable_unmatched (p, inline_tag, compare_tag,
        force_inline = None, skip_underspec = False):
    '''Inline calls on the `inline_tag` side that have no counterpart.

    A callee on the inline_tag side counts as matched if it is the
    inline_tag function of some pairing of a function called on the
    compare_tag side. The function builds a graph slice whose inliner
    (consider_inline) inlines unmatched or force_inline'd calls. It then
    requests the path-condition environment of every node, and of 'Ret'
    and 'Err' on the inline_tag side, with every loop head limited to
    vc_double_range (3, 3). Whenever rep_graph.InlineEvent is raised,
    the walk restarts from scratch.'''
    matched_funs = set ([pair.funs[inline_tag]
        for n in p.nodes
        if p.nodes[n].kind == 'Call'
        if p.node_tags[n][0] == compare_tag
        for pair in pairings.get (p.nodes[n].fname, [])
        if inline_tag in pair.tags])

    rep = mk_graph_slice (p,
        consider_inline (matched_funs, inline_tag, force_inline,
            skip_underspec))
    opts = vc_double_range (3, 3)
    while True:
        try:
            heads = problem.loop_heads_including_inner (p)
            limits = [(n, opts) for n in heads]

            # iterate over a snapshot of the node ids (as Python 2's
            # keys () gave), since inlining may change p.nodes
            for n in list (p.nodes):
                try:
                    rep.get_node_pc_env ((n, limits))
                except rep.TooGeneral:
                    pass

            rep.get_node_pc_env (('Ret', limits), inline_tag)
            rep.get_node_pc_env (('Err', limits), inline_tag)
            break
        except rep_graph.InlineEvent:
            continue

def consider_inline1 (p, n, matched_funs, inline_tag,
        force_inline, skip_underspec):
    '''Decide whether the call at node n should be inlined.

    Returns False if the node is not on the inline_tag side, if the
    callee is underspecified (its `entry` is unset) and skip_underspec
    is set, or if the callee is in matched_funs and not forced by
    force_inline. Otherwise it returns a thunk that performs
    inline_at_point (p, n).'''
    node = p.nodes[n]
    assert node.kind == 'Call'

    if p.node_tags[n][0] != inline_tag:
        return False

    f_nm = node.fname
    if skip_underspec and not functions[f_nm].entry:
        trace ('Skipping inlining underspecified %s' % f_nm)
        return False
    if f_nm not in matched_funs or (force_inline and force_inline (f_nm)):
        return lambda: inline_at_point (p, n)
    else:
        return False

def consider_inline (matched_funs, tag, force_inline, skip_underspec = False):
    '''Build the inliner callback passed to mk_graph_slice.

    The callback takes a single (p, n) pair and returns
    consider_inline1 (p, n, ...) for it.'''
    def inliner (p_n):
        (p, n) = p_n
        return consider_inline1 (p, n, matched_funs, tag,
            force_inline, skip_underspec)
    return inliner
132def inst_eqs (p, restrs, eqs, tag_map = {}): 133 addr_map = {} 134 if not tag_map: 135 tag_map = dict ([(tag, tag) for tag in p.tags ()]) 136 for (pair_tag, p_tag) in tag_map.iteritems (): 137 addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag) 138 addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag) 139 renames = p.entry_exit_renames (tag_map.values ()) 140 for (pair_tag, p_tag) in tag_map.iteritems (): 141 renames[pair_tag + '_IN'] = renames[p_tag + '_IN'] 142 renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT'] 143 hyps = [] 144 for (lhs, rhs) in eqs: 145 vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr]) 146 for (x, x_addr) in (lhs, rhs)] 147 hyps.append (eq_hyp (vals[0], vals[1])) 148 return hyps 149 150def init_point_hyps (p): 151 (inp_eqs, _) = p.pairing.eqs 152 return inst_eqs (p, (), inp_eqs) 153 154class ProofNode: 155 def __init__ (self, kind, args = None, subproofs = []): 156 self.kind = kind 157 self.args = args 158 self.subproofs = tuple (subproofs) 159 if self.kind == 'Leaf': 160 assert args == None 161 assert list (subproofs) == [] 162 elif self.kind == 'Restr': 163 (self.point, self.restr_range) = args 164 assert len (subproofs) == 1 165 elif self.kind == 'SingleRevInduct': 166 (self.point, self.eqs_proof, self.rev_proof) = args 167 assert len (subproofs) == 1 168 elif self.kind == 'Split': 169 self.split = args 170 (l_details, r_details, eqs, n, loop_r_max) = args 171 assert len (subproofs) == 2 172 elif self.kind == 'CaseSplit': 173 (self.point, self.tag) = args 174 assert len (subproofs) == 2 175 else: 176 assert not 'proof node kind understood', kind 177 178 def __repr__ (self): 179 return 'ProofNode (%r, %r, %r)' % (self.kind, 180 self.args, self.subproofs) 181 182 def serialise (self, p, ss): 183 if self.kind == 'Leaf': 184 ss.append ('Leaf') 185 elif self.kind == 'Restr': 186 (kind, (x, y)) = self.restr_range 187 tag = p.node_tags[self.point][0] 188 ss.extend (['Restr', '%d' % self.point, 189 tag, kind, '%d' % x, 
'%d' % y]) 190 elif self.kind == 'SingleRevInduct': 191 tag = p.node_tags[self.point][0] 192 (eqs, n) = self.eqs_proof 193 ss.extend (['SingleRevInduct', '%d' % self.point, 194 tag, '%d' % n, '%d' % len (eqs)]) 195 for (x, y) in eqs: 196 serialise_lambda (x, ss) 197 serialise_lambda (y, ss) 198 (pred, n_bound) = self.rev_proof 199 pred.serialise (ss) 200 ss.append ('%d' % n_bound) 201 elif self.kind == 'Split': 202 (l_details, r_details, eqs, n, loop_r_max) = self.args 203 ss.extend (['Split', '%d' % n, '%d' % loop_r_max]) 204 serialise_details (l_details, ss) 205 serialise_details (r_details, ss) 206 ss.append ('%d' % len (eqs)) 207 for (x, y) in eqs: 208 serialise_lambda (x, ss) 209 serialise_lambda (y, ss) 210 elif self.kind == 'CaseSplit': 211 ss.extend (['CaseSplit', '%d' % self.point, self.tag]) 212 else: 213 assert not 'proof node kind understood' 214 for proof in self.subproofs: 215 proof.serialise (p, ss) 216 217 def all_subproofs (self): 218 return [self] + [proof for proof1 in self.subproofs 219 for proof in proof1.all_subproofs ()] 220 221 def all_subproblems (self, p, restrs, hyps, name): 222 subproblems = proof_subproblems (p, self.kind, 223 self.args, restrs, hyps, name) 224 subproofs = logic.azip (subproblems, self.subproofs) 225 return [(self, restrs, hyps)] + [problem 226 for ((restrs2, hyps2, name2), proof) in subproofs 227 for problem in proof.all_subproblems (p, restrs2, 228 hyps2, name2)] 229 230 def save_serialise (self, p, fname): 231 f = open (fname, 'w') 232 ss = [] 233 self.serialise (p, ss) 234 f.write (' '.join (ss) + '\n') 235 f.close () 236 237 def __hash__ (self): 238 return syntax.hash_tuplify (self.kind, self.args, 239 self.subproofs) 240 241def serialise_details (details, ss): 242 (split, (seq_start, step), eqs) = details 243 ss.extend (['%d' % split, '%d' % seq_start, '%d' % step]) 244 ss.append ('%d' % len (eqs)) 245 for eq in eqs: 246 serialise_lambda (eq, ss) 247 248def serialise_lambda (eq_term, ss): 249 ss.extend (['Lambda', 
'%i']) 250 word32T.serialise (ss) 251 eq_term.serialise (ss) 252 253def deserialise_details (ss, i): 254 (split, seq_start, step) = [int (x) for x in ss[i : i + 3]] 255 (i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3) 256 return (i, (split, (seq_start, step), eqs)) 257 258def deserialise_lambda (ss, i): 259 assert ss[i : i + 2] == ['Lambda', '%i'], (ss, i) 260 (i, typ) = syntax.parse_typ (ss, i + 2) 261 assert typ == word32T, typ 262 (i, eq_term) = syntax.parse_expr (ss, i) 263 return (i, eq_term) 264 265def deserialise_double_lambda (ss, i): 266 (i, x) = deserialise_lambda (ss, i) 267 (i, y) = deserialise_lambda (ss, i) 268 return (i, (x, y)) 269 270def deserialise_inner (ss, i): 271 if ss[i] == 'Leaf': 272 return (i + 1, ProofNode ('Leaf')) 273 elif ss[i] == 'Restr': 274 point = int (ss[i + 1]) 275 tag = ss[i + 2] 276 kind = ss[i + 3] 277 assert kind in ['Number', 'Offset'], (kind, i) 278 x = int (ss[i + 4]) 279 y = int (ss[i + 5]) 280 (i, p1) = deserialise_inner (ss, i + 6) 281 return (i, ProofNode ('Restr', (point, (kind, (x, y))), [p1])) 282 elif ss[i] == 'SingleRevInduct': 283 point = int (ss[i + 1]) 284 tag = ss[i + 2] 285 n = int (ss[i + 3]) 286 (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i + 4) 287 (i, pred) = syntax.parse_term (ss, i) 288 n_bound = int (ss[i]) 289 (i, p1) = deserialise_inner (ss, i + 1) 290 return (i, ProofNode ('SingleRevInduct', (point, (eqs, n), 291 (pred, n_bound)), [p1])) 292 elif ss[i] == 'Split': 293 n = int (ss[i + 1]) 294 loop_r_max = int (ss[i + 2]) 295 (i, l_details) = deserialise_details (ss, i + 3) 296 (i, r_details) = deserialise_details (ss, i) 297 (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i) 298 (i, p1) = deserialise_inner (ss, i) 299 (i, p2) = deserialise_inner (ss, i) 300 return (i, ProofNode ('Split', (l_details, r_details, eqs, 301 n, loop_r_max), [p1, p2])) 302 elif ss[i] == 'CaseSplit': 303 n = int (ss[i + 1]) 304 tag = ss[i + 2] 305 (i, p1) = deserialise_inner (ss, i + 
3) 306 (i, p2) = deserialise_inner (ss, i) 307 return (i, ProofNode ('CaseSplit', (n, tag), [p1, p2])) 308 else: 309 assert not 'proof node type understood', (ss, i) 310 311def deserialise (line): 312 ss = line.split () 313 (i, proof) = deserialise_inner (ss, 0) 314 assert i == len (ss), (ss, i) 315 return proof 316 317def proof_subproblems (p, kind, args, restrs, hyps, path): 318 tags = p.pairing.tags 319 if kind == 'Leaf': 320 return [] 321 elif kind == 'Restr': 322 restr = get_proof_restr (args[0], args[1]) 323 hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)] 324 return [((restr,) + restrs, hyps, 325 '%s (%d limited)' % (path, args[0]))] 326 elif kind == 'SingleRevInduct': 327 hyp = single_induct_resulting_hyp (p, restrs, args) 328 return [(restrs, hyps + [hyp], path)] 329 elif kind == 'Split': 330 split = args 331 return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs), 332 '%d init case in %s' % (split[0][0], path)), 333 (restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True), 334 '%d loop case in %s' % (split[0][0], path))] 335 elif kind == 'CaseSplit': 336 (point, tag) = args 337 visit = ((point, restrs), tag) 338 true_hyps = hyps + [pc_true_hyp (visit)] 339 false_hyps = hyps + [pc_false_hyp (visit)] 340 return [(restrs, true_hyps, 341 'true case (%d visited) in %s' % (point, path)), 342 (restrs, false_hyps, 343 'false case (%d not visited) in %s' % (point, path))] 344 else: 345 assert not 'proof node kind understood', proof.kind 346 347 348def split_heads ((l_details, r_details, eqs, n, _)): 349 (l_split, _, _) = l_details 350 (r_split, _, _) = r_details 351 return [l_split, r_split] 352 353def split_no_loop_hyps (tags, split, restrs): 354 ((_, (l_seq_start, l_step), _), _, _, n, _) = split 355 356 (l_visit, _) = split_visit_visits (tags, split, restrs, vc_num (n)) 357 358 return [pc_false_hyp (l_visit)] 359 360def split_visit_one_visit (tag, details, restrs, visit): 361 if details == None: 362 return None 363 (split, 
(seq_start, step), eqs) = details 364 365 # the split point sequence at low numbers ('Number') is offset 366 # by the point the sequence starts. At symbolic offsets we ignore 367 # that, instead having the loop counter for the two sequences 368 # be the same number of iterations after the sequence start. 369 if visit.kind == 'Offset': 370 visit = vc_offs (visit.n * step) 371 else: 372 visit = vc_num (seq_start + (visit.n * step)) 373 374 visit = ((split, ((split, visit), ) + restrs), tag) 375 return visit 376 377def split_visit_visits (tags, split, restrs, visit): 378 (ltag, rtag) = tags 379 (l_details, r_details, eqs, _, _) = split 380 381 l_visit = split_visit_one_visit (ltag, l_details, restrs, visit) 382 r_visit = split_visit_one_visit (rtag, r_details, restrs, visit) 383 384 return (l_visit, r_visit) 385 386def split_hyps_at_visit (tags, split, restrs, visit): 387 (l_details, r_details, eqs, _, _) = split 388 (l_split, (l_seq_start, l_step), l_eqs) = l_details 389 (r_split, (r_seq_start, r_step), r_eqs) = r_details 390 391 (l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit) 392 (l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0)) 393 (l_tag, r_tag) = tags 394 395 def mksub (v): 396 return lambda exp: logic.var_subst (exp, {('%i', word32T) : v}, 397 must_subst = False) 398 def inst (exp): 399 return logic.inst_eq_at_visit (exp, visit) 400 zsub = mksub (mk_word32 (0)) 401 if visit.kind == 'Number': 402 lsub = mksub (mk_word32 (visit.n)) 403 else: 404 lsub = mksub (mk_plus (mk_var ('%n', word32T), 405 mk_word32 (visit.n))) 406 407 hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'), 408 (Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag), 409 (Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)] 410 hyps += [(eq_hyp ((zsub (l_exp), l_start), (lsub (l_exp), l_visit), 411 (l_split, r_split)), '%s const' % l_tag) 412 for l_exp in l_eqs if inst (l_exp)] 413 hyps += [(eq_hyp ((zsub (r_exp), r_start), (lsub (r_exp), r_visit), 
414 (l_split, r_split)), '%s const' % r_tag) 415 for r_exp in r_eqs if inst (r_exp)] 416 hyps += [(eq_hyp ((lsub (l_exp), l_visit), (lsub (r_exp), r_visit), 417 (l_split, r_split)), 'eq') 418 for (l_exp, r_exp) in eqs 419 if inst (l_exp) and inst (r_exp)] 420 return hyps 421 422def split_loop_hyps (tags, split, restrs, exit): 423 ((r_split, _, _), _, _, n, _) = split 424 (l_visit, _) = split_visit_visits (tags, split, restrs, vc_offs (n - 1)) 425 (l_cont, _) = split_visit_visits (tags, split, restrs, vc_offs (n)) 426 (l_tag, r_tag) = tags 427 428 l_enter = pc_true_hyp (l_visit) 429 l_exit = pc_false_hyp (l_cont) 430 if exit: 431 hyps = [l_enter, l_exit] 432 else: 433 hyps = [l_enter] 434 return hyps + [hyp for offs in map (vc_offs, range (n)) 435 for (hyp, _) in split_hyps_at_visit (tags, split, restrs, offs)] 436 437def loops_to_split (p, restrs): 438 loop_heads_with_split = set ([p.loop_id (n) 439 for (n, visit_set) in restrs]) 440 rem_loop_heads = set (p.loop_heads ()) - loop_heads_with_split 441 for (n, visit_set) in restrs: 442 if not visit_set.has_zero (): 443 # n must be visited, so loop heads must be 444 # reachable from n (or on another tag) 445 rem_loop_heads = [lh for lh in rem_loop_heads 446 if p.is_reachable_from (n, lh) 447 or p.node_tags[n][0] != p.node_tags[lh][0]] 448 return rem_loop_heads 449 450def restr_others (p, restrs, n): 451 extras = [(sp, vc_upto (n)) for sp in loops_to_split (p, restrs)] 452 return restrs + tuple (extras) 453 454def non_r_err_pc_hyp (tags, restrs): 455 return pc_false_hyp ((('Err', restrs), tags[1])) 456 457def split_r_err_pc_hyp (p, split, restrs, tags = None): 458 (_, r_details, _, n, loop_r_max) = split 459 (r_split, (r_seq_start, r_step), r_eqs) = r_details 460 461 nc = n * r_step 462 vc = vc_double_range (r_seq_start + nc, loop_r_max + 2) 463 464 restrs = restr_others (p, ((r_split, vc), ) + restrs, 2) 465 466 if tags == None: 467 tags = p.pairing.tags 468 469 return non_r_err_pc_hyp (tags, restrs) 470 471restr_bump 
= 0 472 473def get_proof_restr (n, (kind, (x, y))): 474 return (n, mk_vc_opts ([VisitCount (kind, i) 475 for i in range (x, y + restr_bump)])) 476 477def restr_trivial_hyp (p, n, (kind, (x, y)), restrs): 478 restr = (n, VisitCount (kind, y - 1)) 479 return rep_graph.pc_triv_hyp (((n, (restr, ) + restrs), 480 p.node_tags[n][0])) 481 482def proof_restr_checks (n, (kind, (x, y)), p, restrs, hyps): 483 restr = get_proof_restr (n, (kind, (x, y))) 484 ncerr_hyp = non_r_err_pc_hyp (p.pairing.tags, 485 restr_others (p, (restr, ) + restrs, 2)) 486 hyps = [ncerr_hyp] + hyps 487 def visit (vc): 488 return ((n, ((n, vc), ) + restrs), p.node_tags[n][0]) 489 490 # this cannot be more uniform because the representation of visit 491 # at offset 0 is all a bit odd, with n being the only node so visited: 492 if kind == 'Offset': 493 min_vc = vc_offs (max (0, x - 1)) 494 elif x > 1: 495 min_vc = vc_num (x - 1) 496 else: 497 min_vc = None 498 if min_vc: 499 init_check = [(hyps, pc_true_hyp (visit (min_vc)), 500 'Check of restr min %d %s for %d' % (x, kind, n))] 501 else: 502 init_check = [] 503 504 # if we can reach node n with (y - 1) visits to n, then the next 505 # node will have y visits to n, which we are disallowing 506 # thus we show that this visit is impossible 507 top_vc = VisitCount (kind, y - 1) 508 top_check = (hyps, pc_false_hyp (visit (top_vc)), 509 'Check of restr max %d %s for %d' % (y, kind, n)) 510 return init_check + [top_check] 511 512def split_init_step_checks (p, restrs, hyps, split, tags = None): 513 (_, _, _, n, _) = split 514 if tags == None: 515 tags = p.pairing.tags 516 517 err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags) 518 hyps = [err_hyp] + hyps 519 checks = [] 520 for i in range (n): 521 (l_visit, r_visit) = split_visit_visits (tags, split, 522 restrs, vc_num (i)) 523 lpc_hyp = pc_true_hyp (l_visit) 524 # this trivial 'hyp' ensures the rep is built to include 525 # the matching rhs visits when checking lhs consts 526 rpc_triv_hyp = 
rep_graph.pc_triv_hyp (r_visit) 527 vis_hyps = split_hyps_at_visit (tags, split, restrs, vc_num (i)) 528 529 for (hyp, desc) in vis_hyps: 530 checks.append ((hyps + [lpc_hyp, rpc_triv_hyp], hyp, 531 'Induct check at visit %d: %s' % (i, desc))) 532 return checks 533 534def split_induct_step_checks (p, restrs, hyps, split, tags = None): 535 ((l_split, _, _), _, _, n, _) = split 536 if tags == None: 537 tags = p.pairing.tags 538 539 err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags) 540 (cont, r_cont) = split_visit_visits (tags, split, restrs, vc_offs (n)) 541 # the 'trivial' hyp here ensures the representation includes a loop 542 # of the rhs when proving const equations on the lhs 543 hyps = ([err_hyp, pc_true_hyp (cont), 544 rep_graph.pc_triv_hyp (r_cont)] + hyps 545 + split_loop_hyps (tags, split, restrs, exit = False)) 546 547 return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' 548 % (desc, l_split)) 549 for (hyp, desc) in split_hyps_at_visit (tags, split, 550 restrs, vc_offs (n))] 551 552def check_split_induct_step_group (rep, restrs, hyps, split, tags = None): 553 checks = split_induct_step_checks (rep.p, restrs, hyps, split, 554 tags = tags) 555 groups = proof_check_groups (checks) 556 for group in groups: 557 (verdict, _) = test_hyp_group (rep, group) 558 if not verdict: 559 return False 560 return True 561 562def split_checks (p, restrs, hyps, split, tags = None): 563 return (split_init_step_checks (p, restrs, hyps, split, tags = tags) 564 + split_induct_step_checks (p, restrs, hyps, split, tags = tags)) 565 566def loop_eq_hyps_at_visit (tag, split, eqs, restrs, visit_num, 567 use_if_at = False): 568 details = (split, (0, 1), eqs) 569 visit = split_visit_one_visit (tag, details, restrs, visit_num) 570 start = split_visit_one_visit (tag, details, restrs, vc_num (0)) 571 572 def mksub (v): 573 return lambda exp: logic.var_subst (exp, {('%i', word32T) : v}, 574 must_subst = False) 575 zsub = mksub (mk_word32 (0)) 576 if visit_num.kind == 
'Number': 577 isub = mksub (mk_word32 (visit_num.n)) 578 else: 579 isub = mksub (mk_plus (mk_var ('%n', word32T), 580 mk_word32 (visit_num.n))) 581 582 hyps = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)] 583 hyps += [(eq_hyp ((zsub (exp), start), (isub (exp), visit), 584 (split, 0), use_if_at = use_if_at), '%s const' % tag) 585 for exp in eqs if logic.inst_eq_at_visit (exp, visit_num)] 586 587 return hyps 588 589def single_induct_resulting_hyp (p, restrs, rev_induct_args): 590 (point, _, (pred, _)) = rev_induct_args 591 (tag, _) = p.node_tags[point] 592 vis = ((point, restrs + tuple ([(point, vc_num (0))])), tag) 593 return rep_graph.true_if_at_hyp (pred, vis) 594 595def single_loop_induct_base_checks (p, restrs, hyps, tag, split, n, eqs): 596 tests = [] 597 details = (split, (0, 1), eqs) 598 for i in range (n + 1): 599 reach = split_visit_one_visit (tag, details, restrs, vc_num (i)) 600 nhyps = [pc_true_hyp (reach)] 601 tests.extend ([(hyps + nhyps, hyp, 602 'Base check (%s, %d) at induct step for %d' 603 % (desc, i, split)) 604 for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, 605 eqs, restrs, vc_num (i))]) 606 return tests 607 608def single_loop_induct_step_checks (p, restrs, hyps, tag, split, n, 609 eqs, eqs_assume = None): 610 if eqs_assume == None: 611 eqs_assume = [] 612 details = (split, (0, 1), eqs_assume + eqs) 613 cont = split_visit_one_visit (tag, details, restrs, vc_offs (n)) 614 hyps = ([pc_true_hyp (cont)] + hyps 615 + [h for i in range (n) 616 for (h, _) in loop_eq_hyps_at_visit (tag, split, 617 eqs_assume + eqs, restrs, vc_offs (i))]) 618 619 return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' 620 % (desc, split)) 621 for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs, 622 restrs, vc_offs (n))] 623 624def mk_loop_counter_eq_hyp (p, split, restrs, n): 625 details = (split, (0, 1), []) 626 (tag, _) = p.node_tags[split] 627 visit = split_visit_one_visit (tag, details, restrs, vc_offs (0)) 628 return eq_hyp ((mk_var ('%n', 
word32T), visit), 629 (mk_word32 (n), visit), (split, 0)) 630 631def single_loop_rev_induct_base_checks (p, restrs, hyps, tag, split, 632 n_bound, eqs_assume, pred): 633 details = (split, (0, 1), eqs_assume) 634 cont = split_visit_one_visit (tag, details, restrs, vc_offs (1)) 635 n_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound) 636 637 split_details = (None, details, None, 1, 1) 638 non_err = split_r_err_pc_hyp (p, split_details, restrs) 639 640 hyps = (hyps + [n_hyp, pc_true_hyp (cont), non_err] 641 + [h for (h, _) in loop_eq_hyps_at_visit (tag, 642 split, eqs_assume, restrs, vc_offs (0))]) 643 goal = rep_graph.true_if_at_hyp (pred, cont) 644 645 return [(hyps, goal, 'Pred true at %d check.' % n_bound)] 646 647def single_loop_rev_induct_checks (p, restrs, hyps, tag, split, 648 eqs_assume, pred): 649 details = (split, (0, 1), eqs_assume) 650 curr = split_visit_one_visit (tag, details, restrs, vc_offs (1)) 651 cont = split_visit_one_visit (tag, details, restrs, vc_offs (2)) 652 653 split_details = (None, details, None, 1, 1) 654 non_err = split_r_err_pc_hyp (p, split_details, restrs) 655 true_next = rep_graph.true_if_at_hyp (pred, cont) 656 657 hyps = (hyps + [pc_true_hyp (curr), true_next, non_err] 658 + [h for (h, _) in loop_eq_hyps_at_visit (tag, split, 659 eqs_assume, restrs, vc_offs (1), use_if_at = True)]) 660 goal = rep_graph.true_if_at_hyp (pred, curr) 661 662 return [(hyps, goal, 'Pred reverse step.')] 663 664def all_rev_induct_checks (p, restrs, hyps, point, (eqs, n), (pred, n_bound)): 665 (tag, _) = p.node_tags[point] 666 checks = (single_loop_induct_step_checks (p, restrs, hyps, tag, 667 point, n, eqs) 668 + single_loop_induct_base_checks (p, restrs, hyps, tag, 669 point, n, eqs) 670 + single_loop_rev_induct_checks (p, restrs, hyps, tag, 671 point, eqs, pred) 672 + single_loop_rev_induct_base_checks (p, restrs, hyps, 673 tag, point, n_bound, eqs, pred)) 674 return checks 675 676def leaf_condition_checks (p, restrs, hyps): 677 '''checks of the 
final refinement conditions''' 678 nrerr_pc_hyp = non_r_err_pc_hyp (p.pairing.tags, restrs) 679 hyps = [nrerr_pc_hyp] + hyps 680 [l_tag, r_tag] = p.pairing.tags 681 682 nlerr_pc = pc_false_hyp ((('Err', restrs), l_tag)) 683 # this 'hypothesis' ensures that the representation is built all 684 # the way to Ret. in particular this ensures that function relations 685 # are available to use in proving single-side equalities 686 ret_eq = eq_hyp ((true_term, (('Ret', restrs), l_tag)), 687 (true_term, (('Ret', restrs), r_tag))) 688 689 ### TODO: previously we considered the case where 'Ret' was unreachable 690 ### (as a result of unsatisfiable hyps) and proved a simpler property. 691 ### we might want to restore this 692 (_, out_eqs) = p.pairing.eqs 693 checks = [(hyps + [nlerr_pc, ret_eq], hyp, 'Leaf eq check') for hyp in 694 inst_eqs (p, restrs, out_eqs)] 695 return [(hyps + [ret_eq], nlerr_pc, 'Leaf path-cond imp')] + checks 696 697def proof_checks (p, proof): 698 return proof_checks_rec (p, (), init_point_hyps (p), proof, 'root') 699 700def proof_checks_imm (p, restrs, hyps, proof, path): 701 if proof.kind == 'Restr': 702 checks = proof_restr_checks (proof.point, proof.restr_range, 703 p, restrs, hyps) 704 elif proof.kind == 'SingleRevInduct': 705 checks = all_rev_induct_checks (p, restrs, hyps, proof.point, 706 proof.eqs_proof, proof.rev_proof) 707 elif proof.kind == 'Split': 708 checks = split_checks (p, restrs, hyps, proof.split) 709 elif proof.kind == 'Leaf': 710 checks = leaf_condition_checks (p, restrs, hyps) 711 elif proof.kind == 'CaseSplit': 712 checks = [] 713 714 return [(hs, hyp, '%s on %s' % (name, path)) 715 for (hs, hyp, name) in checks] 716 717def proof_checks_rec (p, restrs, hyps, proof, path): 718 checks = proof_checks_imm (p, restrs, hyps, proof, path) 719 720 subproblems = proof_subproblems (p, proof.kind, 721 proof.args, restrs, hyps, path) 722 for (subprob, subproof) in logic.azip (subproblems, proof.subproofs): 723 (restrs, hyps, path) = subprob 
724 checks.extend (proof_checks_rec (p, restrs, hyps, subproof, path)) 725 return checks 726 727last_failed_check = [None] 728 729def proof_check_groups (checks): 730 groups = {} 731 for (hyps, hyp, name) in checks: 732 n_vcs = set ([n_vc for hyp2 in [hyp] + hyps 733 for n_vc in hyp2.visits ()]) 734 k = (tuple (sorted (list (n_vcs)))) 735 groups.setdefault (k, []).append ((hyps, hyp, name)) 736 return groups.values () 737 738def test_hyp_group (rep, group, detail = None): 739 imps = [(hyps, hyp) for (hyps, hyp, _) in group] 740 names = set ([name for (_, _, name) in group]) 741 742 trace ('Testing group of hyps: %s' % list (names), push = 1) 743 (res, i, res_kind) = rep.test_hyp_imps (imps) 744 trace ('Group result: %r' % res, push = -1) 745 if res: 746 return (res, None) 747 else: 748 if detail: 749 detail[0] = res_kind 750 return (res, group[i]) 751 752def failed_test_sets (p, checks): 753 failed = [] 754 sets = {} 755 for (hyps, hyp, name) in checks: 756 sets.setdefault (name, []) 757 sets[name].append ((hyps, hyp)) 758 for name in sets: 759 rep = rep_graph.mk_graph_slice (p) 760 (res, _, _) = rep.test_hyp_imps (sets[name]) 761 if not res: 762 failed.append (name) 763 return failed 764 765save_checked_proofs = [None] 766 767def check_proof (p, proof, use_rep = None): 768 checks = proof_checks (p, proof) 769 groups = proof_check_groups (checks) 770 771 for group in groups: 772 if use_rep == None: 773 rep = rep_graph.mk_graph_slice (p) 774 else: 775 rep = use_rep 776 777 detail = [0] 778 (verdict, elt) = test_hyp_group (rep, group, detail) 779 if verdict: 780 continue 781 (hyps, hyp, name) = elt 782 last_failed_check[0] = elt 783 trace ('%s: proof failed!' 
% name) 784 trace (' (failure kind: %r)' % detail[0]) 785 return False 786 if save_checked_proofs[0]: 787 save = save_checked_proofs[0] 788 save (p, proof) 789 return True 790 791def pretty_vseq ((split, (seq_start, seq_step), _)): 792 if (seq_start, seq_step) == (0, 1): 793 return 'visits to %d' % split 794 else: 795 i = seq_start + 1 796 j = i + seq_step 797 k = j + seq_step 798 return 'visits [%d, %d, %d ...] to %d' % (i, j, k, split) 799 800def next_induct_var (n): 801 s = 'ijkabc' 802 v = s[n % 6] 803 if n >= 6: 804 v += str ((n / 6) + 1) 805 return v 806 807def pretty_lambda (t): 808 v = syntax.mk_var ('#seq-visits', word32T) 809 t = logic.var_subst (t, {('%i', word32T) : v}, must_subst = False) 810 return syntax.pretty_expr (t, print_type = True) 811 812def check_proof_report_rec (p, restrs, hyps, proof, step_num, ctxt, inducts, 813 do_check = True): 814 printout ('Step %d: %s' % (step_num, ctxt)) 815 if proof.kind == 'Restr': 816 (kind, (x, y)) = proof.restr_range 817 if kind == 'Offset': 818 v = inducts[1][proof.point] 819 rexpr = '{%s + %s ..< %s + %s}' % (v, x, v, y) 820 else: 821 rexpr = '{%s ..< %s}' % (x, y) 822 printout (' Prove the number of visits to %d is in %s' 823 % (proof.point, rexpr)) 824 825 checks = proof_restr_checks (proof.point, proof.restr_range, 826 p, restrs, hyps) 827 cases = [''] 828 elif proof.kind == 'SingleRevInduct': 829 printout (' Proving a predicate by future induction.') 830 (eqs, n) = proof.eqs_proof 831 point = proof.point 832 printout (' proving these invariants by %d-induction' % n) 833 for x in eqs: 834 printout (' %s (@ addr %s)' 835 % (pretty_lambda (x), point)) 836 printout (' then establishing this predicate') 837 (pred, n_bound) = proof.rev_proof 838 printout (' %s (@ addr %s)' 839 % (pretty_lambda (pred), point)) 840 printout (' at large iterations (%d) and by back induction.' 
841 % n_bound) 842 cases = [''] 843 checks = all_rev_induct_checks (p, restrs, hyps, point, 844 proof.eqs_proof, proof.rev_proof) 845 elif proof.kind == 'Split': 846 (l_dts, r_dts, eqs, n, lrmx) = proof.split 847 v = next_induct_var (inducts[0]) 848 inducts = (inducts[0] + 1, dict (inducts[1])) 849 inducts[1][l_dts[0]] = v 850 inducts[1][r_dts[0]] = v 851 printout (' prove %s related to %s' % (pretty_vseq (l_dts), 852 pretty_vseq (r_dts))) 853 printout (' with equalities') 854 for (x, y) in eqs: 855 printout (' %s (@ addr %s)' % (pretty_lambda (x), 856 l_dts[0])) 857 printout (' = %s (@ addr %s)' % (pretty_lambda (y), 858 r_dts[0])) 859 printout (' and with invariants') 860 for x in l_dts[2]: 861 printout (' %s (@ addr %s)' 862 % (pretty_lambda (x), l_dts[0])) 863 for x in r_dts[2]: 864 printout (' %s (@ addr %s)' 865 % (pretty_lambda (x), r_dts[0])) 866 checks = split_checks (p, restrs, hyps, proof.split) 867 cases = ['case in (%d) where the length of the sequence < %d' 868 % (step_num, n), 869 'case in (%d) where the length of the sequence is %s + %s' 870 % (step_num, v, n)] 871 elif proof.kind == 'Leaf': 872 printout (' prove all verification conditions') 873 checks = leaf_condition_checks (p, restrs, hyps) 874 cases = [] 875 elif proof.kind == 'CaseSplit': 876 printout (' case split on whether %d is visited' % proof.point) 877 checks = [] 878 cases = ['case in (%d) where %d is visited' % (step_num, proof.point), 879 'case in (%d) where %d is not visited' % (step_num, proof.point)] 880 881 if checks and do_check: 882 groups = proof_check_groups (checks) 883 for group in groups: 884 rep = rep_graph.mk_graph_slice (p) 885 detail = [0] 886 (res, _) = test_hyp_group (rep, group, detail) 887 if not res: 888 printout (' .. failed to prove this.') 889 printout (' (failure kind: %r)' % detail[0]) 890 return 891 892 printout (' .. 
proven.') 893 894 subproblems = proof_subproblems (p, proof.kind, 895 proof.args, restrs, hyps, '') 896 xs = logic.azip (subproblems, proof.subproofs) 897 xs = logic.azip (xs, cases) 898 step_num += 1 899 for ((subprob, subproof), case) in xs: 900 (restrs, hyps, _) = subprob 901 res = check_proof_report_rec (p, restrs, hyps, subproof, 902 step_num, case, inducts, do_check = do_check) 903 if not res: 904 return 905 (step_num, induct_var_num) = res 906 inducts = (induct_var_num, inducts[1]) 907 return (step_num, inducts[0]) 908 909def check_proof_report (p, proof, do_check = True): 910 res = check_proof_report_rec (p, (), init_point_hyps (p), proof, 911 1, '', (0, {}), do_check = do_check) 912 return bool (res) 913 914def save_proofs_to_file (fname, mode = 'w'): 915 assert mode in ['w', 'a'] 916 f = open (fname, mode) 917 918 def save (p, proof): 919 f.write ('ProblemProof (%s) {\n' % p.name) 920 for s in p.serialise (): 921 f.write (s + '\n') 922 ss = [] 923 proof.serialise (p, ss) 924 f.write (' '.join (ss)) 925 f.write ('\n}\n') 926 f.flush () 927 return save 928 929def load_proofs_from_file (fname): 930 f = open (fname) 931 932 proofs = {} 933 lines = None 934 for line in f: 935 line = line.strip () 936 if line.startswith ('ProblemProof'): 937 assert line.endswith ('{'), line 938 name_bit = line[len ('ProblemProof') : -1].strip () 939 assert name_bit.startswith ('('), name_bit 940 assert name_bit.endswith (')'), name_bit 941 name = name_bit[1:-1] 942 lines = [] 943 elif line == '}': 944 assert lines[0] == 'Problem' 945 assert lines[-2] == 'EndProblem' 946 import problem 947 trace ('loading proof from %d lines' % len (lines)) 948 p = problem.deserialise (name, lines[:-1]) 949 proof = deserialise (lines[-1]) 950 proofs.setdefault (name, []) 951 proofs[name].append ((p, proof)) 952 trace ('loaded proof %s' % name) 953 lines = None 954 elif line.startswith ('#'): 955 pass 956 elif line: 957 lines.append (line) 958 assert not lines 959 return proofs 960 961