1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9from target_objects import functions, pairings
10import target_objects
11from problem import Problem
12import problem
13import logic
14import syntax
15import solver
16import search
17import rep_graph
18import check
19
20import random
21
def check_entry_var_deps (f):
	"""Check the entry-point variable dependencies of function f.

	Returns an empty set when f has no entry point; otherwise builds a
	Problem from f and returns whatever check_problem_entry_var_deps
	reports for it."""
	if f.entry:
		return check_problem_entry_var_deps (f.as_problem (Problem))
	return set ()
29
def check_problem_entry_var_deps (p, var_deps = None):
	"""Look for variable dependencies that escape an entry point of p.

	var_deps defaults to p.compute_var_dependencies (). For each entry
	in p.entries, the dependencies recorded at the entry node are
	compared with the entry's inputs. The first non-empty difference is
	printed and returned. Entries absent from var_deps are reported and
	skipped. Returns an empty set if nothing escapes."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	for entry_info in p.entries:
		(entry, tag, _, inputs) = entry_info
		if entry not in var_deps:
			print ('Entry missing from var_deps: %d' % entry)
			continue
		escaped = set (var_deps[entry]).difference (inputs)
		if escaped:
			print ('Vars deps escaped in %s in %s: %s' % (tag,
				p.name, escaped))
			return escaped
	return set ()
43
def check_all_var_deps ():
	"""Return the names of all known functions for which
	check_entry_var_deps reports escaping dependencies."""
	failing = []
	for name in functions:
		if check_entry_var_deps (functions[name]):
			failing.append (name)
	return failing
46
def walk_var_deps (p, n, v, var_deps = None,
			interest = set (), symmetric = False):
	"""Walk through p from node n along nodes that have v in var_deps.

	Each step moves to a neighbouring node (a continuation, or a
	predecessor when symmetric is set) at which v appears in var_deps,
	printing the node reached. Nodes in interest are flagged with '**'.
	When several neighbours qualify one is picked at random. Returns n
	if the walk reaches 'Ret' or 'Err', or None once no neighbour
	qualifies."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	while n not in ('Ret', 'Err'):
		if symmetric:
			neighbours = p.preds[n]
		else:
			neighbours = p.nodes[n].get_conts ()
		opts = set ([nbr for nbr in neighbours if nbr in p.nodes])
		choices = [nbr for nbr in opts if v in var_deps[nbr]]
		if not choices:
			print ('Walk ends at %d.' % n)
			return
		if len (choices) > 1:
			print ('choices %s, gambling' % choices)
			random.shuffle (choices)
			print (' ... rolled a %s' % choices[0])
		elif len (opts) > 1:
			print ('picked %s from %s' % (choices[0], opts))
		n = choices[0]
		if n in interest:
			print ('** %d' % n)
		else:
			print (n)
	print (n)
	return n
75
def diagram_var_deps (p, fname, v, var_deps = None):
	"""Save a graph of p to fname, coloured by dependency on v.

	Nodes missing from var_deps are dark grey, nodes whose var_deps
	include v are orange, and all other nodes are dark blue."""
	if var_deps is None:
		var_deps = p.compute_var_dependencies ()
	def colour (n):
		if n not in var_deps:
			return 'darkgrey'
		if v in var_deps[n]:
			return 'orange'
		return 'darkblue'
	cols = dict ([(n, colour (n)) for n in p.nodes])
	problem.save_graph (p.nodes, fname, cols = cols)
88
def trace_model (rep, m, simplify = True):
	"""Print the path taken through each tag's graph in model m.

	For every tag seen in rep.node_pc_env_order (in reversed pairing
	order when these are exactly the pairing's tags), walks the visited
	nodes, announces era changes, and for each 'Cond' node prints its
	condition name, whether the model makes it true, and its two
	successors, followed by investigate_cond output. Stops a tag's walk
	at 'Ret' or 'Err'."""
	p = rep.p
	tags = set ([tag for (tag, _, _) in rep.node_pc_env_order])
	if p.pairing and tags == set (p.pairing.tags):
		tags = reversed (p.pairing.tags)
	for tag in tags:
		print ("Walking %s in model" % tag)
		last_era = None
		for (n, vc) in walk_model (rep, tag, m):
			era = n_vc_era (p, (n, vc))
			if era != last_era:
				print ('now in era %s' % era)
			last_era = era
			if n in ['Ret', 'Err']:
				print ('ends at %s' % n)
				break
			node = logic.simplify_node_elementary (p.nodes[n])
			if node.kind == 'Cond':
				name = rep.cond_name ((n, vc))
				cond = m[name] == syntax.true_term
				print ('%s: %s (%s, %s)' % (name, cond,
					node.left, node.right))
				investigate_cond (rep, m, name, simplify)
114
def walk_model (rep, tag, m):
	"""Return the (n, vc) visits of the given tag whose path condition
	evaluates to true in model m, ordered by era_sort."""
	def pc_holds (n_vc):
		return search.eval_model_expr (m, rep.solv,
			rep.get_pc (n_vc, tag)) == syntax.true_term
	visited = []
	for (tag2, n, vc) in rep.node_pc_env_order:
		if tag2 == tag and pc_holds ((n, vc)):
			visited.append ((n, vc))
	return era_sort (rep, visited)
