1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7import syntax
8import solver
9import problem
10import rep_graph
11import search
12import logic
13import check
14
15from target_objects import functions, trace, pairings, pre_pairings, printout
16import target_objects
17
18from logic import azip
19
20from syntax import mk_var, word32T, builtinTs, mk_eq, mk_less_eq
21
# Debug slot (single-element list); not referenced in the visible part of
# this module -- NOTE(review): confirm whether anything still uses it.
last_stuff = [0]
23
def default_n_vc (p, n):
	"""Build the default visit (n, restrs) for node n of problem p.

	Every loop head other than n's own gets the restriction
	rep_graph.vc_options ([0], [1]); when p.loop_id (n) is truthy (n is
	in a loop) that head additionally gets rep_graph.vc_offs (1), placed
	last."""
	own_head = p.loop_id (n)
	restrs = [(loop_head, rep_graph.vc_options ([0], [1]))
		for loop_head in p.loop_heads () if loop_head != own_head]
	if own_head:
		restrs.append ((own_head, rep_graph.vc_offs (1)))
	return (n, tuple (restrs))
31
def split_sum_s_expr (expr, solv, extra_defs, typ):
	"""divides up a linear expression 'a - b - 1 + a'
	into ({'a':2, 'b': -1}, -1) i.e. 'a' times 2 etc and constant
	value of -1.

	'bvadd' and 'bvsub' nodes are split recursively. Names defined in
	solv.defs and then extra_defs are expanded, and '#x'/'#b' numerals
	are parsed with solver.smt_to_val. Any other expression is an atom
	with multiplier 1. The constant is not reduced modulo the word size."""
	def split (sub):
		return split_sum_s_expr (sub, solv, extra_defs, typ)

	head = expr[0]
	if head == 'bvadd':
		totals = {}
		offset = 0
		for term in expr[1:]:
			(term_vars, term_const) = split (term)
			for (atom, count) in term_vars.items ():
				totals[atom] = totals.get (atom, 0) + count
			offset += term_const
		return (totals, offset)
	if head == 'bvsub':
		(_, lhs, rhs) = expr
		(lhs_vars, lhs_const) = split (lhs)
		(rhs_vars, rhs_const) = split (rhs)
		atoms = set (lhs_vars) | set (rhs_vars)
		diff = dict ([(atom, lhs_vars.get (atom, 0) - rhs_vars.get (atom, 0))
			for atom in atoms])
		return (diff, lhs_const - rhs_const)
	if expr in solv.defs:
		return split (solv.defs[expr])
	if expr in extra_defs:
		return split (extra_defs[expr])
	if expr[:2] in ['#x', '#b']:
		val = solver.smt_to_val (expr)
		assert val.kind == 'Num'
		return ({}, val.val)
	return ({expr: 1}, 0)
66
def split_merge_ite_sum_sexpr (expr, solv, extra_defs = None, typ = None):
	"""divides up an ('ite', cond, x, y) s-expression in the style of
	split_sum_s_expr, distributing the 'ite' over the summands.

	- if y is not the zero numeral, ite (cond, x, y) is rewritten as
	  ite (cond, x - y, 0) + y and both parts are split;
	- otherwise x is split into sum (n_i * v_i) + c, and the result is
	  {ite (cond, v_i, 0): n_i} plus c times the atom ite (cond, 1, 0),
	  with a constant part of 0.
	Both identities hold in modular (bitvector) arithmetic.

	Returns a (atom -> multiplier dict, constant) pair like
	split_sum_s_expr. typ defaults to syntax.word32T.

	(The previous version took a single unused argument and referred to
	undefined names, so every call raised NameError.)"""
	if extra_defs is None:
		extra_defs = {}
	if typ is None:
		typ = syntax.word32T
	(_, cond, x, y) = expr
	(s0, s1) = [solver.smt_num_t (n, typ) for n in [0, 1]]
	if y != s0:
		# ite (c, x, y) == ite (c, x - y, 0) + y
		(ite_vars, ite_const) = split_merge_ite_sum_sexpr (
			('ite', cond, ('bvsub', x, y), s0), solv, extra_defs, typ)
		(y_vars, y_const) = split_sum_s_expr (y, solv, extra_defs, typ)
		var = dict (ite_vars)
		for (v, n) in y_vars.items ():
			var[v] = var.get (v, 0) + n
		return (var, ite_const + y_const)
	(x_vars, x_const) = split_sum_s_expr (x, solv, extra_defs, typ)
	var = dict ([(('ite', cond, v, s0), n) for (v, n) in x_vars.items ()])
	# the constant part of x survives only when cond holds
	cond_one = ('ite', cond, s1, s0)
	var[cond_one] = var.get (cond_one, 0) + x_const
	return (var, 0)
78
def simplify_expr_whyps (sexpr, rep, hyps, cache = None, extra_defs = {},
		bool_hyps = None):
	"""Simplify an 'ite' s-expression tree under the hypotheses 'hyps'.

	sexpr is first expanded through one level of extra_defs and then one
	level of rep.solv.defs.

	For an ('ite', cond, x, y) node:
	- if bool_hyps imply cond (or its negation) under hyps, the matching
	  branch is returned as-is;
	- otherwise both branches are simplified recursively, each with the
	  condition selecting it added to bool_hyps, and the ite is rebuilt,
	  or collapsed when both branches agree.

	Any other expression is returned unchanged. 'cache' is handed to
	rep.test_hyp_whyps."""
	if cache is None:
		cache = {}
	if bool_hyps is None:
		bool_hyps = []
	if sexpr in extra_defs:
		sexpr = extra_defs[sexpr]
	if sexpr in rep.solv.defs:
		sexpr = rep.solv.defs[sexpr]
	if sexpr[0] != 'ite':
		return sexpr

	(_, cond, x, y) = sexpr
	cond_exp = solver.mk_smt_expr (solver.flat_s_expression (cond),
		syntax.boolT)
	neg_cond_exp = syntax.mk_not (cond_exp)

	def implied (goal):
		return rep.test_hyp_whyps (syntax.mk_n_implies (bool_hyps, goal),
			hyps, cache = cache)

	if implied (cond_exp):
		return x
	if implied (neg_cond_exp):
		return y

	def simplify_under (branch, branch_cond):
		return simplify_expr_whyps (branch, rep, hyps, cache = cache,
			extra_defs = extra_defs,
			bool_hyps = bool_hyps + [branch_cond])

	x = simplify_under (x, cond_exp)
	y = simplify_under (y, neg_cond_exp)
	if x == y:
		return x
	return ('ite', cond, x, y)
110
# Debug record of recent offs_expr_const failures: offs_expr_const appends
# (addr_expr, sp_expr, vs, hyps) here and trims the list to its last 10.
last_10_non_const = []
112
def offs_expr_const (addr_expr, sp_expr, rep, hyps, extra_defs = {},
		cache = None, typ = syntax.word32T):
	"""if the offset between a stack addr and the initial stack pointer
	is a constant offset, try to compute it.

	Each round does two things:
	- splits 'addr_expr - sp_expr' into a linear combination of atoms
	  (see split_sum_s_expr), dropping atoms whose multiplier is a
	  multiple of 2 ** typ.num;
	- simplifies the remaining atoms under 'hyps' (see
	  simplify_expr_whyps).

	Returns the accumulated constant (not reduced modulo 2 ** typ.num)
	once no atoms remain. Returns None, after tracing and recording the
	failure in last_10_non_const, if a round makes no progress."""
	addr_x = solver.parse_s_expression (addr_expr)
	sp_x = solver.parse_s_expression (sp_expr)
	if cache is None:
		# share one cache across all rounds; previously each call to
		# simplify_expr_whyps allocated a fresh, immediately discarded
		# cache when none was supplied.
		cache = {}
	modulus = 2 ** typ.num
	vs = [(addr_x, 1), (sp_x, -1)]
	const = 0

	while True:
		start_vs = list (vs)
		new_vs = {}
		for (x, mult) in vs:
			(var, c) = split_sum_s_expr (x, rep.solv, extra_defs,
				typ = typ)
			for v in var:
				new_vs[v] = new_vs.get (v, 0) + var[v] * mult
			const += c * mult
		# atoms with a multiplier of 0 (mod the word size) cancel out
		vs = [(x, n) for (x, n) in new_vs.items ()
			if n % modulus != 0]
		if not vs:
			return const
		vs = [(simplify_expr_whyps (x, rep, hyps,
				cache = cache, extra_defs = extra_defs), n)
			for (x, n) in vs]
		if sorted (vs) == sorted (start_vs):
			# no progress this round: give up
			trace ('offs_expr_const: not const')
			trace ('%s - %s' % (addr_expr, sp_expr))
			trace (str (vs))
			trace (str (hyps))
			last_10_non_const.append ((addr_expr, sp_expr, vs, hyps))
			del last_10_non_const[:-10]
			return None
149
def has_stack_var (expr, stack_var):
	"""Check whether a memory expression is rooted at stack_var.

	Any chain of MemUpdate operations is stripped and the underlying
	variable is compared with stack_var. Raises AssertionError if the
	root is not a variable."""
	while expr.is_op ('MemUpdate'):
		(m, _, _) = expr.vals
		expr = m
	if expr.kind == 'Var':
		return expr == stack_var
	# explicit raise: a bare 'assert' is stripped under 'python -O',
	# which used to leave the 'while True' loop spinning forever here.
	raise AssertionError ('has_stack_var: expr kind', expr)
