1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9# proof scripts and check process
10
11from rep_graph import mk_graph_slice, Hyp, eq_hyp, pc_true_hyp, pc_false_hyp
12import rep_graph
13from problem import Problem, inline_at_point
14import problem
15
16from solver import to_smt_expr
17from target_objects import functions, pairings, trace, printout
18import target_objects
19from rep_graph import (vc_num, vc_offs, vc_double_range, vc_upto, mk_vc_opts,
20	VisitCount)
21import logic
22
23from syntax import (true_term, false_term, boolT, mk_var, mk_word32, mk_word8,
24	mk_plus, mk_minus, word32T, word8T, mk_and, mk_eq, mk_implies, mk_not,
25	rename_expr)
26import syntax
27
def build_problem (pairing, force_inline = None, avoid_abort = False):
	'''Build and analyse the Problem for a pairing.

	Every function named in pairing.funs is added as an entry function
	under its tag. Calls with no pairing at all are then inlined, followed
	by reachable-but-unmatched C calls (and anything force_inline picks).
	With avoid_abort set, underspecified functions are not inlined and
	the final inner-loop check is skipped.
	'''
	prob = Problem (pairing)
	for (tag, fun_name) in pairing.funs.items ():
		prob.add_entry_function (functions[fun_name], tag)
	prob.do_analysis ()

	# FIXME: the inlining is heuristic, and arguably belongs in 'search'
	inline_completely_unmatched (prob, skip_underspec = avoid_abort)
	# then the C-side inlining of reachable unmatched calls
	inline_reachable_unmatched_C (prob, force_inline,
		skip_underspec = avoid_abort)
	trace ('Done inlining.')

	prob.pad_merge_points ()
	prob.do_analysis ()

	if avoid_abort:
		return prob
	prob.check_no_inner_loops ()
	return prob
52
def inline_completely_unmatched (p, ref_tags = None, skip_underspec = False):
	'''Repeatedly inline every call that has no pairing for ref_tags.

	A call node is "completely unmatched" when none of the pairings
	registered for its function name has tags equal to ref_tags (by
	default the tags of p's own pairing). Inlining can expose new calls,
	so the scan repeats until a pass inlines nothing; p.do_analysis () is
	then run once before returning.

	If skip_underspec is set, unmatched calls to functions with no entry
	(underspecified functions) are traced and left in place.
	'''
	if ref_tags is None:
		ref_tags = p.pairing.tags
	while True:
		unmatched = [n for n in p.nodes
			if p.nodes[n].kind == 'Call'
			if not any (pair.tags == ref_tags for pair
				in pairings.get (p.nodes[n].fname, []))]
		# decide skips for the whole pass before tracing anything, as
		# the inlining below changes p.nodes
		flagged = [(n, skip_underspec
				and not functions[p.nodes[n].fname].entry)
			for n in unmatched]
		for (n, skip) in flagged:
			if skip:
				trace ('Skipped inlining underspecified %s.'
					% p.nodes[n].fname)
		to_inline = [n for (n, skip) in flagged if not skip]
		for n in to_inline:
			trace ('Function %s at %d - %s - completely unmatched.'
				% (p.nodes[n].fname, n, p.node_tags[n][0]))
			inline_at_point (p, n, do_analysis = False)
		if not to_inline:
			p.do_analysis ()
			return
74
def inline_reachable_unmatched_C (p, force_inline = None,
		skip_underspec = False):
	'''Inline reachable unmatched C calls, comparing against the other tag.

	Does nothing unless 'C' is one of the pairing's tags; otherwise the
	single non-'C' tag is used as the comparison side.
	'''
	tags = p.pairing.tags
	if 'C' in tags:
		[other_tag] = [t for t in tags if t != 'C']
		inline_reachable_unmatched (p, 'C', other_tag, force_inline,
			skip_underspec = skip_underspec)
82
def inline_reachable_unmatched (p, inline_tag, compare_tag,
		force_inline = None, skip_underspec = False):
	'''Inline inline_tag-side calls that are reachable but not matched.

	funs collects, for every call on the compare_tag side, the inline_tag
	function name of each pairing of the callee that mentions inline_tag.
	A graph slice is then built with consider_inline as its inlining
	callback, so calls on the inline_tag side to functions outside funs
	(or selected by force_inline) become candidates for inlining as the
	slice is explored.
	'''
	funs = [pair.funs[inline_tag]
		for n in p.nodes
		if p.nodes[n].kind == 'Call'
		if p.node_tags[n][0] == compare_tag
		for pair in pairings.get (p.nodes[n].fname, [])
		if inline_tag in pair.tags]

	rep = mk_graph_slice (p,
		consider_inline (funs, inline_tag, force_inline,
			skip_underspec))
	# visit-count options applied to every loop head (including inner
	# loops) while exploring; the exact range is whatever
	# vc_double_range (3, 3) denotes
	opts = vc_double_range (3, 3)
	while True:
		try:
			# recomputed on every retry, since inlining changes p
			heads = problem.loop_heads_including_inner (p)
			limits = [(n, opts) for n in heads]

			for n in p.nodes.keys ():
				try:
					r = rep.get_node_pc_env ((n, limits))
				except rep.TooGeneral:
					# NOTE(review): TooGeneral is looked up on the
					# rep object; presumably raised for visits that
					# cannot be represented under these limits and
					# safe to ignore here — confirm in rep_graph
					pass

			rep.get_node_pc_env (('Ret', limits), inline_tag)
			rep.get_node_pc_env (('Err', limits), inline_tag)
			break
		except rep_graph.InlineEvent:
			# presumably raised when the slice performs an inlining;
			# the graph has changed, so explore again from scratch
			continue
112
def consider_inline1 (p, n, matched_funs, inline_tag,
		force_inline, skip_underspec):
	'''Decide whether the call at node n should be inlined.

	Returns a zero-argument callable that inlines the call, or False.
	Only calls tagged inline_tag are considered; the callee is inlined
	when it is not in matched_funs or when force_inline selects it.
	Underspecified callees are skipped when skip_underspec is set.
	'''
	node = p.nodes[n]
	assert node.kind == 'Call'

	if p.node_tags[n][0] != inline_tag:
		return False

	callee = node.fname
	if skip_underspec and not functions[callee].entry:
		trace ('Skipping inlining underspecified %s' % callee)
		return False

	# force_inline is only consulted when the callee is matched
	wanted = (callee not in matched_funs
		or (force_inline and force_inline (callee)))
	if not wanted:
		return False
	return lambda: inline_at_point (p, n)
