1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7# toplevel graph-refine script
8# usage: python graph-refine.py <target> <proofs>
9
10import syntax
11import pseudo_compile
12import solver
13import rep_graph
14import problem
15import check
16import search
17import logic
18from target_objects import pairings, functions
19from target_objects import trace, tracer, printout
20import target_objects
21
22import re
23import random
24import traceback
25import time
26#import diagnostic
27
28import sys
29
# When run as a script, fetch the argument list before anything else; it is
# passed to main () by the second __main__ block at the end of this file.
# NOTE(review): load_target_args presumably also loads the target named on
# the command line - confirm in target_objects.
if __name__ == '__main__':
	args = target_objects.load_target_args ()
32
def toplevel_check (pair, check_loops = True, report = False, count = None,
		only_build_problem = False):
	"""Build the problem for one function pair, search for a proof and
	check it, returning the outcome as a string.

	pair: the pairing to test; pair.funs maps each tag to a function name.
	check_loops: True to handle every problem, a false value to give up
		with 'Loop' on problems that have loop data, or 'only' to give
		up with 'NoLoop' on problems that have none.
	report: when true, the tracer is replaced by a no-op for the duration
		of the call, progress lines are printed, and the proof is checked
		with check.check_proof_report instead of check.check_proof.
	count: optional (index, total) tuple used for the progress line.
	only_build_problem: stop once the problem is built and return 'True'.

	Returns 'None' (some function has no entry point), 'True', 'Loop',
	'NoLoop', one of the failure labels assigned below, or str () of
	whatever the proof checker returned.  Exceptions derived from
	Exception are caught, turned into 'ProofEXCEPT' / 'CheckEXCEPT' and
	their traceback is printed to stdout.
	"""
	if not only_build_problem:
		printout ('Testing Function pair %s' % pair)
		if count:
			(i, n) = count
			printout ('  (function pairing %d of %d)' % (i + 1, n))

	for (tag, fname) in pair.funs.iteritems ():
		if not functions[fname].entry:
			printout ('Skipping %s, underspecified %s' % (pair, tag))
			return 'None'

	exception = None
	prev_tracer = tracer[0]
	if report:
		# silence tracing while in report mode
		tracer[0] = lambda s, n: ()

	# the finally clause is the single place where the tracer is put
	# back; it covers the early returns below and also anything that is
	# not caught here (e.g. KeyboardInterrupt).
	try:
		trace (time.asctime ())
		start_time = time.time ()
		sys.stdout.flush ()
		try:
			p = check.build_problem (pair)
			if only_build_problem:
				return 'True'
			if report:
				printout (' .. built problem, finding proof')
			if not check_loops and p.loop_data:
				printout ('Problem has loop!')
				return 'Loop'
			if check_loops == 'only' and not p.loop_data:
				printout ('No loop in problem.')
				return 'NoLoop'
			proof = search.build_proof (p)
			if report:
				printout (' .. proof found.')

			# failures while checking are labelled separately from
			# failures while building the problem or the proof.
			try:
				if report:
					result = check.check_proof_report (p, proof)
				else:
					result = check.check_proof (p, proof)
					if result:
						printout ('Refinement proven.')
					else:
						printout ('Refinement NOT proven.')
			except solver.SolverFailure:
				printout ('Solver timeout/failure in proof check.')
				result = 'CheckSolverFailure'
			except Exception:
				trace ('EXCEPTION in checking %s:' % p.name)
				exception = sys.exc_info ()
				result = 'CheckEXCEPT'

		except problem.Abort:
			result = 'ProblemAbort'
		except search.NoSplit:
			result = 'ProofNoSplit'
		except solver.SolverFailure:
			printout ('Solver timeout/failure in proof search.')
			result = 'ProofSolverFailure'
		except Exception:
			trace ('EXCEPTION in handling %s:' % pair)
			exception = sys.exc_info ()
			result = 'ProofEXCEPT'

		end_time = time.time ()
	finally:
		tracer[0] = prev_tracer

	if exception:
		(etype, evalue, tb) = exception
		traceback.print_exception (etype, evalue, tb,
			file = sys.stdout)

	if not only_build_problem:
		printout ('Result %s for pair %s, time taken: %.2fs'
			% (result, pair, end_time - start_time))
		sys.stdout.flush ()

	return str (result)
116
# splits a pair's printed form into words for the whole-word match below
word_re = re.compile('\\w+')

