1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7# proof scripts and check process
8
9from rep_graph import mk_graph_slice, Hyp, eq_hyp, pc_true_hyp, pc_false_hyp
10import rep_graph
11from problem import Problem, inline_at_point
12import problem
13
14from solver import to_smt_expr
15from target_objects import functions, pairings, trace, printout
16import target_objects
17from rep_graph import (vc_num, vc_offs, vc_double_range, vc_upto, mk_vc_opts,
18	VisitCount)
19import logic
20
21from syntax import (true_term, false_term, boolT, mk_var, mk_word32, mk_word8,
22	mk_plus, mk_minus, word32T, word8T, mk_and, mk_eq, mk_implies, mk_not,
23	rename_expr)
24import syntax
25
def build_problem (pairing, force_inline = None, avoid_abort = False):
	'''Construct and analyse a Problem for the given pairing.

	Each (tag, function name) entry of pairing.funs is added as an entry
	function. Unmatched calls are then inlined, merge points padded and
	the analysis redone. force_inline is forwarded to the C inlining
	step; avoid_abort is forwarded to both inlining steps as their
	skip_underspec flag and also disables the final inner loop check.
	Returns the resulting Problem.'''
	prob = Problem (pairing)
	for (tag, fname) in pairing.funs.items ():
		prob.add_entry_function (functions[fname], tag)
	prob.do_analysis ()

	# FIXME: the inlining is heuristic, and arguably belongs in 'search'
	inline_completely_unmatched (prob, skip_underspec = avoid_abort)
	# now do any C inlining
	inline_reachable_unmatched_C (prob, force_inline,
		skip_underspec = avoid_abort)
	trace ('Done inlining.')

	prob.pad_merge_points ()
	prob.do_analysis ()

	if avoid_abort:
		return prob
	prob.check_no_inner_loops ()
	return prob
50
def inline_completely_unmatched (p, ref_tags = None, skip_underspec = False):
	'''Repeatedly inline every call that has no pairing for ref_tags.

	ref_tags defaults to p.pairing.tags. A 'Call' node is a candidate
	when none of the pairings registered for its function name has tags
	equal to ref_tags. With skip_underspec set, candidates whose function
	has a false 'entry' attribute are traced and left alone. The scan is
	repeated until nothing more is inlined, at which point the problem is
	re-analysed.'''
	if ref_tags is None:
		ref_tags = p.pairing.tags

	def is_paired (fname):
		for pair in pairings.get (fname, []):
			if pair.tags == ref_tags:
				return True
		return False

	while True:
		skipped = []
		to_inline = []
		for n in p.nodes:
			node = p.nodes[n]
			if node.kind != 'Call' or is_paired (node.fname):
				continue
			if skip_underspec and not functions[node.fname].entry:
				skipped.append (n)
			else:
				to_inline.append (n)
		for n in skipped:
			trace ('Skipped inlining underspecified %s.'
				% p.nodes[n].fname)
		if not to_inline:
			p.do_analysis ()
			return
		for n in to_inline:
			trace ('Function %s at %d - %s - completely unmatched.'
				% (p.nodes[n].fname, n, p.node_tags[n][0]))
			# analysis is deferred until the scan finds nothing more
			inline_at_point (p, n, do_analysis = False)
72
def inline_reachable_unmatched_C (p, force_inline = None,
		skip_underspec = False):
	'''Run inline_reachable_unmatched for the 'C' tag, if there is one.

	Does nothing when 'C' is not among p.pairing.tags. Otherwise exactly
	one other tag must exist, which is used as the comparison tag.'''
	all_tags = p.pairing.tags
	if 'C' in all_tags:
		# unpacking enforces that there is exactly one non-C tag
		[compare_tag] = [tag for tag in all_tags if tag != 'C']
		inline_reachable_unmatched (p, 'C', compare_tag, force_inline,
			skip_underspec = skip_underspec)
80
def inline_reachable_unmatched (p, inline_tag, compare_tag,
		force_inline = None, skip_underspec = False):
	'''Inline calls on the inline_tag side found while exploring p.

	The functions considered matched are the inline_tag members of every
	pairing registered for a function called from a compare_tag node.
	A graph slice is built with consider_inline as its inliner, and the
	pc/env of every node plus 'Ret' and 'Err' is requested under a
	(3, 3) double-range limit at every loop head. Whenever that raises
	rep_graph.InlineEvent the whole exploration is started again.'''
	matched = []
	for n in p.nodes:
		if p.nodes[n].kind != 'Call':
			continue
		if p.node_tags[n][0] != compare_tag:
			continue
		for pair in pairings.get (p.nodes[n].fname, []):
			if inline_tag in pair.tags:
				matched.append (pair.funs[inline_tag])

	inliner = consider_inline (matched, inline_tag, force_inline,
		skip_underspec)
	rep = mk_graph_slice (p, inliner)
	opts = vc_double_range (3, 3)
	finished = False
	while not finished:
		try:
			heads = problem.loop_heads_including_inner (p)
			limits = [(n, opts) for n in heads]

			for n in p.nodes.keys ():
				try:
					rep.get_node_pc_env ((n, limits))
				except rep.TooGeneral:
					# best effort: this node is simply not explored
					pass

			rep.get_node_pc_env (('Ret', limits), inline_tag)
			rep.get_node_pc_env (('Err', limits), inline_tag)
			finished = True
		except rep_graph.InlineEvent:
			# go around again with freshly computed loop heads
			pass
110
def consider_inline1 (p, n, matched_funs, inline_tag,
		force_inline, skip_underspec):
	'''Decide whether the call at node n of p should be inlined.

	Returns False when the node is not on the inline_tag side, when
	skip_underspec is set and the callee has a false 'entry' attribute,
	or when the callee is in matched_funs and force_inline (if given)
	does not accept it. Otherwise returns a zero-argument callable which
	performs the inlining when called.'''
	node = p.nodes[n]
	assert node.kind == 'Call'

	if p.node_tags[n][0] != inline_tag:
		return False

	callee = node.fname
	if skip_underspec and not functions[callee].entry:
		trace ('Skipping inlining underspecified %s' % callee)
		return False

	if callee in matched_funs:
		# matched functions are only inlined on explicit request
		if not (force_inline and force_inline (callee)):
			return False
	return lambda: inline_at_point (p, n)
127
def consider_inline (matched_funs, tag, force_inline, skip_underspec = False):
	'''Build the inliner callback used by inline_reachable_unmatched.

	The returned function takes one argument, a (p, n) pair, and
	forwards to consider_inline1 with the settings captured here.'''
	def consider (p_n):
		# unpack inside the body: tuple parameters in a signature are
		# not valid syntax on Python 3
		(p, n) = p_n
		return consider_inline1 (p, n, matched_funs, tag,
			force_inline, skip_underspec)
	return consider