159
def mk_not_callable_hyps (p):
	"""Return pc_false_hyp hypotheses for every call node of p whose callee
	is not get_asm_callable, i.e. hypotheses that those calls are not
	reached."""
	hyps = []
	for n in p.nodes:
		node = p.nodes[n]
		if node.kind == 'Call' and not get_asm_callable (node.fname):
			visit_tag = p.node_tags[n][0]
			hyps.append (rep_graph.pc_false_hyp (
				(default_n_vc (p, n), visit_tag)))
	return hyps
171
# Debug slots: get_ptr_offsets stores its (p, n_ptrs, bases, hyps) arguments
# in last_get_ptr_offsets[0]. last_get_ptr_offsets_setup is not written in
# the visible part of this module -- NOTE(review): confirm whether it is used.
last_get_ptr_offsets = [0]
last_get_ptr_offsets_setup = [0]
174
def get_ptr_offsets (p, n_ptrs, bases, hyps = [], cache = None,
		fail_early = False):
	"""detect which ptrs are guaranteed to be at constant offsets
	from some set of basis ptrs

	n_ptrs is a list of (node, ptr) pairs and bases a list of
	(node, ptr, kind) triples.

	Returns a list of ((node, ptr), offset, kind) triples, one per
	pointer for which offs_expr_const finds a constant offset from some
	base. Bases are tried in order and the first success wins.

	Pointers at nodes with no path-condition environment are skipped.
	With fail_early, a pointer without a constant offset fails an
	assertion."""
	rep = rep_graph.mk_graph_slice (p, fast = True)
	if cache is None:
		cache = {}
	last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps)

	# each base pointer, evaluated at its own node
	smt_bases = []
	for (n, ptr, k) in bases:
		(_, env) = rep.get_node_pc_env (default_n_vc (p, n))
		smt_bases.append ((solver.smt_expr (ptr, env, rep.solv), k))
		ptr_typ = ptr.typ

	# each candidate pointer, evaluated at its node, together with the
	# pc_true_hyp for that node
	smt_ptrs = []
	for (n, ptr) in n_ptrs:
		n_vc = default_n_vc (p, n)
		pc_env = rep.get_node_pc_env (n_vc)
		if not pc_env:
			continue
		smt = solver.smt_expr (ptr, pc_env[1], rep.solv)
		reached = rep_graph.pc_true_hyp ((n_vc, p.node_tags[n][0]))
		smt_ptrs.append (((n, ptr), smt, reached))

	tags = set ([p.node_tags[n][0] for (n, _) in n_ptrs])
	hyps = hyps + mk_not_callable_hyps (p)
	for tag in tags:
		hyps = hyps + init_correctness_hyps (p, tag)
	ex_defs = {}
	for tag in tags:
		ex_defs.update (get_extra_sp_defs (rep, tag))

	offs = []
	for (v, ptr, reached) in smt_ptrs:
		off = None
		for (base, k) in smt_bases:
			off = offs_expr_const (ptr, base, rep, [reached] + hyps,
				cache = cache, extra_defs = ex_defs,
				typ = ptr_typ)
			if off is not None:
				offs.append ((v, off, k))
				break
		if off is None:
			trace ('get_ptr_offs fallthrough at %d: %s' % v)
			trace (str ([reached] + hyps))
			assert not fail_early, (v, ptr)
	return offs
225
def init_correctness_hyps (p, tag):
	"""Instantiate the input equalities of the pairing for tag's function.

	Only equalities relating two '<tag>_IN' sides are used. tag itself or
	its hook_tag_hints entry must be one of the pairing's tags. Returns []
	when the function has no pairing or the tag does not match."""
	(_, fname, _) = p.get_entry_details (tag)
	if fname not in pairings:
		# conveniently handles bootstrap case
		return []
	# revise if multi-pairings for ASM an option
	[pair] = pairings[fname]
	true_tag = None
	for candidate in (tag, p.hook_tag_hints.get (tag, tag)):
		if candidate in pair.funs:
			true_tag = candidate
			break
	if true_tag == None:
		return []
	(inp_eqs, _) = pair.eqs
	in_tag = "%s_IN" % true_tag
	both_in = [eq for eq in inp_eqs
		if eq[0][1] == in_tag and eq[1][1] == in_tag]
	return check.inst_eqs (p, (), both_in, {true_tag: tag})
245
# Lazily-populated set of target symbol names (each name plus a variant with
# '.' replaced by '_'); filled on first use by preserves_sp.
extra_symbols = set ()
247
def preserves_sp (fname):
	"""Decide whether a call to fname may be assumed to leave the stack
	pointer unchanged.

	The result is truthy in any of these cases:
	- fname has an ASM calling convention (get_asm_calling_convention);
	- the 'assume_sp_equal' hook is set;
	- fname is one of target_objects.symbols (also matched with '.'
	  replaced by '_')."""
	assume_sp_equal = bool (target_objects.hooks ('assume_sp_equal'))
	if not extra_symbols:
		for sym in target_objects.symbols:
			extra_symbols.add (sym)
			extra_symbols.add (sym.replace ('.', '_'))
	return (get_asm_calling_convention (fname)
		or assume_sp_equal
		or fname in extra_symbols)
259
def get_extra_sp_defs (rep, tag):
	"""add extra defs/equalities about stack pointer for the
	purposes of stack depth analysis.

	The call must be recorded in rep.funcs, belong to this tag, have a
	callee that satisfies preserves_sp, and output the stack pointer.
	For each such call, the parsed output stack-pointer s-expression is
	mapped to the parsed input one, when they differ. get_ptr_offsets
	uses the result as extra_defs."""
	# FIXME how to parametrise this?
	sp = mk_var ('r13', syntax.word32T)
	sp_key = (sp.name, sp.typ)

	def is_sp_preserving_call (n_vc):
		n = n_vc[0]
		return (logic.is_int (n)
			and rep.p.node_tags[n][0] == tag
			and preserves_sp (rep.p.nodes[n].fname))

	fcalls = [n_vc for n_vc in rep.funcs if is_sp_preserving_call (n_vc)]
	defs = {}
	for (n, vc) in fcalls:
		(inputs, outputs, _) = rep.funcs[(n, vc)]
		if sp_key not in outputs:
			continue
		(inp_sp, out_sp) = [solver.parse_s_expression (
				solver.smt_expr (sp, env, rep.solv))
			for env in (inputs, outputs)]
		if inp_sp != out_sp:
			defs[out_sp] = inp_sp
	return defs
282
def get_stack_sp (p, tag):
	"""get stack and stack-pointer variables

	Returns (stack, sp): the 'stack' memory variable and the 'r13' word
	variable, renamed with the '<tag>_IN' renaming of p."""
	# result unused; the lookup is kept so behaviour matches the original
	# (it may reject an unknown tag -- NOTE(review): confirm if needed)
	p.get_entry (tag)
	in_renames = p.entry_exit_renames (tags = [tag])[tag + '_IN']
	sp = syntax.rename_expr (mk_var ('r13', syntax.word32T), in_renames)
	stack = syntax.rename_expr (
		mk_var ('stack', syntax.builtinTs['Mem']), in_renames)
	return (stack, sp)
293
def pseudo_node_lvals_rvals (node):
	"""For a call node, return (rets, arg_vars) per its calling convention.

	rets are the node's return variables minus the callee-saved ones, and
	arg_vars is the set of variables mentioned in the convention's
	arguments. Returns None if the node has no calling convention."""
	assert node.kind == 'Call'
	cc = get_asm_calling_convention_at_node (node)
	if not cc:
		return None

	arg_vars = set ()
	for arg in cc['args']:
		arg_vars.update (syntax.get_expr_var_set (arg))

	saved = set (cc['callee_saved'])
	rets = [(nm, typ) for (nm, typ) in node.rets
		if mk_var (nm, typ) not in saved]

	return (rets, arg_vars)
308
def is_asm_node (p, n):
	"""True if node n of p belongs to the ASM side: its tag is 'ASM', or
	the tag's hook_tag_hints entry is 'ASM'."""
	node_tag = p.node_tags[n][0]
	if node_tag == 'ASM':
		return True
	return p.hook_tag_hints.get (node_tag, None) == 'ASM'
312
def all_pseudo_node_lvals_rvals (p):
	"""Map every ASM call node of p that has a calling convention to its
	pseudo_node_lvals_rvals result."""
	pseudo = {}
	for n in p.nodes:
		if is_asm_node (p, n) and p.nodes[n].kind == 'Call':
			lr = pseudo_node_lvals_rvals (p.nodes[n])
			if lr is not None:
				pseudo[n] = lr
	return pseudo
324
def adjusted_var_dep_outputs_for_tag (p, tag):
	"""Outputs of an ASM tag that matter for variable dependencies.

	An output is kept if its function-level variable is mentioned by a
	calling-convention return or is callee-saved. Kept outputs are
	reported as (name in p.outputs[tag], type)."""
	(_, fname, _) = p.get_entry_details (tag)
	fun = functions[fname]
	cc = get_asm_calling_convention (fname)
	saved = set (cc['callee_saved'])
	ret_vars = set ()
	for ret in cc['rets']:
		ret_vars.update ([(nm, typ)
			for (nm, typ) in syntax.get_expr_var_set (ret)])
	kept = []
	for ((nm, typ), (nm2, _)) in azip (fun.outputs, p.outputs[tag]):
		if (nm, typ) in ret_vars or mk_var (nm, typ) in saved:
			kept.append ((nm2, typ))
	return kept
337
def adjusted_var_dep_outputs (p):
	"""Return a function mapping a node of p to the outputs of its tag.

	ASM tags (by entry node) use adjusted_var_dep_outputs_for_tag; all
	other tags use p.outputs unchanged."""
	per_tag = {}
	for tag in p.outputs:
		if is_asm_node (p, p.get_entry (tag)):
			per_tag[tag] = adjusted_var_dep_outputs_for_tag (p, tag)
		else:
			per_tag[tag] = p.outputs[tag]
	return lambda n: per_tag[p.node_tags[n][0]]