129
def consider_inline (matched_funs, tag, force_inline, skip_underspec = False):
	'''Return an inlining callback that takes a single (p, n) pair.

	The callback forwards to consider_inline1 with the other settings
	fixed.
	'''
	def decide (p_and_n):
		(p, n) = p_and_n
		return consider_inline1 (p, n, matched_funs, tag,
			force_inline, skip_underspec)
	return decide
133
def inst_eqs (p, restrs, eqs, tag_map = None):
	'''Instantiate pairing equations as equality hypotheses on problem p.

	Each equation (lhs, rhs) is a pair of (expr, addr) where addr is
	'<tag>_IN' or '<tag>_OUT'. '_IN' expressions are evaluated at the
	entry node of the corresponding problem tag; '_OUT' expressions at
	'Ret' under restrs. Variables are renamed with
	p.entry_exit_renames.

	tag_map maps pairing tags to problem tags. When it is None or empty,
	the identity map over p.tags () is used.

	Returns a list of eq_hyp hypotheses, one per equation.
	'''
	if not tag_map:
		tag_map = dict ([(tag, tag) for tag in p.tags ()])
	addr_map = {}
	for (pair_tag, p_tag) in tag_map.items ():
		addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag)
		addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag)
	renames = p.entry_exit_renames (tag_map.values ())
	# let pairing tags reuse the renames of the problem tags they map to
	for (pair_tag, p_tag) in tag_map.items ():
		renames[pair_tag + '_IN'] = renames[p_tag + '_IN']
		renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT']
	hyps = []
	for (lhs, rhs) in eqs:
		vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr])
			for (x, x_addr) in (lhs, rhs)]
		hyps.append (eq_hyp (vals[0], vals[1]))
	return hyps
151
def init_point_hyps (p):
	'''Hypotheses from the pairing's input equations, at entry, with no
	visit restrictions.'''
	(input_eqs, _output_eqs) = p.pairing.eqs
	no_restrs = ()
	return inst_eqs (p, no_restrs, input_eqs)