125
def investigate_cond (rep, m, cond, simplify = True, rec = True):
	"""Print the truth value in model m of each piece of the solver
	definition of cond.

	With rec set, definitions that are themselves defined names are
	followed. Leading '=>' premises are printed one at a time, stopping
	at the first that does not hold. The remaining expression is split
	with solver.split_hyp_sexpr and each part printed."""
	defs = rep.solv.defs
	cond_def = defs[cond]
	while rec and type (cond_def) == str and cond_def in defs:
		cond_def = defs[cond_def]
	def show_bit (bit):
		if bit == 'true':
			return True
		valid = eval_model_bool (m, bit)
		shown = bit
		if simplify:
			# simplification happens only after evaluation, since
			# some pointer lookups need the unmodified s-expr
			shown = simplify_sexp (bit, rep, m, flatten = False)
		print ('  %s: %s' % (valid, solver.flat_s_expression (shown)))
		return valid
	while cond_def[0] == '=>':
		if not show_bit (cond_def[1]):
			break
		cond_def = cond_def[2]
	for bit in solver.split_hyp_sexpr (cond_def, []):
		show_bit (bit)
148
def eval_model_bool (m, x):
	"""Evaluate boolean x in model m.

	x may be an expression object (anything with a 'typ' attribute,
	which is first converted to an s-expression) or an s-expression.
	Returns True or False, or the string 'EXCEPT' if evaluation fails
	or yields something other than the true/false terms. Note that
	'EXCEPT' is truthy, so callers must compare against it explicitly
	if they need to tell failure from True."""
	if hasattr (x, 'typ'):
		x = solver.smt_expr (x, {}, None)
		x = solver.parse_s_expression (x)
	try:
		r = search.eval_model (m, x)
		assert r in [syntax.true_term, syntax.false_term], r
		return r == syntax.true_term
	except Exception:
		# best-effort by design, but let KeyboardInterrupt and
		# SystemExit propagate rather than swallowing them.
		return 'EXCEPT'
159
def funcall_name (rep):
	"""Return a function mapping an (n, vc) visit of a call node to a
	printable '<fname> @<node count name>' string."""
	def name_of (n_vc):
		fname = rep.p.nodes[n_vc[0]].fname
		return "%s @%s" % (fname, rep.node_count_name (n_vc))
	return name_of
163
164def n_vc_era (p, (n, vc)):
165	era = 0
166	for (split, vcount) in vc:
167		if not p.loop_id (split):
168			continue
169		(ns, os) = vcount.get_opts ()
170		if len (ns + os) > 1:
171			era += 3
172		elif ns:
173			era += 1
174		elif os:
175			era += 2
176	return era
177
def era_merge (era):
	"""Map an era equal to 1 mod 3 down to the era before it; leave
	any other era unchanged."""
	# fold onramp to loops into pre-loop era
	return era - 1 if era % 3 == 1 else era
183
def do_era_merge (do_merge, era):
	"""Return era_merge (era) when do_merge is set, else era itself."""
	if not do_merge:
		return era
	return era_merge (era)
189
def era_sort (rep, n_vcs):
	"""Return n_vcs stably sorted by era.

	Asserts, for each adjacent pair in the same era, that the first is
	not 'Ret'/'Err' and that rep.is_cont holds between them."""
	tagged = sorted ([(n_vc_era (rep.p, n_vc), n_vc) for n_vc in n_vcs],
		key = lambda pair: pair[0])
	for ((e1, n_vc1), (e2, n_vc2)) in zip (tagged, tagged[1:]):
		if e1 == e2:
			if n_vc1[0] in ['Ret', 'Err']:
				assert not 'Era issues', n_vcs
			assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2]
	return [n_vc for (_, n_vc) in tagged]
202
def investigate_funcalls (rep, m, verbose = False, verbose_imp = False,
		simplify = True, pairing = 'Args', era_merge = True):
	"""Pair up the function calls made on each side in model m and
	investigate each pair.

	pairing selects the strategy: 'Eras', 'Seq', 'Args' or 'All' (every
	left call with every right call). era_merge is forwarded to the
	era-based strategies. Pairs for which rep.get_func_pairing gives
	nothing are reported as mismatches and skipped."""
	(l_tag, r_tag) = rep.p.pairing.tags
	l_walk = walk_model (rep, l_tag, m)
	r_walk = walk_model (rep, r_tag, m)

	l_calls = [n_vc for n_vc in l_walk if n_vc in rep.funcs]
	r_calls = [n_vc for n_vc in r_walk if n_vc in rep.funcs]
	name = funcall_name (rep)
	print ('%s calls: %s' % (l_tag, map (name, l_calls)))
	print ('%s calls: %s' % (r_tag, map (name, r_calls)))

	if pairing == 'Eras':
		fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls,
			era_m = era_merge)
	elif pairing == 'Seq':
		fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls)
	elif pairing == 'Args':
		fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls,
			era_m = era_merge)
	elif pairing == 'All':
		fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls]
	else:
		assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing

	for (l_n_vc, r_n_vc) in fc_pairs:
		if rep.get_func_pairing (l_n_vc, r_n_vc):
			investigate_funcall_pair (rep, m, l_n_vc, r_n_vc,
				verbose, verbose_imp, simplify)
		else:
			print ('call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc))
234
235def pair_funcalls_by_era (rep, l_calls, r_calls, era_m = True):
236	eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls])
237	eras = sorted (eras + set (map (era_merge, eras)))
238	pairs = []
239	for era in eras:
240		ls = [n_vc for n_vc in l_calls
241			if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era]
242		rs = [n_vc for n_vc in r_calls
243			if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era]
244		if len (ls) != len (rs):
245			print 'call seq length mismatch in era %d:' % era
246			print map (funcall_name (rep), ls)
247			print map (funcall_name (rep), rs)
248		pairs.extend (zip (ls, rs))
249	return pairs
250
def pair_funcalls_sequential (rep, l_calls, r_calls):
	"""Pair the i-th left call with the i-th right call.

	If the sequences differ in length the unpaired tail of the longer
	one is printed and dropped."""
	(n_l, n_r) = (len (l_calls), len (r_calls))
	if n_l != n_r:
		print ('call seq tail mismatch')
		name = funcall_name (rep)
		if n_l > n_r:
			print ('dropping lhs: %s' % map (name, l_calls[n_r:]))
		else:
			print ('dropping rhs: %s' % map (name, r_calls[n_l:]))
	# really should add some smarts to this to 'recover' from upsets or
	# reorders, but maybe not worth it.
	return zip (l_calls, r_calls)