350
def is_stack (expr):
	"""True if expr is a variable whose name contains 'stack'."""
	if expr.kind != 'Var':
		return False
	return 'stack' in expr.name
353
class StackOffsMissing (Exception):
	"""Raised during stack virtualisation when a stack access uses a
	pointer that has no entry in the supplied sp_offs mapping."""
356
def stack_virtualise_expr (expr, sp_offs):
	"""Find (and optionally virtualise) the stack reads in expr.

	Returns (ptrs, expr2). ptrs lists (ptr, 'MemAcc') for every pointer a
	stack MemAcc in expr may read. For a byte read these are the four
	candidate word addresses ptr - 0 .. ptr - 3.

	When sp_offs is None only ptrs is meaningful (expr2 is None for a
	direct stack read).

	Otherwise sp_offs maps pointer expressions to (base kind, offset):
	- a stack read whose word-aligned address is in sp_offs is replaced
	  by the ('Fake', base kind, offset) word variable (shifted and cast
	  down for byte reads);
	- a read with no word-aligned known address is returned
	  unchanged."""
	if expr.is_op ('MemAcc') and is_stack (expr.vals[0]):
		[m, p] = expr.vals
		if expr.typ == syntax.word8T:
			# the byte at p belongs to the word at p - n for some n in
			# 0..3; try all four candidates
			ps = [(syntax.mk_minus (p, syntax.mk_word32 (n)), n)
				for n in [0, 1, 2, 3]]
		elif expr.typ == syntax.word32T:
			ps = [(p, 0)]
		else:
			# always fails (typ is not word32T in this branch): only
			# byte and word stack reads are supported
			assert expr.typ == syntax.word32T, expr
		ptrs = [(p, 'MemAcc') for (p, _) in ps]
		if sp_offs == None:
			return (ptrs, None)
		# FIXME: very 32-bit specific
		# keep the candidates with a known, word-aligned offset
		ps = [(p, n) for (p, n) in ps if p in sp_offs
			if sp_offs[p][1] % 4 == 0]
		if not ps:
			# no known slot: leave the access on the real stack
			return (ptrs, expr)
		# raises ValueError if more than one candidate survives
		[(p, n)] = ps
		# NOTE(review): unreachable -- the filter above already requires
		# p in sp_offs
		if p not in sp_offs:
			raise StackOffsMissing ()
		(k, offs) = sp_offs[p]
		v = mk_var (('Fake', k, offs), syntax.word32T)
		if n != 0:
			# NOTE(review): selecting byte n by shifting right 8 * n
			# presumes a little-endian layout -- confirm for the target
			v = syntax.mk_shiftr (v, n * 8)
		v = syntax.mk_cast (v, expr.typ)
		return (ptrs, v)
	elif expr.kind == 'Op':
		# recurse into the operands and rebuild the op from the results
		vs = [stack_virtualise_expr (v, sp_offs) for v in expr.vals]
		return ([p for (ptrs, _) in vs for p in ptrs],
			syntax.adjust_op_vals (expr, [v for (_, v) in vs]))
	else:
		return ([], expr)
390
def stack_virtualise_upd (upd, sp_offs):
	"""Virtualise one Basic-node update ((nm, typ), expr).

	For a stack variable (nm contains 'stack'), the chain of MemUpdate
	operations is peeled off:
	- each written pointer is recorded as (ptr, 'MemUpdate'), together
	  with any stack reads in the written value;
	- when sp_offs is given, each write becomes an update of the 'Fake'
	  slot variable for that pointer.
	For other variables the right-hand side is virtualised with
	stack_virtualise_expr.

	Returns (ptrs, upds). Raises StackOffsMissing if sp_offs lacks a
	written pointer."""
	((nm, typ), expr) = upd
	if 'stack' not in nm:
		(ptrs, expr2) = stack_virtualise_expr (expr, sp_offs)
		return (ptrs, [((nm, typ), expr2)])
	ptrs = []
	slot_upds = []
	while expr.is_op ('MemUpdate'):
		[m, p, v] = expr.vals
		ptrs.append ((p, 'MemUpdate'))
		(val_ptrs, v2) = stack_virtualise_expr (v, sp_offs)
		ptrs.extend (val_ptrs)
		if sp_offs != None:
			if p not in sp_offs:
				raise StackOffsMissing ()
			(k, offs) = sp_offs[p]
			slot_upds.append (((('Fake', k, offs), syntax.word32T), v2))
		expr = m
	assert is_stack (expr), expr
	return (ptrs, slot_upds)
412
def stack_virtualise_ret (expr, sp_offs):
	"""Virtualise a calling-convention return location.

	Returns (ptrs, ret). ptrs lists the written stack pointer (if any) as
	(ptr, 'MemUpdate'). ret is one of:
	- (name, typ) for a register return;
	- the ('Fake', base kind, offset) word slot for a stack return;
	- None for a stack return when sp_offs is None.

	Raises StackOffsMissing when sp_offs is given but lacks the slot
	pointer. This matches stack_virtualise_upd; previously a KeyError
	escaped, which callers that handle StackOffsMissing did not catch."""
	if expr.kind == 'Var':
		return ([], (expr.name, expr.typ))
	elif expr.is_op ('MemAcc'):
		[m, p] = expr.vals
		assert expr.typ == syntax.word32T, expr
		assert is_stack (m), expr
		if sp_offs is None:
			r = None
		elif p not in sp_offs:
			raise StackOffsMissing ()
		else:
			(k, offs) = sp_offs[p]
			r = (('Fake', k, offs), syntax.word32T)
		return ([(p, 'MemUpdate')], r)
	else:
		assert not 'ret expr understood', expr
428
def stack_virtualise_node (node, sp_offs):
	"""Rewrite a node with stack slots replaced by virtual variables.

	Returns (ptrs, node2), where ptrs lists the (ptr, kind) stack accesses
	made by the node.
	- With sp_offs None only ptrs is computed and node2 is None.
	- Otherwise node2 is the rewritten node. Call nodes list the accessed
	  pointers as extra arguments; Basic nodes get extra dummy updates
	  assigning them.
	Instruction calls are returned unchanged with no ptrs. May raise
	StackOffsMissing via the helpers."""
	kind = node.kind
	if kind == 'Cond':
		(ptrs, cond) = stack_virtualise_expr (node.cond, sp_offs)
		if sp_offs == None:
			return (ptrs, None)
		return (ptrs, syntax.Node ('Cond', node.get_conts (), cond))

	if kind == 'Call':
		if is_instruction (node.fname):
			return ([], node)
		cc = get_asm_calling_convention_at_node (node)
		assert cc != None, node.fname
		args = [stack_virtualise_expr (arg, sp_offs)
			for arg in cc['args'] if not is_stack (arg)]
		rets = [stack_virtualise_ret (ret, sp_offs)
			for ret in cc['rets_inp'] if not is_stack (ret)]
		accessed = []
		for (ps, _) in args + rets:
			accessed.extend (ps)
		ptrs = list (set (accessed))
		if sp_offs == None:
			return (ptrs, None)
		call_args = [v for (_, v) in args] + [p for (p, _) in ptrs]
		call_rets = [r for (_, r) in rets]
		return (ptrs, syntax.Node ('Call', node.cont,
			(None, call_args, call_rets)))

	if kind == 'Basic':
		upds = [stack_virtualise_upd (upd, sp_offs) for upd in node.upds]
		accessed = []
		for (ps, _) in upds:
			accessed.extend (ps)
		ptrs = list (set (accessed))
		if sp_offs == None:
			return (ptrs, None)
		ptr_upds = [(('unused#ptr#name%d' % i, syntax.word32T), ptr)
			for (i, (ptr, _)) in enumerate (ptrs)]
		new_upds = [u for (_, us) in upds for u in us] + ptr_upds
		return (ptrs, syntax.Node ('Basic', node.cont, new_upds))

	assert not "node kind understood", node.kind
468
def mk_get_local_offs (p, tag, sp_reps):
	"""Return mk_local (n, kind, off, k), which builds a stack location.

	The location is off bytes from base k at node n, where
	sp_reps[n][k] = (v, off2) gives a representative expression v with
	v + off2 equal to the base.
	- kind 'Ptr' returns the pointer;
	- kind 'MemAcc' returns the word read from the tag's stack at it;
	- any other kind returns None."""
	stack = get_stack_sp (p, tag)[0]

	def mk_local (n, kind, off, k):
		(base_rep, base_adj) = sp_reps[n][k]
		ptr = syntax.mk_plus (base_rep, syntax.mk_word32 (off + base_adj))
		if kind == 'Ptr':
			return ptr
		if kind == 'MemAcc':
			return syntax.mk_memacc (stack, ptr, syntax.word32T)
		return None

	return mk_local
479
def adjust_ret_ptr (ptr):
	"""this is a bit of a hack.

	the return slots are named based on the 'ret_addr_input' variable
	(called r0_input in get_asm_calling_convention_inner), which will be
	unchanged, which is handy, but we really want to be talking about
	r0, which will produce meaningful offsets against the pointers
	actually used in the program.

	Substitutes r0 for ret_addr_input in ptr. must_subst = False
	presumably lets ptrs that do not mention it through unchanged."""
	ret_addr_input = ('ret_addr_input', syntax.word32T)
	r0 = syntax.mk_var ('r0', syntax.word32T)
	return logic.var_subst (ptr, {ret_addr_input: r0},
		must_subst = False)