155
156class ProofNode:
157	def __init__ (self, kind, args = None, subproofs = []):
158		self.kind = kind
159		self.args = args
160		self.subproofs = tuple (subproofs)
161		if self.kind == 'Leaf':
162			assert args == None
163			assert list (subproofs) == []
164		elif self.kind == 'Restr':
165			(self.point, self.restr_range) = args
166			assert len (subproofs) == 1
167		elif self.kind == 'SingleRevInduct':
168			(self.point, self.eqs_proof, self.rev_proof) = args
169			assert len (subproofs) == 1
170		elif self.kind == 'Split':
171			self.split = args
172			(l_details, r_details, eqs, n, loop_r_max) = args
173			assert len (subproofs) == 2
174		elif self.kind == 'CaseSplit':
175			(self.point, self.tag) = args
176			assert len (subproofs) == 2
177		else:
178			assert not 'proof node kind understood', kind
179
180	def __repr__ (self):
181		return 'ProofNode (%r, %r, %r)' % (self.kind,
182			self.args, self.subproofs)
183
184	def serialise (self, p, ss):
185		if self.kind == 'Leaf':
186			ss.append ('Leaf')
187		elif self.kind == 'Restr':
188			(kind, (x, y)) = self.restr_range
189			tag = p.node_tags[self.point][0]
190			ss.extend (['Restr', '%d' % self.point,
191				tag, kind, '%d' % x, '%d' % y])
192		elif self.kind == 'SingleRevInduct':
193			tag = p.node_tags[self.point][0]
194			(eqs, n) = self.eqs_proof
195			ss.extend (['SingleRevInduct', '%d' % self.point,
196				tag, '%d' % n, '%d' % len (eqs)])
197			for (x, y) in eqs:
198				serialise_lambda (x, ss)
199				serialise_lambda (y, ss)
200			(pred, n_bound) = self.rev_proof
201			pred.serialise (ss)
202			ss.append ('%d' % n_bound)
203		elif self.kind == 'Split':
204			(l_details, r_details, eqs, n, loop_r_max) = self.args
205			ss.extend (['Split', '%d' % n, '%d' % loop_r_max])
206			serialise_details (l_details, ss)
207			serialise_details (r_details, ss)
208			ss.append ('%d' % len (eqs))
209			for (x, y) in eqs:
210				serialise_lambda (x, ss)
211				serialise_lambda (y, ss)
212		elif self.kind == 'CaseSplit':
213			ss.extend (['CaseSplit', '%d' % self.point, self.tag])
214		else:
215			assert not 'proof node kind understood'
216		for proof in self.subproofs:
217			proof.serialise (p, ss)
218
219	def all_subproofs (self):
220		return [self] + [proof for proof1 in self.subproofs
221			for proof in proof1.all_subproofs ()]
222
223	def all_subproblems (self, p, restrs, hyps, name):
224		subproblems = proof_subproblems (p, self.kind,
225			self.args, restrs, hyps, name)
226		subproofs = logic.azip (subproblems, self.subproofs)
227		return [(self, restrs, hyps)] + [problem
228			for ((restrs2, hyps2, name2), proof) in subproofs
229			for problem in proof.all_subproblems (p, restrs2,
230				hyps2, name2)]
231
232	def save_serialise (self, p, fname):
233		f = open (fname, 'w')
234		ss = []
235		self.serialise (p, ss)
236		f.write (' '.join (ss) + '\n')
237		f.close ()
238
239	def __hash__ (self):
240		return syntax.hash_tuplify (self.kind, self.args,
241			self.subproofs)
242
243def serialise_details (details, ss):
244	(split, (seq_start, step), eqs) = details
245	ss.extend (['%d' % split, '%d' % seq_start, '%d' % step])
246	ss.append ('%d' % len (eqs))
247	for eq in eqs:
248		serialise_lambda (eq, ss)
249
250def serialise_lambda (eq_term, ss):
251	ss.extend (['Lambda', '%i'])
252	word32T.serialise (ss)
253	eq_term.serialise (ss)
254
255def deserialise_details (ss, i):
256	(split, seq_start, step) = [int (x) for x in ss[i : i + 3]]
257	(i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3)
258	return (i, (split, (seq_start, step), eqs))
259
260def deserialise_lambda (ss, i):
261	assert ss[i : i + 2] == ['Lambda', '%i'], (ss, i)
262	(i, typ) = syntax.parse_typ (ss, i + 2)
263	assert typ == word32T, typ
264	(i, eq_term) = syntax.parse_expr (ss, i)
265	return (i, eq_term)
266
267def deserialise_double_lambda (ss, i):
268	(i, x) = deserialise_lambda (ss, i)
269	(i, y) = deserialise_lambda (ss, i)
270	return (i, (x, y))
271
272def deserialise_inner (ss, i):
273	if ss[i] == 'Leaf':
274		return (i + 1, ProofNode ('Leaf'))
275	elif ss[i] == 'Restr':
276		point = int (ss[i + 1])
277		tag = ss[i + 2]
278		kind = ss[i + 3]
279		assert kind in ['Number', 'Offset'], (kind, i)
280		x = int (ss[i + 4])
281		y = int (ss[i + 5])
282		(i, p1) = deserialise_inner (ss, i + 6)
283		return (i, ProofNode ('Restr', (point, (kind, (x, y))), [p1]))
284	elif ss[i] == 'SingleRevInduct':
285		point = int (ss[i + 1])
286		tag = ss[i + 2]
287		n = int (ss[i + 3])
288		(i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i + 4)
289		(i, pred) = syntax.parse_term (ss, i)
290		n_bound = int (ss[i])
291		(i, p1) = deserialise_inner (ss, i + 1)
292		return (i, ProofNode ('SingleRevInduct', (point, (eqs, n),
293			(pred, n_bound)), [p1]))
294	elif ss[i] == 'Split':
295		n = int (ss[i + 1])
296		loop_r_max = int (ss[i + 2])
297		(i, l_details) = deserialise_details (ss, i + 3)
298		(i, r_details) = deserialise_details (ss, i)
299		(i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i)
300		(i, p1) = deserialise_inner (ss, i)
301		(i, p2) = deserialise_inner (ss, i)
302		return (i, ProofNode ('Split', (l_details, r_details, eqs,
303			n, loop_r_max), [p1, p2]))
304	elif ss[i] == 'CaseSplit':
305		n = int (ss[i + 1])
306		tag = ss[i + 2]
307		(i, p1) = deserialise_inner (ss, i + 3)
308		(i, p2) = deserialise_inner (ss, i)
309		return (i, ProofNode ('CaseSplit', (n, tag), [p1, p2]))
310	else:
311		assert not 'proof node type understood', (ss, i)
312
313def deserialise (line):
314	ss = line.split ()
315	(i, proof) = deserialise_inner (ss, 0)
316	assert i == len (ss), (ss, i)
317	return proof
318
def proof_subproblems (p, kind, args, restrs, hyps, path):
	'''Return the subproblems produced by one proof step.

	Each subproblem is a (restrs, hyps, name) triple, listed in the same
	order as the subproofs of a ProofNode of this kind:
	  'Leaf'            none.
	  'Restr'           one: the visit restriction prepended to restrs,
	                    with a trivial pc hyp at the top of the range.
	  'SingleRevInduct' one: the proved predicate added as a hyp.
	  'Split'           two: the 'init' case (left split point not
	                    visited at count n) and the 'loop' case.
	  'CaseSplit'       two: point visited / point not visited.
	'''
	tags = p.pairing.tags
	if kind == 'Leaf':
		return []
	elif kind == 'Restr':
		restr = get_proof_restr (args[0], args[1])
		hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)]
		return [((restr,) + restrs, hyps,
			'%s (%d limited)' % (path, args[0]))]
	elif kind == 'SingleRevInduct':
		hyp = single_induct_resulting_hyp (p, restrs, args)
		return [(restrs, hyps + [hyp], path)]
	elif kind == 'Split':
		split = args
		return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs),
			'%d init case in %s' % (split[0][0], path)),
		(restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True),
			'%d loop case in %s' % (split[0][0], path))]
	elif kind == 'CaseSplit':
		(point, tag) = args
		visit = ((point, restrs), tag)
		true_hyps = hyps + [pc_true_hyp (visit)]
		false_hyps = hyps + [pc_false_hyp (visit)]
		return [(restrs, true_hyps,
			'true case (%d visited) in %s' % (point, path)),
		(restrs, false_hyps,
			'false case (%d not visited) in %s' % (point, path))]
	else:
		# report the offending kind (there is no 'proof' in scope here)
		assert not 'proof node kind understood', kind
348
349
def split_heads (split):
	'''Return [left split point, right split point] of Split args.'''
	((l_point, _, _), (r_point, _, _), _, _, _) = split
	return [l_point, r_point]
354
def split_no_loop_hyps (tags, split, restrs):
	'''Hyps for the 'init' case of a split: the left split point is not
	reached at visit number n.'''
	((_, (_, _), _), _, _, n, _) = split
	l_visit = split_visit_visits (tags, split, restrs, vc_num (n))[0]
	return [pc_false_hyp (l_visit)]
361
def split_visit_one_visit (tag, details, restrs, visit):
	'''Map a sequence-relative visit count onto a concrete visit.

	Returns ((split, restrs'), tag), where restrs' is restrs with the
	split point's visit count prepended, or None if details is None.
	'''
	if details is None:
		return None
	(split, (seq_start, step), _) = details

	# the split point sequence at low numbers ('Number') is offset
	# by the point the sequence starts. At symbolic offsets we ignore
	# that, instead having the loop counter for the two sequences
	# be the same number of iterations after the sequence start.
	if visit.kind != 'Offset':
		count = vc_num (seq_start + (visit.n * step))
	else:
		count = vc_offs (visit.n * step)

	full_restrs = ((split, count), ) + restrs
	return ((split, full_restrs), tag)