263
def pair_funcalls_by_match (rep, m, l_calls, r_calls, era_m = True):
	"""Pair function calls within each era, aligned on the best match.

	Within an era, every left/right combination that has a function
	pairing is scored by func_assert_premise_strength. The two
	sequences are shifted so the best-scoring combination lines up, and
	are then zipped together. Eras with no possible match are reported
	and skipped."""
	eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls])
	eras = sorted (eras | set ([era_merge (e) for e in eras]))
	def call_era (n_vc):
		return do_era_merge (era_m, n_vc_era (rep.p, n_vc))
	pairs = []
	for era in eras:
		ls = [n_vc for n_vc in l_calls if call_era (n_vc) == era]
		rs = [n_vc for n_vc in r_calls if call_era (n_vc) == era]
		matches = []
		for (i, n_vc) in enumerate (ls):
			for (j, n_vc2) in enumerate (rs):
				if rep.get_func_pairing (n_vc, n_vc2):
					weakness = 1 - func_assert_premise_strength (
						rep, m, n_vc, n_vc2)
					matches.append ((weakness, i, j))
		if not matches:
			print ('Cannot match any (%d, %d) at era %d' % (len (ls),
				len (rs), era))
			continue
		# the smallest tuple is the strongest premise, ties broken
		# by position
		(_, i, j) = min (matches)
		if i > j:
			pairs.extend (zip (ls[i - j:], rs))
		else:
			pairs.extend (zip (ls, rs[j - i:]))
	return pairs
290
def func_assert_premise_strength (rep, m, l_n_vc, r_n_vc):
	"""Score how well model m satisfies the premise of the function
	assertion linking calls l_n_vc and r_n_vc.

	The premise of the 'Implies' assertion is split into conjuncts.
	Each conjunct scores 1.0 if true in m, 0.0 if false, and 0.5 if it
	cannot be evaluated. Returns the mean score, a float in [0, 1]."""
	imp = rep.get_func_assert (l_n_vc, r_n_vc)
	assert imp.is_op ('Implies'), imp
	[pred, _] = imp.vals
	pred = solver.smt_expr (pred, {}, rep.solv)
	pred = solver.parse_s_expression (pred)
	bits = solver.split_hyp_sexpr (pred, [])
	assert bits, bits
	scores = []
	for bit in bits:
		try:
			res = eval_model_bool (m, bit)
		except (solver.EnvMiss, AssertionError):
			res = 'EXCEPT'
		# eval_model_bool reports failures by returning the truthy
		# string 'EXCEPT' instead of raising, so it must be tested for
		# explicitly or failed evaluations would count as true.
		if res == 'EXCEPT':
			scores.append (0.5)
		elif res:
			scores.append (1.0)
		else:
			scores.append (0.0)
	return sum (scores) / len (scores)
313
def investigate_funcall_pair (rep, m, l_n_vc, r_n_vc,
		verbose = False, verbose_imp = False, simplify = True):
	"""Print which premises of the function assertion between two
	calls hold in model m.

	With verbose_imp set, the whole (weakened) implication is printed
	first, simplified under m when simplify is set. Premise conjuncts
	that are not True are printed (all of them when verbose is set).
	For 'word32-eq' conjuncts the model values of the operands are
	shown as well."""
	def call_name (n_vc):
		return "%s @ %s" % (rep.p.nodes[n_vc[0]].fname,
			rep.node_count_name (n_vc))
	print ('Attempt match %s -> %s' % (call_name (l_n_vc),
		call_name (r_n_vc)))
	imp = logic.weaken_assert (rep.get_func_assert (l_n_vc, r_n_vc))
	if verbose_imp:
		imp2 = solver.smt_expr (imp, {}, rep.solv)
		if simplify:
			imp2 = simplify_sexp (imp2, rep, m)
		print (imp2)
	assert imp.is_op ('Implies'), imp
	[pred, concl] = imp.vals
	pred_sx = solver.parse_s_expression (
		solver.smt_expr (pred, {}, rep.solv))
	bits = solver.split_hyp_sexpr (pred_sx, [])
	xs = [eval_model_bool (m, bit) for bit in bits]
	print ('  %s' % xs)
	for (val, bit) in zip (xs, bits):
		if val == True and not verbose:
			continue
		print ('  %s: %s' % (val, bit))
		if bit[0] == 'word32-eq':
			words = [model_sx_word (m, x) for x in bit[1:]]
			print ('    (%s = %s)' % tuple (words))
341
def model_sx_word (m, sx):
	"""Evaluate s-expression sx in model m and format the resulting
	word with solver.smt_num_t."""
	word = search.eval_model (m, sx)
	return solver.smt_num_t (expr_num (word), word.typ)
346
def expr_num (expr):
	"""Return the value of a 'Word'-typed expression masked to the bit
	width given by its type."""
	typ = expr.typ
	assert typ.kind == 'Word'
	mask = (1 << typ.num) - 1
	return expr.val & mask
