1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7from target_objects import functions, pairings
8import target_objects
9from problem import Problem
10import problem
11import logic
12import syntax
13import solver
14import search
15import rep_graph
16import check
17
18import random
19
def check_entry_var_deps (f):
	"""Check that a function's entry-point var deps are covered by its inputs.

	Returns the set of escaping variables, or an empty set when f has no
	entry point or nothing escapes."""
	if f.entry:
		return check_problem_entry_var_deps (f.as_problem (Problem))
	return set ()
27
def check_problem_entry_var_deps (p, var_deps = None):
	"""Check each entry of problem p against its declared inputs.

	var_deps defaults to p.compute_var_dependencies (). Entries missing
	from var_deps are reported and skipped. Returns the set of variables
	that the first offending entry depends on but does not take as input,
	or an empty set if no entry has such variables."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	for (entry, tag, _, inputs) in p.entries:
		if entry not in var_deps:
			print ('Entry missing from var_deps: %d' % entry)
			continue
		escaped = set (var_deps[entry]) - set (inputs)
		if not escaped:
			continue
		print ('Vars deps escaped in %s in %s: %s' % (tag,
			p.name, escaped))
		return escaped
	return set ()
41
def check_all_var_deps ():
	"""Return the names of all functions whose entry var deps escape
	their inputs."""
	failing = []
	for name in functions:
		if check_entry_var_deps (functions[name]):
			failing.append (name)
	return failing
44
def walk_var_deps (p, n, v, var_deps = None,
			interest = set (), symmetric = False):
	"""Follow a path through p from node n along which v stays in the deps.

	At each step the candidates are the successors of n (its predecessors
	when symmetric) that are nodes of p and have v in their var deps. A
	random candidate is chosen when there are several. Each node stepped
	to is printed, marked with '**' when it is in interest.

	Returns 'Ret' or 'Err' when the walk reaches one, or None when no
	candidate remains."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	while n not in ('Ret', 'Err'):
		if symmetric:
			neighbours = p.preds[n]
		else:
			neighbours = p.nodes[n].get_conts ()
		opts = set ([n2 for n2 in neighbours if n2 in p.nodes])
		choices = [n2 for n2 in opts if v in var_deps[n2]]
		if not choices:
			print ('Walk ends at %d.' % n)
			return None
		if len (choices) > 1:
			print ('choices %s, gambling' % choices)
			random.shuffle (choices)
			print (' ... rolled a %s' % choices[0])
		elif len (opts) > 1:
			print ('picked %s from %s' % (choices[0], opts))
		n = choices[0]
		if n in interest:
			print ('** %d' % n)
		else:
			print (n)
	print (n)
	return n