def name_search (s, tags = None):
	"""Resolve the string s to a single function pair, or return None.

	Tried in order: s is a key of pairings with exactly one pair; exactly
	one pair (also filtered by tags, when given) has s as a substring of
	its name; exactly one of those has s as a whole word of str (pair).
	If none of these settles it, the candidates are printed and None is
	returned.
	"""
	if s in pairings:
		exact = pairings[s]
		if len (exact) == 1:
			return exact[0]

	candidates = set ()
	for f in pairings:
		for pair in pairings[f]:
			if s not in pair.name:
				continue
			if tags and not tags.issubset (set (pair.tags)):
				continue
			candidates.add (pair)
	pairs = list (candidates)

	if len (pairs) != 1:
		whole_word = [p for p in pairs
			if s in word_re.findall (str (p))]
		if len (whole_word) != 1:
			print ('Possibilities for %r: %s'
				% (s, [str (p) for p in pairs]))
			return None
		return whole_word[0]
	return pairs[0]
132
def check_search (s, tags = None, report_mode = False,
		check_loops = True):
	"""Look up a pair by name with name_search and run toplevel_check on
	it; returns the string 'None' when the name does not resolve."""
	found = name_search (s, tags = tags)
	if found:
		return toplevel_check (found, report = report_mode,
			check_loops = check_loops)
	return 'None'
141
def problem_search (s):
	"""Look up a pair by name, print its name and return the problem built
	for it by check.build_problem.

	Returns None when name_search cannot resolve s to a single pair
	(name_search has already printed the candidates in that case).
	"""
	pair = name_search (s)
	if pair is None:
		# previously this fell through to an AttributeError on pair.name
		return None
	print (pair.name)
	return check.build_problem (pair)
146
# Maps each outcome string to (numeric code, category).  The numbering is
# somewhat arbitrary: codes follow the order of the list below, and larger
# numbers are (roughly) worse outcomes.  The key categories are success,
# skipped (not in covered cases), and failure.
result_codes = dict ((outcome, (code, category))
	for (code, (outcome, category)) in enumerate ([
		('True', 'Success'),
		('Loop', 'Skipped'),
		('NoLoop', 'Skipped'),
		('None', 'Skipped'),
		('ProblemAbort', 'Skipped'),
		('False', 'Failed'),
		('ProofNoSplit', 'Failed'),
		('ProofSolverFailure', 'Failed'),
		('ProofEXCEPT', 'Failed'),
		('CheckSolverFailure', 'Failed'),
		('CheckEXCEPT', 'Failed'),
	]))
163
def comb_results (r1, r2):
	"""Return whichever of the two outcome strings ranks higher (worse) in
	result_codes; the string itself breaks ties between equal entries."""
	rank1 = (result_codes[r1], r1)
	rank2 = (result_codes[r2], r2)
	if rank2 > rank1:
		return r2
	return r1
167
def check_pairs (pairs, loops = True, report_mode = False,
		only_build_problem = False):
	"""Run toplevel_check on every pair, print a summary and return the
	worst outcome string.

	pairs: sequence of function pairs to check.
	loops: passed to toplevel_check as check_loops.
	report_mode: passed to toplevel_check as report.
	only_build_problem: passed through; also switches the summary wording
		and suppresses the per-pair result list.

	The return value is 'True' combined with every per-pair outcome by
	comb_results, so an empty list of pairs yields 'True'.
	"""
	num_pairs = len (pairs)
	printout ('Checking %d function pair problems' % num_pairs)
	results = [(pair, toplevel_check (pair, check_loops = loops,
			report = report_mode, count = (i, num_pairs),
			only_build_problem = only_build_problem))
		for (i, pair) in enumerate (pairs)]
	# group the pairs by outcome category (Success / Skipped / Failed)
	result_dict = logic.dict_list ([(result_codes[r][1], pair)
		for (pair, r) in results])
	if not only_build_problem:
		printout ('Results: %s'
			% [(pair.name, r) for (pair, r) in results])
	printout ('Result summary:')
	success = result_dict.get ('Success', [])
	if only_build_problem:
		printout ('  - %d problems build' % len (success))
	else:
		printout ('  - %d proofs checked' % len (success))
	skipped = result_dict.get ('Skipped', [])
	printout ('  - %d proofs skipped' % len (skipped))
	failed = result_dict.get ('Failed', [])
	fails = [pair.name for pair in failed]
	# the coverage report needs pair objects, not names: covered means
	# attempted, i.e. succeeded or failed, as opposed to skipped.
	print_coverage_report (set (skipped), set (success + failed))
	printout ('  - failures: %s' % fails)
	return syntax.foldr1 (comb_results, ['True']
		+ [r for (_, r) in results])