490
def get_loop_virtual_stack_analysis (p, tag):
	"""computes variable liveness etc analyses with stack slots treated
	as virtual variables.

	Stack accesses in the nodes of 'tag' are eligible when their offset
	is constant from one of two bases:
	- the initial stack pointer (base kind 'stack');
	- at most one indirect-return pointer ('indirect_ret').
	Eligible accesses are replaced by 'Fake' slot variables (see
	stack_virtualise_node), and logic.compute_var_deps is run over the
	rewritten nodes.

	Returns (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs)):
	  vds          -- result of logic.compute_var_deps
	  adj_nodes    -- node -> virtualised node (the original node when
	                  virtualisation raised StackOffsMissing)
	  loc_offs     -- the mk_local function from mk_get_local_offs
	  upd_offsets  -- loop id -> set of (base kind, offset) slots written
	                  by MemUpdates at loop nodes
	  (ptrs, offs) -- the raw pointer list and get_ptr_offsets result
	The result is cached in p.cached_analysis."""
	cache_key = ('loop_stack_analysis', tag)
	if cache_key in p.cached_analysis:
		return p.cached_analysis[cache_key]

	(ent, fname, _) = p.get_entry_details (tag)
	(_, sp) = get_stack_sp (p, tag)
	cc = get_asm_calling_convention (fname)
	# stack pointers of the calling convention's return slots, rewritten
	# in terms of r0 (adjust_ret_ptr) and renamed to the tag's _OUT names
	rets = list (set ([ptr for arg in cc['rets']
		for (ptr, _) in stack_virtualise_expr (arg, None)[0]]))
	rets = [adjust_ret_ptr (ret) for ret in rets]
	renames = p.entry_exit_renames (tags = [tag])
	r = renames[tag + '_OUT']
	rets = [syntax.rename_expr (ret, r) for ret in rets]

	ns = [n for n in p.nodes if p.node_tags[n][0] == tag]
	loop_ns = logic.minimal_loop_node_set (p)

	# (node, (ptr, kind)) for every stack access in this tag's nodes,
	# plus the stack pointer itself at every loop node
	ptrs = list (set ([(n, ptr) for n in ns
		for ptr in (stack_virtualise_node (p.nodes[n], None))[0]]))
	ptrs += [(n, (sp, 'StackPointer')) for n in ns if n in loop_ns]
	offs = get_ptr_offsets (p, [(n, ptr) for (n, (ptr, _)) in ptrs],
		[(ent, sp, 'stack')]
			+ [(ent, ptr, 'indirect_ret') for ptr in rets[:1]])

	# ptr_offs[n][ptr] = (base kind, offset of ptr from that base)
	# rep_offs[n][base kind] = (a ptr at n, minus its offset), so that
	# ptr + (-offset) represents the base at node n
	ptr_offs = {}
	rep_offs = {}
	upd_offsets = {}
	for ((n, ptr), off, k) in offs:
		off = norm_int (off, 32)
		ptr_offs.setdefault (n, {})
		rep_offs.setdefault (n, {})
		ptr_offs[n][ptr] = (k, off)
		rep_offs[n][k] = (ptr, - off)

	for (n, (ptr, kind)) in ptrs:
		if kind == 'MemUpdate' and n in loop_ns:
			loop = p.loop_id (n)
			# NOTE(review): KeyError if get_ptr_offsets found no
			# constant offset for this written pointer
			(k, off) = ptr_offs[n][ptr]
			upd_offsets.setdefault (loop, set ())
			upd_offsets[loop].add ((k, off))
	loc_offs = mk_get_local_offs (p, tag, rep_offs)

	adj_nodes = {}
	for n in ns:
		try:
			(_, node) = stack_virtualise_node (p.nodes[n],
				ptr_offs.get (n, {}))
		except StackOffsMissing, e:
			# fall back to the unvirtualised node
			printout ("Stack analysis issue at (%d, %s)."
				% (n, p.node_tags[n]))
			node = p.nodes[n]
		adj_nodes[n] = node

	# finally do analysis on this collection of nodes

	# restrict the 'Ret'/'Err' predecessors to this tag's nodes
	preds = dict (p.preds)
	preds['Ret'] = [n for n in preds['Ret'] if p.node_tags[n][0] == tag]
	preds['Err'] = [n for n in preds['Err'] if p.node_tags[n][0] == tag]
	vds = logic.compute_var_deps (adj_nodes,
		adjusted_var_dep_outputs (p), preds)

	result = (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs))
	p.cached_analysis[cache_key] = result
	return result
558
def norm_int (n, radix):
	"""Reduce n modulo 2 ** radix to the representative of smallest
	absolute value.

	'radix' is a bit width. For radix >= 1 the result lies in
	[-(2 ** (radix - 1)) + 1, 2 ** (radix - 1)]. The midpoint
	2 ** (radix - 1) stays positive, whereas signed_offset maps it to
	the negative end."""
	modulus = 1 << radix
	low = n & (modulus - 1)
	if 2 * low > modulus:
		return low - modulus
	return low
566
def loop_var_analysis (p, split):
	"""computes the same loop dataflow analysis as in the 'logic' module
	but with stack slots treated as virtual variables.

	Returns None unless split is an ASM node (see is_asm_node). Otherwise
	it runs logic.compute_loop_var_analysis over the virtualised nodes
	from get_loop_virtual_stack_analysis and translates the result back:
	  - a 'Fake' slot variable is reported as the corresponding stack
	    MemAcc at split, but only if that slot is written in the loop;
	    other slots are dropped;
	  - stack variables themselves are dropped;
	  - everything else is passed through unchanged.

	Finally it appends, as 'LoopConst', the stack with every
	loop-written slot overwritten by 0, wrapped with
	logic.mk_stack_wrapper and logic.mk_eq_selective_wrapper. This
	presumably states that the remaining stack contents are
	loop-invariant -- NOTE(review): confirm. The result is cached in
	p.cached_analysis."""
	if not is_asm_node (p, split):
		return None
	head = p.loop_id (split)
	tag = p.node_tags[split][0]
	assert head

	key = ('loop_stack_virtual_var_cycle_analysis', split)
	if key in p.cached_analysis:
		return p.cached_analysis[key]

	(vds, adj_nodes, loc_offs,
		upd_offsets, _) = get_loop_virtual_stack_analysis (p, tag)
	# NOTE: result unused
	loop = p.loop_body (head)

	va = logic.compute_loop_var_analysis (p, vds, split,
		override_nodes = adj_nodes)

	(stack, _) = get_stack_sp (p, tag)

	va2 = []
	# (base kind, offset) slots written inside this loop
	uoffs = upd_offsets.get (head, [])
	for (v, data) in va:
		if v.kind == 'Var' and v.name[0] == 'Fake':
			(_, k, offs) = v.name
			if (k, offs) not in uoffs:
				continue
			v2 = loc_offs (split, 'MemAcc', offs, k)
			va2.append ((v2, data))
		elif v.kind == 'Var' and v.name.startswith ('stack'):
			assert v.typ == stack.typ
			continue
		else:
			va2.append ((v, data))
	stack_const = stack
	for (k, off) in uoffs:
		stack_const = syntax.mk_memupd (stack_const,
			loc_offs (split, 'Ptr', off, k),
			syntax.mk_word32 (0))
	# asm_stack_rep_hook is defined outside this section; it must
	# report a 'SplitMem' representation for the stack at this loop
	sp = asm_stack_rep_hook (p, (stack.name, stack.typ), 'Loop', split)
	assert sp and sp[0] == 'SplitMem', (split, sp)
	(_, st_split) = sp
	stack_const = logic.mk_stack_wrapper (st_split, stack_const, [])
	stack_const = logic.mk_eq_selective_wrapper (stack_const,
		([], [0]))

	va2.append ((stack_const, 'LoopConst'))

	p.cached_analysis[key] = va2
	return va2
619
def inline_no_pre_pairing (p):
	"""Repeatedly inline every call in p to a function that has no
	pre-pairing and is not an instruction, until none remain.

	NOTE(review): a recursive unpaired function would presumably make
	this loop forever -- confirm callers exclude that case."""
	# FIXME: handle code sharing with check.inline_completely_unmatched
	while True:
		to_inline = [n for n in p.nodes
			if p.nodes[n].kind == 'Call'
			and p.nodes[n].fname not in pre_pairings
			and not is_instruction (p.nodes[n].fname)]
		if not to_inline:
			return
		for n in to_inline:
			trace ('Inlining %s at %d.' % (p.nodes[n].fname, n))
			problem.inline_at_point (p, n)
631
# Debug slot: guess_asm_stack_depth stores the name of the function it is
# analysing here.
last_asm_stack_depth_fun = [0]
633
def check_before_guess_asm_stack_depth (fun):
	"""Prepare the 'Target' problem for stack-depth analysis of fun.

	Returns the problem, with calls to non-instruction functions that
	lack a pre-pairing inlined. Returns None in any of these cases:
	- fun has no entry;
	- preparation raises problem.Abort;
	- the path conditions of 'Ret' and 'Err' cannot be built
	  (solver.EnvMiss).
	Prints the inlined functions, if any."""
	if not fun.entry:
		return None
	p = fun.as_problem (problem.Problem, name = 'Target')
	try:
		p.do_analysis ()
		p.check_no_inner_loops ()
		inline_no_pre_pairing (p)
	except problem.Abort:
		return None

	rep = rep_graph.mk_graph_slice (p, fast = True)
	try:
		# only checks that both path conditions can be built
		for exit_node in ['Ret', 'Err']:
			rep.get_pc (default_n_vc (p, exit_node), 'Target')
	except solver.EnvMiss:
		return None

	inlined = set ([fn for (_, _, fn) in p.inline_scripts['Target']])
	if inlined:
		printout ('  (stack analysis also involves %s)'
			% ', '.join(inlined))

	return p