73
def diagram_var_deps (p, fname, v, var_deps = None):
	"""Save a graph of p's nodes to fname, coloured by the status of v.

	darkgrey: the node has no var_deps entry; orange: v is in the node's
	var deps; darkblue: it is not."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	def colour (n):
		if n not in var_deps:
			return 'darkgrey'
		if v in var_deps[n]:
			return 'orange'
		return 'darkblue'
	cols = dict ([(n, colour (n)) for n in p.nodes])
	problem.save_graph (p.nodes, fname, cols = cols)
86
def trace_model (rep, m, simplify = True):
	"""Walk each tag's path in model m and report every condition node.

	Era changes along the path are announced, the walk stops at 'Ret' or
	'Err', and for each Cond node the model value of its condition is
	printed and its definition inspected with investigate_cond. When the
	tags present are exactly the pairing's tags, they are walked in
	reverse pairing order."""
	p = rep.p
	tags = set ([t for (t, _, _) in rep.node_pc_env_order])
	if p.pairing and tags == set (p.pairing.tags):
		tags = reversed (p.pairing.tags)
	for tag in tags:
		print ("Walking %s in model" % tag)
		prev_era = None
		for (n, vc) in walk_model (rep, tag, m):
			era = n_vc_era (p, (n, vc))
			if era != prev_era:
				print ('now in era %s' % era)
			prev_era = era
			if n in ['Ret', 'Err']:
				print ('ends at %s' % n)
				break
			node = logic.simplify_node_elementary (p.nodes[n])
			if node.kind == 'Cond':
				name = rep.cond_name ((n, vc))
				cond = m[name] == syntax.true_term
				print ('%s: %s (%s, %s)' % (name, cond,
					node.left, node.right))
				investigate_cond (rep, m, name, simplify)
112
def walk_model (rep, tag, m):
	"""Return the (n, vc) visits of tag whose path condition is true in
	model m, sorted into era order by era_sort."""
	def reached (n, vc):
		pc = rep.get_pc ((n, vc), tag)
		return search.eval_model_expr (m, rep.solv, pc) == syntax.true_term
	visits = [(n, vc) for (tag2, n, vc) in rep.node_pc_env_order
		if tag2 == tag and reached (n, vc)]
	return era_sort (rep, visits)
123
def investigate_cond (rep, m, cond, simplify = True, rec = True):
	"""Print the model value of each part of the definition of cond.

	The definition comes from rep.solv.defs. When rec is set, chains of
	string aliases are followed. A leading chain of '=>' implications is
	walked premise by premise, and the walk stops at the first premise that
	does not hold. The expression reached is then split into conjuncts,
	and each conjunct is printed with its value."""
	defs = rep.solv.defs
	cond_def = defs[cond]
	if rec:
		while type (cond_def) == str and cond_def in defs:
			cond_def = defs[cond_def]
	def report (bit):
		if bit == 'true':
			return True
		valid = eval_model_bool (m, bit)
		if simplify:
			# looks a bit strange to do this now but some pointer
			# lookups have to be done with unmodified s-exprs
			bit = simplify_sexp (bit, rep, m, flatten = False)
		print ('  %s: %s' % (valid, solver.flat_s_expression (bit)))
		return valid
	while cond_def[0] == '=>':
		if not report (cond_def[1]):
			break
		cond_def = cond_def[2]
	for bit in solver.split_hyp_sexpr (cond_def, []):
		report (bit)
146
def eval_model_bool (m, x):
	"""Evaluate the boolean x in model m.

	x may be an expression object (anything with a 'typ' attribute), which
	is first converted to a parsed SMT s-expression, or an s-expression.

	Returns True or False. Returns the string 'EXCEPT' if evaluation raises
	or does not produce a boolean term. 'EXCEPT' is truthy, so callers
	must test for it explicitly."""
	if hasattr (x, 'typ'):
		x = solver.smt_expr (x, {}, None)
		x = solver.parse_s_expression (x)
	try:
		r = search.eval_model (m, x)
	except Exception:
		# best-effort evaluation: report failure as 'EXCEPT', but unlike a
		# bare except, let KeyboardInterrupt / SystemExit propagate.
		return 'EXCEPT'
	# explicit check rather than assert, so the result does not change
	# under python -O
	if r not in [syntax.true_term, syntax.false_term]:
		return 'EXCEPT'
	return r == syntax.true_term
157
def funcall_name (rep):
	"""Return a function that labels a call visit as '<fname> @<count name>'."""
	def label (n_vc):
		fname = rep.p.nodes[n_vc[0]].fname
		return "%s @%s" % (fname, rep.node_count_name (n_vc))
	return label
161
def n_vc_era (p, n_vc):
	"""Compute the era number of the visit n_vc = (n, vc).

	Only splits of vc that are loops (p.loop_id) contribute. Each one adds
	3 when the two option lists from vcount.get_opts () hold more than one
	entry in total, otherwise 1 if the first list is non-empty, otherwise 2
	if the second is."""
	(_, vc) = n_vc
	weight = 0
	for (split, vcount) in vc:
		if p.loop_id (split):
			(ns, os) = vcount.get_opts ()
			if len (ns + os) > 1:
				weight += 3
			elif ns:
				weight += 1
			elif os:
				weight += 2
	return weight
175
def era_merge (era):
	"""Fold the onramp to a loop (era % 3 == 1) into the pre-loop era."""
	return era - 1 if era % 3 == 1 else era
181
def do_era_merge (do_merge, era):
	"""Apply era_merge to era only when do_merge is set."""
	if not do_merge:
		return era
	return era_merge (era)
187
def era_sort (rep, n_vcs):
	"""Stably sort visits by era (n_vc_era) and sanity-check the result.

	Each pair of adjacent visits that share an era must be continuations
	of each other (rep.is_cont), and no 'Ret'/'Err' visit may be followed
	by another visit in the same era. Returns the sorted visits."""
	p = rep.p
	keyed = sorted ([(n_vc_era (p, n_vc), n_vc) for n_vc in n_vcs],
		key = lambda pair: pair[0])
	for ((e1, n_vc1), (e2, n_vc2)) in zip (keyed, keyed[1:]):
		if e1 == e2:
			if n_vc1[0] in ['Ret', 'Err']:
				assert not 'Era issues', n_vcs
			assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2]
	return [n_vc for (_, n_vc) in keyed]
200
def investigate_funcalls (rep, m, verbose = False, verbose_imp = False,
		simplify = True, pairing = 'Args', era_merge = True):
	"""Pair up the function calls on both sides of a pairing in model m and
	investigate each pair's call assertion.

	pairing picks the matching strategy: 'Eras', 'Seq', 'Args' or 'All'.
	Any other value fails an assertion. era_merge is passed as era_m to
	the era-based strategies. Within this function the parameter shadows
	the module-level era_merge function."""
	(l_tag, r_tag) = rep.p.pairing.tags
	l_calls = [n_vc for n_vc in walk_model (rep, l_tag, m)
		if n_vc in rep.funcs]
	r_calls = [n_vc for n_vc in walk_model (rep, r_tag, m)
		if n_vc in rep.funcs]
	name_of = funcall_name (rep)
	print ('%s calls: %s' % (l_tag, [name_of (c) for c in l_calls]))
	print ('%s calls: %s' % (r_tag, [name_of (c) for c in r_calls]))

	if pairing == 'Eras':
		fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls,
			era_m = era_merge)
	elif pairing == 'Seq':
		fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls)
	elif pairing == 'Args':
		fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls,
			era_m = era_merge)
	elif pairing == 'All':
		fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls]
	else:
		assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing

	for (l_n_vc, r_n_vc) in fc_pairs:
		if rep.get_func_pairing (l_n_vc, r_n_vc):
			investigate_funcall_pair (rep, m, l_n_vc, r_n_vc,
				verbose, verbose_imp, simplify)
		else:
			print ('call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc))
232
def pair_funcalls_by_era (rep, l_calls, r_calls, era_m = True):
	"""Pair left and right calls positionally within each era.

	The candidate eras are the eras of all calls plus their merged forms
	(era_merge). A call belongs to an era when its era, merged if era_m is
	set, equals it. Within an era the calls are zipped in order. A length
	mismatch is reported, and the surplus calls are dropped.

	Returns the list of (left call, right call) pairs."""
	name_of = funcall_name (rep)
	def era_of (n_vc):
		return do_era_merge (era_m, n_vc_era (rep.p, n_vc))
	eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls])
	# sets do not support '+'; take the union (as pair_funcalls_by_match
	# does) so this no longer raises TypeError.
	eras = sorted (eras | set (map (era_merge, eras)))
	pairs = []
	for era in eras:
		ls = [n_vc for n_vc in l_calls if era_of (n_vc) == era]
		rs = [n_vc for n_vc in r_calls if era_of (n_vc) == era]
		if len (ls) != len (rs):
			print ('call seq length mismatch in era %d:' % era)
			print ([name_of (n_vc) for n_vc in ls])
			print ([name_of (n_vc) for n_vc in rs])
		pairs.extend (zip (ls, rs))
	return pairs
248
def pair_funcalls_sequential (rep, l_calls, r_calls):
	"""Pair calls purely by position.

	When the lengths differ, the surplus calls on the longer side are
	reported and dropped, as zip does."""
	n_l = len (l_calls)
	n_r = len (r_calls)
	if n_l != n_r:
		print ('call seq tail mismatch')
		name_of = funcall_name (rep)
		if n_l > n_r:
			print ('dropping lhs: %s' % [name_of (c)
				for c in l_calls[n_r:]])
		else:
			print ('dropping rhs: %s' % [name_of (c)
				for c in r_calls[n_l:]])
	# really should add some smarts to this to 'recover' from upsets or
	# reorders, but maybe not worth it.
	return zip (l_calls, r_calls)
261
def pair_funcalls_by_match (rep, m, l_calls, r_calls, era_m = True):
	"""Pair left and right calls era by era, aligning on the best match.

	Within each era, every (left, right) pair that rep.get_func_pairing
	accepts is scored with func_assert_premise_strength. The strongest
	pair (ties broken by the lowest indices) fixes an offset, and the two
	call lists are zipped from that alignment. Eras with no acceptable pair
	are reported and skipped. Returns the list of pairs."""
	p = rep.p
	raw_eras = set ([n_vc_era (p, n_vc) for n_vc in l_calls + r_calls])
	eras = sorted (set.union (raw_eras, set (map (era_merge, raw_eras))))
	def in_era (calls, era):
		return [n_vc for n_vc in calls
			if do_era_merge (era_m, n_vc_era (p, n_vc)) == era]
	pairs = []
	for era in eras:
		ls = in_era (l_calls, era)
		rs = in_era (r_calls, era)
		scored = []
		for (i, l_n_vc) in enumerate (ls):
			for (j, r_n_vc) in enumerate (rs):
				if not rep.get_func_pairing (l_n_vc, r_n_vc):
					continue
				strength = func_assert_premise_strength (rep, m,
					l_n_vc, r_n_vc)
				scored.append ((1 - strength, i, j))
		if not scored:
			print ('Cannot match any (%d, %d) at era %d' % (len (ls),
				len (rs), era))
			continue
		(_, i, j) = min (scored)
		if i > j:
			pairs.extend (zip (ls[i - j:], rs))
		else:
			pairs.extend (zip (ls, rs[j - i:]))
	return pairs
288
def func_assert_premise_strength (rep, m, l_n_vc, r_n_vc):
	"""Score how well model m satisfies the premise of the call assertion
	between two call visits.

	The premise is split into conjuncts. Each conjunct scores 1.0 if it
	evaluates to true in m, 0.0 if false, and 0.5 if it cannot be
	evaluated. Returns the mean score, a float in [0.0, 1.0]."""
	imp = rep.get_func_assert (l_n_vc, r_n_vc)
	assert imp.is_op ('Implies'), imp
	[pred, _] = imp.vals
	pred = solver.smt_expr (pred, {}, rep.solv)
	pred = solver.parse_s_expression (pred)
	bits = solver.split_hyp_sexpr (pred, [])
	assert bits, bits
	scores = []
	for bit in bits:
		try:
			res = eval_model_bool (m, bit)
		except (solver.EnvMiss, AssertionError):
			res = 'EXCEPT'
		# eval_model_bool signals failure by returning the (truthy)
		# string 'EXCEPT' rather than raising. Without this check those
		# conjuncts would be scored as true.
		if res == 'EXCEPT':
			scores.append (0.5)
		elif res:
			scores.append (1.0)
		else:
			scores.append (0.0)
	return sum (scores) / len (scores)
311
def investigate_funcall_pair (rep, m, l_n_vc, r_n_vc,
		verbose = False, verbose_imp = False, simplify = True):
	"""Print which premise conjuncts of the call assertion between two
	call visits fail in model m.

	The weakened assertion (logic.weaken_assert) is printed as SMT when
	verbose_imp is set, with ite branches resolved when simplify. Conjuncts
	whose value is not True are printed; all of them are printed when
	verbose. For word32-eq conjuncts the model values of both sides are
	shown as well."""
	def label (n_vc):
		return "%s @ %s" % (rep.p.nodes[n_vc[0]].fname,
			rep.node_count_name (n_vc))
	l_nm = label (l_n_vc)
	r_nm = label (r_n_vc)
	print ('Attempt match %s -> %s' % (l_nm, r_nm))
	imp = logic.weaken_assert (rep.get_func_assert (l_n_vc, r_n_vc))
	if verbose_imp:
		imp_s = solver.smt_expr (imp, {}, rep.solv)
		if simplify:
			imp_s = simplify_sexp (imp_s, rep, m)
		print (imp_s)
	assert imp.is_op ('Implies'), imp
	[pred, _] = imp.vals
	pred = solver.parse_s_expression (solver.smt_expr (pred, {}, rep.solv))
	bits = solver.split_hyp_sexpr (pred, [])
	results = [eval_model_bool (m, bit) for bit in bits]
	print ('  %s' % results)
	for (res, bit) in zip (results, bits):
		if res == True and not verbose:
			continue
		print ('  %s: %s' % (res, bit))
		if bit[0] == 'word32-eq':
			sides = [model_sx_word (m, x) for x in bit[1:]]
			print ('    (%s = %s)' % tuple (sides))
339
def model_sx_word (m, sx):
	"""Evaluate the word s-expression sx in model m and render it as an
	SMT numeral of its type."""
	val = search.eval_model (m, sx)
	return solver.smt_num_t (expr_num (val), val.typ)
344
def expr_num (expr):
	"""Return the value of the word expression expr as an unsigned int,
	truncated to the word's bit width (expr.typ.num)."""
	assert expr.typ.kind == 'Word'
	width = expr.typ.num
	mask = (1 << width) - 1
	return expr.val & mask