131
def inst_eqs (p, restrs, eqs, tag_map = None):
	'''Instantiate pairing equations as hypotheses on problem p.

	eqs is a sequence of (lhs, rhs) pairs, each side being an
	(expression, address name) pair where the address name is a pairing
	tag suffixed with '_IN' or '_OUT'. tag_map maps pairing tags to
	problem tags; when empty or omitted every tag of p maps to itself.
	'_IN' addresses refer to the entry node of the mapped tag with no
	restrictions, '_OUT' addresses to 'Ret' under restrs. Returns one
	eq_hyp per equation.'''
	if not tag_map:
		tag_map = dict ([(tag, tag) for tag in p.tags ()])

	addr_map = {}
	# .items () rather than .iteritems () so this also runs on Python 3
	for (pair_tag, p_tag) in tag_map.items ():
		addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag)
		addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag)

	renames = p.entry_exit_renames (list (tag_map.values ()))
	# make the renames available under the pairing's tag names as well
	for (pair_tag, p_tag) in tag_map.items ():
		renames[pair_tag + '_IN'] = renames[p_tag + '_IN']
		renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT']

	hyps = []
	for (lhs, rhs) in eqs:
		vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr])
			for (x, x_addr) in (lhs, rhs)]
		hyps.append (eq_hyp (vals[0], vals[1]))
	return hyps
149
def init_point_hyps (p):
	'''Hypotheses for the first component of p.pairing.eqs, instantiated
	with no restrictions.'''
	(input_eqs, _output_eqs) = p.pairing.eqs
	no_restrs = ()
	return inst_eqs (p, no_restrs, input_eqs)
153
class ProofNode:
	'''A node of a proof script tree.

	kind selects the proof step and fixes the shape of args and the
	number of subproofs:
	  'Leaf'            - args is None, no subproofs
	  'Restr'           - args is (point, restr_range), one subproof
	  'SingleRevInduct' - args is (point, eqs_proof, rev_proof),
	                      one subproof
	  'Split'           - args is (l_details, r_details, eqs, n,
	                      loop_r_max), two subproofs
	  'CaseSplit'       - args is (point, tag), two subproofs
	Any other kind fails an assertion at construction time.'''
	def __init__ (self, kind, args = None, subproofs = []):
		# the default list is never mutated, only copied to a tuple
		self.kind = kind
		self.args = args
		self.subproofs = tuple (subproofs)
		if self.kind == 'Leaf':
			assert args is None
			assert list (subproofs) == []
		elif self.kind == 'Restr':
			(self.point, self.restr_range) = args
			assert len (subproofs) == 1
		elif self.kind == 'SingleRevInduct':
			(self.point, self.eqs_proof, self.rev_proof) = args
			assert len (subproofs) == 1
		elif self.kind == 'Split':
			self.split = args
			# unpacked only to check the shape of args
			(l_details, r_details, eqs, n, loop_r_max) = args
			assert len (subproofs) == 2
		elif self.kind == 'CaseSplit':
			(self.point, self.tag) = args
			assert len (subproofs) == 2
		else:
			assert not 'proof node kind understood', kind

	def __repr__ (self):
		return 'ProofNode (%r, %r, %r)' % (self.kind,
			self.args, self.subproofs)

	def serialise (self, p, ss):
		'''Append the token form of this proof tree to the list ss.

		This node's tokens come first, followed by those of each
		subproof in order. p supplies node_tags, used to emit the tag
		of the point of Restr and SingleRevInduct nodes.'''
		if self.kind == 'Leaf':
			ss.append ('Leaf')
		elif self.kind == 'Restr':
			(kind, (x, y)) = self.restr_range
			tag = p.node_tags[self.point][0]
			ss.extend (['Restr', '%d' % self.point,
				tag, kind, '%d' % x, '%d' % y])
		elif self.kind == 'SingleRevInduct':
			tag = p.node_tags[self.point][0]
			(eqs, n) = self.eqs_proof
			ss.extend (['SingleRevInduct', '%d' % self.point,
				tag, '%d' % n, '%d' % len (eqs)])
			for (x, y) in eqs:
				serialise_lambda (x, ss)
				serialise_lambda (y, ss)
			(pred, n_bound) = self.rev_proof
			pred.serialise (ss)
			ss.append ('%d' % n_bound)
		elif self.kind == 'Split':
			(l_details, r_details, eqs, n, loop_r_max) = self.args
			ss.extend (['Split', '%d' % n, '%d' % loop_r_max])
			serialise_details (l_details, ss)
			serialise_details (r_details, ss)
			ss.append ('%d' % len (eqs))
			for (x, y) in eqs:
				serialise_lambda (x, ss)
				serialise_lambda (y, ss)
		elif self.kind == 'CaseSplit':
			ss.extend (['CaseSplit', '%d' % self.point, self.tag])
		else:
			assert not 'proof node kind understood'
		for proof in self.subproofs:
			proof.serialise (p, ss)

	def all_subproofs (self):
		'''This node followed by all of its descendants, depth first.'''
		return [self] + [proof for proof1 in self.subproofs
			for proof in proof1.all_subproofs ()]

	def all_subproblems (self, p, restrs, hyps, name):
		'''List (proof node, restrs, hyps) for this node and, using the
		restrictions and hypotheses computed by proof_subproblems, for
		every descendant.'''
		subproblems = proof_subproblems (p, self.kind,
			self.args, restrs, hyps, name)
		subproofs = logic.azip (subproblems, self.subproofs)
		return [(self, restrs, hyps)] + [sub
			for ((restrs2, hyps2, name2), proof) in subproofs
			for sub in proof.all_subproblems (p, restrs2,
				hyps2, name2)]

	def save_serialise (self, p, fname):
		'''Write the serialised proof to fname as a single line.'''
		ss = []
		self.serialise (p, ss)
		# the with block closes the file even if the write fails
		with open (fname, 'w') as f:
			f.write (' '.join (ss) + '\n')

	def __hash__ (self):
		return syntax.hash_tuplify (self.kind, self.args,
			self.subproofs)
240
def serialise_details (details, ss):
	'''Append the tokens for one side's split details to ss: the split
	point, sequence start and step, the number of equations, then each
	equation as a lambda.'''
	(split, (seq_start, step), eqs) = details
	header = ['%d' % num for num in (split, seq_start, step)]
	header.append ('%d' % len (eqs))
	ss.extend (header)
	for eq in eqs:
		serialise_lambda (eq, ss)
247
def serialise_lambda (eq_term, ss):
	'''Append eq_term to ss as a lambda over the variable '%i': the
	tokens 'Lambda' and '%i', the serialised word32T type, then the
	serialised term.'''
	ss.append ('Lambda')
	ss.append ('%i')
	word32T.serialise (ss)
	eq_term.serialise (ss)