378
def split_visit_visits (tags, split, restrs, visit):
	'''The (left, right) concrete visits for one sequence visit count.'''
	(left_tag, right_tag) = tags
	(left, right, _, _, _) = split
	return (split_visit_one_visit (left_tag, left, restrs, visit),
		split_visit_one_visit (right_tag, right, restrs, visit))
387
def split_hyps_at_visit (tags, split, restrs, visit):
	'''Hypotheses relating both sides of a split at one visit count.

	Returns a list of (hyp, description) pairs:
	  - the left visit implies the right visit, and each side's visit
	    implies that side's start visit (count 0);
	  - each side's invariants: the value at the start (with %i := 0)
	    equals the value at this visit (with %i := the visit number,
	    or %n + offset for symbolic visits);
	  - each (l_exp, r_exp) pair in eqs: the left value at the left
	    visit equals the right value at the right visit.
	Expressions that logic.inst_eq_at_visit rejects for this visit are
	left out.
	'''
	(l_details, r_details, eqs, _, _) = split
	(l_split, (l_seq_start, l_step), l_eqs) = l_details
	(r_split, (r_seq_start, r_step), r_eqs) = r_details

	(l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit)
	(l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0))
	(l_tag, r_tag) = tags

	# returns a function substituting v for the lambda variable %i
	def mksub (v):
		return lambda exp: logic.var_subst (exp, {('%i', word32T) : v},
			must_subst = False)
	def inst (exp):
		return logic.inst_eq_at_visit (exp, visit)
	zsub = mksub (mk_word32 (0))
	if visit.kind == 'Number':
		lsub = mksub (mk_word32 (visit.n))
	else:
		# symbolic visit: %i becomes the loop counter %n plus the offset
		lsub = mksub (mk_plus (mk_var ('%n', word32T),
			mk_word32 (visit.n)))

	hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'),
		(Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag),
		(Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)]
	hyps += [(eq_hyp ((zsub (l_exp), l_start), (lsub (l_exp), l_visit),
				(l_split, r_split)), '%s const' % l_tag)
			for l_exp in l_eqs if inst (l_exp)]
	hyps += [(eq_hyp ((zsub (r_exp), r_start), (lsub (r_exp), r_visit),
				(l_split, r_split)), '%s const' % r_tag)
			for r_exp in r_eqs if inst (r_exp)]
	hyps += [(eq_hyp ((lsub (l_exp), l_visit), (lsub (r_exp), r_visit),
				(l_split, r_split)), 'eq')
			for (l_exp, r_exp) in eqs
			if inst (l_exp) and inst (r_exp)]
	return hyps
423
def split_loop_hyps (tags, split, restrs, exit):
	'''Hyps for the loop case of a split with n-induction.

	The left split point is visited at offset n - 1. If exit is set it is
	additionally not visited at offset n. The split relations at offsets
	0 .. n - 1 are appended.
	'''
	(_, _, _, n, _) = split
	assert len (split[0]) == 3
	last = split_visit_visits (tags, split, restrs, vc_offs (n - 1))[0]
	after = split_visit_visits (tags, split, restrs, vc_offs (n))[0]

	entered = pc_true_hyp (last)
	not_continued = pc_false_hyp (after)
	result = [entered, not_continued] if exit else [entered]
	for i in range (n):
		at_offset = split_hyps_at_visit (tags, split, restrs, vc_offs (i))
		result.extend ([hyp for (hyp, _) in at_offset])
	return result
438
def loops_to_split (p, restrs):
	'''Loop heads of p that restrs does not restrict yet and that remain
	relevant.

	For each restriction whose visit set excludes zero (so the node is
	definitely visited), only loop heads reachable from that node, or on
	a different tag, are kept.
	'''
	restricted = set ([p.loop_id (node) for (node, _) in restrs])
	candidates = set (p.loop_heads ()) - restricted
	for (node, visits) in restrs:
		if visits.has_zero ():
			continue
		# node must be visited, so loop heads must be
		# reachable from it (or on another tag)
		candidates = [head for head in candidates
			if p.is_reachable_from (node, head)
			or p.node_tags[node][0] != p.node_tags[head][0]]
	return candidates
451
def restr_others (p, restrs, n):
	'''Extend restrs, limiting every loop still to split to vc_upto (n).'''
	limit = vc_upto (n)
	return restrs + tuple ([(head, limit)
		for head in loops_to_split (p, restrs)])
455
def non_r_err_pc_hyp (tags, restrs):
	'''Hyp that the second (right) tag does not reach 'Err' under restrs.'''
	right_tag = tags[1]
	err_visit = (('Err', restrs), right_tag)
	return pc_false_hyp (err_visit)
458
def split_r_err_pc_hyp (p, split, restrs, tags = None):
	'''Non-'Err' hyp for the right side, with the right split point
	restricted to vc_double_range (r_seq_start + n * r_step,
	loop_r_max + 2) and other loops limited via restr_others (.., 2).'''
	(_, r_details, _, n, loop_r_max) = split
	(r_split, (r_seq_start, r_step), _) = r_details

	first_count = r_seq_start + n * r_step
	r_restr = (r_split, vc_double_range (first_count, loop_r_max + 2))
	full_restrs = restr_others (p, ((r_restr, ) + restrs), 2)

	if tags is None:
		tags = p.pairing.tags
	return non_r_err_pc_hyp (tags, full_restrs)
472
# extra visit counts added to the top of every proof restriction range
# built by get_proof_restr; 0 adds none
restr_bump = 0

def get_proof_restr (n, restr_range):
	'''Build the restriction (n, visit-count options) for a 'Restr' step.

	restr_range is (kind, (x, y)); counts x .. y - 1 (+ restr_bump) of
	that kind are allowed.'''
	(kind, (lo, hi)) = restr_range
	counts = [VisitCount (kind, i) for i in range (lo, hi + restr_bump)]
	return (n, mk_vc_opts (counts))
478
def restr_trivial_hyp (p, n, restr_range, restrs):
	'''Trivial pc hyp at the highest allowed visit (y - 1) to n.'''
	(kind, (_, y)) = restr_range
	top_restr = (n, VisitCount (kind, y - 1))
	visit = ((n, (top_restr, ) + restrs), p.node_tags[n][0])
	return rep_graph.pc_triv_hyp (visit)