350
def str_to_num (smt_str):
	"""Convert an SMT value string to a number via solver.smt_to_val
	and expr_num."""
	return expr_num (solver.smt_to_val (smt_str))
354
def m_var_name (expr):
	"""Return a printable name for the memory underlying expr.

	Any chain of 'MemUpdate' operations is stripped by following the
	first operand. A variable yields its name, and anything else yields
	a '<Op name>' or '<Expr kind>' placeholder."""
	while expr.is_op ('MemUpdate'):
		[expr, p, v] = expr.vals
	if expr.kind == 'Var':
		return expr.name
	elif expr.kind == 'Op':
		# was 'op.name', an undefined name that raised NameError
		return '<Op %s>' % expr.name
	else:
		return '<Expr %s>' % expr.kind
364
def eval_str (expr, env, solv, m):
	"""Evaluate expr in model m and return the result as a string.

	Booleans give the name of the true/false term and words are
	formatted with solver.smt_num_t. Any other result type trips an
	assertion."""
	smt = solver.to_smt_expr (expr, env, solv)
	v = search.eval_model_expr (m, solv, smt)
	if v.typ == syntax.boolT:
		assert v in [syntax.true_term, syntax.false_term]
		return v.name
	if v.typ.kind == 'Word':
		return solver.smt_num_t (v.val, v.typ)
	assert not 'type printable', v
375
def trace_mem (rep, tag, m, verbose = False, simplify = True, symbs = True,
		resolve_addrs = False):
	"""Print and return the memory accesses made along the path that
	model m takes through the graph with the given tag.

	verbose additionally prints the SMT expressions for address and
	value (simplified under m when simplify is set). symbs looks each
	address up with find_symbol and prints any hits. resolve_addrs
	makes the returned address/value expressions the results of
	solver.to_smt_expr in the node's environment.

	Returns a list that mixes (kind, addr, v, mem, n, vc) tuples with
	'<function call to ...>' marker strings for call nodes."""
	p = rep.p
	ns = walk_model (rep, tag, m)
	trace = []
	for (n, vc) in ns:
		if (n, vc) not in rep.arc_pc_envs:
			# this n_vc has a pre-state, but has not been emitted.
			# no point trying to evaluate its expressions, the
			# solver won't have seen them yet.
			continue
		n_nm = rep.node_count_name ((n, vc))
		node = p.nodes[n]
		# collect the expressions this node evaluates.
		# NOTE(review): exprs stays unbound (or stale from the previous
		# iteration) for any other node kind - confirm only these
		# three kinds can occur here.
		if node.kind == 'Call':
			exprs = list (node.args)
		elif node.kind == 'Basic':
			exprs = [expr for (_, expr) in node.upds]
		elif node.kind == 'Cond':
			exprs = [node.cond]
		env = rep.node_pc_envs[(tag, n, vc)][1]
		# deduplicate the accesses; the set makes their order arbitrary
		accs = list (set ([acc for expr in exprs
			for acc in expr.get_mem_accesses ()]))
		for (kind, addr, v, mem) in accs:
			# take the symbolic forms before addr and v are rebound
			# to their evaluated strings below
			addr_s = solver.smt_expr (addr, env, rep.solv)
			v_s = solver.smt_expr (v, env, rep.solv)
			addr = eval_str (addr, env, rep.solv, m)
			v = eval_str (v, env, rep.solv, m)
			m_nm = m_var_name (mem)
			print '%s: %s @ <%s>   -- %s -- %s' % (kind, m_nm, addr, v, n_nm)
			if simplify:
				addr_s = simplify_sexp (addr_s, rep, m)
				v_s = simplify_sexp (v_s, rep, m)
			if verbose:
				print '\t %s -- %s' % (addr_s, v_s)
			if symbs:
				addr_n = str_to_num (addr)
				(hit_symbs, secs) = find_symbol (addr_n, output = False)
				ss = hit_symbs + secs
				if ss:
					print '\t [%s]' % ', '.join (ss)
		# accs still holds the original expressions here; the loop
		# above only rebound its own loop variables
		if resolve_addrs:
			accs = [(kind, solver.to_smt_expr (addr, env, rep.solv),
				solver.to_smt_expr (v, env, rep.solv), mem)
				for (kind, addr, v, mem) in accs]
		trace.extend ([(kind, addr, v, mem, n, vc)
			for (kind, addr, v, mem) in accs])
		if node.kind == 'Call':
			# callers must expect these strings among the tuples
			msg = '<function call to %s at %s>' % (node.fname, n_nm)
			print msg
			trace.append (msg)
	return trace
427
def simplify_sexp (smt_xp, rep, m, flatten = True):
	"""Simplify an s-expression by resolving every 'ite' according to
	model m.

	Accepts a string (parsed first) or a parsed s-expression. Returns a
	flat string when flatten is set, otherwise the structured form."""
	if type (smt_xp) == str:
		smt_xp = solver.parse_s_expression (smt_xp)
	if smt_xp[0] == 'ite':
		(_, c, x, y) = smt_xp
		branch = x if eval_model_bool (m, c) else y
		return simplify_sexp (branch, rep, m, flatten)
	if type (smt_xp) == tuple:
		# inner results stay structured; flattening happens at the top
		smt_xp = tuple ([simplify_sexp (part, rep, m, False)
			for part in smt_xp])
	return solver.flat_s_expression (smt_xp) if flatten else smt_xp