252
def deserialise_details (ss, i):
	'''Parse one side's split details from the tokens ss at index i.

	Returns (next index, (split, (seq_start, step), eqs)).'''
	header = ss[i : i + 3]
	(split, seq_start, step) = [int (tok) for tok in header]
	(end, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3)
	details = (split, (seq_start, step), eqs)
	return (end, details)
257
def deserialise_lambda (ss, i):
	'''Parse a lambda over '%i' from the tokens ss at index i.

	The tokens must start with 'Lambda' '%i' and the parsed type must be
	word32T. Returns (next index, body expression).'''
	prefix = ss[i : i + 2]
	assert prefix == ['Lambda', '%i'], (ss, i)
	(after_typ, typ) = syntax.parse_typ (ss, i + 2)
	assert typ == word32T, typ
	(end, body) = syntax.parse_expr (ss, after_typ)
	return (end, body)
264
def deserialise_double_lambda (ss, i):
	'''Parse two consecutive lambdas from ss at index i, returning
	(next index, (first, second)).'''
	(mid, first) = deserialise_lambda (ss, i)
	(end, second) = deserialise_lambda (ss, mid)
	return (end, (first, second))
269
def deserialise_inner (ss, i):
	'''Parse one proof node, with its subproofs, from tokens ss at i.

	Returns (next index, ProofNode). The layout read here is the one
	written by ProofNode.serialise; the tag tokens of Restr,
	SingleRevInduct and CaseSplit nodes are read but only CaseSplit
	keeps its tag.'''
	head = ss[i]
	if head == 'Leaf':
		return (i + 1, ProofNode ('Leaf'))

	if head == 'Restr':
		point = int (ss[i + 1])
		tag = ss[i + 2]
		kind = ss[i + 3]
		assert kind in ['Number', 'Offset'], (kind, i)
		lo = int (ss[i + 4])
		hi = int (ss[i + 5])
		(i, sub) = deserialise_inner (ss, i + 6)
		return (i, ProofNode ('Restr', (point, (kind, (lo, hi))), [sub]))

	if head == 'SingleRevInduct':
		point = int (ss[i + 1])
		tag = ss[i + 2]
		n = int (ss[i + 3])
		(i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i + 4)
		(i, pred) = syntax.parse_term (ss, i)
		n_bound = int (ss[i])
		(i, sub) = deserialise_inner (ss, i + 1)
		args = (point, (eqs, n), (pred, n_bound))
		return (i, ProofNode ('SingleRevInduct', args, [sub]))

	if head == 'Split':
		n = int (ss[i + 1])
		loop_r_max = int (ss[i + 2])
		(i, l_details) = deserialise_details (ss, i + 3)
		(i, r_details) = deserialise_details (ss, i)
		(i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i)
		(i, first) = deserialise_inner (ss, i)
		(i, second) = deserialise_inner (ss, i)
		args = (l_details, r_details, eqs, n, loop_r_max)
		return (i, ProofNode ('Split', args, [first, second]))

	if head == 'CaseSplit':
		point = int (ss[i + 1])
		tag = ss[i + 2]
		(i, first) = deserialise_inner (ss, i + 3)
		(i, second) = deserialise_inner (ss, i)
		return (i, ProofNode ('CaseSplit', (point, tag), [first, second]))

	assert not 'proof node type understood', (ss, i)
310
def deserialise (line):
	'''Parse a whitespace separated proof line into a ProofNode tree,
	asserting that every token is consumed.'''
	tokens = line.split ()
	(end, proof) = deserialise_inner (tokens, 0)
	assert end == len (tokens), (tokens, end)
	return proof
316
def proof_subproblems (p, kind, args, restrs, hyps, path):
	'''List the subproblems created by a proof step of the given kind.

	Returns one (restrs, hyps, path description) triple per subproof, in
	the order the subproofs of a ProofNode of that kind are stored:
	none for 'Leaf', one for 'Restr' and 'SingleRevInduct', two for
	'Split' (init case, loop case) and 'CaseSplit' (true case, false
	case). An unknown kind fails an assertion.'''
	tags = p.pairing.tags
	if kind == 'Leaf':
		return []
	elif kind == 'Restr':
		restr = get_proof_restr (args[0], args[1])
		hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)]
		return [((restr,) + restrs, hyps,
			'%s (%d limited)' % (path, args[0]))]
	elif kind == 'SingleRevInduct':
		hyp = single_induct_resulting_hyp (p, restrs, args)
		return [(restrs, hyps + [hyp], path)]
	elif kind == 'Split':
		split = args
		return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs),
			'%d init case in %s' % (split[0][0], path)),
		(restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True),
			'%d loop case in %s' % (split[0][0], path))]
	elif kind == 'CaseSplit':
		(point, tag) = args
		visit = ((point, restrs), tag)
		true_hyps = hyps + [pc_true_hyp (visit)]
		false_hyps = hyps + [pc_false_hyp (visit)]
		return [(restrs, true_hyps,
			'true case (%d visited) in %s' % (point, path)),
		(restrs, false_hyps,
			'false case (%d not visited) in %s' % (point, path))]
	else:
		# report the kind argument; there is no proof object in scope
		assert not 'proof node kind understood', kind