483
def proof_restr_checks (n, restr_range, p, restrs, hyps):
	'''Checks justifying a 'Restr' step on node n.

	Returns a list of (hyps, goal, name). There is a lower-bound check
	when the range starts above the trivial minimum, and always an
	upper-bound check that visit y - 1 is unreachable.'''
	(kind, (x, y)) = restr_range
	restr = get_proof_restr (n, (kind, (x, y)))
	err_hyp = non_r_err_pc_hyp (p.pairing.tags,
		restr_others (p, (restr, ) + restrs, 2))
	hyps = [err_hyp] + hyps

	def at_count (vc):
		return ((n, ((n, vc), ) + restrs), p.node_tags[n][0])

	checks = []
	# this cannot be more uniform because the representation of visit
	# at offset 0 is all a bit odd, with n being the only node so visited:
	min_vc = None
	if kind == 'Offset':
		min_vc = vc_offs (max (0, x - 1))
	elif x > 1:
		min_vc = vc_num (x - 1)
	if min_vc:
		checks.append ((hyps, pc_true_hyp (at_count (min_vc)),
			'Check of restr min %d %s for %d' % (x, kind, n)))

	# if we can reach node n with (y - 1) visits to n, then the next
	# node will have y visits to n, which we are disallowing
	# thus we show that this visit is impossible
	top_vc = VisitCount (kind, y - 1)
	checks.append ((hyps, pc_false_hyp (at_count (top_vc)),
		'Check of restr max %d %s for %d' % (y, kind, n)))
	return checks
513
def split_init_step_checks (p, restrs, hyps, split, tags = None):
	'''Base-case checks for a split: the split relations at each of the
	first n left visits, assuming that visit happens.'''
	(_, _, _, n, _) = split
	if tags is None:
		tags = p.pairing.tags

	base = [split_r_err_pc_hyp (p, split, restrs, tags = tags)] + hyps
	checks = []
	for i in range (n):
		count = vc_num (i)
		(l_visit, r_visit) = split_visit_visits (tags, split,
			restrs, count)
		reached = pc_true_hyp (l_visit)
		# this trivial 'hyp' ensures the rep is built to include
		# the matching rhs visits when checking lhs consts
		r_triv = rep_graph.pc_triv_hyp (r_visit)
		for (hyp, desc) in split_hyps_at_visit (tags, split, restrs,
				count):
			checks.append ((base + [reached, r_triv], hyp,
				'Induct check at visit %d: %s' % (i, desc)))
	return checks
535
def split_induct_step_checks (p, restrs, hyps, split, tags = None):
	'''Inductive-step checks for a split: assuming the relations at
	offsets 0 .. n - 1, show them at offset n.'''
	((l_split, _, _), _, _, n, _) = split
	if tags is None:
		tags = p.pairing.tags

	err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags)
	step = vc_offs (n)
	(cont, r_cont) = split_visit_visits (tags, split, restrs, step)
	# the 'trivial' hyp here ensures the representation includes a loop
	# of the rhs when proving const equations on the lhs
	assumptions = [err_hyp, pc_true_hyp (cont),
		rep_graph.pc_triv_hyp (r_cont)]
	assumptions += hyps
	assumptions += split_loop_hyps (tags, split, restrs, exit = False)

	checks = []
	for (hyp, desc) in split_hyps_at_visit (tags, split, restrs,
			vc_offs (n)):
		name = ('Induct check (%s) at inductive step for %d'
			% (desc, l_split))
		checks.append ((assumptions, hyp, name))
	return checks
553
def check_split_induct_step_group (rep, restrs, hyps, split, tags = None):
	'''True iff every group of inductive-step checks passes on rep;
	stops at the first failing group.'''
	step_checks = split_induct_step_checks (rep.p, restrs, hyps, split,
		tags = tags)
	for checks_group in proof_check_groups (step_checks):
		(ok, _) = test_hyp_group (rep, checks_group)
		if not ok:
			return False
	return True
563
def split_checks (p, restrs, hyps, split, tags = None):
	'''All checks for a split step: base cases, then inductive step.'''
	base = split_init_step_checks (p, restrs, hyps, split, tags = tags)
	step = split_induct_step_checks (p, restrs, hyps, split, tags = tags)
	return base + step
567
def loop_eq_hyps_at_visit (tag, split, eqs, restrs, visit_num,
		use_if_at = False):
	'''Single-loop analogue of split_hyps_at_visit.

	Returns (hyp, description) pairs: the visit implies the start visit,
	and each invariant in eqs has the same value at the start
	(%i := 0) as at this visit. Invariants that
	logic.inst_eq_at_visit rejects for visit_num are left out.
	'''
	details = (split, (0, 1), eqs)
	visit = split_visit_one_visit (tag, details, restrs, visit_num)
	start = split_visit_one_visit (tag, details, restrs, vc_num (0))

	def subst_i (val):
		def apply_subst (exp):
			return logic.var_subst (exp, {('%i', word32T) : val},
				must_subst = False)
		return apply_subst
	at_start = subst_i (mk_word32 (0))
	if visit_num.kind != 'Number':
		at_visit = subst_i (mk_plus (mk_var ('%n', word32T),
			mk_word32 (visit_num.n)))
	else:
		at_visit = subst_i (mk_word32 (visit_num.n))

	result = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)]
	for exp in eqs:
		if not logic.inst_eq_at_visit (exp, visit_num):
			continue
		hyp = eq_hyp ((at_start (exp), start), (at_visit (exp), visit),
			(split, 0), use_if_at = use_if_at)
		result.append ((hyp, '%s const' % tag))
	return result
590
def single_induct_resulting_hyp (p, restrs, rev_induct_args):
	'''The hyp a 'SingleRevInduct' step establishes: pred holds (if
	reached) at the first visit to the point.'''
	(point, _, (pred, _)) = rev_induct_args
	(tag, _) = p.node_tags[point]
	first_restrs = restrs + ((point, vc_num (0)), )
	return rep_graph.true_if_at_hyp (pred, ((point, first_restrs), tag))