658
def guess_asm_stack_depth (fun):
	"""Estimate how far the stack pointer of ASM function fun moves.

	Returns (depth, fun_offs):
	- depth is the largest distance between the stack pointer at any
	  reachable node and its initial value;
	- fun_offs maps each called function name to the largest such
	  distance at any of its call sites.

	Returns (0, {}) when check_before_guess_asm_stack_depth rejects fun.
	Asserts that every reachable node has a constant offset and that all
	offsets share one sign."""
	p = check_before_guess_asm_stack_depth (fun)
	if not p:
		return (0, {})

	last_asm_stack_depth_fun[0] = fun.name

	entry = p.get_entry ('Target')
	(_, sp) = get_stack_sp (p, 'Target')

	nodes = get_asm_reachable_nodes (p, tag_set = ['Target'])

	sp_offsets = get_ptr_offsets (p, [(n, sp) for n in nodes],
		[(entry, sp, 'InitSP')], fail_early = True)

	assert len (sp_offsets) == len (nodes), map (hex, set (nodes)
		- set ([n for ((n, _), _, _) in sp_offsets]))

	all_offs = [(n, signed_offset (off, 32, 10 ** 6))
		for ((n, _), off, _) in sp_offsets]
	depths = [d for (_, d) in all_offs]
	(lowest, highest) = (min (depths), max (depths))

	# the stack may grow in either direction, but only in one
	assert lowest >= 0 or highest <= 0, all_offs
	if lowest < 0:
		(sign, depth) = (-1, - lowest)
	else:
		(sign, depth) = (1, highest)

	fun_offs = {}
	for (n, d) in all_offs:
		node = p.nodes[n]
		if node.kind != 'Call':
			continue
		d = d * sign
		fun_offs[node.fname] = max (fun_offs.get (node.fname, d), d)

	return (depth, fun_offs)
696
def signed_offset (n, bits, bound = 0):
	"""Interpret the low 'bits' bits of n as a two's-complement integer.

	If bound is non-zero, also assert that -bound <= result <= bound."""
	width = 1 << bits
	n = n & (width - 1)
	if n >= (1 << (bits - 1)):
		n -= width
	if bound:
		assert - bound <= n <= bound, (n, bound)
	return n
705
def ident_conds (fname, idents):
	"""Pair each recursion identifier of fname with its selection
	condition.

	Returns [(ident, cond)], where cond is ident conjoined with the
	negations of all earlier identifiers, so the conditions are mutually
	exclusive and tried in order. A function without identifiers gets a
	single true_term entry."""
	conds = []
	no_earlier = syntax.true_term
	for ident in idents.get (fname, [syntax.true_term]):
		selected = syntax.mk_and (no_earlier, ident)
		conds.append ((ident, selected))
		no_earlier = syntax.mk_and (no_earlier, syntax.mk_not (ident))
	return conds
713
def ident_callables (fname, callees, idents):
	"""Decide which callee identifiers are reachable from each recursion
	identifier of fname.

	Returns a dict keyed by (ident, callee, callee_ident) with boolean
	values.
	- Callees without identifiers are always callable (callee_ident is
	  true_term).
	- For a callee with identifiers, each call site is checked. The pair
	  is uncallable (False) when the call's path condition cannot hold
	  together with callee_ident's selection condition. The assumptions
	  are fname's entry condition for 'ident', no error exit, and the
	  not-callable hypotheses."""
	from solver import to_smt_expr
	from syntax import mk_not, mk_and, true_term

	callables = dict ([((ident, f, true_term), True)
		for ident in idents.get (fname, [true_term])
		for f in callees if f not in idents])

	if not [f for f in callees if f in idents]:
		return callables

	fun = functions[fname]
	p = fun.as_problem (problem.Problem, name = 'Target')
	check_ns = [(n, ident, cond) for n in p.nodes
		if p.nodes[n].kind == 'Call'
		if p.nodes[n].fname in idents
		for (ident, cond) in ident_conds (p.nodes[n].fname, idents)]

	p.do_analysis ()
	assert check_ns

	rep = rep_graph.mk_graph_slice (p, fast = True)
	err_hyp = rep_graph.pc_false_hyp ((default_n_vc (p, 'Err'), 'Target'))
	nhyps = mk_not_callable_hyps (p)

	# these depend only on p, not on the identifier being tested, so
	# compute them once instead of on every loop iteration
	in_renames = p.entry_exit_renames (tags = ['Target'])['Target_IN']
	e_vis = ((p.get_entry ('Target'), ()), 'Target')

	for (ident, cond) in ident_conds (fname, idents):
		cond = syntax.rename_expr (cond, in_renames)
		hyps = [err_hyp, rep_graph.eq_hyp ((cond, e_vis),
				(true_term, e_vis))]

		for (n, ident2, cond2) in check_ns:
			k = (ident, p.nodes[n].fname, ident2)
			n_vc = default_n_vc (p, n)
			(inp_env, _, _) = rep.get_func (n_vc)
			pc = rep.get_pc (n_vc)
			cond2_smt = to_smt_expr (cond2, inp_env, rep.solv)
			# callable unless hyps rule out reaching n with cond2
			callables[k] = not rep.test_hyp_whyps (
				mk_not (mk_and (pc, cond2_smt)), hyps + nhyps)
	return callables
760
def compute_immediate_stack_bounds (idents, names):
	"""Per-function stack usage, ignoring what callees use.

	Returns (and records in last_immediate_stack_bounds[0]) a dict
	mapping (fname, ident) to (own depth, {(callee, callee_ident): depth
	at the call}). Only callee identifiers that ident_callables considers
	callable are included. Functions are processed in sorted order."""
	from syntax import true_term
	immed = {}
	ordered = sorted (names)
	total = len (ordered)
	for (i, fname) in enumerate (ordered):
		printout ('Doing stack analysis for %r. (%d of %d)' % (fname,
			i + 1, total))
		(offs, fn_offs) = guess_asm_stack_depth (functions[fname])
		callables = ident_callables (fname, fn_offs.keys (), idents)
		for ident in idents.get (fname, [true_term]):
			calls = {}
			for fname2 in fn_offs:
				for ident2 in idents.get (fname2, [true_term]):
					if callables[(ident, fname2, ident2)]:
						calls[(fname2, ident2)] = fn_offs[fname2]
			immed[(fname, ident)] = (offs, calls)
	last_immediate_stack_bounds[0] = immed
	return immed
779
# Debug slot: compute_immediate_stack_bounds stores its latest result here.
last_immediate_stack_bounds = [0]
781
def immediate_stack_bounds_loop (immed):
	"""Return the call cycles among the keys of immed.

	Runs logic.tarjan from a synthetic 'ENTRY' node that reaches every
	key. Each component with a non-empty tail is returned as the list
	[head] + tail."""
	graph = {}
	for key in immed:
		graph[key] = immed[key][1].keys ()
	graph['ENTRY'] = list (immed)
	comps = logic.tarjan (graph, ['ENTRY'])
	return [[head] + tail for (head, tail) in comps if tail]
788
def compute_recursive_stack_bounds (immed):
	"""Worst-case stack depth of every (fname, ident), including callees.

	Each bound is the maximum of the function's own depth and, for each
	callee key, the callee's bound plus the depth at the call. Asserts
	there are no call cycles. An explicit work list is used instead of
	recursion; its size is traced every time it passes a multiple of
	1000."""
	assert not immediate_stack_bounds_loop (immed)
	bounds = {}
	todo = list (immed)
	report = 1000
	while todo:
		if len (todo) >= report:
			trace ('todo length %d' % len (todo))
			trace ('tail: %s' % todo[-20:])
			report += 1000
		(fname, ident) = todo.pop ()
		key = (fname, ident)
		if key in bounds:
			continue
		(static, calls) = immed[key]
		if any (k not in bounds for k in calls):
			# revisit this key after its callees are bounded
			todo.append (key)
			todo.extend (list (calls))
			continue
		bounds[key] = max ([static]
			+ [bounds[k] + calls[k] for k in calls])
	return bounds
811
def stack_bounds_to_closed_form (bounds, names, idents):
	"""Turn per-(fname, ident) bounds into one word32 expression per name.

	The expression starts from the true_term bound. It is wrapped in an
	if-then-else for each of the function's other identifiers, so the
	first identifier in idents[fname] is tested first. idents[fname]
	must end with true_term."""
	closed = {}
	for fname in names:
		expr = syntax.mk_word32 (bounds[(fname, syntax.true_term)])
		if fname in idents:
			fn_idents = idents[fname]
			assert fn_idents[-1] == syntax.true_term
			for ident in reversed (fn_idents[:-1]):
				alt = syntax.mk_word32 (bounds[(fname, ident)])
				expr = syntax.mk_if (ident, alt, expr)
		closed[fname] = expr
	return closed
825
def compute_asm_stack_bounds (idents, names):
	"""Compute a closed-form stack bound for each ASM function in names.

	The bound is a word32 expression, possibly an if-chain over recursion
	identifiers. It combines each function's own stack usage with the
	bounds of its callees."""
	return stack_bounds_to_closed_form (
		compute_recursive_stack_bounds (
			compute_immediate_stack_bounds (idents, names)),
		names, idents)