346
347
def split_heads (split):
	'''Return [left split point, right split point] of a split, given
	as (l_details, r_details, eqs, n, loop_r_max).'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(l_details, r_details, eqs, n, _) = split
	(l_split, _, _) = l_details
	(r_split, _, _) = r_details
	return [l_split, r_split]
352
def split_no_loop_hyps (tags, split, restrs):
	'''Hypotheses for the init case of a split: a single pc_false_hyp
	on the left-hand visit at numbered visit n, n being the fourth
	component of split.'''
	# the nested pattern also checks the shape of the left details
	((_, (_, _), _), _, _, n, _) = split
	visits = split_visit_visits (tags, split, restrs, vc_num (n))
	(l_visit, _) = visits
	return [pc_false_hyp (l_visit)]
359
def split_visit_one_visit (tag, details, restrs, visit):
	'''Build the visit to one side's split point for sequence index
	'visit'.

	details is (split, (seq_start, step), eqs) or None; None yields
	None. Otherwise returns ((split, restrs extended by a restriction
	on split), tag).'''
	if details is None:
		return None
	(split, (seq_start, step), eqs) = details

	# the split point sequence at low numbers ('Number') is offset
	# by the point the sequence starts. At symbolic offsets we ignore
	# that, instead having the loop counter for the two sequences
	# be the same number of iterations after the sequence start.
	scaled = visit.n * step
	if visit.kind == 'Offset':
		count = vc_offs (scaled)
	else:
		count = vc_num (seq_start + scaled)

	split_restrs = ((split, count), ) + restrs
	return ((split, split_restrs), tag)
376
def split_visit_visits (tags, split, restrs, visit):
	'''Return the (left, right) pair of split point visits for the
	given sequence index, one per tag in tags.'''
	(ltag, rtag) = tags
	(l_details, r_details, eqs, _, _) = split
	return (split_visit_one_visit (ltag, l_details, restrs, visit),
		split_visit_one_visit (rtag, r_details, restrs, visit))
385
def split_hyps_at_visit (tags, split, restrs, visit):
	'''Hypotheses relating the two sides of a split at one sequence
	index.

	Returns a list of (hyp, description) pairs: three 'PCImp' hyps
	(left visit to right visit, and each visit to its side's start
	visit), then for each per-side expression an equality between its
	value at the start visit and at this visit, then one equality per
	pair in the split's eqs. Expressions rejected by
	logic.inst_eq_at_visit for this visit are left out. The variable
	'%i' in the expressions is replaced by 0 at the start visit and, at
	this visit, by visit.n or ('%n' + visit.n) depending on the visit
	kind.'''
	(l_details, r_details, eqs, _, _) = split
	(l_split, (_, _), l_eqs) = l_details
	(r_split, (_, _), r_eqs) = r_details

	(l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit)
	(l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0))
	(l_tag, r_tag) = tags
	split_pair = (l_split, r_split)

	def subst (exp, counter):
		return logic.var_subst (exp, {('%i', word32T) : counter},
			must_subst = False)
	def applies (exp):
		return logic.inst_eq_at_visit (exp, visit)

	zero = mk_word32 (0)
	if visit.kind == 'Number':
		here = mk_word32 (visit.n)
	else:
		here = mk_plus (mk_var ('%n', word32T), mk_word32 (visit.n))

	hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'),
		(Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag),
		(Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)]

	sides = [(l_tag, l_eqs, l_start, l_visit),
		(r_tag, r_eqs, r_start, r_visit)]
	for (side_tag, side_eqs, start, vis) in sides:
		for exp in side_eqs:
			if not applies (exp):
				continue
			hyp = eq_hyp ((subst (exp, zero), start),
				(subst (exp, here), vis), split_pair)
			hyps.append ((hyp, '%s const' % side_tag))

	for (l_exp, r_exp) in eqs:
		if applies (l_exp) and applies (r_exp):
			hyp = eq_hyp ((subst (l_exp, here), l_visit),
				(subst (r_exp, here), r_visit), split_pair)
			hyps.append ((hyp, 'eq'))
	return hyps
421
def split_loop_hyps (tags, split, restrs, exit):
	'''Hypotheses for the loop case of a split.

	Always includes pc_true_hyp on the left visit at offset n - 1 and
	the split_hyps_at_visit hypotheses for offsets 0 .. n - 1. When exit
	is true, pc_false_hyp on the left visit at offset n is added as
	well.'''
	# the nested pattern also checks the shape of the first details
	((_, _, _), _, _, n, _) = split
	(l_visit, _) = split_visit_visits (tags, split, restrs, vc_offs (n - 1))
	(l_cont, _) = split_visit_visits (tags, split, restrs, vc_offs (n))

	hyps = [pc_true_hyp (l_visit)]
	l_exit = pc_false_hyp (l_cont)
	if exit:
		hyps.append (l_exit)

	offsets = [vc_offs (i) for i in range (n)]
	for offs in offsets:
		for (hyp, _) in split_hyps_at_visit (tags, split, restrs, offs):
			hyps.append (hyp)
	return hyps
436
def loops_to_split (p, restrs):
	'''Loop heads of p not yet covered by a restriction in restrs.

	A loop is covered when restrs restricts a node whose p.loop_id is
	that loop's head. For every restriction whose visit set excludes
	zero, remaining heads on the same tag that are not reachable from
	the restricted node are dropped too. The result is a set if no such
	filtering took place and a list otherwise.'''
	covered = set ()
	for (n, visit_set) in restrs:
		covered.add (p.loop_id (n))
	remaining = set (p.loop_heads ()) - covered

	for (n, visit_set) in restrs:
		if visit_set.has_zero ():
			continue
		# n must be visited, so loop heads must be
		# reachable from n (or on another tag)
		n_tag = p.node_tags[n][0]
		remaining = [lh for lh in remaining
			if p.is_reachable_from (n, lh)
			or n_tag != p.node_tags[lh][0]]
	return remaining
449
def restr_others (p, restrs, n):
	'''Extend restrs with a vc_upto (n) restriction for every loop head
	returned by loops_to_split.'''
	extras = []
	for sp in loops_to_split (p, restrs):
		extras.append ((sp, vc_upto (n)))
	return restrs + tuple (extras)
453
def non_r_err_pc_hyp (tags, restrs):
	'''pc_false_hyp for the 'Err' node of the second tag in tags, under
	restrs.'''
	r_tag = tags[1]
	err_visit = (('Err', restrs), r_tag)
	return pc_false_hyp (err_visit)
456
def split_r_err_pc_hyp (p, split, restrs, tags = None):
	'''non_r_err_pc_hyp for a split.

	The right split point is restricted to the double range
	(r_seq_start + n * r_step, loop_r_max + 2) and every loop still to
	be split to vc_upto (2). tags defaults to p.pairing.tags. Only the
	right details, n and loop_r_max of split are read.'''
	(_, r_details, _, n, loop_r_max) = split
	(r_split, (r_seq_start, r_step), r_eqs) = r_details

	first = r_seq_start + n * r_step
	r_restr = (r_split, vc_double_range (first, loop_r_max + 2))
	all_restrs = restr_others (p, (r_restr, ) + restrs, 2)

	if tags is None:
		tags = p.pairing.tags
	return non_r_err_pc_hyp (tags, all_restrs)
470
# added to the upper bound of the visit count range in get_proof_restr
restr_bump = 0
472
def get_proof_restr (n, restr_range):
	'''Build the restriction (n, visit count options) for a Restr step.

	restr_range is (kind, (x, y)); the options are the visit counts of
	that kind for every i in range (x, y + restr_bump).'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(kind, (x, y)) = restr_range
	return (n, mk_vc_opts ([VisitCount (kind, i)
		for i in range (x, y + restr_bump)]))