444
def trace_mems (rep, m, verbose = False, symbs = True, tags = None):
	"""Print a memory trace (see trace_mem) for each tag.

	tags defaults to the pairing's tags in reverse order, or to
	rep.p.tags () when the problem has no pairing."""
	if tags is None:
		pairing = rep.p.pairing
		tags = reversed (pairing.tags) if pairing else rep.p.tags ()
	for tag in tags:
		print ('%s mem trace:' % tag)
		trace_mem (rep, tag, m, verbose = verbose, symbs = symbs)
454
def trace_mems_diff (rep, m, tags = ['ASM', 'C']):
	"""Compare the memory updates made on the two sides in model m.

	tags[0] names the assembly-side tag and tags[1] the C-side tag. On
	the assembly side only updates whose memory variable name contains
	'mem' are considered.

	Returns (equal, mism, c_upds, asm_upds), where equal says whether
	both sides write the same final address -> value mapping, mism
	lists the addresses that differ, and the last two are the
	(address, value) update lists."""
	asms = trace_mem (rep, tags[0], m, resolve_addrs = True)
	cs = trace_mem (rep, tags[1], m, resolve_addrs = True)
	ev = lambda expr: eval_str (expr, {}, None, m)
	# trace_mem interleaves '<function call ...>' marker strings with
	# its access tuples; unpacking one of those as a 6-tuple would
	# raise, so keep only the tuples.
	is_access = lambda entry: isinstance (entry, tuple)
	c_upds = [(ev (addr), ev (v))
		for (kind, addr, v, mem, _, _) in filter (is_access, cs)
		if kind == 'MemUpdate']
	asm_upds = [(ev (addr), ev (v))
		for (kind, addr, v, mem, _, _) in filter (is_access, asms)
		if kind == 'MemUpdate' and 'mem' in m_var_name (mem)]
	# later updates to the same address override earlier ones
	c_upd_d = dict (c_upds)
	asm_upd_d = dict (asm_upds)
	addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds
		if addr not in asm_upd_d]
	mism = [addr for addr in addr_ord
		if c_upd_d.get (addr) != asm_upd_d.get (addr)]
	return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds)
470
def get_pv_type (pv):
	"""Summarise a 'PValid' or 'PArrayValid' expression as a tuple
	(op name, type, array-length expression or None)."""
	assert pv.is_op (['PValid', 'PArrayValid'])
	typ_v = pv.vals[1]
	assert typ_v.kind == 'Type'
	typ = typ_v.val
	if not pv.is_op ('PArrayValid'):
		return ('PValid', typ, None)
	return ('PArrayValid', typ, pv.vals[3])
480
def guess_pv (p, n, addr_expr):
	"""Guess which PValid/PArrayValid guard covers an access at node n.

	Scans the condition of n's single predecessor for 'PValid' and
	'PArrayValid' expressions whose operands from index 2 onward use
	exactly the variables of addr_expr. When several match, only the
	'PArrayValid' ones are kept. Returns the first remaining match."""
	vs = syntax.get_expr_var_set (addr_expr)
	[pred] = p.preds[n]
	pvs = []
	def vis (expr):
		if expr.is_op (['PValid', 'PArrayValid']):
			pvs.append (expr)
	p.nodes[pred].cond.visit (vis)
	def pv_vars (pv):
		return set.union (* [syntax.get_expr_var_set (v)
			for v in pv.vals[2:]])
	match_pvs = [pv for pv in pvs if pv_vars (pv) == vs]
	if len (match_pvs) > 1:
		match_pvs = [pv for pv in match_pvs if pv.is_op ('PArrayValid')]
	return match_pvs[0]
496
497def eval_pv_type (rep, (n, vc), m, data):
498	if data[0] == 'PValid':
499		return data
500	else:
501		(nm, typ, offs) = data
502		offs = rep.to_smt_expr (offs, (n, vc))
503		offs = search.eval_model_expr (m, rep.solv, offs)
504		return (nm, typ, offs)
505
506def trace_suspicious_mem (rep, m, tag = 'C'):
507	cs = trace_mem (rep, tag, m)
508	data = [(addr, search.eval_model_expr (m, rep.solv,
509			rep.to_smt_expr (addr, (n, vc))), (n, vc))
510		for (kind, addr, v, mem, n, vc) in cs]
511	addr_sets = {}
512	for (addr, addr_v, _) in data:
513		addr_sets.setdefault (addr_v, set ())
514		addr_sets[addr_v].add (addr)
515	dup_addrs = set ([addr_v for addr_v in addr_sets
516		if len (addr_sets[addr_v]) > 1])
517	data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc))
518		for (addr, addr_v, (n, vc)) in data
519		if addr_v in dup_addrs]
520	data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m,
521			get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n)
522		for (addr, addr_v, pv, (n, vc)) in data]
523	dup_addr_types = set ([addr_v for addr_v in dup_addrs
524		if len (set ([t for (_, addr_v2, t, _, _) in data
525			if addr_v2 == addr_v])) > 1])
526	res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data
527			if addr_v2 == addr_v])
528		for addr_v in dup_addr_types]
529	for (addr_v, insts) in res:
530		print 'Address %s' % addr_v
531		for (t, pv, n) in insts:
532			print '  -- accessed with type %s at %s' % (t, n)
533			print '    (covered by %s)' % pv
534	return res
535
def trace_var (rep, tag, m, v):
	"""Print and return how the value of expression v changes along
	the path that model m takes through the graph with the given tag.

	Returns a list mixing ((n, vc), value) entries, recorded whenever
	the value differs from the previous one, with '<function call to
	...>' marker strings for call nodes. A value is either None
	(undefined) or a pair (SMT expression string, evaluated value)."""
	p = rep.p
	ns = walk_model (rep, tag, m)
	vds = rep.p.compute_var_dependencies ()
	trace = []
	vs = syntax.get_expr_var_set (v)
	def fetch ((n, vc)):
		# value of v at this visit, or None if it cannot be evaluated
		if n in vds and [(nm, typ) for (nm, typ) in vs
				if (nm, typ) not in vds[n]]:
			# some variable of v is not in var deps at n
			return None
		try:
			(_, env) = rep.get_node_pc_env ((n, vc), tag)
			s = solver.smt_expr (v, env, rep.solv)
			s_x = solver.parse_s_expression (s)
			ev = search.eval_model (m, s_x)
			return (s, solver.smt_expr (ev, {}, None))
		except solver.EnvMiss, e:
			return None
		except AssertionError, e:
			return None
	val = None
	for (n, vc) in ns:
		n_nm = rep.node_count_name ((n, vc))
		val2 = fetch ((n, vc))
		# only report (and record) when the value changes
		if val2 != val:
			if val2 == None:
				print 'at %s: undefined' % n_nm
			else:
				print 'at %s:\t\t%s:\t\t%s' % (n_nm,
					val2[0], val2[1])
			val = val2
			trace.append (((n, vc), val))
		# stop at the first visit that is not a node of p
		if n not in p.nodes:
			break
		node = p.nodes[n]
		if node.kind == 'Call':
			msg = '<function call to %s at %s>' % (node.fname,
				rep.node_count_name ((n, vc)))
			print msg
			trace.append (msg)
	return trace