194
def print_coverage_report (skipped_pairs, covered_pairs):
	"""Print instruction-coverage figures for a finished batch of checks.

	skipped_pairs / covered_pairs: collections of pairs; the l_f and r_f
	attributes of each pair are the function names fed to trace_refute.

	This is best-effort: trace_refute is imported lazily and any failure
	(missing module, missing data) only produces a trace line, never an
	exception.
	NOTE(review): addrs_covered appears to return a fraction (it is
	scaled by 100 for display) and funs_sort_by_num_addrs presumably
	sorts ascending by size - confirm in trace_refute.
	"""
	try:
		from trace_refute import addrs_covered, funs_sort_by_num_addrs
		covered_fs = set ([f for pair in covered_pairs
			for f in [pair.l_f, pair.r_f]])
		coverage = addrs_covered (covered_fs)
		printout ('  - %.2f%% instructions covered' % (coverage * 100))
		skipped_fs = set ([f for pair in skipped_pairs
			for f in [pair.l_f, pair.r_f]])
		fs = funs_sort_by_num_addrs (skipped_fs)
		if not fs:
			return
		# the last three entries, reported last-first
		lrg_msgs = ['%s (%.2f%%)' % (f, addrs_covered ([f]) * 100)
			for f in reversed (fs[-3:])]
		printout ('  - largest skipped functions:')
		printout ('      %s' % ', '.join (lrg_msgs))
	except Exception as e:
		# keep the report optional, but say why it is missing instead
		# of hiding the error completely
		trace ('Coverage report unavailable (%s: %s)'
			% (type (e).__name__, e))
213
def check_all (omit_set = set (), loops = True, tags = None,
		report_mode = False, only_build_problem = False):
	"""Check every known pairing, in random order, except those that
	involve a function named in omit_set.

	When tags is given, only pairings whose tags include all of them are
	checked.  Pairings left out because of omit_set are listed after the
	run.  Returns the combined outcome from check_pairs.
	"""
	every_pairing = [pair for f in pairings for pair in pairings[f]]

	def is_omitted (pair):
		return not omit_set.isdisjoint (pair.funs.values ())

	def has_tags (pair):
		return not tags or tags.issubset (set (pair.tags))

	pairs = list (set ([pair for pair in every_pairing
		if not is_omitted (pair) and has_tags (pair)]))
	omitted = list (set (filter (is_omitted, every_pairing)))

	random.shuffle (pairs)
	outcome = check_pairs (pairs, loops = loops,
		report_mode = report_mode,
		only_build_problem = only_build_problem)
	if omitted:
		omitted_names = [pair.name for pair in omitted]
		printout ('  - %d pairings omitted: %s'
			% (len (omitted), omitted_names))
	return outcome
228
def check_division_pairs (num, denom, loops = True, report_mode = False):
	"""Check one slice of all pairings: order them by the combined node
	count of their two functions and take every denom-th pair starting at
	index num.  Returns the combined outcome from check_pairs."""
	distinct = set ([pair for f in pairings for pair in pairings[f]])
	# sort on (size, pair) so that the size is the primary key
	ranked = [(len (functions[pair.l_f].nodes)
			+ len (functions[pair.r_f].nodes), pair)
		for pair in distinct]
	ranked.sort ()
	division = []
	for ix in range (num, len (ranked), denom):
		(_, pair) = ranked[ix]
		division.append (pair)
	return check_pairs (division, loops = loops, report_mode = report_mode)
237
def check_deps (fname, report_mode = False):
	"""Check the pairings of fname and of everything it calls.

	Follows function_calls () from fname until no new function names
	appear, keeps the names that have pairings, and hands their pairs (in
	sorted name order) to check_pairs.
	"""
	reached = set ()
	pending = set ([fname])
	while pending:
		current = pending.pop ()
		if current not in reached:
			reached.add (current)
			pending.update (functions[current].function_calls ())
	paired = [name for name in sorted (reached) if name in pairings]
	printout ('Testing functions: %s' % paired)
	to_check = []
	for name in paired:
		to_check.extend (pairings[name])
	return check_pairs (to_check, report_mode = report_mode)
252
def save_compiled_funcs (fname):
	"""Write the serialised form of every known function to the file
	fname, one serialised string per line, replacing any previous
	contents."""
	# the with block closes the file even if serialise () or write ()
	# raises part way through
	with open (fname, 'w') as out:
		for (f, func) in functions.iteritems ():
			trace ('Saving %s' % f)
			for s in func.serialise ():
				out.write (s + '\n')