831
# Human-readable log of the recursion analysis: cleared by
# get_recursion_identifiers and appended to while identifiers are computed.
recursion_trace = []
# Debug slot: holds the assumption list of the latest add_recursion_ident run.
recursion_last_assns = [[]]
834
def get_recursion_identifiers (funs, extra_unfolds = []):
	"""Compute recursion identifiers for every recursive group connected
	to funs.

	The call graph is first restricted to the fixed point of repeatedly
	adding the callers and callees of the current set, starting from
	funs. A strongly connected component is recursive if it has several
	functions or one function that calls itself. Each recursive
	component is analysed with compute_recursion_idents, and the results
	are merged into the returned dict (function name -> list of
	identifier conditions)."""
	idents = {}
	del recursion_trace[:]
	call_graph = dict ([(f, list (functions[f].function_calls ()))
		for f in functions])
	current = funs
	previous = set ()
	while previous != current:
		previous = current
		callers = set ([f for f in call_graph
			if [f2 for f2 in call_graph[f] if f2 in previous]])
		callees = set ([f2 for f in previous for f2 in call_graph[f]])
		current = set.union (callers, callees, previous)
	sub_graph = dict ([(f, call_graph[f]) for f in current])
	called = set ([f2 for f in sub_graph for f2 in sub_graph[f]])
	entries = list (current - called)
	for (head, tail) in logic.tarjan (sub_graph, entries):
		if tail or head in sub_graph[head]:
			group = [head] + list (tail)
			idents.update (compute_recursion_idents (group,
				extra_unfolds))
	return idents
857
def compute_recursion_idents (group, extra_unfolds):
	"""Compute recursion identifiers for one recursive function group.

	For every function outside the group that calls into it,
	add_recursion_ident is called repeatedly until it finds no more
	identifiers. Returns the dict of identifiers found."""
	idents = {}
	group = set (group)
	recursion_trace.append ('Computing for group %s' % group)
	printout ('Doing recursion analysis for function group:')
	printout ('  %s' % list(group))
	prevs = set ([f for f in functions
		if [f2 for f2 in functions[f].function_calls () if f2 in group]])
	outside_callers = prevs - group
	for f in outside_callers:
		recursion_trace.append ('  checking for %s' % f)
		trace ('Checking idents for %s' % f)
		found = add_recursion_ident (f, group, idents, extra_unfolds)
		while found:
			found = add_recursion_ident (f, group, idents,
				extra_unfolds)
	return idents
872
def function_link_assns (p, call_site, tag):
	"""Hypotheses linking the call at call_site to the function entry
	that was added to p under tag (rep_graph.mk_function_link_hyps)."""
	site_vc = default_n_vc (p, call_site)
	site_tag = p.node_tags[call_site][0]
	return rep_graph.mk_function_link_hyps (p, (site_vc, site_tag), tag)
876
def add_recursion_ident (f, group, idents, extra_unfolds):
	"""Try to find one more recursion identifier for a function of 'group'
	reached from f.

	Starting from f (tag 'fun0'), it repeatedly takes a call reported by
	find_unknown_recursion:
	  - calls to functions not in 'group' (which find_unknown_recursion
	    only reports for extra_unfolds) are inlined;
	  - calls into 'group' are followed by adding the callee as a new
	    entry ('fun1', 'fun2', ...) linked to the call site by
	    function_link_assns.
	Returns None if there is no such call from f.

	Otherwise an identifier for the last function in the chain (fname)
	is appended to idents[fname] and returned:
	  - syntax.true_term, if fname has no unknown recursive call even
	    without assumptions;
	  - otherwise 'arg = value' for the first word-typed argument whose
	    value (taken from a solver model) is forced by the chain's
	    assumptions, and which, when assumed at fname's entry, rules out
	    further unknown recursion.
	The final assertion fails if no such argument exists."""
	# NOTE: mk_implies is imported but unused
	from syntax import mk_eq, mk_implies, mk_var
	p = problem.Problem (None, name = 'Recursion Test')
	chain = []
	tag = 'fun0'
	p.add_entry_function (functions[f], tag)
	p.do_analysis ()
	assns = []
	# published for debugging; 'assns +=' below extends this same list
	# in place, so the slot always shows the current assumptions
	recursion_last_assns[0] = assns

	while True:
		res = find_unknown_recursion (p, group, idents, tag, assns,
			extra_unfolds)
		if res == None:
			break
		if p.nodes[res].fname not in group:
			problem.inline_at_point (p, res)
			continue
		fname = p.nodes[res].fname
		chain.append (fname)
		tag = 'fun%d' % len (chain)
		(args, _, entry) = p.add_entry_function (functions[fname], tag)
		p.do_analysis ()
		assns += function_link_assns (p, res, tag)
	if chain == []:
		return None
	recursion_trace.append ('  created fun chain %s' % chain)
	# (index, variable) for the word-typed arguments of the last entry
	word_args = [(i, mk_var (s, typ))
		for (i, (s, typ)) in enumerate (args)
		if typ.kind == 'Word']
	rep = rep_graph.mk_graph_slice (p, fast = True)
	(_, env) = rep.get_node_pc_env ((entry, ()))

	m = {}
	# 'res' is unused: the call is made to obtain a model m of assns
	res = rep.test_hyp_whyps (syntax.false_term, assns, model = m)
	assert m

	# no unknown recursion from fname at all, even without assumptions
	if find_unknown_recursion (p, group, idents, tag, [], []) == None:
		idents.setdefault (fname, [])
		idents[fname].append (syntax.true_term)
		recursion_trace.append ('      found final ident for %s' % fname)
		return syntax.true_term
	assert word_args
	recursion_trace.append ('      scanning for ident for %s' % fname)
	for (i, arg) in word_args:
		(nm, typ) = functions[fname].inputs[i]
		arg_smt = solver.to_smt_expr (arg, env, rep.solv)
		val = search.eval_model_expr (m, rep.solv, arg_smt)
		# the argument must take this value whenever assns hold
		if not rep.test_hyp_whyps (mk_eq (arg_smt, val), assns):
			recursion_trace.append ('      discarded %s = 0x%x, not stable' % (nm, val.val))
			continue
		entry_vis = ((entry, ()), tag)
		ass = rep_graph.eq_hyp ((arg, entry_vis), (val, entry_vis))
		res = find_unknown_recursion (p, group, idents, tag,
				assns + [ass], [])
		if res:
			fname2 = p.nodes[res].fname
			recursion_trace.append ('      discarded %s, allows recursion to %s' % (nm, fname2))
			continue
		eq = syntax.mk_eq (mk_var (nm, typ), val)
		idents.setdefault (fname, [])
		idents[fname].append (eq)
		recursion_trace.append ('    found ident for %s: %s' % (fname, eq))
		return eq
	assert not "identifying assertion found"
942
def find_unknown_recursion (p, group, idents, tag, assns, extra_unfolds):
	"""Find a call node of 'tag' that needs further recursion analysis.

	It returns the first call (in p.nodes order) that either:
	  - calls a function in extra_unfolds, or
	  - calls a function in 'group' and, under assns, might be reached
	    with none of that function's known identifiers holding.
	Returns None if there is no such call."""
	from syntax import mk_not, mk_and, foldr1
	rep = rep_graph.mk_graph_slice (p, fast = True)
	for n in p.nodes:
		node = p.nodes[n]
		if node.kind != 'Call' or p.node_tags[n][0] != tag:
			continue
		if node.fname in extra_unfolds:
			return n
		if node.fname not in group:
			continue
		n_vc = default_n_vc (p, n)
		(inp_env, _, _) = rep.get_func (n_vc)
		pc = rep.get_pc (n_vc)
		excluded = [mk_not (solver.to_smt_expr (ident, inp_env, rep.solv))
			for ident in idents.get (node.fname, [])]
		reachable_unknown = foldr1 (mk_and, [pc] + excluded)
		if not rep.test_hyp_whyps (mk_not (reachable_unknown), assns):
			return n
	return None
965
# Memo table for calling conventions. get_asm_calling_convention keys it by
# ASM function name; get_asm_calling_convention_inner keys it by
# ('Inner', num_c_args, num_c_rets, const_mem).
asm_cc_cache = {}
967
def is_instruction (fname):
	"""Return True if fname names an instruction pseudo-function.

	Such names contain at least one "'", and the text before the first
	"'" is 'l_impl' or 'instruction'. Always returns a bool; previously
	an empty list was returned for names without a "'"."""
	bits = fname.split ("'")
	return len (bits) > 1 and bits[0] in ('l_impl', 'instruction')
971
def get_asm_calling_convention (fname):
	"""calling convention for the ASM function fname, or None.

	The convention is derived from the paired C function: how many scalar
	inputs and outputs it has (as split by logic.split_scalar_pairs) and
	whether its outputs have a memory component (if not, memory is treated
	as preserved). Results are memoised in asm_cc_cache. Unpaired functions
	give None, with a warning traced unless they are instructions."""
	if fname in asm_cc_cache:
		return asm_cc_cache[fname]
	if fname not in pre_pairings:
		if not is_instruction (fname):
			trace ("Warning: unusual unmatched function (%s, %s)."
				% (fname, fname.split ("'")))
		return None

	pair = pre_pairings[fname]
	assert pair['ASM'] == fname
	c_fun = functions[pair['C']]

	from logic import split_scalar_pairs
	(c_args, _, _) = split_scalar_pairs (c_fun.inputs)
	(c_rets, c_out_mem, _) = split_scalar_pairs (c_fun.outputs)

	cc = get_asm_calling_convention_inner (len (c_args), len (c_rets),
		not c_out_mem)
	asm_cc_cache[fname] = cc
	return cc