577
def trace_deriv_ops (rep, m, tag):
	"""Print the model value of the operand of every
	CountTrailingZeroes, CountLeadingZeroes or WordReverse operation
	at the nodes that model m visits for the given tag."""
	derivs = set (['CountTrailingZeroes', 'CountLeadingZeroes',
		'WordReverse'])
	def get_derivs (node):
		found = set ()
		def visit (expr):
			if expr.is_op (derivs):
				found.add (expr)
		node.visit (lambda x: (), visit)
		return found
	for (n, vc) in walk_model (rep, tag, m):
		if n not in rep.p.nodes:
			continue
		dvs = get_derivs (rep.p.nodes[n])
		if dvs:
			print ('%s:' % (rep.node_count_name ((n, vc))))
			for dv in dvs:
				[arg] = dv.vals
				arg_smt = rep.to_smt_expr (arg, (n, vc))
				val = eval_str (arg_smt, {}, rep.solv, m)
				print ('\t%s: %s' % (dv.name, val))
601
602def check_pairings ():
603	for p in pairings.itervalues ():
604		print p['C'], p['ASM']
605		as_args = functions[p['ASM']].inputs
606		c_args = functions[p['C']].inputs
607		print as_args, c_args
608		logic.mk_fun_inp_eqs (as_args, c_args, True)
609
def loop_var_deps (p):
	"""For each node in p.loop_data, list the variables that
	p.var_deps marks as 'LoopVariable' at that node.

	Returns a list of (node, [variables]) pairs."""
	result = []
	for n in p.loop_data:
		deps = p.var_deps[n]
		loop_vars = [v for v in deps if deps[v] == 'LoopVariable']
		result.append ((n, loop_vars))
	return result
614
def find_symbol (n, output = True):
	"""Find the symbols and sections whose address range contains n.

	Each hit is printed when output is set. Returns (symbol names,
	section names)."""
	from target_objects import symbols, sections
	symbs = []
	secs = []
	def report (s):
		if output:
			print (s)
	for (s, (addr, size, _)) in symbols.iteritems ():
		if addr <= n < addr + size:
			symbs.append (s)
			report ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1))
	for (s, (start, end)) in sections.iteritems ():
		# section ranges are inclusive at both ends
		if start <= n <= end:
			secs.append (s)
			report ('%x in section %s (%x - %x)' % (n, s, start, end))
	return (symbs, secs)
633
def assembly_point (p, n):
	"""Return the assembly address associated with node n, or None.

	The address is taken from the second component of the node's tag
	hints. If it is not a multiple of 4, the single predecessor chain
	is followed until one is found. Returns None when the hints are not
	a tuple carrying an integer address."""
	(_, hints) = p.node_tags[n]
	if not (type (hints) == tuple and logic.is_int (hints[1])):
		return None
	addr = hints[1]
	while addr % 4 != 0:
		[n] = p.preds[n]
		addr = p.node_tags[n][1][1]
	return addr
641
def assembly_points (p, ns):
	"""Return the assembly addresses of the nodes ns, leaving out any
	node for which assembly_point gives None."""
	points = []
	for n in ns:
		addr = assembly_point (p, n)
		if addr is not None:
			points.append (addr)
	return points
646
def disassembly_lines (addrs):
	"""Return the stripped lines of '<target_dir>/kernel.elf.txt'
	whose text before the first ':' is the lowercase hex form of one of
	addrs."""
	addr_set = set (['%x' % addr for addr in addrs])
	# 'with' ensures the file is closed; previously the handle was
	# left open for the garbage collector to reclaim.
	with open ('%s/kernel.elf.txt' % target_objects.target_dir) as f:
		return [l.strip ()
			for l in f if ':' in l and l.split (':', 1)[0] in addr_set]
653
def disassembly (p, n):
	"""Print the assembly addresses and matching disassembly lines for
	node n, or for each node if n is iterable."""
	ns = set (n) if hasattr (n, '__iter__') else [n]
	points = set ([assembly_point (p, node) for node in ns])
	addrs = sorted (points - set ([None]))
	print ('asm %s' % ', '.join (['0x%x' % addr for addr in addrs]))
	for line in disassembly_lines (addrs):
		print (line)