476
def restr_trivial_hyp (p, n, restr_range, restrs):
	'''rep_graph.pc_triv_hyp for the visit to n with visit count y - 1
	of the given kind, restr_range being (kind, (x, y)).'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(kind, (x, y)) = restr_range
	restr = (n, VisitCount (kind, y - 1))
	return rep_graph.pc_triv_hyp (((n, (restr, ) + restrs),
		p.node_tags[n][0]))
481
def proof_restr_checks (n, restr_range, p, restrs, hyps):
	'''Checks needed to justify restricting the visits to n.

	restr_range is (kind, (x, y)). Returns a list of (hyps, goal, name)
	checks: optionally a check that the minimum visit is reached, then
	a check that the visit with count y - 1 is not. All checks assume
	non_r_err_pc_hyp under the new restriction in addition to hyps.'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(kind, (x, y)) = restr_range
	restr = get_proof_restr (n, (kind, (x, y)))
	ncerr_hyp = non_r_err_pc_hyp (p.pairing.tags,
		restr_others (p, (restr, ) + restrs, 2))
	hyps = [ncerr_hyp] + hyps
	def visit (vc):
		return ((n, ((n, vc), ) + restrs), p.node_tags[n][0])

	# this cannot be more uniform because the representation of visit
	# at offset 0 is all a bit odd, with n being the only node so visited:
	if kind == 'Offset':
		min_vc = vc_offs (max (0, x - 1))
	elif x > 1:
		min_vc = vc_num (x - 1)
	else:
		min_vc = None
	if min_vc:
		init_check = [(hyps, pc_true_hyp (visit (min_vc)),
			'Check of restr min %d %s for %d' % (x, kind, n))]
	else:
		init_check = []

	# if we can reach node n with (y - 1) visits to n, then the next
	# node will have y visits to n, which we are disallowing
	# thus we show that this visit is impossible
	top_vc = VisitCount (kind, y - 1)
	top_check = (hyps, pc_false_hyp (visit (top_vc)),
		'Check of restr max %d %s for %d' % (y, kind, n))
	return init_check + [top_check]
511
def split_init_step_checks (p, restrs, hyps, split, tags = None):
	'''Checks of the split hypotheses at the numbered visits 0 .. n - 1.

	For each such visit, every hypothesis from split_hyps_at_visit
	becomes a check assuming split_r_err_pc_hyp, hyps, that the left
	visit is reached, and a trivial hyp on the right visit. tags
	defaults to p.pairing.tags.'''
	n = split[3]
	assert len (split) == 5
	if tags is None:
		tags = p.pairing.tags

	base_hyps = [split_r_err_pc_hyp (p, split, restrs, tags = tags)] + hyps
	checks = []
	for i in range (n):
		(l_visit, r_visit) = split_visit_visits (tags, split,
			restrs, vc_num (i))
		lpc_hyp = pc_true_hyp (l_visit)
		# this trivial 'hyp' ensures the rep is built to include
		# the matching rhs visits when checking lhs consts
		rpc_triv_hyp = rep_graph.pc_triv_hyp (r_visit)
		for (hyp, desc) in split_hyps_at_visit (tags, split, restrs,
				vc_num (i)):
			name = 'Induct check at visit %d: %s' % (i, desc)
			checks.append ((base_hyps + [lpc_hyp, rpc_triv_hyp],
				hyp, name))
	return checks
533
def split_induct_step_checks (p, restrs, hyps, split, tags = None):
	'''Checks of the split hypotheses at offset n.

	Every hypothesis of split_hyps_at_visit at offset n becomes a
	check. All checks share one assumption list: split_r_err_pc_hyp,
	that the left visit at offset n is reached, a trivial hyp on the
	right one, hyps, and split_loop_hyps without the exit hypothesis.
	tags defaults to p.pairing.tags.'''
	((l_split, _, _), _, _, n, _) = split
	if tags is None:
		tags = p.pairing.tags

	err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags)
	step = vc_offs (n)
	(cont, r_cont) = split_visit_visits (tags, split, restrs, step)
	# the 'trivial' hyp here ensures the representation includes a loop
	# of the rhs when proving const equations on the lhs
	assumed = [err_hyp, pc_true_hyp (cont), rep_graph.pc_triv_hyp (r_cont)]
	assumed = (assumed + hyps
		+ split_loop_hyps (tags, split, restrs, exit = False))

	checks = []
	for (hyp, desc) in split_hyps_at_visit (tags, split, restrs,
			vc_offs (n)):
		name = ('Induct check (%s) at inductive step for %d'
			% (desc, l_split))
		checks.append ((assumed, hyp, name))
	return checks
551
def check_split_induct_step_group (rep, restrs, hyps, split, tags = None):
	'''Test the checks of split_induct_step_checks against rep.

	Returns True if every group of checks passes; testing stops and
	False is returned at the first group that fails.'''
	checks = split_induct_step_checks (rep.p, restrs, hyps, split,
		tags = tags)
	for group in proof_check_groups (checks):
		verdict = test_hyp_group (rep, group)[0]
		if not verdict:
			return False
	return True
561
def split_checks (p, restrs, hyps, split, tags = None):
	'''All checks for a split: the initial step checks followed by the
	inductive step checks.'''
	init = split_init_step_checks (p, restrs, hyps, split, tags = tags)
	step = split_induct_step_checks (p, restrs, hyps, split, tags = tags)
	return init + step
565
def loop_eq_hyps_at_visit (tag, split, eqs, restrs, visit_num,
		use_if_at = False):
	'''Single-sided invariant hypotheses at one visit to split.

	Returns (hyp, description) pairs: a 'PCImp' hyp from the visit to
	the start visit (numbered visit 0), then for each expression in eqs
	accepted by logic.inst_eq_at_visit an equality between its value at
	the start, with '%i' replaced by 0, and at the visit, with '%i'
	replaced by visit_num.n or ('%n' + visit_num.n) depending on the
	visit kind. use_if_at is passed on to eq_hyp.'''
	details = (split, (0, 1), eqs)
	visit = split_visit_one_visit (tag, details, restrs, visit_num)
	start = split_visit_one_visit (tag, details, restrs, vc_num (0))

	def subst (exp, counter):
		return logic.var_subst (exp, {('%i', word32T) : counter},
			must_subst = False)

	zero = mk_word32 (0)
	if visit_num.kind == 'Number':
		here = mk_word32 (visit_num.n)
	else:
		here = mk_plus (mk_var ('%n', word32T),
			mk_word32 (visit_num.n))

	hyps = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)]
	for exp in eqs:
		if not logic.inst_eq_at_visit (exp, visit_num):
			continue
		hyp = eq_hyp ((subst (exp, zero), start),
			(subst (exp, here), visit),
			(split, 0), use_if_at = use_if_at)
		hyps.append ((hyp, '%s const' % tag))
	return hyps
588
def single_induct_resulting_hyp (p, restrs, rev_induct_args):
	'''The hypothesis made available by a SingleRevInduct step:
	rep_graph.true_if_at_hyp of the predicate at numbered visit 0 of
	the induction point.'''
	(point, _, (pred, _)) = rev_induct_args
	(tag, _) = p.node_tags[point]
	first_visit = (point, vc_num (0))
	vis = ((point, restrs + tuple ([first_visit])), tag)
	return rep_graph.true_if_at_hyp (pred, vis)