596
def single_loop_induct_base_checks (p, restrs, hyps, tag, split, n, eqs):
	'''Base-case checks: the invariants at visits 0 .. n, each assuming
	that visit is reached.'''
	details = (split, (0, 1), eqs)
	checks = []
	for i in range (n + 1):
		count = vc_num (i)
		reached = pc_true_hyp (split_visit_one_visit (tag, details,
			restrs, count))
		for (hyp, desc) in loop_eq_hyps_at_visit (tag, split,
				eqs, restrs, count):
			checks.append ((hyps + [reached], hyp,
				'Base check (%s, %d) at induct step for %d'
					% (desc, i, split)))
	return checks
609
def single_loop_induct_step_checks (p, restrs, hyps, tag, split, n,
				eqs, eqs_assume = None):
	'''Inductive-step checks: assuming eqs_assume + eqs at offsets
	0 .. n - 1, show eqs at offset n.'''
	if eqs_assume is None:
		eqs_assume = []
	assumed = eqs_assume + eqs
	details = (split, (0, 1), assumed)
	cont = split_visit_one_visit (tag, details, restrs, vc_offs (n))

	assumptions = [pc_true_hyp (cont)] + hyps
	for i in range (n):
		assumptions += [h for (h, _) in loop_eq_hyps_at_visit (tag,
			split, assumed, restrs, vc_offs (i))]

	name_fmt = 'Induct check (%s) at inductive step for %d'
	return [(assumptions, hyp, name_fmt % (desc, split))
		for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs,
			restrs, vc_offs (n))]
625
def mk_loop_counter_eq_hyp (p, split, restrs, n):
	'''Hyp that the loop counter %n equals n at offset 0 of split.'''
	no_eqs_details = (split, (0, 1), [])
	(tag, _) = p.node_tags[split]
	visit = split_visit_one_visit (tag, no_eqs_details, restrs,
		vc_offs (0))
	counter = mk_var ('%n', word32T)
	return eq_hyp ((counter, visit), (mk_word32 (n), visit), (split, 0))
632
def single_loop_rev_induct_base_checks (p, restrs, hyps, tag, split,
		n_bound, eqs_assume, pred):
	'''Base check for reverse induction: pred holds at offset 1 when
	the loop counter is n_bound.'''
	details = (split, (0, 1), eqs_assume)
	cont = split_visit_one_visit (tag, details, restrs, vc_offs (1))
	counter_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound)
	no_err = split_r_err_pc_hyp (p, (None, details, None, 1, 1), restrs)

	reached = pc_true_hyp (cont)
	inv_hyps = [h for (h, _) in loop_eq_hyps_at_visit (tag,
		split, eqs_assume, restrs, vc_offs (0))]
	assumptions = hyps + [counter_hyp, reached, no_err] + inv_hyps
	goal = rep_graph.true_if_at_hyp (pred, cont)
	return [(assumptions, goal, 'Pred true at %d check.' % n_bound)]
648
def single_loop_rev_induct_checks (p, restrs, hyps, tag, split,
		eqs_assume, pred):
	'''Reverse step: if pred holds at offset 2 it holds at offset 1.'''
	details = (split, (0, 1), eqs_assume)
	current = split_visit_one_visit (tag, details, restrs, vc_offs (1))
	following = split_visit_one_visit (tag, details, restrs, vc_offs (2))

	no_err = split_r_err_pc_hyp (p, (None, details, None, 1, 1), restrs)
	pred_next = rep_graph.true_if_at_hyp (pred, following)

	reached = pc_true_hyp (current)
	inv_hyps = [h for (h, _) in loop_eq_hyps_at_visit (tag, split,
		eqs_assume, restrs, vc_offs (1), use_if_at = True)]
	assumptions = hyps + [reached, pred_next, no_err] + inv_hyps
	goal = rep_graph.true_if_at_hyp (pred, current)
	return [(assumptions, goal, 'Pred reverse step.')]
665
def all_rev_induct_checks (p, restrs, hyps, point, eqs_proof, rev_proof):
	'''All checks for a 'SingleRevInduct' step at point.

	eqs_proof is (eqs, n): invariants proved by n-induction.
	rev_proof is (pred, n_bound): pred proved at n_bound and then by
	reverse induction.'''
	(eqs, n) = eqs_proof
	(pred, n_bound) = rev_proof
	(tag, _) = p.node_tags[point]
	checks = []
	checks += single_loop_induct_step_checks (p, restrs, hyps, tag,
		point, n, eqs)
	checks += single_loop_induct_base_checks (p, restrs, hyps, tag,
		point, n, eqs)
	checks += single_loop_rev_induct_checks (p, restrs, hyps, tag,
		point, eqs, pred)
	checks += single_loop_rev_induct_base_checks (p, restrs, hyps,
		tag, point, n_bound, eqs, pred)
	return checks
677
def leaf_condition_checks (p, restrs, hyps):
	'''checks of the final refinement conditions'''
	base = [non_r_err_pc_hyp (p.pairing.tags, restrs)] + hyps
	[l_tag, r_tag] = p.pairing.tags

	no_l_err = pc_false_hyp ((('Err', restrs), l_tag))
	# this 'hypothesis' ensures that the representation is built all
	# the way to Ret. in particular this ensures that function relations
	# are available to use in proving single-side equalities
	l_ret = (('Ret', restrs), l_tag)
	r_ret = (('Ret', restrs), r_tag)
	ret_eq = eq_hyp ((true_term, l_ret), (true_term, r_ret))

	### TODO: previously we considered the case where 'Ret' was unreachable
	### (as a result of unsatisfiable hyps) and proved a simpler property.
	### we might want to restore this
	(_, out_eqs) = p.pairing.eqs
	eq_checks = []
	for hyp in inst_eqs (p, restrs, out_eqs):
		eq_checks.append ((base + [no_l_err, ret_eq], hyp,
			'Leaf eq check'))
	path_check = (base + [ret_eq], no_l_err, 'Leaf path-cond imp')
	return [path_check] + eq_checks
698
def proof_checks (p, proof):
	'''Every (hyps, goal, name) check needed to validate proof on p.'''
	root_hyps = init_point_hyps (p)
	return proof_checks_rec (p, (), root_hyps, proof, 'root')