348
def str_to_num (smt_str):
	"""Parse an SMT word literal string and return its unsigned value."""
	return expr_num (solver.smt_to_val (smt_str))
352
def m_var_name (expr):
	"""Return the name of the memory variable underneath expr.

	MemUpdate layers are peeled off first. If the remaining expression is
	not a variable, a descriptive placeholder ('<Op name>' or
	'<Expr kind>') is returned instead."""
	while expr.is_op ('MemUpdate'):
		[expr, _, _] = expr.vals
	if expr.kind == 'Var':
		return expr.name
	elif expr.kind == 'Op':
		# the operator's name is held in expr.name; the previous code
		# referenced an undefined 'op' here and raised NameError.
		return '<Op %s>' % expr.name
	else:
		return '<Expr %s>' % expr.kind
362
def eval_str (expr, env, solv, m):
	"""Evaluate expr (in env) in model m and render the result as a string.

	Booleans render as their term name and words as SMT numerals. Any
	other type fails an assertion."""
	smt = solver.to_smt_expr (expr, env, solv)
	val = search.eval_model_expr (m, solv, smt)
	if val.typ == syntax.boolT:
		assert val in [syntax.true_term, syntax.false_term]
		return val.name
	if val.typ.kind == 'Word':
		return solver.smt_num_t (val.val, val.typ)
	assert not 'type printable', val