594
def single_loop_induct_base_checks (p, restrs, hyps, tag, split, n, eqs):
	'''Base case checks for single-loop induction: the hypotheses of
	loop_eq_hyps_at_visit at the numbered visits 0 .. n, each assuming
	hyps and that the visit in question is reached.'''
	details = (split, (0, 1), eqs)
	tests = []
	for i in range (n + 1):
		reach = split_visit_one_visit (tag, details, restrs, vc_num (i))
		reached = [pc_true_hyp (reach)]
		for (hyp, desc) in loop_eq_hyps_at_visit (tag, split,
				eqs, restrs, vc_num (i)):
			name = ('Base check (%s, %d) at induct step for %d'
				% (desc, i, split))
			tests.append ((hyps + reached, hyp, name))
	return tests
607
def single_loop_induct_step_checks (p, restrs, hyps, tag, split, n,
				eqs, eqs_assume = None):
	'''Inductive step checks for single-loop induction.

	The goals are the hypotheses of loop_eq_hyps_at_visit for eqs at
	offset n. They share one assumption list: that the visit at offset
	n is reached, hyps, and the hypotheses for eqs_assume + eqs at
	offsets 0 .. n - 1. eqs_assume defaults to an empty list.'''
	if eqs_assume is None:
		eqs_assume = []
	details = (split, (0, 1), eqs_assume + eqs)
	cont = split_visit_one_visit (tag, details, restrs, vc_offs (n))

	assumed = [pc_true_hyp (cont)] + hyps
	for i in range (n):
		for (h, _) in loop_eq_hyps_at_visit (tag, split,
				eqs_assume + eqs, restrs, vc_offs (i)):
			assumed.append (h)

	checks = []
	for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs,
			restrs, vc_offs (n)):
		name = ('Induct check (%s) at inductive step for %d'
			% (desc, split))
		checks.append ((assumed, hyp, name))
	return checks
623
def mk_loop_counter_eq_hyp (p, split, restrs, n):
	'''Equality hypothesis between the variable '%n' and the word32
	constant n, both at the offset 0 visit to split.'''
	(tag, _) = p.node_tags[split]
	details = (split, (0, 1), [])
	visit = split_visit_one_visit (tag, details, restrs, vc_offs (0))
	counter = (mk_var ('%n', word32T), visit)
	bound = (mk_word32 (n), visit)
	return eq_hyp (counter, bound, (split, 0))
630
def single_loop_rev_induct_base_checks (p, restrs, hyps, tag, split,
		n_bound, eqs_assume, pred):
	'''Base check for reverse induction: pred holds at the offset 1
	visit to split when the loop counter equals n_bound.

	Assumes hyps, the counter equality, that the offset 1 visit is
	reached, split_r_err_pc_hyp, and the eqs_assume hypotheses at
	offset 0. Returns a one element list of (hyps, goal, name).'''
	details = (split, (0, 1), eqs_assume)
	cont = split_visit_one_visit (tag, details, restrs, vc_offs (1))
	n_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound)

	# only the right-hand details and the two counts are read here
	split_details = (None, details, None, 1, 1)
	non_err = split_r_err_pc_hyp (p, split_details, restrs)

	assumed = hyps + [n_hyp, pc_true_hyp (cont), non_err]
	for (h, _) in loop_eq_hyps_at_visit (tag, split, eqs_assume,
			restrs, vc_offs (0)):
		assumed.append (h)
	goal = rep_graph.true_if_at_hyp (pred, cont)

	return [(assumed, goal, 'Pred true at %d check.' % n_bound)]
646
def single_loop_rev_induct_checks (p, restrs, hyps, tag, split,
		eqs_assume, pred):
	'''Reverse induction step: pred at the offset 1 visit to split
	follows from pred at the offset 2 visit.

	Assumes hyps, that the offset 1 visit is reached, pred at offset 2,
	split_r_err_pc_hyp, and the eqs_assume hypotheses at offset 1 (built
	with use_if_at). Returns a one element list of (hyps, goal, name).'''
	details = (split, (0, 1), eqs_assume)
	curr = split_visit_one_visit (tag, details, restrs, vc_offs (1))
	cont = split_visit_one_visit (tag, details, restrs, vc_offs (2))

	# only the right-hand details and the two counts are read here
	split_details = (None, details, None, 1, 1)
	non_err = split_r_err_pc_hyp (p, split_details, restrs)
	true_next = rep_graph.true_if_at_hyp (pred, cont)

	assumed = hyps + [pc_true_hyp (curr), true_next, non_err]
	for (h, _) in loop_eq_hyps_at_visit (tag, split, eqs_assume,
			restrs, vc_offs (1), use_if_at = True):
		assumed.append (h)
	goal = rep_graph.true_if_at_hyp (pred, curr)

	return [(assumed, goal, 'Pred reverse step.')]
663
def all_rev_induct_checks (p, restrs, hyps, point, eqs_proof, rev_proof):
	'''All checks for a SingleRevInduct step at point.

	eqs_proof is (eqs, n) and rev_proof is (pred, n_bound). The result
	concatenates the induction step and base checks for eqs with the
	reverse induction step and base checks for pred.'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(eqs, n) = eqs_proof
	(pred, n_bound) = rev_proof
	(tag, _) = p.node_tags[point]
	checks = (single_loop_induct_step_checks (p, restrs, hyps, tag,
			point, n, eqs)
		+ single_loop_induct_base_checks (p, restrs, hyps, tag,
			point, n, eqs)
		+ single_loop_rev_induct_checks (p, restrs, hyps, tag,
			point, eqs, pred)
		+ single_loop_rev_induct_base_checks (p, restrs, hyps,
			tag, point, n_bound, eqs, pred))
	return checks
675
def leaf_condition_checks (p, restrs, hyps):
	'''checks of the final refinement conditions

	Returns the 'Leaf path-cond imp' check followed by one 'Leaf eq
	check' per instantiated output equation of the pairing.'''
	assumed = [non_r_err_pc_hyp (p.pairing.tags, restrs)] + hyps
	[l_tag, r_tag] = p.pairing.tags

	nlerr_pc = pc_false_hyp ((('Err', restrs), l_tag))
	# this 'hypothesis' ensures that the representation is built all
	# the way to Ret. in particular this ensures that function relations
	# are available to use in proving single-side equalities
	l_ret = (true_term, (('Ret', restrs), l_tag))
	r_ret = (true_term, (('Ret', restrs), r_tag))
	ret_eq = eq_hyp (l_ret, r_ret)

	### TODO: previously we considered the case where 'Ret' was unreachable
	### (as a result of unsatisfiable hyps) and proved a simpler property.
	### we might want to restore this
	(_, out_eqs) = p.pairing.eqs
	checks = [(assumed + [ret_eq], nlerr_pc, 'Leaf path-cond imp')]
	for hyp in inst_eqs (p, restrs, out_eqs):
		checks.append ((assumed + [nlerr_pc, ret_eq], hyp,
			'Leaf eq check'))
	return checks