701
def proof_checks_imm (p, restrs, hyps, proof, path):
	'''Checks for the top step of proof only (not its subproofs).

	Returns a list of (hyps, goal, name), with ' on <path>' appended to
	each name.

	Raises AssertionError for an unknown proof kind; previously this fell
	through and raised UnboundLocalError on 'checks'.'''
	if proof.kind == 'Restr':
		checks = proof_restr_checks (proof.point, proof.restr_range,
			p, restrs, hyps)
	elif proof.kind == 'SingleRevInduct':
		checks = all_rev_induct_checks (p, restrs, hyps, proof.point,
			proof.eqs_proof, proof.rev_proof)
	elif proof.kind == 'Split':
		checks = split_checks (p, restrs, hyps, proof.split)
	elif proof.kind == 'Leaf':
		checks = leaf_condition_checks (p, restrs, hyps)
	elif proof.kind == 'CaseSplit':
		# nothing to check here: the two subproblems split on a pc
		checks = []
	else:
		assert not 'proof node kind understood', proof.kind

	return [(hs, hyp, '%s on %s' % (name, path))
		for (hs, hyp, name) in checks]
718
def proof_checks_rec (p, restrs, hyps, proof, path):
	'''Checks for proof and, recursively, for all of its subproofs.'''
	all_checks = proof_checks_imm (p, restrs, hyps, proof, path)
	subproblems = proof_subproblems (p, proof.kind,
		proof.args, restrs, hyps, path)
	pairs = logic.azip (subproblems, proof.subproofs)
	for ((sub_restrs, sub_hyps, sub_path), subproof) in pairs:
		all_checks.extend (proof_checks_rec (p, sub_restrs, sub_hyps,
			subproof, sub_path))
	return all_checks
728
# check_proof stores the (hyps, hyp, name) of the latest failing check here
last_failed_check = [None]

def proof_check_groups (checks):
	'''Group checks by the exact set of visits their hyps mention.

	Returns the groups (lists of (hyps, hyp, name)) as dict values.'''
	grouped = {}
	for (hyps, hyp, name) in checks:
		visits = set ()
		for h in [hyp] + hyps:
			visits.update (h.visits ())
		key = tuple (sorted (visits))
		grouped.setdefault (key, []).append ((hyps, hyp, name))
	return grouped.values ()
739
def test_hyp_group (rep, group, detail = None):
	'''Test all implications of one check group on rep.

	Returns (result, None) on success or (result, failing check) on
	failure; on failure the failure kind is stored in detail[0] if a
	(truthy) detail list is given.'''
	implications = [(hyps, hyp) for (hyps, hyp, _) in group]
	group_names = set ([name for (_, _, name) in group])

	trace ('Testing group of hyps: %s' % list (group_names), push = 1)
	(outcome, failed_idx, failure_kind) = rep.test_hyp_imps (implications)
	trace ('Group result: %r' % outcome, push = -1)
	if outcome:
		return (outcome, None)
	if detail:
		detail[0] = failure_kind
	return (outcome, group[failed_idx])
753
def failed_test_sets (p, checks):
	'''Names of check sets that fail, each tested on a fresh slice of p.'''
	by_name = {}
	for (hyps, hyp, name) in checks:
		by_name.setdefault (name, []).append ((hyps, hyp))
	failures = []
	for name in by_name:
		rep = rep_graph.mk_graph_slice (p)
		(ok, _, _) = rep.test_hyp_imps (by_name[name])
		if not ok:
			failures.append (name)
	return failures
766
# when set, called as save (p, proof) after check_proof succeeds
save_checked_proofs = [None]

def check_proof (p, proof, use_rep = None):
	'''Check every group of proof obligations; True iff all pass.

	A fresh graph slice is used per group unless use_rep is given. The
	first failure is recorded in last_failed_check and traced.'''
	for group in proof_check_groups (proof_checks (p, proof)):
		if use_rep is None:
			rep = rep_graph.mk_graph_slice (p)
		else:
			rep = use_rep

		detail = [0]
		(verdict, failed) = test_hyp_group (rep, group, detail)
		if not verdict:
			(_, _, name) = failed
			last_failed_check[0] = failed
			trace ('%s: proof failed!' % name)
			trace ('  (failure kind: %r)' % detail[0])
			return False

	saver = save_checked_proofs[0]
	if saver:
		saver (p, proof)
	return True
792
def pretty_vseq (details):
	'''Human-readable description of a split's visit sequence.'''
	(split, (seq_start, seq_step), _) = details
	if (seq_start, seq_step) == (0, 1):
		return 'visits to %d' % split
	first = seq_start + 1
	shown = [first + k * seq_step for k in range (3)]
	return 'visits [%d, %d, %d ...] to %d' % tuple (shown + [split])