995
def get_asm_calling_convention_inner (num_c_args, num_c_rets, const_mem):
	"""build (and memoise in asm_cc_cache) the ARM calling convention for
	an ASM function with num_c_args scalar arguments and num_c_rets scalar
	results. const_mem says whether memory is preserved by the call.

	Returns a dict with lists of variable expressions under 'args' (scalar
	arguments followed by the global state), 'rets' and 'callee_saved'."""
	cache_key = ('Inner', num_c_args, num_c_rets, const_mem)
	if cache_key in asm_cc_cache:
		return asm_cc_cache[cache_key]

	from logic import mk_var_list, mk_stack_sequence
	from syntax import mk_var, word32T, builtinTs

	memT = builtinTs['Mem']
	domT = builtinTs['Dom']

	reg_vars = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T)
	sp = mk_var ('r13', word32T)
	stack = mk_var ('stack', memT)
	ret_ptr = mk_var ('ret_addr_input', word32T)
	mem = mk_var ('mem', memT)
	dom = mk_var ('dom', domT)
	dom_stack = mk_var ('dom_stack', domT)

	# global state passed after the scalar arguments
	global_args = [mem, dom, stack, dom_stack, sp, mk_var ('ret', word32T)]

	# scalar arguments: r0-r3, then consecutive 4-byte stack slots from sp
	stack_slots = mk_stack_sequence (sp, 4, stack, word32T, num_c_args + 1)
	arg_seq = list (reg_vars) + [slot for (slot, _) in stack_slots]

	if num_c_rets <= 1:
		rets = [reg_vars[0]]
	else:
		# the 'return-too-much' issue: r0 instead carries a pointer to
		# where the results are saved, so it is not an argument.
		del arg_seq[0]
		ret_slots = mk_stack_sequence (ret_ptr, 4, stack, word32T,
			num_c_rets)
		rets = [slot for (slot, _) in ret_slots]

	saved_regs = 'r4 r5 r6 r7 r8 r9 r10 r11 r13'.split ()
	callee_saved = [mk_var (r, word32T) for r in saved_regs]
	callee_saved.extend ([dom, dom_stack])

	if const_mem:
		callee_saved.append (mem)
	else:
		rets.append (mem)
	rets.append (stack)

	cc = {'args': arg_seq[:num_c_args] + global_args,
		'rets': rets, 'callee_saved': callee_saved}
	asm_cc_cache[cache_key] = cc
	return cc
1043
def get_asm_calling_convention_at_node (node):
	"""instantiate the callee's calling convention at Call node 'node'.

	Convention variables are substituted with the node's actual arguments
	('args', 'rets_inp') or with its return variables ('rets',
	'callee_saved'). Returns None if the callee has no convention."""
	cc = get_asm_calling_convention (node.fname)
	if not cc:
		return None

	fun = functions[node.fname]
	inp_subst = dict (azip (fun.inputs, node.args))
	out_vars = [mk_var (nm, typ) for (nm, typ) in node.rets]
	out_subst = dict (azip (fun.outputs, out_vars))

	def subst_all (exprs, env):
		return [logic.var_subst (e, env) for e in exprs]

	# 'rets_inp' is useful because it happens to map ret r0_input back to
	# the previous value r0, rather than the useless value r0_input_ignore.
	return {'args': subst_all (cc['args'], inp_subst),
		'rets': subst_all (cc['rets'], out_subst),
		'rets_inp': subst_all (cc['rets'], inp_subst),
		'callee_saved': subst_all (cc['callee_saved'], out_subst)}
1063
# Filled on first use by get_asm_callable: function name -> whether some
# Call node among any function's reachable nodes targets it.
call_cache = {}
1065
def get_asm_callable (fname):
	"""Return True for an unpaired fname. Otherwise report whether the C
	function paired with fname is the target of a Call node among some
	function's reachable_nodes (simplify = True).

	The table is built once, on first use, in call_cache."""
	if fname not in pre_pairings:
		return True
	c_name = pre_pairings[fname]['C']

	if not call_cache:
		# every known function starts out uncalled
		for f in functions:
			call_cache[f] = False
		for f in functions:
			fun = functions[f]
			called = [fun.nodes[n].fname
				for n in fun.reachable_nodes (simplify = True)
				if fun.nodes[n].kind == 'Call']
			call_cache.update ([(callee, True) for callee in called])
	return call_cache[c_name]
1080
def get_asm_reachable_nodes (p, tag_set = None):
	"""collect the nodes of p reachable from the entry points of the tags
	in tag_set (default: every tag whose entry is an ASM node).

	Names not in p.nodes are not collected. A Call node to a function for
	which get_asm_callable is false is collected, but its successors are
	not explored from it."""
	if tag_set is None:
		tag_set = [t for t in p.tags () if is_asm_node (p, p.get_entry (t))]
	seen = set ()
	pending = [p.get_entry (t) for t in tag_set]
	while pending:
		n = pending.pop ()
		if n in seen or n not in p.nodes:
			continue
		seen.add (n)
		node = p.nodes[n]
		if node.kind == 'Call' and not get_asm_callable (node.fname):
			continue
		pending.extend (logic.simplify_node_elementary (node).get_conts ())
	return seen
1098
def convert_recursion_idents (idents):
	"""translate recursion identifiers of C functions into identifiers of
	their paired ASM functions.

	idents maps function name -> list of identifiers, each a True term or
	an equation 'input_var == value'. True terms are copied; an equation on
	the i-th C input becomes an equation on cc['args'][i] of the ASM
	calling convention. Unpaired functions are dropped."""
	asm_idents = {}
	for f in idents:
		if f not in pre_pairings:
			continue
		asm_f = pre_pairings[f]['ASM']
		assert asm_f != f
		converted = []
		asm_idents[asm_f] = converted
		for ident in idents[f]:
			if ident.is_op ('True'):
				converted.append (ident)
				continue
			if not ident.is_op ('Equals'):
				assert not 'ident kind convertible'
				continue
			[lhs, rhs] = ident.vals
			# this is a bit hacky: the position among all the C inputs
			# is reused as a position in the ASM argument list
			[i] = [j for (j, (nm, typ)) in enumerate (functions[f].inputs)
				if lhs.is_var ((nm, typ))]
			asm_arg = get_asm_calling_convention (asm_f)['args'][i]
			converted.append (syntax.mk_eq (asm_arg, rhs))
	return asm_idents
1122
def mk_pairing (pre_pair, stack_bounds):
	"""build the logic.Pairing between the ASM and C functions named in
	pre_pair (a dict with 'ASM' and 'C' entries). It uses the equations of
	logic.mk_eqs_arm_none_eabi_gnu and the stack bound of the ASM function."""
	asm_name = pre_pair['ASM']
	bound = stack_bounds[asm_name]
	c_fun = functions[pre_pair['C']]

	from logic import split_scalar_pairs
	(c_args, c_in_mem, _) = split_scalar_pairs (c_fun.inputs)
	(c_rets, c_out_mem, _) = split_scalar_pairs (c_fun.outputs)

	eqs = logic.mk_eqs_arm_none_eabi_gnu (c_args, c_rets,
		c_in_mem, c_out_mem, bound)
	names = {'ASM': asm_name, 'C': c_fun.name}
	return logic.Pairing (['ASM', 'C'], names, eqs)
1137
def mk_pairings (stack_bounds):
	"""build one Pairing per entry of pre_pairings. Returns a dict from
	every function named in a Pairing to a one-element list holding it."""
	result = {}
	for name in pre_pairings:
		if name not in result:
			pairing = mk_pairing (pre_pairings[name], stack_bounds)
			for fun_name in pairing.funs.itervalues ():
				result[fun_name] = [pairing]
	return result
1147
def serialise_stack_bounds (stack_bounds):
	"""render stack_bounds as a list of newline-terminated text lines,
	'StackBound <fname> <tokens from the bound's serialise method>'."""
	def render (fname):
		tokens = ['StackBound', fname]
		stack_bounds[fname].serialise (tokens)
		return ' '.join (tokens) + '\n'
	return [render (fname) for fname in stack_bounds]
1155
def deserialise_stack_bounds (lines):
	"""parse 'StackBound <fname> <expr tokens>' lines (the serialise_stack_bounds
	format) back into a dict from function name to parsed bound expression.
	Blank lines are skipped. Any other line must start with 'StackBound'."""
	result = {}
	for tokens in (line.split () for line in lines):
		if not tokens:
			continue
		assert tokens[0] == 'StackBound'
		name = tokens[1]
		(_, result[name]) = syntax.parse_expr (tokens, 2)
	return result
1167
# Memo table for get_functions_with_tag: tag -> set of function names.
funs_with_tag = {}
1169
def get_functions_with_tag (tag):
	"""the set of function names on side 'tag', memoised in funs_with_tag.

	Starts from every function named under 'tag' in pre_pairings and in
	the Pairings of pairings, then adds everything they call transitively
	(via function_calls)."""
	if tag in funs_with_tag:
		return funs_with_tag[tag]
	roots = set ()
	for f in pre_pairings:
		if tag in pre_pairings[f]:
			roots.add (pre_pairings[f][tag])
	for f in pairings:
		for pair in pairings[f]:
			if tag in pair.funs:
				roots.add (pair.funs[tag])
	found = set (roots)
	worklist = set (roots)
	while worklist:
		f = worklist.pop ()
		found.add (f)
		worklist.update (set (functions[f].function_calls ()) - found)
	funs_with_tag[tag] = found
	return found