260
def rerun_set (vs):
	"""Turn a (possibly nested) collection of pair-name strings into a
	space separated string of function names.

	Every distinct string found in vs is resolved with name_search; for
	each pair found, the function under the pair's first tag is emitted.
	"""
	def flatten (items):
		# strings at this level first, then whatever the nested
		# collections contain
		found = [item for item in items if type (item) == str]
		for item in items:
			if type (item) != str:
				found.extend (flatten (item))
		return found

	names = set (flatten (vs))
	resolved = [name_search (name) for name in names]
	reruns = []
	for pair in resolved:
		if pair:
			reruns.append (pair.funs[pair.tags[0]])
	return ' '.join (reruns)
269
def main (args):
	"""Process the argument strings in order and return the worst outcome.

	Arguments are handled one at a time, so a setting only affects the
	arguments that follow it.  Recognised forms:
	  verbose                turn report mode off
	  trace-to:<file>        open <file> and add it to the trace files
	  all                    check all pairings not excluded so far
	  all_safe               as 'all', also excluding the danger set
	  coverage               as 'all', but only build the problems
	  div:<num>:<denom>      check one division of the pairings
	  no_loops / only_loops  set the loop handling for later checks
	  save:<file>            save the compiled functions to <file>
	  save-proofs:<file>     append checked proofs to <file>
	  -exclude / -end-exclude  bracket a list of function names to exclude
	  t:<tag>                require <tag> on the pairings checked
	  target:...             ignored here
	  skip-proofs-of:<file>  exclude functions with proofs in <file>
	  deps:<function>        check <function> and everything it calls
	Anything else is an excluded function name (while excluding) or a
	pair name, which is queued and checked after all arguments are read.

	An exception while handling one argument is printed and counted as
	'ProofEXCEPT'; later arguments are still processed.  The result is a
	key of result_codes, combined with comb_results.
	"""
	excluding = False
	excludes = set ()
	loops = True
	tags = set ()
	report = True
	result = 'True'
	pairs_to_check = []
	for arg in args:
		r = 'True'
		try:
			if arg == 'verbose':
				report = False
			elif arg.startswith ('trace-to:'):
				(_, s) = arg.split (':', 1)
				# not closed here: the handle is handed over to
				# target_objects.trace_files for later use
				f = open (s, 'w')
				target_objects.trace_files.append (f)
			elif arg == 'all':
				r = check_all (excludes, loops = loops,
					tags = tags, report_mode = report)
			elif arg == 'all_safe':
				ex = set.union (excludes,
					target_objects.danger_set)
				r = check_all (ex, loops = loops,
					tags = tags, report_mode = report)
			elif arg == 'coverage':
				r = check_all (excludes, loops = loops,
					tags = tags, report_mode = report,
					only_build_problem = True)
			elif arg.startswith ('div:'):
				[_, num, denom] = arg.split (':')
				num = int (num)
				denom = int (denom)
				r = check_division_pairs (num, denom,
					loops = loops, report_mode = report)
			elif arg == 'no_loops':
				loops = False
			elif arg == 'only_loops':
				loops = 'only'
			elif arg.startswith('save:'):
				save_compiled_funcs (arg[5:])
			elif arg.startswith('save-proofs:'):
				fname = arg[len ('save-proofs:') :]
				save = check.save_proofs_to_file (fname, 'a')
				check.save_checked_proofs[0] = save
			elif arg == '-exclude':
				excluding = True
			elif arg == '-end-exclude':
				excluding = False
			elif arg.startswith ('t:'):
				tags.add (arg[2:])
			elif arg.startswith ('target:'):
				pass
			elif arg.startswith ('skip-proofs-of:'):
				(_, fname) = arg.split(':', 1)
				import stats
				# read everything inside the with block, in case
				# scan_proofs reads the file lazily
				with open (fname) as proofs_file:
					prev_proofs = stats.scan_proofs (proofs_file)
					prev_fs = [f for pair in prev_proofs
						for f in pair.funs.values ()]
				excludes.update (prev_fs)
			elif excluding:
				excludes.add (arg)
			elif arg.startswith ('deps:'):
				r = check_deps (arg[5:],
					report_mode = report)
			else:
				pair = name_search (arg, tags = tags)
				if pair is not None:
					pairs_to_check.append (pair)
				else:
					r = 'None'
		except Exception:
			print ('EXCEPTION in syscall arg %s:' % arg)
			print (traceback.format_exc ())
			r = 'ProofEXCEPT'
		result = comb_results (r, result)
	if pairs_to_check:
		r = check_pairs (pairs_to_check, loops = loops,
			report_mode = report)
		result = comb_results (r, result)
	return result
352
if __name__ == '__main__':
	# Run the checks and exit with the numeric code that result_codes
	# assigns to the overall outcome (0 for success, larger for worse),
	# so that calling scripts can tell the outcomes apart.
	result = main (args)
	(code, category) = result_codes[result]
	sys.exit (code)
357
358