664
def disassembly_loop (p, n):
	"""Print the disassembly of the loop containing node n, followed
	by the disassembly of the points from which the loop head is
	entered from outside the loop."""
	head = p.loop_id (n)
	loop = p.loop_body (n)
	def hex_list (addrs):
		return ', '.join (['%x' % addr for addr in addrs])
	ns = sorted (set (assembly_points (p, loop)))
	outside_preds = [pred for pred in p.preds[head] if pred not in loop]
	entries = assembly_points (p, outside_preds)
	print ('Loop: [%s]' % hex_list (ns))
	for line in disassembly_lines (ns):
		print (line)
	print ('entry from %s' % hex_list (entries))
	for line in disassembly_lines (entries):
		print (line)
677
def try_interpret_hyp (rep, hyp):
	"""Check that hyp can be interpreted by rep and converted to SMT.

	Returns None on success, or ('Broken Hyp', hyp) if either step
	raises an exception."""
	try:
		expr = rep.interpret_hyp (hyp)
		solver.smt_expr (expr, {}, rep.solv)
		return None
	except Exception:
		# any failure marks the hyp as broken, but do not swallow
		# KeyboardInterrupt or SystemExit.
		return ('Broken Hyp', hyp)
685
def check_checks ():
	"""Return the hypotheses in the checks of the last proof (for the
	last problem) that try_interpret_hyp reports as broken."""
	p = problem.last_problem[0]
	rep = rep_graph.mk_graph_slice (p)
	proof = search.last_proof[0]
	checks = check.proof_checks (p, proof)
	goals = [hyp for (_, hyp, _) in checks]
	premises = [hyp for (hyps, _, _) in checks for hyp in hyps]
	broken = []
	for hyp in set (goals + premises):
		res = try_interpret_hyp (rep, hyp)
		if res:
			broken.append (res[1])
	return broken
695
def proof_failed_groups (p = None, proof = None):
	"""Re-test every check group of a proof and return those that fail.

	p and proof default to the last problem and last proof. Each group
	is tested against a fresh graph slice. The failing element of each
	failed group is printed, then the set of names of all failed
	checks."""
	if p is None:
		p = problem.last_problem[0]
	if proof is None:
		proof = search.last_proof[0]
	groups = check.proof_check_groups (check.proof_checks (p, proof))
	failed = []
	for group in groups:
		rep = rep_graph.mk_graph_slice (p)
		(res, el) = check.test_hyp_group (rep, group)
		if res:
			continue
		failed.append (group)
		print ('Failed element: %s' % el)
	failed_nms = set ([s for group in failed for (_, _, s) in group])
	print ('Failed: %s' % failed_nms)
	return failed
713
def read_summary (f):
	"""Parse the 'Time taken to check ...' lines of a summary.

	f is an iterable of lines. Field 4 of each matching line is the
	result, the field after '<=' is a function name, and the last field
	is the time. The pair is the unique pairing of that function whose
	name occurs in the line. Returns (results, times), both keyed by
	pair."""
	results = {}
	times = {}
	for line in f:
		if line.startswith ('Time taken to'):
			bits = line.split ()
			assert bits[:4] == ['Time', 'taken', 'to', 'check']
			res = bits[4]
			[ref] = [i for (i, b) in enumerate (bits) if b == '<=']
			fname = bits[ref + 1]
			[pair] = [cand for cand in pairings[fname]
				if cand.name in line]
			time = float (bits[-1])
			results[pair] = res
			times[pair] = time
	return (results, times)
731
def unfold_defs_sexpr (defs, sexpr, depthlimit = -1):
	"""Replace defined names in an s-expression by their definitions.

	A string is looked up in defs (one level only), printed and
	returned. A compound expression has its arguments, but not its
	head, unfolded recursively. Recursion into compound expressions
	stops once depthlimit reaches 0, and a negative limit never does."""
	if type (sexpr) == str:
		unfolded = defs.get (sexpr, sexpr)
		print (unfolded)
		return unfolded
	if depthlimit == 0:
		return sexpr
	args = [unfold_defs_sexpr (defs, s, depthlimit - 1)
		for s in sexpr[1:]]
	return tuple ([sexpr[0]] + args)
741
def unfold_defs (defs, hyp, depthlimit = -1):
	"""Parse hyp, unfold definitions with unfold_defs_sexpr, and return
	the result as a flat s-expression string."""
	parsed = solver.parse_s_expression (hyp)
	unfolded = unfold_defs_sexpr (defs, parsed, depthlimit)
	return solver.flat_s_expression (unfolded)
745
def investigate_unsat (solv, hyps = None):
	"""Reduce an unsatisfiable hypothesis set to a smaller unsat core.

	hyps is a collection of (hypothesis, tag) pairs and defaults to
	solver.last_hyps[0]. Hypotheses are dropped one at a time while the
	rest stay unsat. The survivors are then split into their
	components, and afterwards have definitions unfolded to depth 2,
	recursing whenever either step changes the set. Returns the final
	list of kept hypotheses."""
	if hyps == None:
		hyps = solver.last_hyps[0]
	# work on a private copy: the loop below pops every element, which
	# would otherwise empty the caller's list
	hyps = list (hyps)
	assert solv.hyps_sat_raw (hyps) == 'unsat', hyps
	kept_hyps = []
	while hyps:
		h = hyps.pop ()
		# h is needed only if the remainder is no longer unsat
		if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat':
			kept_hyps.append (h)
	assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps
	split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps
		for hyp2 in solver.split_hyp (hyp)]))
	if len (split_hyps) > len (kept_hyps):
		return investigate_unsat (solv, split_hyps)
	def_hyps = [(unfold_defs (solv.defs, h, 2), tag)
		for (h, tag) in kept_hyps]
	if def_hyps != kept_hyps:
		return investigate_unsat (solv, def_hyps)
	return kept_hyps