1184
def compute_stack_bounds (quiet = False):
	"""compute stack bounds for the ASM functions.

	Recursion identifiers are found for the C functions, translated to the
	ASM side, and passed with the ASM functions to compute_asm_stack_bounds.
	If quiet, target_objects.tracer[0] is replaced by a no-op while this
	runs. The previous tracer is restored afterwards, including when any
	exception (even KeyboardInterrupt) escapes."""
	prev_tracer = target_objects.tracer[0]
	if quiet:
		target_objects.tracer[0] = lambda s, n: ()

	try:
		c_fs = get_functions_with_tag ('C')
		idents = get_recursion_identifiers (c_fs)
		asm_idents = convert_recursion_idents (idents)
		asm_fs = get_functions_with_tag ('ASM')
		printout ('Computed recursion limits.')

		bounds = compute_asm_stack_bounds (asm_idents, asm_fs)
		printout ('Computed stack bounds.')
	finally:
		# finally rather than 'except Exception' so that non-Exception
		# exits (e.g. KeyboardInterrupt) also restore the tracer
		if quiet:
			target_objects.tracer[0] = prev_tracer
	return bounds
1207
def read_fn_hash (fname):
	"""return the integer on the 'FunctionHash <n>' first line of file
	fname. Returns None if fname is None, the file cannot be read, or the
	first line does not have exactly that form. The file is always closed."""
	if fname is None:
		return None
	try:
		with open (fname) as f:
			first = f.readline ()
	except IOError:
		return None
	bits = first.split ()
	if len (bits) != 2 or bits[0] != 'FunctionHash':
		return None
	try:
		return int (bits[1])
	except ValueError:
		return None
1222
def mk_stack_pairings (pairing_tups, stack_bounds_fname = None,
		quiet = True):
	"""build the stack-aware calling-convention-aware logical pairings
	once a collection of function pairs have been read.

	pairing_tups: iterable of (asm_fname, c_fname) pairs. pre_pairings is
	reset and refilled from it.
	stack_bounds_fname: cache file for the stack bounds. It is reused when
	its 'FunctionHash' line matches a hash of the current functions, and
	rewritten otherwise. If None, bounds are always recomputed and no
	file is read or written.
	quiet: passed on to compute_stack_bounds.

	Returns the result of mk_pairings."""

	# simplifies interactive testing of this function
	pre_pairings.clear ()

	for (asm_f, c_f) in pairing_tups:
		pair = {'ASM': asm_f, 'C': c_f}
		assert c_f not in pre_pairings
		assert asm_f not in pre_pairings
		pre_pairings[c_f] = pair
		pre_pairings[asm_f] = pair

	fn_hash = hash (tuple (sorted ([(f, hash (functions[f]))
		for f in functions])))
	if stack_bounds_fname is not None:
		prev_hash = read_fn_hash (stack_bounds_fname)
	else:
		prev_hash = None

	if prev_hash == fn_hash:
		with open (stack_bounds_fname) as fh:
			# skip the FunctionHash header line
			fh.readline ()
			stack_bounds = deserialise_stack_bounds (fh)
	else:
		printout ('Computing stack bounds.')
		stack_bounds = compute_stack_bounds (quiet = quiet)
		if stack_bounds_fname is not None:
			with open (stack_bounds_fname, 'w') as fh:
				fh.write ('FunctionHash %s\n' % fn_hash)
				fh.writelines (serialise_stack_bounds (stack_bounds))

	problematic_synthetic ()

	return mk_pairings (stack_bounds)
1258
1259def asm_stack_rep_hook (p, (nm, typ), kind, n):
1260	if not is_asm_node (p, n):
1261		return None
1262
1263	if not (nm.startswith ('stack') and typ == syntax.builtinTs['Mem']):
1264		return None
1265
1266	assert kind in ['Call', 'Init', 'Loop'], kind
1267	if kind == 'Init':
1268		return None
1269
1270	tag = p.node_tags[n][0]
1271	(_, sp) = get_stack_sp (p, tag)
1272
1273	return ('SplitMem', sp)
1274
# Alternative names for registers. inst_const_rets also looks for these
# among the components of an instruction name to decide whether the
# instruction may write the register.
reg_aliases = {'r11': ['fp'], 'r14': ['lr'], 'r13': ['sp']}
1276
def inst_const_rets (node):
	"""for a Call node to an instruction function, list (as mk_var
	expressions) the return variables treated as unchanged by the call.

	An output counts as unchanged if its type is Mem or Dom, or if it is a
	word32 register whose name (or an alias from reg_aliases) is not among
	the lower-cased '_'-separated parts of the instruction name. Only
	return variables that also occur in the call's arguments are listed."""
	assert "instruction'" in node.fname
	name_parts = set ([part.lower () for part in node.fname.split ('_')])
	outputs = functions[node.fname].outputs

	def unchanged (nm, typ):
		if typ in [builtinTs['Mem'], builtinTs['Dom']]:
			return True
		if typ != word32T:
			return False
		if nm in name_parts:
			return False
		aliases = reg_aliases.get (nm, [])
		return not [al for al in aliases if al in name_parts]

	flags = [unchanged (nm, typ) for (nm, typ) in outputs]
	used = set ()
	for arg in node.args:
		used.update (syntax.get_expr_var_set (arg))
	return [mk_var (nm, typ)
		for ((nm, typ), keep) in azip (node.rets, flags)
		if keep and (nm, typ) in used]
1294
def node_const_rets (node):
	"""list the return variables of Call node 'node' treated as keeping
	their pre-call value, or None if nothing is known.

	Instruction calls are delegated to inst_const_rets. For the ASM side
	of a pairing, these are the callee-saved variables of its calling
	convention that are also read by the call's arguments. For an unpaired
	function accepted by preserves_sp and on the ASM side, it is the
	variable receiving the callee's r13 output."""
	fname = node.fname
	if "instruction'" in fname:
		return inst_const_rets (node)

	if fname not in pre_pairings:
		if not preserves_sp (fname):
			return None
		if fname not in get_functions_with_tag ('ASM'):
			return None
		outputs = functions[fname].outputs
		return [mk_var (nm, typ)
			for ((nm, typ), (out_nm, _)) in azip (node.rets, outputs)
			if out_nm == 'r13']

	if pre_pairings[fname]['ASM'] != fname:
		# fname is the C side of its pairing
		return None
	cc = get_asm_calling_convention_at_node (node)
	read_vars = set ([v for arg in node.args
		for v in syntax.get_expr_var_set (arg)])
	saved = set (cc['callee_saved'])
	return [mk_var (nm, typ) for (nm, typ) in node.rets
		if mk_var (nm, typ) in saved and (nm, typ) in read_vars]
1317
def const_ret_hook (node, nm, typ):
	"""rep_unsafe_const_ret hook: whether variable (nm, typ) is among the
	constant returns of Call node 'node'. When node_const_rets gives
	nothing (None or an empty list), that falsy value is returned as is."""
	known = node_const_rets (node)
	if not known:
		return known
	return mk_var (nm, typ) in known
1321
def get_const_rets (p, node_set = None):
	"""map each Call node in node_set (default: all nodes of p) to the
	(name, type) pairs of its return variables that node_const_rets treats
	as unchanged by the call.

	Non-Call nodes are omitted. Call nodes for which nothing is known map
	to an empty list."""
	if node_set is None:
		node_set = p.nodes
	const_rets = {}
	for n in node_set:
		node = p.nodes[n]
		if node.kind != 'Call':
			continue
		# node_const_rets gives None when the callee is not understood
		consts = node_const_rets (node) or []
		const_rets[n] = [(v.name, v.typ) for v in consts]
	return const_rets
1332
def problematic_synthetic ():
	"""report compiler-generated ('synthetic') symbols that may be
	problematic for the stack analysis, and return the problematic ones.

	A symbol is synthetic if its name contains '.clone.', '.part.' or
	'.constprop.'. Dots are mapped to '_' to get the function name. A
	synthetic function is problematic if it makes function calls, assigns
	r13 in some Basic node, and is called by some function that calls more
	than one distinct function. Progress is reported with printout.

	Returns a set of function names, which is empty when there are none."""
	synth = ['_'.join (s.split ('.')) for s in target_objects.symbols
		if '.clone.' in s or '.part.' in s or '.constprop.' in s]
	if not synth:
		return set ()
	printout ('Synthetic symbols: %s' % synth)
	synth_calls = set ([f for f in synth
		if f in functions
		if functions[f].function_calls ()])
	printout ('Synthetic symbols which make function calls: %s'
		% synth_calls)
	if not synth_calls:
		return set ()
	synth_stack = set ([f for f in synth_calls
		if [node for node in functions[f].nodes.itervalues ()
			if node.kind == 'Basic'
			if ('r13', word32T) in node.get_lvals ()]])
	printout ('Synthetic symbols which call and move sp: %s'
		% synth_stack)
	# distinct callees of every function, computed once here rather than
	# once per (synthetic symbol, function) pair
	callees = {}
	if synth_stack:
		callees = dict ([(f2, set (functions[f2].function_calls ()))
			for f2 in functions])
	synth_problems = set ([f for f in synth_stack
		if [f2 for f2 in callees
			if f in callees[f2] and len (callees[f2]) > 1]])
	printout ('Problematic synthetics: %s' % synth_problems)
	return synth_problems
1360
def add_hooks ():
	"""register this module's hooks with target_objects.add_hook under the
	key 'stack_logic'."""
	hooks = [('problem_var_rep', asm_stack_rep_hook),
		('loop_var_analysis', loop_var_analysis),
		('rep_unsafe_const_ret', const_ret_hook)]
	for (hook_name, fn) in hooks:
		target_objects.add_hook (hook_name, 'stack_logic', fn)
1367
# Register this module's hooks as a side effect of importing it.
add_hooks ()
1369
1370