373
def trace_mem (rep, tag, m, verbose = False, simplify = True, symbs = True,
		resolve_addrs = False):
	"""Print and collect the memory accesses along tag's path in model m.

	For each visited node that has been emitted (is in rep.arc_pc_envs),
	the memory accesses in its call arguments, update expressions or
	condition are printed with their model-evaluated address and value.
	verbose also prints the SMT address/value, with ite branches resolved
	when simplify is set. symbs also prints any symbols/sections containing
	the address (find_symbol).

	Returns a list that mixes (kind, addr, v, mem, n, vc) tuples with
	'<function call to ...>' message strings for Call nodes. In the tuples,
	addr and v are the node's expressions, or their SMT forms when
	resolve_addrs is set. Consumers that unpack entries must skip the
	strings."""
	p = rep.p
	ns = walk_model (rep, tag, m)
	trace = []
	for (n, vc) in ns:
		if (n, vc) not in rep.arc_pc_envs:
			# this n_vc has a pre-state, but has not been emitted.
			# no point trying to evaluate its expressions, the
			# solver won't have seen them yet.
			continue
		n_nm = rep.node_count_name ((n, vc))
		node = p.nodes[n]
		# NOTE(review): exprs is only bound for Call/Basic/Cond nodes.
		# A node of any other kind would reuse the previous node's exprs
		# (or raise NameError on the first visit). Confirm these are the
		# only node kinds.
		if node.kind == 'Call':
			exprs = list (node.args)
		elif node.kind == 'Basic':
			exprs = [expr for (_, expr) in node.upds]
		elif node.kind == 'Cond':
			exprs = [node.cond]
		env = rep.node_pc_envs[(tag, n, vc)][1]
		# deduplicate accesses shared between expressions (order is
		# whatever set iteration gives)
		accs = list (set ([acc for expr in exprs
			for acc in expr.get_mem_accesses ()]))
		for (kind, addr, v, mem) in accs:
			# SMT forms are taken before addr and v are rebound below
			addr_s = solver.smt_expr (addr, env, rep.solv)
			v_s = solver.smt_expr (v, env, rep.solv)
			# addr and v now become evaluated strings; str_to_num
			# below parses the string form of addr
			addr = eval_str (addr, env, rep.solv, m)
			v = eval_str (v, env, rep.solv, m)
			m_nm = m_var_name (mem)
			print '%s: %s @ <%s>   -- %s -- %s' % (kind, m_nm, addr, v, n_nm)
			if simplify:
				addr_s = simplify_sexp (addr_s, rep, m)
				v_s = simplify_sexp (v_s, rep, m)
			if verbose:
				print '\t %s -- %s' % (addr_s, v_s)
			if symbs:
				addr_n = str_to_num (addr)
				(hit_symbs, secs) = find_symbol (addr_n, output = False)
				ss = hit_symbs + secs
				if ss:
					print '\t [%s]' % ', '.join (ss)
		# accs still holds the unevaluated expressions; the loop above
		# only rebound its own loop variables
		if resolve_addrs:
			accs = [(kind, solver.to_smt_expr (addr, env, rep.solv),
				solver.to_smt_expr (v, env, rep.solv), mem)
				for (kind, addr, v, mem) in accs]
		trace.extend ([(kind, addr, v, mem, n, vc)
			for (kind, addr, v, mem) in accs])
		if node.kind == 'Call':
			msg = '<function call to %s at %s>' % (node.fname, n_nm)
			print msg
			trace.append (msg)
	return trace
425
def simplify_sexp (smt_xp, rep, m, flatten = True):
	"""Resolve ite branches in an s-expression using model m.

	A string input is parsed first. Each (ite c x y) is replaced by the
	branch selected by evaluating c in m, and other tuples are simplified
	element by element. The result is flattened back to a string when
	flatten is set."""
	if type (smt_xp) == str:
		smt_xp = solver.parse_s_expression (smt_xp)
	if smt_xp[0] == 'ite':
		(_, c, x, y) = smt_xp
		chosen = x if eval_model_bool (m, c) else y
		return simplify_sexp (chosen, rep, m, flatten)
	if type (smt_xp) == tuple:
		smt_xp = tuple ([simplify_sexp (sub, rep, m, False)
			for sub in smt_xp])
	return solver.flat_s_expression (smt_xp) if flatten else smt_xp