801
def next_induct_var (n):
	'''Name of the n-th induction variable: i, j, k, a, b, c, then i2,
	j2, ... for the next six, and so on.'''
	s = 'ijkabc'
	v = s[n % 6]
	if n >= 6:
		# floor division, so this stays an integer under Python 3 too
		v += str ((n // 6) + 1)
	return v
808
def pretty_lambda (t):
	'''Pretty-print a lambda term, showing %i as #seq-visits.'''
	seq_var = syntax.mk_var ('#seq-visits', word32T)
	body = logic.var_subst (t, {('%i', word32T) : seq_var},
		must_subst = False)
	return syntax.pretty_expr (body, print_type = True)
813
def check_proof_report_rec (p, restrs, hyps, proof, step_num, ctxt, inducts,
		do_check = True):
	'''Print a readable account of proof and optionally check each step.

	step_num numbers the printed steps; ctxt describes the case this step
	belongs to. inducts is (index of the next induction variable name,
	dict from split point to its induction variable name); 'Offset'
	restrictions are printed relative to that variable.

	Returns (next step number, next induction variable index), or None as
	soon as a check here or in a subproof fails. Raises AssertionError for
	an unknown proof kind (previously an UnboundLocalError on 'checks').
	'''
	printout ('Step %d: %s' % (step_num, ctxt))
	if proof.kind == 'Restr':
		(kind, (x, y)) = proof.restr_range
		if kind == 'Offset':
			v = inducts[1][proof.point]
			rexpr = '{%s + %s ..< %s + %s}' % (v, x, v, y)
		else:
			rexpr = '{%s ..< %s}' % (x, y)
		printout ('  Prove the number of visits to %d is in %s'
			% (proof.point, rexpr))

		checks = proof_restr_checks (proof.point, proof.restr_range,
			p, restrs, hyps)
		cases = ['']
	elif proof.kind == 'SingleRevInduct':
		printout ('  Proving a predicate by future induction.')
		(eqs, n) = proof.eqs_proof
		point = proof.point
		printout ('    proving these invariants by %d-induction' % n)
		for x in eqs:
			printout ('      %s (@ addr %s)'
				% (pretty_lambda (x), point))
		printout ('    then establishing this predicate')
		(pred, n_bound) = proof.rev_proof
		printout ('      %s (@ addr %s)'
			% (pretty_lambda (pred), point))
		printout ('    at large iterations (%d) and by back induction.'
			% n_bound)
		cases = ['']
		checks = all_rev_induct_checks (p, restrs, hyps, point,
			proof.eqs_proof, proof.rev_proof)
	elif proof.kind == 'Split':
		(l_dts, r_dts, eqs, n, lrmx) = proof.split
		# allocate a fresh induction variable shared by both split points;
		# the dict is copied so sibling subproofs don't see it
		v = next_induct_var (inducts[0])
		inducts = (inducts[0] + 1, dict (inducts[1]))
		inducts[1][l_dts[0]] = v
		inducts[1][r_dts[0]] = v
		printout ('  prove %s related to %s' % (pretty_vseq (l_dts),
			pretty_vseq (r_dts)))
		printout ('    with equalities')
		for (x, y) in eqs:
			printout ('      %s (@ addr %s)' % (pretty_lambda (x),
				l_dts[0]))
			printout ('      = %s (@ addr %s)' % (pretty_lambda (y),
				r_dts[0]))
		printout ('    and with invariants')
		for x in l_dts[2]:
			printout ('      %s (@ addr %s)'
				% (pretty_lambda (x), l_dts[0]))
		for x in r_dts[2]:
			printout ('      %s (@ addr %s)'
				% (pretty_lambda (x), r_dts[0]))
		checks = split_checks (p, restrs, hyps, proof.split)
		cases = ['case in (%d) where the length of the sequence < %d'
				% (step_num, n),
			'case in (%d) where the length of the sequence is %s + %s'
				% (step_num, v, n)]
	elif proof.kind == 'Leaf':
		printout ('  prove all verification conditions')
		checks = leaf_condition_checks (p, restrs, hyps)
		cases = []
	elif proof.kind == 'CaseSplit':
		printout ('  case split on whether %d is visited' % proof.point)
		checks = []
		cases = ['case in (%d) where %d is visited' % (step_num, proof.point),
			'case in (%d) where %d is not visited' % (step_num, proof.point)]
	else:
		assert not 'proof node kind understood', proof.kind

	if checks and do_check:
		groups = proof_check_groups (checks)
		for group in groups:
			rep = rep_graph.mk_graph_slice (p)
			detail = [0]
			(res, _) = test_hyp_group (rep, group, detail)
			if not res:
				printout ('    .. failed to prove this.')
				printout ('      (failure kind: %r)' % detail[0])
				return

		printout ('    .. proven.')

	subproblems = proof_subproblems (p, proof.kind,
		proof.args, restrs, hyps, '')
	xs = logic.azip (subproblems, proof.subproofs)
	xs = logic.azip (xs, cases)
	step_num += 1
	for ((subprob, subproof), case) in xs:
		(restrs, hyps, _) = subprob
		res = check_proof_report_rec (p, restrs, hyps, subproof,
			step_num, case, inducts, do_check = do_check)
		if not res:
			return
		# thread the step counter and variable index through siblings
		(step_num, induct_var_num) = res
		inducts = (induct_var_num, inducts[1])
	return (step_num, inducts[0])
910
def check_proof_report (p, proof, do_check = True):
	'''Print a step-by-step report of proof; True iff every step passed.'''
	root_hyps = init_point_hyps (p)
	no_inducts = (0, {})
	outcome = check_proof_report_rec (p, (), root_hyps, proof,
		1, '', no_inducts, do_check = do_check)
	return bool (outcome)
915
def save_proofs_to_file (fname, mode = 'w'):
	'''Open fname and return a save (p, proof) callback writing to it.

	Each call appends one 'ProblemProof (<name>) { ... }' section: the
	serialised problem lines, then the proof on one line. The file is
	flushed after every section and deliberately left open so the
	callback can be reused.'''
	assert mode in ['w', 'a']
	out = open (fname, mode)

	def save (p, proof):
		out.write ('ProblemProof (%s) {\n' % p.name)
		for line in p.serialise ():
			out.write (line + '\n')
		tokens = []
		proof.serialise (p, tokens)
		out.write (' '.join (tokens) + '\n}\n')
		out.flush ()
	return save
930
def load_proofs_from_file (fname):
	'''Load proofs written by save_proofs_to_file.

	Returns a dict from problem name to a list of (problem, proof) pairs.
	Blank lines and lines starting with '#' are ignored. Raises
	AssertionError on malformed input, including text outside a
	ProblemProof section and a section that is never closed.
	'''
	proofs = {}
	lines = None
	# 'with' so the file is closed even when parsing fails
	with open (fname) as f:
		for line in f:
			line = line.strip ()
			if line.startswith ('ProblemProof'):
				assert line.endswith ('{'), line
				name_bit = line[len ('ProblemProof') : -1].strip ()
				assert name_bit.startswith ('('), name_bit
				assert name_bit.endswith (')'), name_bit
				name = name_bit[1:-1]
				lines = []
			elif line == '}':
				assert lines is not None, 'unmatched } in %s' % fname
				assert lines[0] == 'Problem'
				assert lines[-2] == 'EndProblem'
				import problem
				trace ('loading proof from %d lines' % len (lines))
				p = problem.deserialise (name, lines[:-1])
				proof = deserialise (lines[-1])
				proofs.setdefault (name, [])
				proofs[name].append ((p, proof))
				trace ('loaded proof %s' % name)
				lines = None
			elif line.startswith ('#'):
				pass
			elif line:
				assert lines is not None, (
					'text outside a ProblemProof section: %r' % line)
				lines.append (line)
	# an open section, even an empty one, means a truncated file
	assert lines is None, 'unterminated ProblemProof section in %s' % fname
	return proofs
962
963