765
def test_interesting_linear_series_exprs ():
	"""Build the problem for every pairing and note which ones have a
	loop with interesting linear series expressions.

	Returns a dict from pair name to True, or to 'Call!' when the
	string form of the expressions mentions 'Call'."""
	pairs = set ([pair for f in pairings for pair in pairings[f]])
	notes = {}
	for pair in pairs:
		p = check.build_problem (pair)
		for n in search.init_loops_to_split (p, ()):
			va = search.get_loop_var_analysis_at (p, n)
			intr = logic.interesting_linear_series_exprs (p, n, va)
			if intr:
				notes[pair.name] = True
			if 'Call' in str (intr):
				notes[pair.name] = 'Call!'
	return notes
779
def var_analysis (p, n):
	"""Print the loop variable analysis at node n, grouped by kind.

	Each variable is shown with its type. Variables of kind
	'LoopLinearSeries' also show the third component of their kind
	after '++'."""
	cats = {}
	for (v, kind) in search.get_loop_var_analysis_at (p, n):
		if kind[0] == 'LoopLinearSeries':
			(kind, offs) = (kind[0], kind[2])
		else:
			offs = None
		cats.setdefault (kind, []).append ((v, offs))
	for kind in cats:
		print ('%s:' % kind)
		for (v, offs) in cats[kind]:
			print ('  %s   (%s)' % (syntax.pretty_expr (v),
				syntax.pretty_type (v.typ)))
			if offs:
				print ('      ++ %s' % syntax.pretty_expr (offs))
798
def var_value_sites (rep, v):
	"""List the SMT values that variable v takes and where it takes
	them.

	v is either a string, matching any variable whose name contains
	it, or an exact (name, type) tuple. Every distinct SMT expression
	is printed with its solver definition, if any, and the visits at
	which it occurs. Returns (expressions in first-seen order, dict
	from expression to visit list)."""
	if type (v) == str:
		def matches (key):
			(nm, _) = key
			return v in nm
	elif type (v) == tuple:
		def matches (key):
			(nm, typ) = key
			return v == (nm, typ)
	v_ord = []
	d = {}
	for (tag, n, vc) in rep.node_pc_env_order:
		(pc, env) = rep.get_node_pc_env ((n, vc), tag = tag)
		for (v2, smt_exp) in env.iteritems ():
			if not matches (v2):
				continue
			if smt_exp not in d:
				v_ord.append (smt_exp)
				d[smt_exp] = []
			d[smt_exp].append ((n, vc))
	defs = rep.solv.defs
	for smt_exp in v_ord:
		print (smt_exp)
		if smt_exp in defs:
			print ('  = %s' % repr (defs[smt_exp]))
		print ('  - at: %s' % d[smt_exp])
	if v_ord:
		print ('')
	return (v_ord, d)
822
def loop_num_leaves (p, n):
	"""For each node in the body of the loop containing n, print the
	node together with the number of variables that the loop variable
	analysis classifies as 'LoopLeaf' there."""
	for node in p.loop_body (n):
		va = search.get_loop_var_analysis_at (p, node)
		leaves = [v for (v, kind) in va if kind == 'LoopLeaf']
		print (node, len (leaves))
828
def try_pairing_at_funcall (p, name, head = None, restrs = None, hyps = None,
		at = 'At'):
	"""Search for a loop split restricted to the call sites of a
	function.

	The candidate nodes are the calls in p to any function that shares
	a pairing with name, or the continuations of those calls when at is
	'After'. head defaults to the single loop to split on the first
	pairing tag, restrs to (), and hyps to check.init_point_hyps (p).
	Each 'CaseSplit' result is turned into an extra hypothesis and the
	search is retried. The first other result is returned."""
	pairs = set (pairings[name])
	addrs = [n for (n, name2) in p.function_call_addrs ()
		if pairs.intersection (pairings[name2])]
	assert at in ['At', 'After']
	if at == 'After':
		addrs = [p.nodes[n].cont for n in addrs]
	if head is None:
		first_tag = p.pairing.tags[0]
		[head] = [n for n in search.init_loops_to_split (p, ())
			if p.node_tags[n][0] == first_tag]
	if restrs is None:
		restrs = ()
	if hyps is None:
		hyps = check.init_point_hyps (p)
	while True:
		res = search.find_split_loop (p, head, restrs, hyps,
			node_restrs = set (addrs))
		if res[0] != 'CaseSplit':
			return res
		(_, ((n, tag), _)) = res
		hyps = hyps + [rep_graph.pc_true_hyp (((n, restrs), tag))]
854
def init_true_hyp (p, tag, expr):
	"""Build a hypothesis stating that boolean expr is true at the
	entry point of the given tag, with no restrictions."""
	entry_vis = ((p.get_entry (tag), ()), tag)
	assert expr.typ == syntax.boolT, expr
	lhs = (expr, entry_vis)
	rhs = (syntax.true_term, entry_vis)
	return rep_graph.eq_hyp (lhs, rhs)
860
def smt_print (expr):
	"""Convert expr to an SMT string, mapping each variable the
	environment lacks to its own name."""
	env = {}
	result = None
	done = False
	while not done:
		try:
			result = solver.smt_expr (expr, env, None)
			done = True
		except solver.EnvMiss as miss:
			# bind the missing variable to its own name and retry
			env[(miss.name, miss.typ)] = miss.name
	return result
868
869