442
def trace_mems (rep, m, verbose = False, symbs = True, tags = None):
	"""Print the memory-access trace of each tag in model m.

	tags defaults to the pairing's tags in reverse order, or to all the
	problem's tags when there is no pairing."""
	if tags is None:
		pairing = rep.p.pairing
		tags = reversed (pairing.tags) if pairing else rep.p.tags ()
	for tag in tags:
		print ('%s mem trace:' % tag)
		trace_mem (rep, tag, m, verbose = verbose, symbs = symbs)
452
def trace_mems_diff (rep, m, tags = ('ASM', 'C')):
	"""Compare the memory updates of two sides of a pairing in model m.

	tags gives (asm tag, C tag). All 'MemUpdate' accesses are taken from
	the C side. On the asm side only updates whose memory variable name
	contains 'mem' are used. Addresses and values are evaluated to strings.

	Returns (same, mism, c_upds, asm_upds). same says whether the final
	address -> value maps agree. mism lists the disagreeing addresses, asm
	order first. The other two are the raw (addr, value) update lists."""
	asms = trace_mem (rep, tags[0], m, resolve_addrs = True)
	cs = trace_mem (rep, tags[1], m, resolve_addrs = True)
	ev = lambda expr: eval_str (expr, {}, None, m)
	# trace_mem interleaves '<function call ...>' message strings with the
	# access tuples; unpacking those as 6-tuples would raise ValueError.
	c_accs = [acc for acc in cs if type (acc) == tuple]
	asm_accs = [acc for acc in asms if type (acc) == tuple]
	c_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in c_accs
		if kind == 'MemUpdate']
	asm_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in asm_accs
		if kind == 'MemUpdate' and 'mem' in m_var_name (mem)]
	c_upd_d = dict (c_upds)
	asm_upd_d = dict (asm_upds)
	addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds
		if addr not in asm_upd_d]
	mism = [addr for addr in addr_ord
		if c_upd_d.get (addr) != asm_upd_d.get (addr)]
	return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds)
468
def get_pv_type (pv):
	"""Describe a PValid/PArrayValid expression as (op name, type, extra).

	The type is the value of pv.vals[1], which must be a 'Type' expression.
	extra is pv.vals[3] for PArrayValid and None for PValid."""
	assert pv.is_op (['PValid', 'PArrayValid'])
	typ_expr = pv.vals[1]
	assert typ_expr.kind == 'Type'
	if not pv.is_op ('PArrayValid'):
		return ('PValid', typ_expr.val, None)
	return ('PArrayValid', typ_expr.val, pv.vals[3])
478
def guess_pv (p, n, addr_expr):
	"""Guess which PValid/PArrayValid in the condition before n covers addr_expr.

	n must have exactly one predecessor, whose condition is searched. A
	candidate matches when the variables of its arguments from index 2 on
	are exactly the variables of addr_expr. If several match, PArrayValid
	ones are preferred. Returns the first remaining candidate."""
	target_vars = syntax.get_expr_var_set (addr_expr)
	[pred] = p.preds[n]
	found = []
	def collect (expr):
		if expr.is_op (['PValid', 'PArrayValid']):
			found.append (expr)
	p.nodes[pred].cond.visit (collect)
	def ptr_vars (pv):
		return set.union (* [syntax.get_expr_var_set (x)
			for x in pv.vals[2:]])
	candidates = [pv for pv in found if ptr_vars (pv) == target_vars]
	if len (candidates) > 1:
		candidates = [pv for pv in candidates if pv.is_op ('PArrayValid')]
	return candidates[0]
494
def eval_pv_type (rep, n_vc, m, data):
	"""Evaluate the extra component of a get_pv_type result in model m.

	PValid descriptions are returned unchanged. For PArrayValid the third
	component is converted to SMT at visit n_vc = (n, vc) and evaluated."""
	(n, vc) = n_vc
	if data[0] == 'PValid':
		return data
	(nm, typ, offs) = data
	smt_offs = rep.to_smt_expr (offs, (n, vc))
	return (nm, typ, search.eval_model_expr (m, rep.solv, smt_offs))
503
def trace_suspicious_mem (rep, m, tag = 'C'):
	"""Report addresses that are accessed at more than one type in model m.

	Accesses on tag's memory trace are grouped by the model value of their
	address. For addresses reached through more than one address
	expression, the covering pointer-validity assertion of each access is
	guessed (guess_pv) and its type evaluated. Addresses accessed at
	several distinct types are printed.

	Returns a list of (address value, [(type info, validity expr, node)])."""
	cs = trace_mem (rep, tag, m)
	# trace_mem interleaves '<function call ...>' message strings with the
	# access tuples; unpacking those as 6-tuples would raise ValueError.
	accs = [acc for acc in cs if type (acc) == tuple]
	data = [(addr, search.eval_model_expr (m, rep.solv,
			rep.to_smt_expr (addr, (n, vc))), (n, vc))
		for (kind, addr, v, mem, n, vc) in accs]
	addr_sets = {}
	for (addr, addr_v, _) in data:
		addr_sets.setdefault (addr_v, set ())
		addr_sets[addr_v].add (addr)
	dup_addrs = set ([addr_v for addr_v in addr_sets
		if len (addr_sets[addr_v]) > 1])
	data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc))
		for (addr, addr_v, (n, vc)) in data
		if addr_v in dup_addrs]
	data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m,
			get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n)
		for (addr, addr_v, pv, (n, vc)) in data]
	dup_addr_types = set ([addr_v for addr_v in dup_addrs
		if len (set ([t for (_, addr_v2, t, _, _) in data
			if addr_v2 == addr_v])) > 1])
	res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data
			if addr_v2 == addr_v])
		for addr_v in dup_addr_types]
	for (addr_v, insts) in res:
		print ('Address %s' % addr_v)
		for (t, pv, n) in insts:
			print ('  -- accessed with type %s at %s' % (t, n))
			print ('    (covered by %s)' % pv)
	return res