696
def proof_checks (p, proof):
	'''All checks needed to validate proof on p, starting from no
	restrictions, the initial point hypotheses and the path 'root'.'''
	start_hyps = init_point_hyps (p)
	return proof_checks_rec (p, (), start_hyps, proof, 'root')
699
def proof_checks_imm (p, restrs, hyps, proof, path):
	'''Checks required by the top node of proof alone.

	Returns (hyps, goal, name) triples, each name suffixed with
	' on <path>'. A CaseSplit node needs no checks. An unknown proof
	kind fails an assertion.'''
	if proof.kind == 'Restr':
		checks = proof_restr_checks (proof.point, proof.restr_range,
			p, restrs, hyps)
	elif proof.kind == 'SingleRevInduct':
		checks = all_rev_induct_checks (p, restrs, hyps, proof.point,
			proof.eqs_proof, proof.rev_proof)
	elif proof.kind == 'Split':
		checks = split_checks (p, restrs, hyps, proof.split)
	elif proof.kind == 'Leaf':
		checks = leaf_condition_checks (p, restrs, hyps)
	elif proof.kind == 'CaseSplit':
		checks = []
	else:
		# fail clearly rather than with an unbound 'checks' below
		assert not 'proof node kind understood', proof.kind

	return [(hs, hyp, '%s on %s' % (name, path))
		for (hs, hyp, name) in checks]
716
def proof_checks_rec (p, restrs, hyps, proof, path):
	'''Checks for the top node of proof followed by, recursively, the
	checks of each subproof under the restrictions, hypotheses and path
	given by proof_subproblems.'''
	checks = proof_checks_imm (p, restrs, hyps, proof, path)

	subproblems = proof_subproblems (p, proof.kind,
		proof.args, restrs, hyps, path)
	paired = logic.azip (subproblems, proof.subproofs)
	for (subprob, subproof) in paired:
		(sub_restrs, sub_hyps, sub_path) = subprob
		sub_checks = proof_checks_rec (p, sub_restrs, sub_hyps,
			subproof, sub_path)
		checks.extend (sub_checks)
	return checks
726
# one element holder; check_proof stores the failing check here
last_failed_check = [None]
728
def proof_check_groups (checks):
	'''Group checks by the set of visits they mention.

	Each check is (hyps, goal, name). Checks whose goal and hypotheses
	report the same set of visits (via their visits () method) land in
	the same group. Returns the groups as the values of a dict.'''
	groups = {}
	for (hyps, hyp, name) in checks:
		seen = set ()
		for hyp2 in [hyp] + hyps:
			for n_vc in hyp2.visits ():
				seen.add (n_vc)
		key = tuple (sorted (seen))
		if key not in groups:
			groups[key] = []
		groups[key].append ((hyps, hyp, name))
	return groups.values ()
737
def test_hyp_group (rep, group, detail = None):
	'''Test one group of checks with rep.test_hyp_imps.

	Returns (result, None) on success and (result, failing check) on
	failure, the failing check being picked by the index that
	test_hyp_imps reports. On failure the result kind is also stored in
	detail[0] when a detail list is supplied.'''
	imps = [(hyps, hyp) for (hyps, hyp, _) in group]
	names = set ([name for (_, _, name) in group])

	trace ('Testing group of hyps: %s' % list (names), push = 1)
	(res, i, res_kind) = rep.test_hyp_imps (imps)
	trace ('Group result: %r' % res, push = -1)

	if res:
		return (res, None)
	if detail:
		detail[0] = res_kind
	return (res, group[i])
751
def failed_test_sets (p, checks):
	'''Names of the check sets that fail.

	Checks are grouped by name and each group is tested on its own
	fresh graph slice of p; the names of the failing groups are
	returned.'''
	by_name = {}
	for (hyps, hyp, name) in checks:
		by_name.setdefault (name, []).append ((hyps, hyp))

	failed = []
	for name in by_name:
		rep = rep_graph.mk_graph_slice (p)
		(res, _, _) = rep.test_hyp_imps (by_name[name])
		if res:
			continue
		failed.append (name)
	return failed
764
# one element holder; if set, check_proof calls it as save (p, proof)
# after a proof has been checked successfully
save_checked_proofs = [None]
766
def check_proof (p, proof, use_rep = None):
	'''Check proof on problem p, returning True if all checks pass.

	Each group of checks is tested with use_rep if given, otherwise
	with a fresh graph slice per group. At the first failing group the
	failing check is stored in last_failed_check[0], the failure is
	traced and False is returned. After a successful check the saver in
	save_checked_proofs[0], if any, is called with (p, proof).'''
	checks = proof_checks (p, proof)

	for group in proof_check_groups (checks):
		if use_rep is None:
			rep = rep_graph.mk_graph_slice (p)
		else:
			rep = use_rep

		detail = [0]
		(verdict, elt) = test_hyp_group (rep, group, detail)
		if not verdict:
			(hyps, hyp, name) = elt
			last_failed_check[0] = elt
			trace ('%s: proof failed!' % name)
			trace ('  (failure kind: %r)' % detail[0])
			return False

	save = save_checked_proofs[0]
	if save:
		save (p, proof)
	return True
790
def pretty_vseq (details):
	'''Describe the visit sequence of one side of a split.

	details is (split, (seq_start, seq_step), eqs). The sequence
	starting at 0 with step 1 is described as all visits to the split
	point; any other sequence by its first three elements, shown
	counting from one.'''
	# unpacked in the body: tuple parameters are not valid on Python 3
	(split, (seq_start, seq_step), _) = details
	if (seq_start, seq_step) == (0, 1):
		return 'visits to %d' % split
	else:
		i = seq_start + 1
		j = i + seq_step
		k = j + seq_step
		return 'visits [%d, %d, %d ...] to %d' % (i, j, k, split)