533
def trace_var (rep, tag, m, v):
	"""Print how the value of expression v changes along tag's path in m.

	A line is printed, and recorded, only when the value changes. The value
	is 'undefined' at a node whose var deps do not include all of v's
	variables, or where evaluation raises EnvMiss/AssertionError. Calls on
	the path are printed and recorded as '<function call ...>' strings. The
	walk stops at the first visit whose node is not in rep.p.nodes.

	Returns the list of ((n, vc), (smt, value) or None) change records,
	interleaved with the call message strings."""
	p = rep.p
	visits = walk_model (rep, tag, m)
	var_deps = rep.p.compute_var_dependencies ()
	trace = []
	v_vars = syntax.get_expr_var_set (v)
	def fetch (n_vc):
		(n, vc) = n_vc
		if n in var_deps:
			absent = [(nm, typ) for (nm, typ) in v_vars
				if (nm, typ) not in var_deps[n]]
			if absent:
				return None
		try:
			(_, env) = rep.get_node_pc_env ((n, vc), tag)
			smt = solver.smt_expr (v, env, rep.solv)
			value = search.eval_model (m, solver.parse_s_expression (smt))
			return (smt, solver.smt_expr (value, {}, None))
		except (solver.EnvMiss, AssertionError):
			return None
	current = None
	for (n, vc) in visits:
		n_nm = rep.node_count_name ((n, vc))
		latest = fetch ((n, vc))
		if latest != current:
			if latest is None:
				print ('at %s: undefined' % n_nm)
			else:
				print ('at %s:\t\t%s:\t\t%s' % (n_nm,
					latest[0], latest[1]))
			current = latest
			trace.append (((n, vc), current))
		if n not in p.nodes:
			break
		node = p.nodes[n]
		if node.kind == 'Call':
			msg = '<function call to %s at %s>' % (node.fname,
				rep.node_count_name ((n, vc)))
			print (msg)
			trace.append (msg)
	return trace
575
def trace_deriv_ops (rep, m, tag):
	"""Print the model value of the argument of every CountTrailingZeroes,
	CountLeadingZeroes or WordReverse operation on tag's path in m."""
	visits = walk_model (rep, tag, m)
	deriv_ops = set (('CountTrailingZeroes', 'CountLeadingZeroes',
		'WordReverse'))
	def derivs_of (node):
		found = set ()
		def visit_expr (expr):
			if expr.is_op (deriv_ops):
				found.add (expr)
		node.visit (lambda x: (), visit_expr)
		return found
	for (n, vc) in visits:
		if n not in rep.p.nodes:
			continue
		dvs = derivs_of (rep.p.nodes[n])
		if not dvs:
			continue
		print ('%s:' % (rep.node_count_name ((n, vc))))
		for dv in dvs:
			[arg] = dv.vals
			smt_arg = rep.to_smt_expr (arg, (n, vc))
			print ('\t%s: %s' % (dv.name,
				eval_str (smt_arg, {}, rep.solv, m)))
599
def check_pairings ():
	"""Debug helper: for each value in pairings, print the 'C' and 'ASM'
	function names and their inputs, then build the input equalities with
	logic.mk_fun_inp_eqs (the result is discarded).

	NOTE(review): elsewhere in this file the values of pairings are
	iterated as collections of pair objects with a .name attribute, so
	indexing a value with p['C'] here may fail. Confirm the current shape
	of pairings before relying on this helper."""
	for p in pairings.itervalues ():
		print p['C'], p['ASM']
		as_args = functions[p['ASM']].inputs
		c_args = functions[p['C']].inputs
		print as_args, c_args
		logic.mk_fun_inp_eqs (as_args, c_args, True)
607
def loop_var_deps (p):
	"""For each node in p.loop_data, list the variables whose entry in
	p.var_deps[n] is 'LoopVariable'."""
	result = []
	for n in p.loop_data:
		deps = p.var_deps[n]
		result.append ((n, [v for v in deps if deps[v] == 'LoopVariable']))
	return result
612
def find_symbol (n, output = True):
	"""Find the symbols and sections whose address ranges contain n.

	Symbols cover [addr, addr + size); sections cover [start, end]
	inclusive. Each hit is printed when output is set.
	Returns (symbol names, section names)."""
	from target_objects import symbols, sections
	if output:
		def report (s):
			print (s)
	else:
		report = lambda s: ()
	symbs = []
	for (s, (addr, size, _)) in symbols.iteritems ():
		if addr <= n < addr + size:
			symbs.append (s)
			report ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1))
	secs = []
	for (s, (start, end)) in sections.iteritems ():
		if start <= n <= end:
			secs.append (s)
			report ('%x in section %s (%x - %x)' % (n, s, start, end))
	return (symbs, secs)
631
def assembly_point (p, n):
	"""Return the instruction address recorded for node n, or None.

	The address is p.node_tags[n][1][1]. If it is not a multiple of 4, the
	walk moves back through single predecessors until it reaches one."""
	(_, hints) = p.node_tags[n]
	if type (hints) != tuple or not logic.is_int (hints[1]):
		return None
	addr = hints[1]
	while addr % 4 != 0:
		[n] = p.preds[n]
		addr = p.node_tags[n][1][1]
	return addr
639
def assembly_points (p, ns):
	"""Return the assembly addresses of nodes ns, skipping nodes without one."""
	points = [assembly_point (p, n) for n in ns]
	return [pt for pt in points if pt is not None]
644
def disassembly_lines (addrs):
	"""Return the disassembly lines for the given instruction addresses.

	Reads kernel.elf.txt from target_objects.target_dir. A line is kept,
	stripped, when its text before the first ':' equals one of the
	addresses formatted as lower-case hex without a '0x' prefix."""
	path = '%s/kernel.elf.txt' % target_objects.target_dir
	# 'with' closes the file; the previous version leaked the handle.
	with open (path) as f:
		addr_set = set (['%x' % addr for addr in addrs])
		return [l.strip () for l in f
			if ':' in l and l.split (':', 1)[0] in addr_set]
651
def disassembly (p, n):
	"""Print the disassembly for node n, or for each node if n is iterable."""
	ns = set (n) if hasattr (n, '__iter__') else [n]
	addrs = set ([assembly_point (p, x) for x in ns])
	addrs.discard (None)
	addrs = sorted (addrs)
	print ('asm %s' % ', '.join (['0x%x' % a for a in addrs]))
	for line in disassembly_lines (addrs):
		print (line)
662
def disassembly_loop (p, n):
	"""Print the disassembly of the loop body containing n, then the
	disassembly of the loop head's predecessors from outside the body."""
	head = p.loop_id (n)
	body = p.loop_body (n)
	body_addrs = sorted (set (assembly_points (p, body)))
	entry_addrs = assembly_points (p, [x for x in p.preds[head]
		if x not in body])
	def hex_list (addrs):
		return ', '.join (['%x' % a for a in addrs])
	print ('Loop: [%s]' % hex_list (body_addrs))
	for line in disassembly_lines (body_addrs):
		print (line)
	print ('entry from %s' % hex_list (entry_addrs))
	for line in disassembly_lines (entry_addrs):
		print (line)
675
def try_interpret_hyp (rep, hyp):
	"""Check that hyp can be interpreted by rep and converted to SMT.

	Returns None on success. Returns ('Broken Hyp', hyp) if either step
	raises an ordinary exception. KeyboardInterrupt and SystemExit are no
	longer swallowed."""
	try:
		expr = rep.interpret_hyp (hyp)
		solver.smt_expr (expr, {}, rep.solv)
	except Exception:
		return ('Broken Hyp', hyp)
	return None
683
def check_checks ():
	"""Return every hypothesis in the last proof's checks that cannot be
	interpreted (see try_interpret_hyp)."""
	p = problem.last_problem[0]
	rep = rep_graph.mk_graph_slice (p)
	proof = search.last_proof[0]
	checks = check.proof_checks (p, proof)
	single_hyps = [hyp for (_, hyp, _) in checks]
	group_hyps = [hyp for (hyps, _, _) in checks for hyp in hyps]
	all_hyps = set (single_hyps + group_hyps)
	broken = [try_interpret_hyp (rep, hyp) for hyp in all_hyps]
	return [b[1] for b in broken if b]
693
def proof_failed_groups (p = None, proof = None):
	"""Re-test each hypothesis group of a proof and return the failing groups.

	p and proof default to the last problem and the last proof. The
	failing element of each group and the names of all failed checks are
	printed."""
	if p is None:
		p = problem.last_problem[0]
	if proof is None:
		proof = search.last_proof[0]
	groups = check.proof_check_groups (check.proof_checks (p, proof))
	failed = []
	for group in groups:
		# each group is tested against a fresh graph slice
		(res, el) = check.test_hyp_group (rep_graph.mk_graph_slice (p), group)
		if res:
			continue
		failed.append (group)
		print ('Failed element: %s' % el)
	failed_nms = set ([nm for group in failed for (_, _, nm) in group])
	print ('Failed: %s' % failed_nms)
	return failed
711
def read_summary (f):
	"""Parse the 'Time taken to check ...' lines of a summary log.

	f is any iterable of lines. The function name is the token after
	'<=', and the pairing is the unique one for that function whose name
	occurs in the line. Returns (results, times): dicts keyed by pairing,
	holding the fifth token and the last token as a float."""
	results = {}
	times = {}
	for line in f:
		if not line.startswith ('Time taken to'):
			continue
		tokens = line.split ()
		assert tokens[:4] == ['Time', 'taken', 'to', 'check']
		outcome = tokens[4]
		[ref] = [i for (i, tok) in enumerate (tokens) if tok == '<=']
		fname = tokens[ref + 1]
		[pair] = [pr for pr in pairings[fname] if pr.name in line]
		elapsed = float (tokens[-1])
		results[pair] = outcome
		times[pair] = elapsed
	return (results, times)
729
def unfold_defs_sexpr (defs, sexpr, depthlimit = -1):
	"""Replace atoms of the s-expression sexpr by their definitions in defs.

	An atom (string) is looked up once, printed and returned. A tuple keeps
	its head and has each argument unfolded with depthlimit reduced by one.
	When depthlimit reaches 0 the tuple is returned as is. A negative
	depthlimit never reaches 0, so it means no limit."""
	if type (sexpr) == str:
		unfolded = defs.get (sexpr, sexpr)
		print (unfolded)
		return unfolded
	if depthlimit == 0:
		return sexpr
	head = sexpr[0]
	args = [unfold_defs_sexpr (defs, s, depthlimit - 1) for s in sexpr[1:]]
	return tuple ([head] + args)
739
def unfold_defs (defs, hyp, depthlimit = -1):
	"""Unfold definitions in the SMT string hyp and return a flat string."""
	parsed = solver.parse_s_expression (hyp)
	unfolded = unfold_defs_sexpr (defs, parsed, depthlimit)
	return solver.flat_s_expression (unfolded)
743
def investigate_unsat (solv, hyps = None):
	"""Reduce an unsatisfiable hypothesis list to a smaller unsat core.

	hyps is a list of (hyp, tag) pairs, defaulting to solver.last_hyps[0],
	and must be unsat under solv. Hypotheses are dropped one at a time
	whenever the rest stays unsat. If splitting the survivors
	(solver.split_hyp) or unfolding their definitions two levels deep
	(unfold_defs) changes them, the reduction is repeated on the result.
	The caller's list is not modified.

	Returns the reduced list of (hyp, tag) pairs."""
	if hyps is None:
		hyps = solver.last_hyps[0]
	# copy: the reduction loop below pops from this list, which previously
	# emptied the caller's own list.
	hyps = list (hyps)
	assert solv.hyps_sat_raw (hyps) == 'unsat', hyps
	kept_hyps = []
	while hyps:
		h = hyps.pop ()
		if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat':
			# dropping h loses unsatisfiability, so h is needed
			kept_hyps.append (h)
	assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps
	split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps
		for hyp2 in solver.split_hyp (hyp)]))
	if len (split_hyps) > len (kept_hyps):
		return investigate_unsat (solv, split_hyps)
	def_hyps = [(unfold_defs (solv.defs, h, 2), tag)
		for (h, tag) in kept_hyps]
	if def_hyps != kept_hyps:
		return investigate_unsat (solv, def_hyps)
	return kept_hyps