799
def next_induct_var (n):
	'''Name of the n-th induction variable: i, j, k, a, b, c, then the
	same letters again with a numeric suffix (i2, j2, ..., i3, ...).'''
	s = 'ijkabc'
	v = s[n % 6]
	if n >= 6:
		# floor division keeps the suffix an integer on Python 3 too
		v += str ((n // 6) + 1)
	return v
806
def pretty_lambda (t):
	'''Pretty print t, types included, with the variable '%i' shown as
	'#seq-visits'.'''
	placeholder = syntax.mk_var ('#seq-visits', word32T)
	shown = logic.var_subst (t, {('%i', word32T) : placeholder},
		must_subst = False)
	return syntax.pretty_expr (shown, print_type = True)
811
def check_proof_report_rec (p, restrs, hyps, proof, step_num, ctxt, inducts,
		do_check = True):
	'''Print a numbered report of proof, optionally checking each step.

	ctxt describes the case this step belongs to. inducts is a pair
	(number of induction variables used so far, dict from split point
	to its induction variable name). With do_check set, the checks of
	each step are tested on fresh graph slices. Returns None as soon as
	a step fails, otherwise (next step number, number of induction
	variables used).'''
	printout ('Step %d: %s' % (step_num, ctxt))
	kind = proof.kind
	if kind == 'Restr':
		(vc_kind, (lo, hi)) = proof.restr_range
		if vc_kind == 'Offset':
			var = inducts[1][proof.point]
			rexpr = '{%s + %s ..< %s + %s}' % (var, lo, var, hi)
		else:
			rexpr = '{%s ..< %s}' % (lo, hi)
		printout ('  Prove the number of visits to %d is in %s'
			% (proof.point, rexpr))

		checks = proof_restr_checks (proof.point, proof.restr_range,
			p, restrs, hyps)
		cases = ['']
	elif kind == 'SingleRevInduct':
		printout ('  Proving a predicate by future induction.')
		(eqs, n) = proof.eqs_proof
		point = proof.point
		printout ('    proving these invariants by %d-induction' % n)
		for inv in eqs:
			printout ('      %s (@ addr %s)'
				% (pretty_lambda (inv), point))
		printout ('    then establishing this predicate')
		(pred, n_bound) = proof.rev_proof
		printout ('      %s (@ addr %s)'
			% (pretty_lambda (pred), point))
		printout ('    at large iterations (%d) and by back induction.'
			% n_bound)
		cases = ['']
		checks = all_rev_induct_checks (p, restrs, hyps, point,
			proof.eqs_proof, proof.rev_proof)
	elif kind == 'Split':
		(l_dts, r_dts, eqs, n, lrmx) = proof.split
		var = next_induct_var (inducts[0])
		# copy the dict so sibling cases do not see this binding
		inducts = (inducts[0] + 1, dict (inducts[1]))
		inducts[1][l_dts[0]] = var
		inducts[1][r_dts[0]] = var
		printout ('  prove %s related to %s' % (pretty_vseq (l_dts),
			pretty_vseq (r_dts)))
		printout ('    with equalities')
		for (lhs, rhs) in eqs:
			printout ('      %s (@ addr %s)' % (pretty_lambda (lhs),
				l_dts[0]))
			printout ('      = %s (@ addr %s)' % (pretty_lambda (rhs),
				r_dts[0]))
		printout ('    and with invariants')
		for (dts, invs) in [(l_dts, l_dts[2]), (r_dts, r_dts[2])]:
			for inv in invs:
				printout ('      %s (@ addr %s)'
					% (pretty_lambda (inv), dts[0]))
		checks = split_checks (p, restrs, hyps, proof.split)
		cases = ['case in (%d) where the length of the sequence < %d'
				% (step_num, n),
			'case in (%d) where the length of the sequence is %s + %s'
				% (step_num, var, n)]
	elif kind == 'Leaf':
		printout ('  prove all verification conditions')
		checks = leaf_condition_checks (p, restrs, hyps)
		cases = []
	elif kind == 'CaseSplit':
		printout ('  case split on whether %d is visited' % proof.point)
		checks = []
		cases = ['case in (%d) where %d is visited' % (step_num, proof.point),
			'case in (%d) where %d is not visited' % (step_num, proof.point)]

	if checks and do_check:
		for group in proof_check_groups (checks):
			rep = rep_graph.mk_graph_slice (p)
			detail = [0]
			(ok, _) = test_hyp_group (rep, group, detail)
			if not ok:
				printout ('    .. failed to prove this.')
				printout ('      (failure kind: %r)' % detail[0])
				return

		printout ('    .. proven.')

	subproblems = proof_subproblems (p, proof.kind,
		proof.args, restrs, hyps, '')
	work = logic.azip (logic.azip (subproblems, proof.subproofs), cases)
	step_num += 1
	for ((subprob, subproof), case) in work:
		(sub_restrs, sub_hyps, _) = subprob
		res = check_proof_report_rec (p, sub_restrs, sub_hyps, subproof,
			step_num, case, inducts, do_check = do_check)
		if not res:
			return
		# carry the step and variable counters over to the next case
		(step_num, induct_var_num) = res
		inducts = (induct_var_num, inducts[1])
	return (step_num, inducts[0])
908
def check_proof_report (p, proof, do_check = True):
	'''Print a report of proof on p starting at step 1 with no
	induction variables; returns True unless the report stopped at a
	failed step.'''
	start_hyps = init_point_hyps (p)
	no_inducts = (0, {})
	res = check_proof_report_rec (p, (), start_hyps, proof,
		1, '', no_inducts, do_check = do_check)
	return bool (res)
913
def save_proofs_to_file (fname, mode = 'w'):
	'''Open fname in mode 'w' or 'a' and return a save (p, proof)
	function.

	Each call of save writes a 'ProblemProof (<name>) { ... }' block
	holding the serialised problem, one item per line, followed by the
	serialised proof on a single line, and flushes the file. The file
	is left open so that save can be called repeatedly.'''
	assert mode in ['w', 'a']
	out = open (fname, mode)

	def save (p, proof):
		out.write ('ProblemProof (%s) {\n' % p.name)
		for text in p.serialise ():
			out.write (text + '\n')
		tokens = []
		proof.serialise (p, tokens)
		out.write (' '.join (tokens))
		out.write ('\n}\n')
		out.flush ()

	return save
928
def load_proofs_from_file (fname):
	'''Load the proofs written by save_proofs_to_file.

	Returns a dict from problem name to a list of (problem, proof)
	pairs, in file order. Inside a 'ProblemProof (<name>) {' ... '}'
	block the last line is the proof and the lines before it the
	problem. Blank lines and lines starting with '#' are ignored.'''
	proofs = {}
	lines = None
	# the with block closes the file even if parsing fails part way
	with open (fname) as f:
		for line in f:
			line = line.strip ()
			if line.startswith ('ProblemProof'):
				assert line.endswith ('{'), line
				name_bit = line[len ('ProblemProof') : -1].strip ()
				assert name_bit.startswith ('('), name_bit
				assert name_bit.endswith (')'), name_bit
				name = name_bit[1:-1]
				lines = []
			elif line == '}':
				assert lines[0] == 'Problem'
				assert lines[-2] == 'EndProblem'
				import problem
				trace ('loading proof from %d lines' % len (lines))
				p = problem.deserialise (name, lines[:-1])
				proof = deserialise (lines[-1])
				proofs.setdefault (name, [])
				proofs[name].append ((p, proof))
				trace ('loaded proof %s' % name)
				lines = None
			elif line.startswith ('#'):
				pass
			elif line:
				lines.append (line)
	# a block still open at end of file means the file was truncated
	assert not lines
	return proofs
960
961