763
def test_interesting_linear_series_exprs ():
	"""Build each pairing's problem and flag those whose initial split loops
	have interesting linear-series expressions.

	Returns a dict from pair name to True, or to 'Call!' when 'Call'
	appears in the string form of some loop's expressions."""
	all_pairs = set ([pair for f in pairings for pair in pairings[f]])
	notes = {}
	for pair in all_pairs:
		p = check.build_problem (pair)
		for n in search.init_loops_to_split (p, ()):
			analysis = search.get_loop_var_analysis_at (p, n)
			intr = logic.interesting_linear_series_exprs (p, n, analysis)
			if intr:
				notes[pair.name] = True
			if 'Call' in str (intr):
				notes[pair.name] = 'Call!'
	return notes
777
def var_analysis (p, n):
	"""Print the loop variable analysis at n, grouped by kind.

	'LoopLinearSeries' entries are grouped under that name, and their
	offset (kind[2]) is printed under each variable."""
	cats = {}
	for (v, kind) in search.get_loop_var_analysis_at (p, n):
		offs = None
		if kind[0] == 'LoopLinearSeries':
			offs = kind[2]
			kind = kind[0]
		cats.setdefault (kind, []).append ((v, offs))
	for kind in cats:
		print ('%s:' % kind)
		for (v, offs) in cats[kind]:
			print ('  %s   (%s)' % (syntax.pretty_expr (v),
				syntax.pretty_type (v.typ)))
			if offs:
				print ('      ++ %s' % syntax.pretty_expr (offs))
796
def var_value_sites (rep, v):
	"""Find the SMT values that matching variables take across rep's visits.

	v is either a string, which matches variables whose name contains it,
	or a (name, type) pair, which matches exactly. Each distinct SMT value
	is printed with its solver definition, if any, and the visits where it
	occurs. Returns (values in first-seen order, value -> list of visits)."""
	if type (v) == str:
		def matches (var):
			(nm, _) = var
			return v in nm
	elif type (v) == tuple:
		def matches (var):
			(nm, typ) = var
			return v == (nm, typ)
	v_ord = []
	d = {}
	for (tag, n, vc) in rep.node_pc_env_order:
		(_, env) = rep.get_node_pc_env ((n, vc), tag = tag)
		for (v2, smt_exp) in env.iteritems ():
			if not matches (v2):
				continue
			if smt_exp not in d:
				v_ord.append (smt_exp)
				d[smt_exp] = []
			d[smt_exp].append ((n, vc))
	for smt_exp in v_ord:
		print (smt_exp)
		if smt_exp in rep.solv.defs:
			print ('  = %s' % repr (rep.solv.defs[smt_exp]))
		print ('  - at: %s' % d[smt_exp])
	if v_ord:
		print ('')
	return (v_ord, d)
820
def loop_num_leaves (p, n):
	"""Print (node, number of 'LoopLeaf' variables) for each node of
	p.loop_body (n)."""
	for node in p.loop_body (n):
		analysis = search.get_loop_var_analysis_at (p, node)
		leaves = [v for (v, kind) in analysis if kind == 'LoopLeaf']
		print ((node, len (leaves)))
826
def try_pairing_at_funcall (p, name, head = None, restrs = None, hyps = None,
		at = 'At'):
	"""Search for a loop split of head that is restricted to the calls of
	functions sharing a pairing with name.

	The restriction is applied at those call nodes, or at their successors
	when at == 'After'. head defaults to the single initial split loop of
	the pairing's first tag. While the search answers 'CaseSplit', a hyp
	that the split node's path condition is true is added and the search
	is retried. Returns the first other result."""
	pairs = set (pairings[name])
	def shares_pairing (name2):
		return [pair for pair in pairings[name2] if pair in pairs]
	addrs = [n for (n, name2) in p.function_call_addrs ()
		if shares_pairing (name2)]
	assert at in ['At', 'After']
	if at == 'After':
		addrs = [p.nodes[n].cont for n in addrs]
	if head is None:
		first_tag = p.pairing.tags[0]
		[head] = [n for n in search.init_loops_to_split (p, ())
			if p.node_tags[n][0] == first_tag]
	if restrs is None:
		restrs = ()
	if hyps is None:
		hyps = check.init_point_hyps (p)
	res = search.find_split_loop (p, head, restrs, hyps,
		node_restrs = set (addrs))
	while res[0] == 'CaseSplit':
		(_, ((n, tag), _)) = res
		hyps = hyps + [rep_graph.pc_true_hyp (((n, restrs), tag))]
		res = search.find_split_loop (p, head, restrs, hyps,
			node_restrs = set (addrs))
	return res
852
def init_true_hyp (p, tag, expr):
	"""Build a hypothesis that the boolean expr is true at tag's entry
	visit (entry node, empty visit count)."""
	entry_vis = ((p.get_entry (tag), ()), tag)
	assert expr.typ == syntax.boolT, expr
	return rep_graph.eq_hyp ((expr, entry_vis), (syntax.true_term, entry_vis))
858
def smt_print (expr):
	"""Render expr as an SMT string without a solver.

	Each free variable the conversion misses (EnvMiss) is added to the
	environment under its own name, and the conversion is retried."""
	names = {}
	while True:
		try:
			return solver.smt_expr (expr, names, None)
		except solver.EnvMiss as miss:
			names[(miss.name, miss.typ)] = miss.name
866
867