1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9import syntax
10import solver
11import problem
12import rep_graph
13import search
14import logic
15import check
16
17from target_objects import functions, trace, pairings, pre_pairings, printout
18import target_objects
19
20from logic import azip
21
22from syntax import mk_var, word32T, builtinTs, mk_eq, mk_less_eq
23
# Module-level debug slot (a one-element list so it can be rebound in place).
# NOTE(review): nothing in the visible part of this file reads or writes it.
last_stuff = [0]
25
def default_n_vc (p, n):
	"""Build the default visit of node n as (n, restrictions).

	Every loop head of p other than n's own loop id is restricted with
	rep_graph.vc_options ([0], [1]); if p.loop_id (n) is truthy, that
	head is then restricted with rep_graph.vc_offs (1). restrictions is
	returned as a tuple, other heads first."""
	own_head = p.loop_id (n)
	restrictions = []
	for other_head in p.loop_heads ():
		if other_head == own_head:
			continue
		restrictions.append ((other_head,
			rep_graph.vc_options ([0], [1])))
	if own_head:
		restrictions.append ((own_head, rep_graph.vc_offs (1)))
	return (n, tuple (restrictions))
33
def split_sum_s_expr (expr, solv, extra_defs, typ):
	"""Split a linear SMT sum into per-atom multipliers and a constant.

	For example 'a - b - 1 + a' becomes ({'a': 2, 'b': -1}, -1): atom 'a'
	occurs with multiplier 2, 'b' with -1, and the constant part is -1.
	'bvadd' and 'bvsub' nodes are split recursively; names found in
	solv.defs (checked first) or extra_defs are expanded; '#x'/'#b'
	literals contribute to the constant; anything else is an atom.
	typ is only threaded through the recursive calls."""
	def split (sub):
		return split_sum_s_expr (sub, solv, extra_defs, typ)

	head = expr[0]
	if head == 'bvadd':
		counts = {}
		total = 0
		for term in expr[1:]:
			(term_counts, term_const) = split (term)
			for (atom, mult) in term_counts.items ():
				counts[atom] = counts.get (atom, 0) + mult
			total += term_const
		return (counts, total)
	if head == 'bvsub':
		(_, minuend, subtrahend) = expr
		(left, left_const) = split (minuend)
		(right, right_const) = split (subtrahend)
		atoms = set (left) | set (right)
		diff = dict ((atom, left.get (atom, 0) - right.get (atom, 0))
			for atom in atoms)
		return (diff, left_const - right_const)
	if expr in solv.defs:
		return split (solv.defs[expr])
	if expr in extra_defs:
		return split (extra_defs[expr])
	if expr[:2] in ['#x', '#b']:
		lit = solver.smt_to_val (expr)
		assert lit.kind == 'Num'
		return ({}, lit.val)
	return ({expr: 1}, 0)
68
def split_merge_ite_sum_sexpr (foo):
	"""Unfinished helper for splitting an SMT 'ite' sum -- not usable as is.

	NOTE(review): the body refers to typ, cond, x, y and rec, none of which
	are defined here or (in the visible code) at module level, and it
	ignores its parameter foo, so any call fails before doing useful work.
	It reads like an 'ite' branch lifted out of split_sum_s_expr: rewrite
	(ite c x y) as (ite c (x - y) 0) + y, then push the condition onto
	each atom of x. TODO confirm the intended signature before reviving.
	The only visible reference, in offs_expr_const, is commented out."""
	# presumably the SMT literals 0 and 1 of type typ
	(s0, s1) = [solver.smt_num_t (n, typ) for n in [0, 1]]
	if y != s0:
		# ite c x y == (ite c (x - y) 0) + y
		expr = ('bvadd', ('ite', cond, ('bvsub', x, y), s0), y)
		return rec (expr)
	(xvar, xconst) = rec (x)
	# ite c (k * v) 0 == k * (ite c v 0): condition moved onto each atom
	var = dict ([(('ite', cond, v, s0), n)
		for (v, n) in xvar.iteritems ()])
	# the constant part becomes a multiple of the atom (ite c 1 0)
	var.setdefault (('ite', cond, s1, s0), 0)
	var[('ite', cond, s1, s0)] += xconst
	return (var, 0)
80
def simplify_expr_whyps (sexpr, rep, hyps, cache = None, extra_defs = {},
		bool_hyps = None):
	"""Simplify the 'ite' structure of a parsed SMT s-expression under hyps.

	sexpr is first expanded once through extra_defs, then once through
	rep.solv.defs. Anything that is not then an ('ite', cond, x, y) node
	is returned as is. For an ite, rep.test_hyp_whyps (under hyps, with
	bool_hyps as extra premises) is asked whether cond is forced true,
	then whether it is forced false; if either holds, the matching branch
	is returned unchanged. Otherwise both branches are simplified
	recursively with cond / not cond added to bool_hyps, and a single
	branch is returned if the two results are equal. cache is passed to
	every solver query."""
	if cache is None:
		cache = {}
	if bool_hyps is None:
		bool_hyps = []
	if sexpr in extra_defs:
		sexpr = extra_defs[sexpr]
	if sexpr in rep.solv.defs:
		sexpr = rep.solv.defs[sexpr]
	if sexpr[0] != 'ite':
		return sexpr

	(_, cond, then_x, else_x) = sexpr
	cond_exp = solver.mk_smt_expr (solver.flat_s_expression (cond),
		syntax.boolT)
	if rep.test_hyp_whyps (syntax.mk_n_implies (bool_hyps, cond_exp),
			hyps, cache = cache):
		return then_x
	neg_cond = syntax.mk_not (cond_exp)
	if rep.test_hyp_whyps (syntax.mk_n_implies (bool_hyps, neg_cond),
			hyps, cache = cache):
		return else_x

	then_s = simplify_expr_whyps (then_x, rep, hyps, cache = cache,
		extra_defs = extra_defs, bool_hyps = bool_hyps + [cond_exp])
	else_s = simplify_expr_whyps (else_x, rep, hyps, cache = cache,
		extra_defs = extra_defs, bool_hyps = bool_hyps + [neg_cond])
	if then_s == else_s:
		return then_s
	return ('ite', cond, then_s, else_s)
112
# Debug record of recent offs_expr_const failures: each entry is
# (addr_expr, sp_expr, remaining_terms, hyps); trimmed to the last 10.
last_10_non_const = []
114
def offs_expr_const (addr_expr, sp_expr, rep, hyps, extra_defs = {},
		cache = None, typ = syntax.word32T):
	"""if the offset between a stack addr and the initial stack pointer
	is a constant offset, try to compute it.

	addr_expr and sp_expr are SMT expression strings. Their difference is
	repeatedly split into a linear combination of atoms with
	split_sum_s_expr; atoms whose multiplier is a multiple of 2 ** typ.num
	cancel out. If all atoms cancel, the accumulated constant is returned
	(NOT reduced modulo the word size; callers normalise it). Otherwise
	the remaining atoms are simplified under hyps with simplify_expr_whyps
	and the process repeats. If a round makes no progress, the failure is
	traced, recorded in last_10_non_const, and None is returned."""
	if cache is None:
		# one cache shared by every simplification round, not a fresh
		# one per call
		cache = {}
	addr_x = solver.parse_s_expression (addr_expr)
	sp_x = solver.parse_s_expression (sp_expr)
	vs = [(addr_x, 1), (sp_x, -1)]
	const = 0
	modulus = 2 ** typ.num

	while True:
		start_vs = list (vs)
		new_vs = {}
		for (x, mult) in vs:
			(var, c) = split_sum_s_expr (x, rep.solv, extra_defs,
				typ = typ)
			for v in var:
				new_vs[v] = new_vs.get (v, 0) + var[v] * mult
			const += c * mult
		# multipliers are modulo the word size, so k * 2**bits cancels
		vs = [(x, n) for (x, n) in new_vs.items ()
			if n % modulus != 0]
		if not vs:
			return const
		vs = [(simplify_expr_whyps (x, rep, hyps,
				cache = cache, extra_defs = extra_defs), n)
			for (x, n) in vs]
		if sorted (vs) == sorted (start_vs):
			trace ('offs_expr_const: not const')
			trace ('%s - %s' % (addr_expr, sp_expr))
			trace (str (vs))
			trace (str (hyps))
			last_10_non_const.append ((addr_expr, sp_expr, vs, hyps))
			del last_10_non_const[:-10]
			return None
151
def has_stack_var (expr, stack_var):
	"""Check whether a memory expression is built on top of stack_var.

	Strips any number of MemUpdate layers from expr and compares the
	underlying Var with stack_var. Raises AssertionError (carrying the
	offending expression) if the base of the chain is not a Var."""
	while expr.is_op ('MemUpdate'):
		(inner, _, _) = expr.vals
		expr = inner
	if expr.kind == 'Var':
		return expr == stack_var
	# explicit raise: a bare 'assert' here would be stripped under
	# python -O and the loop above would then never terminate
	raise AssertionError (expr)
161
def mk_not_callable_hyps (p):
	"""Hypotheses that calls to non-callable functions are never reached.

	For every 'Call' node of p whose callee get_asm_callable rejects,
	returns a rep_graph.pc_false_hyp at the node's default visit
	(default_n_vc) under the node's tag, in p.nodes order."""
	result = []
	for n in p.nodes:
		node = p.nodes[n]
		if node.kind == 'Call' and not get_asm_callable (node.fname):
			node_tag = p.node_tags[n][0]
			result.append (rep_graph.pc_false_hyp (
				(default_n_vc (p, n), node_tag)))
	return result
173
# Debug slots: last_get_ptr_offsets[0] holds the (p, n_ptrs, bases, hyps)
# arguments of the most recent get_ptr_offsets call.
# NOTE(review): last_get_ptr_offsets_setup is not read or written anywhere
# in the visible part of this file.
last_get_ptr_offsets = [0]
last_get_ptr_offsets_setup = [0]
176
def get_ptr_offsets (p, n_ptrs, bases, hyps = [], cache = None,
		fail_early = False):
	"""detect which ptrs are guaranteed to be at constant offsets
	from some set of basis ptrs

	n_ptrs: list of (node, ptr_expr) to analyse.
	bases: list of (node, base_ptr_expr, key); for each pointer the bases
	  are tried in order and the first constant offset found is used.
	hyps: extra hypotheses (not modified).
	Returns a list of ((node, ptr_expr), offset, key). Pointers whose node
	has no pc/env in the graph slice are skipped. Pointers with no constant
	offset from any base are traced and omitted, or trip an assertion when
	fail_early is set."""
	rep = rep_graph.mk_graph_slice (p, fast = True)
	if cache is None:
		cache = {}
	last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps)

	smt_bases = []
	for (n, ptr, k) in bases:
		n_vc = default_n_vc (p, n)
		(_, env) = rep.get_node_pc_env (n_vc)
		smt = solver.smt_expr (ptr, env, rep.solv)
		# keep each base's own type: offset arithmetic is done modulo
		# that type's word size
		smt_bases.append ((smt, k, ptr.typ))

	smt_ptrs = []
	for (n, ptr) in n_ptrs:
		n_vc = default_n_vc (p, n)
		pc_env = rep.get_node_pc_env (n_vc)
		if not pc_env:
			continue
		smt = solver.smt_expr (ptr, pc_env[1], rep.solv)
		hyp = rep_graph.pc_true_hyp ((n_vc, p.node_tags[n][0]))
		smt_ptrs.append (((n, ptr), smt, hyp))

	tags = set ([p.node_tags[n][0] for (n, _) in n_ptrs])
	hyps = hyps + mk_not_callable_hyps (p)
	for tag in tags:
		hyps = hyps + init_correctness_hyps (p, tag)
	ex_defs = {}
	for t in tags:
		ex_defs.update (get_extra_sp_defs (rep, t))

	offs = []
	for (v, ptr, hyp) in smt_ptrs:
		off = None
		for (ptr2, k, base_typ) in smt_bases:
			off = offs_expr_const (ptr, ptr2, rep, [hyp] + hyps,
				cache = cache, extra_defs = ex_defs,
				typ = base_typ)
			if off is not None:
				offs.append ((v, off, k))
				break
		if off is None:
			trace ('get_ptr_offs fallthrough at %d: %s' % v)
			trace (str ([hyp] + hyps))
			assert not fail_early, (v, ptr)
	return offs
227
def init_correctness_hyps (p, tag):
	"""Instantiate the input equalities of tag's pairing as hypotheses.

	Returns [] when the function entered at tag has no pairing (which
	also covers the bootstrap case), or when neither tag nor
	p.hook_tag_hints.get (tag, tag) is one of the pairing's tags.
	Otherwise keeps the pairing's input equalities whose two sides are
	both at '<pairing tag>_IN' and instantiates them with check.inst_eqs,
	mapping the pairing tag to tag."""
	(_, fname, _) = p.get_entry_details (tag)
	if fname not in pairings:
		return []
	# revise if multi-pairings for ASM an option
	[pair] = pairings[fname]
	if tag in pair.funs:
		true_tag = tag
	else:
		true_tag = p.hook_tag_hints.get (tag, tag)
		if true_tag not in pair.funs:
			return []
	(inp_eqs, _) = pair.eqs
	in_tag = "%s_IN" % true_tag
	eqs = [eq for eq in inp_eqs
		if eq[0][1] == in_tag and eq[1][1] == in_tag]
	return check.inst_eqs (p, (), eqs, {true_tag: tag})
247
# Lazily filled by preserves_sp: every name in target_objects.symbols,
# together with the same name with each '.' replaced by '_'.
extra_symbols = set ()
249
def preserves_sp (fname):
	"""Whether a call to fname may be assumed to leave the stack pointer equal.

	Holds for functions with an ASM calling convention, for every
	function when the 'assume_sp_equal' hook is set, and for any known
	target symbol (see extra_symbols), whether or not it is paired.
	NOTE: returns the first truthy operand (possibly the calling
	convention itself), not necessarily a bool."""
	assume_sp_equal = bool (target_objects.hooks ('assume_sp_equal'))
	if not extra_symbols:
		for sym in target_objects.symbols:
			extra_symbols.update ([sym, sym.replace ('.', '_')])
	return (get_asm_calling_convention (fname)
		or assume_sp_equal
		or fname in extra_symbols)
261
def get_extra_sp_defs (rep, tag):
	"""add extra defs/equalities about stack pointer for the
	purposes of stack depth analysis.

	For every function call in rep.funcs at an integer node tagged tag,
	whose callee satisfies preserves_sp and which outputs r13, maps the
	parsed SMT form of r13 after the call to its form before the call.
	Calls where the two forms already coincide are skipped."""
	# FIXME how to parametrise this?
	sp = mk_var ('r13', syntax.word32T)
	sp_key = (sp.name, sp.typ)

	def is_sp_call (n_vc):
		n = n_vc[0]
		return (logic.is_int (n)
			and rep.p.node_tags[n][0] == tag
			and preserves_sp (rep.p.nodes[n].fname))

	calls = [n_vc for n_vc in rep.funcs if is_sp_call (n_vc)]
	defs = {}
	for n_vc in calls:
		(inputs, outputs, _) = rep.funcs[n_vc]
		if sp_key not in outputs:
			continue
		before = solver.parse_s_expression (
			solver.smt_expr (sp, inputs, rep.solv))
		after = solver.parse_s_expression (
			solver.smt_expr (sp, outputs, rep.solv))
		if before != after:
			defs[after] = before
	return defs
284
def get_stack_sp (p, tag):
	"""get stack and stack-pointer variables

	Returns (stack, sp): the 'stack' memory variable and the stack
	pointer r13, both renamed with tag's '<tag>_IN' entry renaming."""
	# the entry itself is unused; the lookup is kept in case it
	# validates tag -- NOTE(review): confirm whether it is needed
	p.get_entry (tag)
	entry_renames = p.entry_exit_renames (tags = [tag])[tag + '_IN']
	sp = syntax.rename_expr (mk_var ('r13', syntax.word32T), entry_renames)
	stack_var = mk_var ('stack', syntax.builtinTs['Mem'])
	stack = syntax.rename_expr (stack_var, entry_renames)
	return (stack, sp)
295
def pseudo_node_lvals_rvals (node):
	"""Summarise the variables an ASM call node writes and reads.

	Returns None if the callee has no ASM calling convention
	(get_asm_calling_convention_at_node). Otherwise returns
	(rets, arg_vars): the node's return variables minus those listed as
	callee-saved by the convention, and the set of variables occurring in
	the convention's argument expressions."""
	assert node.kind == 'Call'
	cc = get_asm_calling_convention_at_node (node)
	if not cc:
		return None

	used = []
	for arg in cc['args']:
		used.extend (syntax.get_expr_var_set (arg))
	arg_vars = set (used)

	callee_saved = set (cc['callee_saved'])
	rets = []
	for (nm, typ) in node.rets:
		if mk_var (nm, typ) not in callee_saved:
			rets.append ((nm, typ))
	return (rets, arg_vars)
310
def is_asm_node (p, n):
	"""True if node n of p is on the ASM side.

	That is, its tag is 'ASM' or p.hook_tag_hints maps its tag to 'ASM'."""
	node_tag = p.node_tags[n][0]
	if node_tag == 'ASM':
		return True
	return p.hook_tag_hints.get (node_tag, None) == 'ASM'
314
def all_pseudo_node_lvals_rvals (p):
	"""Map each ASM-side call node of p to its pseudo_node_lvals_rvals summary.

	Nodes whose summary is None (no calling convention) are left out."""
	summaries = {}
	for n in p.nodes:
		if is_asm_node (p, n) and p.nodes[n].kind == 'Call':
			summary = pseudo_node_lvals_rvals (p.nodes[n])
			if summary is not None:
				summaries[n] = summary
	return summaries
326
def adjusted_var_dep_outputs_for_tag (p, tag):
	"""Outputs of tag's function that its ASM calling convention cares about.

	Zips the function's declared outputs with p.outputs[tag] (azip). Each
	p.outputs name is kept, with the declared type, when its declared
	counterpart occurs in one of the convention's return expressions or
	is callee-saved."""
	(_, fname, _) = p.get_entry_details (tag)
	fun = functions[fname]
	cc = get_asm_calling_convention (fname)
	callee_saved = set (cc['callee_saved'])
	ret_vars = set ([(nm, typ) for ret in cc['rets']
		for (nm, typ) in syntax.get_expr_var_set (ret)])
	kept = []
	for ((nm, typ), (out_nm, _)) in azip (fun.outputs, p.outputs[tag]):
		if (nm, typ) in ret_vars or mk_var (nm, typ) in callee_saved:
			kept.append ((out_nm, typ))
	return kept
339
def adjusted_var_dep_outputs (p):
	"""Return a node -> outputs function for variable-dependency analysis.

	Tags whose entry node is ASM-side use
	adjusted_var_dep_outputs_for_tag; all others keep p.outputs[tag].
	The returned function looks up the list for the node's tag."""
	per_tag = {}
	for tag in p.outputs:
		if is_asm_node (p, p.get_entry (tag)):
			per_tag[tag] = adjusted_var_dep_outputs_for_tag (p, tag)
		else:
			per_tag[tag] = p.outputs[tag]
	return lambda n: per_tag[p.node_tags[n][0]]
352
def is_stack (expr):
	"""True for a Var expression whose name contains 'stack'."""
	if expr.kind != 'Var':
		return False
	return 'stack' in expr.name
355
class StackOffsMissing (Exception):
	"""Raised by the stack-virtualisation helpers when a stack pointer
	they need has no entry in the supplied offsets map (sp_offs)."""
	pass
358
def stack_virtualise_expr (expr, sp_offs):
	"""Replace aligned stack reads in expr with 'Fake' word variables.

	Returns (ptrs, expr2). ptrs lists (ptr, 'MemAcc') for every pointer
	read from a stack memory (is_stack) anywhere in expr. For an 8-bit
	read, the four candidate word pointers ptr - 0 .. ptr - 3 are all
	listed. Only 8- and 32-bit stack reads are supported; others raise
	AssertionError.

	When sp_offs is None only ptrs is meaningful, and a stack read itself
	yields expr2 None. Otherwise sp_offs maps pointer expressions to
	(key, offset). A stack read with a candidate word pointer whose offset
	is divisible by 4 becomes the variable ('Fake', key, offset); for
	8-bit reads it is shifted right by 8 * byte-index and then cast. Reads
	without such a candidate are left unchanged. Other Op nodes are
	rebuilt from their virtualised operands. Anything else is returned
	as is."""
	if expr.is_op ('MemAcc') and is_stack (expr.vals[0]):
		[_, ptr] = expr.vals
		if expr.typ == syntax.word8T:
			# (word pointer, byte index) for each word that could hold it
			cands = [(syntax.mk_minus (ptr, syntax.mk_word32 (i)), i)
				for i in [0, 1, 2, 3]]
		elif expr.typ == syntax.word32T:
			cands = [(ptr, 0)]
		else:
			# explicit raise so it also fires under python -O
			raise AssertionError (expr)
		ptrs = [(cand, 'MemAcc') for (cand, _) in cands]
		if sp_offs is None:
			return (ptrs, None)
		# FIXME: very 32-bit specific
		aligned = [(cand, i) for (cand, i) in cands
			if cand in sp_offs and sp_offs[cand][1] % 4 == 0]
		if not aligned:
			return (ptrs, expr)
		[(word_ptr, byte_idx)] = aligned
		(k, offs) = sp_offs[word_ptr]
		v = mk_var (('Fake', k, offs), syntax.word32T)
		if byte_idx != 0:
			v = syntax.mk_shiftr (v, byte_idx * 8)
		v = syntax.mk_cast (v, expr.typ)
		return (ptrs, v)
	elif expr.kind == 'Op':
		parts = [stack_virtualise_expr (sub, sp_offs) for sub in expr.vals]
		all_ptrs = [ptr for (sub_ptrs, _) in parts for ptr in sub_ptrs]
		return (all_ptrs,
			syntax.adjust_op_vals (expr, [v for (_, v) in parts]))
	else:
		return ([], expr)
392
393def stack_virtualise_upd (((nm, typ), expr), sp_offs):
394	if 'stack' in nm:
395		upds = []
396		ptrs = []
397		while expr.is_op ('MemUpdate'):
398			[m, p, v] = expr.vals
399			ptrs.append ((p, 'MemUpdate'))
400			(ptrs2, v2) = stack_virtualise_expr (v, sp_offs)
401			ptrs.extend (ptrs2)
402			if sp_offs != None:
403				if p not in sp_offs:
404					raise StackOffsMissing ()
405				(k, offs) = sp_offs[p]
406				upds.append (((('Fake', k, offs),
407					syntax.word32T), v2))
408			expr = m
409		assert is_stack (expr), expr
410		return (ptrs, upds)
411	else:
412		(ptrs, expr2) = stack_virtualise_expr (expr, sp_offs)
413		return (ptrs, [((nm, typ), expr2)])
414
def stack_virtualise_ret (expr, sp_offs):
	"""Stack-virtualise one return location of a call.

	A Var yields ([], (name, typ)). A 32-bit MemAcc on a stack variable is
	a return slot and yields ([(ptr, 'MemUpdate')], r), where r is the
	(('Fake', key, offset), word32T) variable from sp_offs, or None when
	sp_offs is None. Raises StackOffsMissing if sp_offs has no entry for
	the slot pointer, and AssertionError for any other expression."""
	if expr.kind == 'Var':
		return ([], (expr.name, expr.typ))
	if not expr.is_op ('MemAcc'):
		# explicit raise so it also fires under python -O
		raise AssertionError (expr)
	[m, p] = expr.vals
	assert expr.typ == syntax.word32T, expr
	assert is_stack (m), expr
	if sp_offs is None:
		r = None
	else:
		if p not in sp_offs:
			# same signal as stack_virtualise_upd, which the caller in
			# get_loop_virtual_stack_analysis catches, instead of a
			# bare KeyError
			raise StackOffsMissing ()
		(k, offs) = sp_offs[p]
		r = (('Fake', k, offs), syntax.word32T)
	return ([(p, 'MemUpdate')], r)
430
def stack_virtualise_node (node, sp_offs):
	"""Stack-virtualise a single graph node.

	Returns (ptrs, new_node). ptrs collects the stack pointers the node
	reads or writes, as produced by the stack_virtualise_* helpers. For
	'Call' and 'Basic' nodes they are de-duplicated through a set, so
	their order follows set iteration. When sp_offs is None only ptrs is
	computed and new_node is None. Otherwise new_node is a rebuilt node:
	  Cond  -- condition virtualised;
	  Call  -- instruction calls are returned unchanged with no ptrs;
	           other calls (which must have an ASM calling convention)
	           become a Call whose arguments are the virtualised
	           non-stack convention args followed by the pointer
	           expressions, and whose returns are the virtualised
	           non-stack return locations;
	  Basic -- updates virtualised, plus one dummy update per pointer.
	The helpers may raise StackOffsMissing when sp_offs lacks a pointer.
	Other node kinds trip an assertion."""
	if node.kind == 'Cond':
		(ptrs, cond) = stack_virtualise_expr (node.cond, sp_offs)
		if sp_offs == None:
			return (ptrs, None)
		else:
			return (ptrs, syntax.Node ('Cond',
				node.get_conts (), cond))
	elif node.kind == 'Call':
		if is_instruction (node.fname):
			# instruction calls are left untouched
			return ([], node)
		cc = get_asm_calling_convention_at_node (node)
		assert cc != None, node.fname
		args = [arg for arg in cc['args'] if not is_stack (arg)]
		args = [stack_virtualise_expr (arg, sp_offs) for arg in args]
		rets = [ret for ret in cc['rets_inp'] if not is_stack (ret)]
		rets = [stack_virtualise_ret (ret, sp_offs) for ret in rets]
		ptrs = list (set ([p for (ps, _) in args for p in ps]
			+ [p for (ps, _) in rets for p in ps]))
		if sp_offs == None:
			return (ptrs, None)
		else:
			# ptrs holds (ptr, kind) pairs; the pointer expressions are
			# appended after the virtualised arguments. The third Node
			# argument is presumably (fname, args, rets) with fname None
			# -- NOTE(review): confirm against syntax.Node
			return (ptrs, syntax.Node ('Call', node.cont,
				(None, [v for (_, v) in args]
					+ [p for (p, _) in ptrs],
					[r for (_, r) in rets])))
	elif node.kind == 'Basic':
		upds = [stack_virtualise_upd (upd, sp_offs) for upd in node.upds]
		ptrs = list (set ([p for (ps, _) in upds for p in ps]))
		if sp_offs == None:
			return (ptrs, None)
		else:
			# one dummy assignment of each pointer expression to a fresh
			# 'unused#ptr#name<i>' variable; presumably so the pointers
			# still count as used by dependency analysis -- TODO confirm
			ptr_upds = [(('unused#ptr#name%d' % i, syntax.word32T),
				ptr) for (i, (ptr, _)) in enumerate (ptrs)]
			return (ptrs, syntax.Node ('Basic', node.cont,
				[upd for (_, us) in upds for upd in us]
					+ ptr_upds))
	else:
		# always fails; note that under python -O it is stripped and the
		# function returns None instead
		assert not "node kind understood", node.kind
470
def mk_get_local_offs (p, tag, sp_reps):
	"""Return a function that rebuilds stack-slot expressions at a node.

	sp_reps maps node -> key -> (base_ptr, base_off). The returned
	mk_local (n, kind, off, k) forms base_ptr + (off + base_off) from the
	entry for key k at node n. kind 'Ptr' returns that address; 'MemAcc'
	returns a 32-bit read of tag's entry stack variable at that address;
	any other kind returns None."""
	(stack, _) = get_stack_sp (p, tag)

	def mk_local (n, kind, off, k):
		(base_ptr, base_off) = sp_reps[n][k]
		addr = syntax.mk_plus (base_ptr, syntax.mk_word32 (off + base_off))
		if kind == 'MemAcc':
			return syntax.mk_memacc (stack, addr, syntax.word32T)
		if kind == 'Ptr':
			return addr
		return None

	return mk_local
481
def adjust_ret_ptr (ptr):
	"""Re-express a return-slot pointer over r0 instead of r0_input.

	this is a bit of a hack. Return slots are named after r0_input,
	which stays unchanged and is therefore handy. Offsets are only
	meaningful against the pointers the program actually uses, though,
	and those are expressed over r0. The substitution is optional
	(must_subst = False)."""
	r0 = syntax.mk_var ('r0', syntax.word32T)
	subst = {('r0_input', syntax.word32T): r0}
	return logic.var_subst (ptr, subst, must_subst = False)
492
def get_loop_virtual_stack_analysis (p, tag):
	"""computes variable liveness etc analyses with stack slots treated
	as virtual variables.

	The result is cached in p.cached_analysis under
	('loop_stack_analysis', tag) and is the tuple
	(vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs)):
	  vds          -- logic.compute_var_deps over adj_nodes, using the
	                  outputs from adjusted_var_dep_outputs;
	  adj_nodes    -- node -> stack-virtualised node, for every node of
	                  tag (the unchanged node where StackOffsMissing was
	                  raised);
	  loc_offs     -- the mk_get_local_offs function for rebuilding
	                  stack-slot expressions;
	  upd_offsets  -- loop id -> set of (key, offset) slots written by
	                  MemUpdates at nodes in loops;
	  (ptrs, offs) -- the pointer list and the get_ptr_offsets result."""
	cache_key = ('loop_stack_analysis', tag)
	if cache_key in p.cached_analysis:
		return p.cached_analysis[cache_key]

	(ent, fname, _) = p.get_entry_details (tag)
	(_, sp) = get_stack_sp (p, tag)
	cc = get_asm_calling_convention (fname)
	# pointers read by the convention's return expressions (return slots),
	# re-expressed over r0 and renamed to tag's exit names
	rets = list (set ([ptr for arg in cc['rets']
		for (ptr, _) in stack_virtualise_expr (arg, None)[0]]))
	rets = [adjust_ret_ptr (ret) for ret in rets]
	renames = p.entry_exit_renames (tags = [tag])
	r = renames[tag + '_OUT']
	rets = [syntax.rename_expr (ret, r) for ret in rets]

	ns = [n for n in p.nodes if p.node_tags[n][0] == tag]
	loop_ns = logic.minimal_loop_node_set (p)

	# every (node, (ptr, kind)) stack access of tag's nodes, plus the stack
	# pointer itself at each node in loop_ns
	ptrs = list (set ([(n, ptr) for n in ns
		for ptr in (stack_virtualise_node (p.nodes[n], None))[0]]))
	ptrs += [(n, (sp, 'StackPointer')) for n in ns if n in loop_ns]
	# offsets against the entry stack pointer ('stack') and, if there is
	# one, a single return-slot pointer ('indirect_ret')
	offs = get_ptr_offsets (p, [(n, ptr) for (n, (ptr, _)) in ptrs],
		[(ent, sp, 'stack')]
			+ [(ent, ptr, 'indirect_ret') for ptr in rets[:1]])

	ptr_offs = {}
	rep_offs = {}
	upd_offsets = {}
	for ((n, ptr), off, k) in offs:
		# signed 32-bit representative nearest zero
		off = norm_int (off, 32)
		ptr_offs.setdefault (n, {})
		rep_offs.setdefault (n, {})
		# ptr_offs: node -> ptr -> (base key, offset from the base)
		ptr_offs[n][ptr] = (k, off)
		# rep_offs: node -> base key -> (a ptr at n, offset back to the
		# base); a later ptr with the same key overwrites an earlier one
		rep_offs[n][k] = (ptr, - off)

	for (n, (ptr, kind)) in ptrs:
		if kind == 'MemUpdate' and n in loop_ns:
			loop = p.loop_id (n)
			# NOTE(review): KeyError if get_ptr_offsets found no constant
			# offset for this ptr (it omits such ptrs rather than failing
			# here); confirm this cannot happen for writes in loops
			(k, off) = ptr_offs[n][ptr]
			upd_offsets.setdefault (loop, set ())
			upd_offsets[loop].add ((k, off))
	loc_offs = mk_get_local_offs (p, tag, rep_offs)

	adj_nodes = {}
	for n in ns:
		try:
			(_, node) = stack_virtualise_node (p.nodes[n],
				ptr_offs.get (n, {}))
		except StackOffsMissing, e:
			# fall back to the unvirtualised node
			printout ("Stack analysis issue at (%d, %s)."
				% (n, p.node_tags[n]))
			node = p.nodes[n]
		adj_nodes[n] = node

	# finally do analysis on this collection of nodes

	preds = dict (p.preds)
	# keep only this tag's predecessors of the Ret/Err exits
	preds['Ret'] = [n for n in preds['Ret'] if p.node_tags[n][0] == tag]
	preds['Err'] = [n for n in preds['Err'] if p.node_tags[n][0] == tag]
	vds = logic.compute_var_deps (adj_nodes,
		adjusted_var_dep_outputs (p), preds)

	result = (vds, adj_nodes, loc_offs, upd_offsets, (ptrs, offs))
	p.cached_analysis[cache_key] = result
	return result
560
def norm_int (n, radix):
	"""Reduce n modulo 2 ** radix, preferring the representative nearer zero.

	Returns the value in [0, 2 ** radix) unless that value minus
	2 ** radix has strictly smaller magnitude. On a tie the non-negative
	value is kept."""
	modulus = 1 << radix
	unsigned = n & (modulus - 1)
	signed = unsigned - modulus
	return signed if abs (signed) < abs (unsigned) else unsigned
568
def loop_var_analysis (p, split):
	"""computes the same loop dataflow analysis as in the 'logic' module
	but with stack slots treated as virtual variables.

	Returns None unless split is an ASM-side node (is_asm_node); split
	must lie in a loop. Otherwise returns, and caches in p.cached_analysis,
	the (expr, data) list from logic.compute_loop_var_analysis over the
	stack-virtualised nodes, where:
	  - a ('Fake', key, offs) slot variable is kept only if that slot is
	    written in split's loop, and is then replaced by the stack read of
	    that slot at split;
	  - stack variables themselves are dropped.
	A final (stack_const, 'LoopConst') entry is appended. stack_const is
	the stack with every loop-written slot overwritten by 0, wrapped by
	logic.mk_stack_wrapper and logic.mk_eq_selective_wrapper (presumably
	stating that the rest of the stack is loop-constant -- confirm)."""
	if not is_asm_node (p, split):
		return None
	head = p.loop_id (split)
	tag = p.node_tags[split][0]
	assert head

	key = ('loop_stack_virtual_var_cycle_analysis', split)
	if key in p.cached_analysis:
		return p.cached_analysis[key]

	(vds, adj_nodes, loc_offs,
		upd_offsets, _) = get_loop_virtual_stack_analysis (p, tag)
	# NOTE(review): 'loop' is not used below
	loop = p.loop_body (head)

	va = logic.compute_loop_var_analysis (p, vds, split,
		override_nodes = adj_nodes)

	(stack, _) = get_stack_sp (p, tag)

	va2 = []
	# (key, offset) slots written inside this loop
	uoffs = upd_offsets.get (head, [])
	for (v, data) in va:
		# virtual slot variables are named ('Fake', key, offs)
		if v.kind == 'Var' and v.name[0] == 'Fake':
			(_, k, offs) = v.name
			if (k, offs) not in uoffs:
				continue
			v2 = loc_offs (split, 'MemAcc', offs, k)
			va2.append ((v2, data))
		elif v.kind == 'Var' and v.name.startswith ('stack'):
			assert v.typ == stack.typ
			continue
		else:
			va2.append ((v, data))
	# the stack with every loop-written slot zeroed out
	stack_const = stack
	for (k, off) in uoffs:
		stack_const = syntax.mk_memupd (stack_const,
			loc_offs (split, 'Ptr', off, k),
			syntax.mk_word32 (0))
	# asm_stack_rep_hook is not defined in the visible code; it is
	# expected to describe the stack here as ('SplitMem', st_split)
	sp = asm_stack_rep_hook (p, (stack.name, stack.typ), 'Loop', split)
	assert sp and sp[0] == 'SplitMem', (split, sp)
	(_, st_split) = sp
	stack_const = logic.mk_stack_wrapper (st_split, stack_const, [])
	stack_const = logic.mk_eq_selective_wrapper (stack_const,
		([], [0]))

	va2.append ((stack_const, 'LoopConst'))

	p.cached_analysis[key] = va2
	return va2
621
def inline_no_pre_pairing (p):
	"""Inline every call whose callee has no pre-pairing, until none remain.

	Calls to instructions (is_instruction) are kept. The scan is repeated
	after each round, because inlined bodies may contain further such
	calls."""
	# FIXME: handle code sharing with check.inline_completely_unmatched
	def unpaired_calls ():
		return [n for n in p.nodes if p.nodes[n].kind == 'Call'
			if p.nodes[n].fname not in pre_pairings
			if not is_instruction (p.nodes[n].fname)]

	pending = unpaired_calls ()
	while pending:
		for n in pending:
			trace ('Inlining %s at %d.' % (p.nodes[n].fname, n))
			problem.inline_at_point (p, n)
		pending = unpaired_calls ()
633
# Debug slot: name of the function most recently analysed by
# guess_asm_stack_depth (after its pre-checks passed).
last_asm_stack_depth_fun = [0]
635
def check_before_guess_asm_stack_depth (fun):
	"""Build and pre-check the 'Target' problem for stack-depth analysis.

	Returns None if fun has no entry, if analysis, the inner-loop check or
	inlining raises problem.Abort, or if the graph slice cannot supply the
	Ret or Err pc (solver.EnvMiss). Otherwise returns the problem, with
	every call lacking a pre-pairing inlined, after printing which
	functions were inlined."""
	if not fun.entry:
		return None
	p = fun.as_problem (problem.Problem, name = 'Target')
	try:
		p.do_analysis ()
		p.check_no_inner_loops ()
		inline_no_pre_pairing (p)
	except problem.Abort:
		return None

	rep = rep_graph.mk_graph_slice (p, fast = True)
	try:
		# only checks that both exit pcs can be built
		for exit_node in ['Ret', 'Err']:
			rep.get_pc (default_n_vc (p, exit_node), 'Target')
	except solver.EnvMiss:
		return None

	inlined = set ([fn for (_, _, fn) in p.inline_scripts['Target']])
	if inlined:
		printout ('  (stack analysis also involves %s)'
			% ', '.join(inlined))
	return p
660
def guess_asm_stack_depth (fun):
	"""Estimate the stack usage of an ASM function.

	Returns (max_depth, callee_depths). max_depth is the largest distance
	of the stack pointer from its entry value at any reachable node.
	callee_depths maps each called function to the largest such distance
	at a call to it. Returns (0, {}) when
	check_before_guess_asm_stack_depth rejects fun. Asserts that every
	reachable node has a constant stack-pointer offset (within +-10**6)
	and that the stack pointer moves in only one direction."""
	p = check_before_guess_asm_stack_depth (fun)
	if not p:
		return (0, {})

	last_asm_stack_depth_fun[0] = fun.name

	entry = p.get_entry ('Target')
	(_, sp) = get_stack_sp (p, 'Target')
	nodes = get_asm_reachable_nodes (p, tag_set = ['Target'])
	offs = get_ptr_offsets (p, [(n, sp) for n in nodes],
		[(entry, sp, 'InitSP')], fail_early = True)
	assert len (offs) == len (nodes), map (hex, set (nodes)
		- set ([n for ((n, _), _, _) in offs]))

	node_offs = [(n, signed_offset (off, 32, 10 ** 6))
		for ((n, _), off, _) in offs]
	lowest = min ([o for (_, o) in node_offs])
	highest = max ([o for (_, o) in node_offs])
	assert lowest >= 0 or highest <= 0, node_offs
	if lowest < 0:
		# stack grows downwards: report distances as positive
		sign = -1
		highest = - lowest
	else:
		sign = 1

	call_offs = [(p.nodes[n].fname, o * sign)
		for (n, o) in node_offs if p.nodes[n].kind == 'Call']
	callee_depths = {}
	for (callee, depth) in call_offs:
		if callee not in callee_depths or depth > callee_depths[callee]:
			callee_depths[callee] = depth
	return (highest, callee_depths)
698
def signed_offset (n, bits, bound = 0):
	"""Interpret the low bits of n as a two's-complement signed number.

	If bound is non-zero, also asserts -bound <= result <= bound."""
	value = n & ((1 << bits) - 1)
	if value >> (bits - 1):
		# top bit set: negative
		value -= 1 << bits
	if bound:
		assert value <= bound, (value, bound)
		assert value >= (- bound), (value, bound)
	return value
707
def ident_conds (fname, idents):
	"""Pair each recursion ident of fname with its exclusive condition.

	For idents i1, i2, ... (default [true_term]) returns
	[(i1, true & i1), (i2, (true & not i1) & i2), ...]: each ident
	together with the condition that it holds and no earlier one does."""
	none_before = syntax.true_term
	pairs = []
	for ident in idents.get (fname, [syntax.true_term]):
		pairs.append ((ident, syntax.mk_and (none_before, ident)))
		none_before = syntax.mk_and (none_before, syntax.mk_not (ident))
	return pairs
715
def ident_callables (fname, callees, idents):
	"""Decide which (callee, callee ident) pairs each ident of fname may reach.

	Returns a dict (ident, callee, callee_ident) -> bool. Callees without
	idents are recorded as callable (True, callee_ident true_term) for
	every ident of fname. For callees with idents, fname's 'Target'
	problem is checked with the solver. The hypotheses are: Err is
	unreachable, the hypotheses from mk_not_callable_hyps, and fname's
	ident condition (from ident_conds, renamed to entry names) holds at
	entry. A call is callable unless the solver proves that its pc and
	the callee ident's condition cannot hold together."""
	from solver import to_smt_expr
	from syntax import mk_not, mk_and, true_term

	callables = dict ([((ident, f, true_term), True)
		for ident in idents.get (fname, [true_term])
		for f in callees if f not in idents])
	if not any ([f in idents for f in callees]):
		return callables

	p = functions[fname].as_problem (problem.Problem, name = 'Target')
	check_ns = [(n, ident, cond) for n in p.nodes
		if p.nodes[n].kind == 'Call'
		if p.nodes[n].fname in idents
		for (ident, cond) in ident_conds (p.nodes[n].fname, idents)]

	p.do_analysis ()
	assert check_ns

	rep = rep_graph.mk_graph_slice (p, fast = True)
	err_hyp = rep_graph.pc_false_hyp ((default_n_vc (p, 'Err'), 'Target'))
	not_callable_hyps = mk_not_callable_hyps (p)

	for (ident, cond) in ident_conds (fname, idents):
		in_renames = p.entry_exit_renames (tags = ['Target'])['Target_IN']
		entry_cond = syntax.rename_expr (cond, in_renames)
		entry_vis = ((p.get_entry ('Target'), ()), 'Target')
		hyps = [err_hyp, rep_graph.eq_hyp ((entry_cond, entry_vis),
				(true_term, entry_vis))]

		for (n, callee_ident, callee_cond) in check_ns:
			key = (ident, p.nodes[n].fname, callee_ident)
			n_vc = default_n_vc (p, n)
			(inp_env, _, _) = rep.get_func (n_vc)
			pc = rep.get_pc (n_vc)
			callee_cond_smt = to_smt_expr (callee_cond, inp_env, rep.solv)
			unreachable = rep.test_hyp_whyps (
				mk_not (mk_and (pc, callee_cond_smt)),
				hyps + not_callable_hyps)
			callables[key] = not unreachable
	return callables
762
def compute_immediate_stack_bounds (idents, names):
	"""Per-function stack usage, not yet counting the callees' own usage.

	For each function in names (sorted, with progress printed) and each
	of its recursion idents (default [true_term]), returns
	immed[(fname, ident)] = (own_depth, calls). calls maps each
	(callee, callee_ident) that ident_callables marks callable to the
	stack depth at the call site. The result is also stored in
	last_immediate_stack_bounds[0]."""
	from syntax import true_term
	immed = {}
	ordered = sorted (names)
	total = len (ordered)
	for (i, fname) in enumerate (ordered):
		printout ('Doing stack analysis for %r. (%d of %d)' % (fname,
			i + 1, total))
		(own_depth, call_depths) = guess_asm_stack_depth (functions[fname])
		callables = ident_callables (fname, call_depths.keys (), idents)
		for ident in idents.get (fname, [true_term]):
			reachable = [((callee, callee_ident), call_depths[callee])
				for callee in call_depths
				for callee_ident in idents.get (callee, [true_term])
				if callables[(ident, callee, callee_ident)]]
			immed[(fname, ident)] = (own_depth, dict (reachable))
	last_immediate_stack_bounds[0] = immed
	return immed
781
# Debug slot: the result of the most recent compute_immediate_stack_bounds.
last_immediate_stack_bounds = [0]
783
def immediate_stack_bounds_loop (immed):
	"""Find recursion remaining in the immediate stack-bound call graph.

	The graph maps each (fname, ident) key of immed to the keys it calls,
	plus an 'ENTRY' node pointing at every key. Returns, as lists
	[head] + tail, every component from logic.tarjan that is recursive:
	one with several members, or a single member that calls itself
	(matching the test used in get_recursion_identifiers). An empty
	result means no recursion was found."""
	graph = dict ([(k, immed[k][1].keys ()) for k in immed])
	graph['ENTRY'] = list (immed)
	comps = logic.tarjan (graph, ['ENTRY'])
	# a lone key is recursive only if it calls itself; missing such
	# self-loops would let compute_recursive_stack_bounds spin forever
	return [[x] + list (y) for (x, y) in comps
		if y or x in graph.get (x, [])]
790
def compute_recursive_stack_bounds (immed):
	"""Total stack bound for every (fname, ident), including its callees.

	bound (k) = max (own_depth (k), max over callees c of
	bound (c) + depth at the call to c). It is computed with an explicit
	work-list; the call graph must be free of recursion, which is
	asserted up front."""
	assert not immediate_stack_bounds_loop (immed)
	bounds = {}
	todo = list (immed.keys ())
	report = 1000
	while todo:
		if len (todo) >= report:
			trace ('todo length %d' % len (todo))
			trace ('tail: %s' % todo[-20:])
			report += 1000
		key = todo.pop ()
		if key in bounds:
			continue
		(static, calls) = immed[key]
		missing = [k for k in calls if k not in bounds]
		if missing:
			# revisit key after its unresolved callees; pushing only
			# those (not every callee again) keeps the work-list small
			todo.append (key)
			todo.extend (missing)
			continue
		bounds[key] = max ([static]
			+ [bounds[k] + calls[k] for k in calls])
	return bounds
813
def stack_bounds_to_closed_form (bounds, names, idents):
	"""Express each function's stack bound as a word32 term.

	A function without idents gets the constant bound of its true_term
	entry. For idents [i1, ..., ik, true_term] (the last must be
	true_term) the term is
	if i1 then b(i1) else if i2 then b(i2) ... else b(true_term)."""
	closed = {}
	for fname in names:
		term = syntax.mk_word32 (bounds[(fname, syntax.true_term)])
		guards = []
		if fname in idents:
			assert idents[fname][-1] == syntax.true_term
			guards = idents[fname][:-1]
		# build inside-out so the first ident ends up outermost
		for ident in guards[::-1]:
			guarded = syntax.mk_word32 (bounds[(fname, ident)])
			term = syntax.mk_if (ident, guarded, term)
		closed[fname] = term
	return closed
827
def compute_asm_stack_bounds (idents, names):
	"""Closed-form stack bounds (fname -> word32 term) for names.

	Pipeline: immediate per-function bounds, then transitive bounds over
	the call graph, then the closed form selected by recursion ident."""
	return stack_bounds_to_closed_form (
		compute_recursive_stack_bounds (
			compute_immediate_stack_bounds (idents, names)),
		names, idents)
833
# Debug records for the recursion-ident search. recursion_trace collects
# progress messages and is cleared by get_recursion_identifiers.
# recursion_last_assns[0] is the hypothesis list of the latest
# add_recursion_ident run.
recursion_trace = []
recursion_last_assns = [[]]
836
def get_recursion_identifiers (funs, extra_unfolds = []):
	"""Compute recursion idents for recursive functions related to funs.

	Starting from funs, the set is grown until it stops changing, adding
	every function that calls, or is called by, a member. logic.tarjan is
	then run on the call graph restricted to that set, from the members
	no other member calls. Each component (presumably a strongly
	connected component) with several members, or with a single member
	that calls itself, is passed to compute_recursion_idents. Returns the
	merged dict fname -> list of ident conditions. Clears recursion_trace
	first. extra_unfolds is only passed through (never modified)."""
	idents = {}
	del recursion_trace[:]
	graph = dict ([(f, list (functions[f].function_calls ()))
		for f in functions])
	fs = funs
	fs2 = set ()
	# grow fs to a fixed point: add callers and callees of every member
	while fs2 != fs:
		fs2 = fs
		fs = set.union (set ([f for f in graph if [f2 for f2 in graph[f]
				if f2 in fs2]]),
			set ([f2 for f in fs2 for f2 in graph[f]]), fs2)
	graph = dict ([(f, graph[f]) for f in fs])
	# roots: members not called by any member
	entries = list (fs - set ([f2 for f in graph for f2 in graph[f]]))
	comps = logic.tarjan (graph, entries)
	for (head, tail) in comps:
		# recursive components only: several members or a self-call
		if tail or head in graph[head]:
			group = [head] + list (tail)
			idents2 = compute_recursion_idents (group,
				extra_unfolds)
			idents.update (idents2)
	return idents
859
def compute_recursion_idents (group, extra_unfolds):
	"""Find recursion idents for one recursive group of functions.

	For every function outside the group that calls into it,
	add_recursion_ident is run repeatedly until it returns a falsy
	result. Returns the collected dict fname -> list of idents."""
	idents = {}
	group = set (group)
	recursion_trace.append ('Computing for group %s' % group)
	printout ('Doing recursion analysis for function group:')
	printout ('  %s' % list(group))
	callers = set ([f for f in functions
		if any ([f2 in group for f2 in functions[f].function_calls ()])])
	for caller in callers - group:
		recursion_trace.append ('  checking for %s' % caller)
		trace ('Checking idents for %s' % caller)
		found = add_recursion_ident (caller, group, idents, extra_unfolds)
		while found:
			found = add_recursion_ident (caller, group, idents,
				extra_unfolds)
	return idents
874
def function_link_assns (p, call_site, tag):
	"""Hypotheses linking the call at call_site to the function entered at tag.

	A thin wrapper over rep_graph.mk_function_link_hyps, using the call
	site's default visit and its tag."""
	site_visit = (default_n_vc (p, call_site), p.node_tags[call_site][0])
	return rep_graph.mk_function_link_hyps (p, site_visit, tag)
878
def add_recursion_ident (f, group, idents, extra_unfolds):
	"""Try to add one new recursion ident for a function in group.

	Builds a 'Recursion Test' problem entered at f. It repeatedly asks
	find_unknown_recursion for a call at the latest tag: a callee outside
	group is inlined, and a callee in group is added as a new entry
	function ('fun1', 'fun2', ...) with link hypotheses appended to assns.
	Returns None if no such call ever appears. Otherwise, for the last
	function fname added:
	  - if its calls have no unknown recursion even without hypotheses,
	    true_term is appended to idents[fname] and returned;
	  - otherwise its word-typed arguments are tried in order. An argument
	    qualifies if assns force it to equal its value in a solver model
	    and fixing it rules out unknown recursion; then
	    (argument == value) is appended to idents[fname] and returned.
	If no argument qualifies, the final assertion fails."""
	# NOTE(review): mk_implies is imported but unused
	from syntax import mk_eq, mk_implies, mk_var
	p = problem.Problem (None, name = 'Recursion Test')
	chain = []
	tag = 'fun0'
	p.add_entry_function (functions[f], tag)
	p.do_analysis ()
	assns = []
	# same list object as assns, so the debug slot sees later '+='
	recursion_last_assns[0] = assns

	while True:
		res = find_unknown_recursion (p, group, idents, tag, assns,
			extra_unfolds)
		if res == None:
			break
		if p.nodes[res].fname not in group:
			problem.inline_at_point (p, res)
			continue
		fname = p.nodes[res].fname
		chain.append (fname)
		tag = 'fun%d' % len (chain)
		(args, _, entry) = p.add_entry_function (functions[fname], tag)
		p.do_analysis ()
		assns += function_link_assns (p, res, tag)
	if chain == []:
		return None
	recursion_trace.append ('  created fun chain %s' % chain)
	# (input index, renamed arg var) for each word-typed input of fname
	word_args = [(i, mk_var (s, typ))
		for (i, (s, typ)) in enumerate (args)
		if typ.kind == 'Word']
	rep = rep_graph.mk_graph_slice (p, fast = True)
	(_, env) = rep.get_node_pc_env ((entry, ()))

	m = {}
	# asking to prove False under assns is used to fill m with a model;
	# res itself is not used afterwards
	res = rep.test_hyp_whyps (syntax.false_term, assns, model = m)
	assert m

	if find_unknown_recursion (p, group, idents, tag, [], []) == None:
		idents.setdefault (fname, [])
		idents[fname].append (syntax.true_term)
		recursion_trace.append ('      found final ident for %s' % fname)
		return syntax.true_term
	assert word_args
	recursion_trace.append ('      scanning for ident for %s' % fname)
	for (i, arg) in word_args:
		(nm, typ) = functions[fname].inputs[i]
		arg_smt = solver.to_smt_expr (arg, env, rep.solv)
		val = search.eval_model_expr (m, rep.solv, arg_smt)
		# the argument must be forced to this value by assns
		if not rep.test_hyp_whyps (mk_eq (arg_smt, val), assns):
			recursion_trace.append ('      discarded %s = 0x%x, not stable' % (nm, val.val))
			continue
		entry_vis = ((entry, ()), tag)
		ass = rep_graph.eq_hyp ((arg, entry_vis), (val, entry_vis))
		res = find_unknown_recursion (p, group, idents, tag,
				assns + [ass], [])
		# NOTE(review): 'if res:' would treat a node numbered 0 as
		# "no recursion found"; 'res != None' may be intended
		if res:
			fname2 = p.nodes[res].fname
			recursion_trace.append ('      discarded %s, allows recursion to %s' % (nm, fname2))
			continue
		eq = syntax.mk_eq (mk_var (nm, typ), val)
		idents.setdefault (fname, [])
		idents[fname].append (eq)
		recursion_trace.append ('    found ident for %s: %s' % (fname, eq))
		return eq
	# always fails (the message reads inverted: no identifying assertion
	# was found); under python -O this falls through and returns None
	assert not "identifying assertion found"
944
def find_unknown_recursion (p, group, idents, tag, assns, extra_unfolds):
	"""Find a call at tag that may recurse without any known ident holding.

	Scans p's 'Call' nodes tagged tag, in p.nodes order. Returns the first
	whose callee is in extra_unfolds, or whose callee is in group and
	whose pc can hold (under assns) with none of the callee's current
	idents true at the call. Returns None if there is no such call."""
	from syntax import mk_not, mk_and, foldr1
	rep = rep_graph.mk_graph_slice (p, fast = True)
	for n in p.nodes:
		node = p.nodes[n]
		if node.kind != 'Call' or p.node_tags[n][0] != tag:
			continue
		if node.fname in extra_unfolds:
			return n
		if node.fname not in group:
			continue
		n_vc = default_n_vc (p, n)
		(inp_env, _, _) = rep.get_func (n_vc)
		pc = rep.get_pc (n_vc)
		conjuncts = [pc]
		for ident in idents.get (node.fname, []):
			conjuncts.append (syntax.mk_not (
				solver.to_smt_expr (ident, inp_env, rep.solv)))
		if not rep.test_hyp_whyps (mk_not (foldr1 (mk_and, conjuncts)),
				assns):
			return n
	return None
967
# Memo table shared by get_asm_calling_convention (keyed by ASM function
# name) and get_asm_calling_convention_inner (keyed by
# ('Inner', num_c_args, num_c_rets, const_mem) tuples).
asm_cc_cache = {}
969
def is_instruction (fname):
	"""Return True iff fname names an instruction function.

	Instruction names have the form "l_impl'<...>" or "instruction'<...>":
	one of those prefixes followed by at least one "'"-separated part.
	Always returns a bool (never an empty list), with the same truthiness
	callers relied on before."""
	bits = fname.split ("'")
	return len (bits) > 1 and bits[0] in ('l_impl', 'instruction')
973
def get_asm_calling_convention (fname):
	"""Return the calling convention of ASM function fname, or None.

	The convention depends only on the paired C function's signature:
	how many scalar inputs and outputs it has, and whether it has an
	output memory component. Results are memoised in asm_cc_cache. A
	function with no pairing yields None; unless it is an instruction,
	a warning is traced first."""
	if fname in asm_cc_cache:
		return asm_cc_cache[fname]

	if fname not in pre_pairings:
		if not is_instruction (fname):
			trace ("Warning: unusual unmatched function (%s, %s)."
				% (fname, fname.split ("'")))
		return None

	pairing = pre_pairings[fname]
	assert pairing['ASM'] == fname
	c_function = functions[pairing['C']]
	from logic import split_scalar_pairs
	(c_scalar_args, _, _) = split_scalar_pairs (c_function.inputs)
	(c_scalar_rets, c_out_mem, _) = split_scalar_pairs (c_function.outputs)

	convention = get_asm_calling_convention_inner (len (c_scalar_args),
		len (c_scalar_rets), not c_out_mem)
	asm_cc_cache[fname] = convention
	return convention
997
def get_asm_calling_convention_inner (num_c_args, num_c_rets, const_mem):
	"""Build the ASM calling convention for a C signature shape.

	Arguments go in r0-r3 and then in successive stack words built by
	mk_stack_sequence from r13. When there is more than one C return
	value, r0 is not an argument. Instead it points (as r0_input) at the
	stack words that receive the returns. const_mem selects whether
	'mem' is callee-saved (True) or returned (False).

	Returns a dict with 'args', 'rets' and 'callee_saved' variable lists,
	memoised in asm_cc_cache."""
	key = ('Inner', num_c_args, num_c_rets, const_mem)
	if key in asm_cc_cache:
		return asm_cc_cache[key]

	from logic import mk_var_list, mk_stack_sequence
	from syntax import mk_var, word32T, builtinTs

	mem_t = builtinTs['Mem']
	dom_t = builtinTs['Dom']
	regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T)
	sp = mk_var ('r13', word32T)
	stack = mk_var ('stack', mem_t)
	mem = mk_var ('mem', mem_t)
	dom = mk_var ('dom', dom_t)
	dom_stack = mk_var ('dom_stack', dom_t)

	stack_words = mk_stack_sequence (sp, 4, stack, word32T, num_c_args + 1)
	positions = list (regs) + [w for (w, _) in stack_words]

	if num_c_rets <= 1:
		rets = [regs[0]]
	else:
		# the 'return-too-much' issue: r0 is a save-returns-here
		# pointer rather than the first argument.
		positions = positions[1:]
		ret_ptr = mk_var ('r0_input', word32T)
		ret_words = mk_stack_sequence (ret_ptr, 4, stack, word32T,
			num_c_rets)
		rets = [w for (w, _) in ret_words]

	saved = [mk_var (r, word32T)
		for r in 'r4 r5 r6 r7 r8 r9 r10 r11 r13'.split ()]
	saved.extend ([dom, dom_stack])
	if const_mem:
		saved.append (mem)
	else:
		rets.append (mem)
	rets.append (stack)

	global_vars = [mem, dom, stack, dom_stack, sp, mk_var ('ret', word32T)]
	convention = {'args': positions[:num_c_args] + global_vars,
		'rets': rets, 'callee_saved': saved}
	asm_cc_cache[key] = convention
	return convention
1045
def get_asm_calling_convention_at_node (node):
	"""Instantiate the callee's calling convention at call node 'node'.

	Returns None when get_asm_calling_convention gives nothing. Otherwise
	returns a dict containing:
	  'args'         - convention args, with the callee's inputs replaced
	                   by the node's argument expressions;
	  'rets'         - convention rets, with the callee's outputs replaced
	                   by the node's return variables;
	  'rets_inp'     - the convention rets under the input substitution;
	  'callee_saved' - callee-saved vars under the output substitution."""
	convention = get_asm_calling_convention (node.fname)
	if not convention:
		return None

	callee = functions[node.fname]
	input_subst = dict (azip (callee.inputs, node.args))
	ret_vars = [mk_var (nm, typ) for (nm, typ) in node.rets]
	output_subst = dict (azip (callee.outputs, ret_vars))

	def subst_all (exprs, subst):
		return [logic.var_subst (e, subst) for e in exprs]

	result = {}
	result['args'] = subst_all (convention['args'], input_subst)
	result['rets'] = subst_all (convention['rets'], output_subst)
	# rets_inp happens to map the ret r0_input back to the previous
	# value r0, rather than the useless value r0_input_ignore.
	result['rets_inp'] = subst_all (convention['rets'], input_subst)
	result['callee_saved'] = subst_all (convention['callee_saved'],
		output_subst)
	return result
1065
# Lazily built by get_asm_callable: maps every name in 'functions' (and
# every call target) to whether some reachable Call node in any function
# calls it.
call_cache = {}
1067
def get_asm_callable (fname):
	"""Return whether fname is ever called, judged via its C pairing.

	Functions without a pairing are assumed callable (True). For a paired
	function the answer is whether its paired C function is the target of
	a Call node among the reachable nodes (simplify = True) of any
	function. The table is built once, on first use, in call_cache."""
	if fname not in pre_pairings:
		return True
	c_name = pre_pairings[fname]['C']

	if not call_cache:
		call_cache.update (dict.fromkeys (functions, False))
		for f in functions:
			fun = functions[f]
			callees = [fun.nodes[n].fname
				for n in fun.reachable_nodes (simplify = True)
				if fun.nodes[n].kind == 'Call']
			for callee in callees:
				call_cache[callee] = True
	return call_cache[c_name]
1082
def get_asm_reachable_nodes (p, tag_set = None):
	"""Return the set of nodes of p reachable from the given entries.

	The search starts at the entry of each tag in tag_set, which defaults
	to the tags whose entry is an ASM node. It follows node continuations
	after logic.simplify_node_elementary and does not continue past calls
	that get_asm_callable reports as uncallable. Continuations that are
	not in p.nodes are ignored."""
	if tag_set is None:
		tag_set = [t for t in p.tags () if is_asm_node (p, p.get_entry (t))]
	seen = set ()
	pending = [p.get_entry (t) for t in tag_set]
	while pending:
		n = pending.pop ()
		if n in seen or n not in p.nodes:
			continue
		seen.add (n)
		node = p.nodes[n]
		if node.kind == 'Call' and not get_asm_callable (node.fname):
			continue
		simplified = logic.simplify_node_elementary (node)
		pending.extend (simplified.get_conts ())
	return seen
1100
def convert_recursion_idents (idents):
	"""Translate per-C-function recursion identifiers to the ASM side.

	idents maps function names to lists of identifier expressions. Each
	identifier is either True or an Equals whose left side is one of the
	function's input variables. Functions without a pairing are dropped.
	For the rest, the result maps the paired ASM function to the
	converted identifiers: True is kept as-is, and an Equals has its
	input variable replaced by the matching ASM calling-convention
	argument. Any other identifier kind fails an assertion."""
	asm_idents = {}
	for (c_name, c_idents) in idents.items ():
		if c_name not in pre_pairings:
			continue
		asm_name = pre_pairings[c_name]['ASM']
		assert asm_name != c_name
		converted = []
		asm_idents[asm_name] = converted
		for ident in c_idents:
			if ident.is_op ('True'):
				converted.append (ident)
			elif ident.is_op ('Equals'):
				[lhs, rhs] = ident.vals
				# this is a bit hacky: find which input lhs is
				[idx] = [i for (i, (nm, typ))
					in enumerate (functions[c_name].inputs)
					if lhs.is_var ((nm, typ))]
				asm_cc = get_asm_calling_convention (asm_name)
				converted.append (syntax.mk_eq (asm_cc['args'][idx], rhs))
			else:
				assert not 'ident kind convertible'
	return asm_idents
1124
def mk_pairing (pre_pair, stack_bounds):
	"""Build the logic.Pairing for one {'ASM': ..., 'C': ...} pre-pairing.

	The pairing equations come from logic.mk_eqs_arm_none_eabi_gnu. They
	are computed from the C function's scalar inputs and outputs, the
	non-scalar parts split off by split_scalar_pairs, and the ASM
	function's entry in stack_bounds."""
	asm_name = pre_pair['ASM']
	bound = stack_bounds[asm_name]
	c_function = functions[pre_pair['C']]

	from logic import split_scalar_pairs
	(c_args, c_in_mem, _) = split_scalar_pairs (c_function.inputs)
	(c_rets, c_out_mem, _) = split_scalar_pairs (c_function.outputs)

	equations = logic.mk_eqs_arm_none_eabi_gnu (c_args, c_rets,
		c_in_mem, c_out_mem, bound)
	names = {'ASM': asm_name, 'C': c_function.name}
	return logic.Pairing (['ASM', 'C'], names, equations)
1139
def mk_pairings (stack_bounds):
	"""Return a dict from each paired function name (both the ASM and the
	C side) to a fresh one-element list holding its logic.Pairing."""
	result = {}
	for fname in pre_pairings:
		if fname not in result:
			pairing = mk_pairing (pre_pairings[fname], stack_bounds)
			result.update ([(g, [pairing]) for g in pairing.funs.values ()])
	return result
1149
def serialise_stack_bounds (stack_bounds):
	"""Render stack_bounds as a list of newline-terminated text lines.

	Each line reads 'StackBound <fname> <tokens...>', where the bound
	object contributes its tokens by appending them to the list passed
	to its serialise method."""
	def line_for (fname):
		tokens = ['StackBound', fname]
		stack_bounds[fname].serialise (tokens)
		return ' '.join (tokens) + '\n'
	return [line_for (fname) for fname in stack_bounds]
1157
def deserialise_stack_bounds (lines):
	"""Parse 'StackBound <fname> <expr tokens...>' lines into a dict.

	Maps each fname to the expression syntax.parse_expr reads from the
	tokens after the name. Whitespace-only lines are skipped. Any other
	line kind fails an assertion."""
	bounds_by_fn = {}
	for tokens in (line.split () for line in lines):
		if tokens:
			assert tokens[0] == 'StackBound'
			fname = tokens[1]
			(_, bound) = syntax.parse_expr (tokens, 2)
			bounds_by_fn[fname] = bound
	return bounds_by_fn
1169
# Memo table for get_functions_with_tag: tag -> set of function names.
funs_with_tag = {}
1171
def get_functions_with_tag (tag):
	"""Return the set of functions reachable on side 'tag'.

	The roots are the tag's functions named in pre_pairings and in
	pairings. The result is their closure under function_calls (). It is
	memoised per tag in funs_with_tag."""
	if tag in funs_with_tag:
		return funs_with_tag[tag]

	roots = set ()
	for f in pre_pairings:
		if tag in pre_pairings[f]:
			roots.add (pre_pairings[f][tag])
	for f in pairings:
		roots.update ([pair.funs[tag] for pair in pairings[f]
			if tag in pair.funs])

	reached = set (roots)
	worklist = list (roots)
	while worklist:
		f = worklist.pop ()
		for callee in functions[f].function_calls ():
			if callee not in reached:
				reached.add (callee)
				worklist.append (callee)

	funs_with_tag[tag] = reached
	return reached
1186
def compute_stack_bounds (quiet = False):
	"""Compute the stack bounds of all ASM functions.

	Steps:
	  1. compute recursion identifiers for the C-side functions;
	  2. convert them to the paired ASM functions;
	  3. compute the ASM stack bounds with them.

	If quiet, the tracer is replaced with a no-op while this runs. It is
	restored afterwards whether the computation returns or raises. That
	includes KeyboardInterrupt, which an 'except Exception' handler would
	not catch."""
	prev_tracer = target_objects.tracer[0]
	if quiet:
		target_objects.tracer[0] = lambda s, n: ()

	try:
		c_fs = get_functions_with_tag ('C')
		idents = get_recursion_identifiers (c_fs)
		asm_idents = convert_recursion_idents (idents)
		asm_fs = get_functions_with_tag ('ASM')
		printout ('Computed recursion limits.')

		bounds = compute_asm_stack_bounds (asm_idents, asm_fs)
		printout ('Computed stack bounds.')
	finally:
		# restore the tracer on every exit path, not just on Exception
		if quiet:
			target_objects.tracer[0] = prev_tracer
	return bounds
1209
def read_fn_hash (fname):
	"""Read the function hash recorded on the first line of file fname.

	The first line must be exactly 'FunctionHash <int>'. Returns that
	integer. Returns None if fname is None, if the file cannot be opened
	or read, or if the first line is malformed. The file is always
	closed."""
	if fname is None:
		# no cache file configured (open (None) would raise TypeError)
		return None
	try:
		with open (fname) as f:
			bits = f.readline ().split ()
	except (IOError, ValueError):
		return None
	if len (bits) != 2 or bits[0] != 'FunctionHash':
		return None
	try:
		return int (bits[1])
	except ValueError:
		return None
1224
def mk_stack_pairings (pairing_tups, stack_bounds_fname = None,
		quiet = True):
	"""build the stack-aware calling-convention-aware logical pairings
	once a collection of function pairs have been read.

	pairing_tups is a sequence of (asm_fname, c_fname) pairs. Stack
	bounds are loaded from stack_bounds_fname when the 'FunctionHash'
	header there matches a hash of the current functions. Otherwise they
	are recomputed and written back to that file. When stack_bounds_fname
	is None, the bounds are always recomputed and no file is read or
	written."""

	# simplifies interactive testing of this function
	pre_pairings.clear ()

	for (asm_f, c_f) in pairing_tups:
		pair = {'ASM': asm_f, 'C': c_f}
		assert c_f not in pre_pairings
		assert asm_f not in pre_pairings
		pre_pairings[c_f] = pair
		pre_pairings[asm_f] = pair

	fn_hash = hash (tuple (sorted ([(f, hash (functions[f]))
		for f in functions])))
	if stack_bounds_fname is None:
		prev_hash = None
	else:
		prev_hash = read_fn_hash (stack_bounds_fname)

	if prev_hash == fn_hash:
		with open (stack_bounds_fname) as f:
			# skip the 'FunctionHash' header line
			f.readline ()
			stack_bounds = deserialise_stack_bounds (f)
	else:
		printout ('Computing stack bounds.')
		stack_bounds = compute_stack_bounds (quiet = quiet)
		if stack_bounds_fname is not None:
			with open (stack_bounds_fname, 'w') as f:
				f.write ('FunctionHash %s\n' % fn_hash)
				f.writelines (serialise_stack_bounds (stack_bounds))

	problematic_synthetic ()

	return mk_pairings (stack_bounds)
1260
1261def asm_stack_rep_hook (p, (nm, typ), kind, n):
1262	if not is_asm_node (p, n):
1263		return None
1264
1265	if not (nm.startswith ('stack') and typ == syntax.builtinTs['Mem']):
1266		return None
1267
1268	assert kind in ['Call', 'Init', 'Loop'], kind
1269	if kind == 'Init':
1270		return None
1271
1272	tag = p.node_tags[n][0]
1273	(_, sp) = get_stack_sp (p, tag)
1274
1275	return ('SplitMem', sp)
1276
# Alternative register names that may appear among the '_'-separated parts
# of an instruction function's name (consulted by inst_const_rets).
reg_aliases = {'r11': ['fp'], 'r14': ['lr'], 'r13': ['sp']}
1278
def inst_const_rets (node):
	"""Return the return variables of an instruction call that count as
	constant across the call.

	An output of the instruction function counts as constant when either:
	  - it is Mem or Dom typed; or
	  - it is a word32 register not named, directly or via reg_aliases,
	    among the lower-cased '_'-separated parts of the instruction name.
	A constant output is returned only if its variable also occurs in
	the call's arguments."""
	assert "instruction'" in node.fname
	name_parts = set ([part.lower () for part in node.fname.split ('_')])
	fun = functions[node.fname]

	def is_const (nm, typ):
		if typ in [builtinTs['Mem'], builtinTs['Dom']]:
			return True
		if typ != word32T:
			return False
		spellings = [nm] + reg_aliases.get (nm, [])
		return not [s for s in spellings if s in name_parts]

	const_flags = [is_const (nm, typ) for (nm, typ) in fun.outputs]
	arg_vars = set ()
	for arg in node.args:
		arg_vars.update (syntax.get_expr_var_set (arg))
	return [mk_var (nm, typ)
		for ((nm, typ), flag) in azip (node.rets, const_flags)
		if flag and (nm, typ) in arg_vars]
1296
def node_const_rets (node):
	"""Return the return variables of call node that count as constant,
	or None when there is no rule for the callee.

	Rules, by callee:
	  - instruction: delegated to inst_const_rets;
	  - paired ASM function: returns that are callee-saved under its
	    calling convention and also occur in the call's arguments;
	  - paired C function: None;
	  - otherwise, if preserves_sp holds and the function is on the ASM
	    side: the returns bound to output 'r13'; else None."""
	fname = node.fname
	if "instruction'" in fname:
		return inst_const_rets (node)

	if fname in pre_pairings:
		if pre_pairings[fname]['ASM'] != fname:
			return None
		convention = get_asm_calling_convention_at_node (node)
		arg_vars = set ()
		for arg in node.args:
			arg_vars.update (syntax.get_expr_var_set (arg))
		saved = set (convention['callee_saved'])
		found = []
		for (nm, typ) in node.rets:
			v = mk_var (nm, typ)
			if v in saved and (nm, typ) in arg_vars:
				found.append (v)
		return found

	if not preserves_sp (fname):
		return None
	if fname not in get_functions_with_tag ('ASM'):
		return None
	outputs = functions[fname].outputs
	return [mk_var (nm, typ)
		for ((nm, typ), (out_nm, _)) in azip (node.rets, outputs)
		if out_nm == 'r13']
1319
def const_ret_hook (node, nm, typ):
	"""'rep_unsafe_const_ret' hook: is (nm, typ) a constant return of node?

	Returns whether mk_var (nm, typ) is listed by node_const_rets. When
	that list is empty or None, that falsy value is returned unchanged."""
	consts = node_const_rets (node)
	if not consts:
		return consts
	return mk_var (nm, typ) in consts
1323
def get_const_rets (p, node_set = None):
	"""Map each Call node of p to its constant return variables.

	node_set restricts the nodes considered and defaults to all of
	p.nodes. Each Call node n maps to a list of (name, type) pairs taken
	from node_const_rets. When node_const_rets returns None (no rule for
	the callee), n maps to []."""
	if node_set is None:
		node_set = p.nodes
	const_rets = {}
	for n in node_set:
		# previously referenced an undefined 'node' here (NameError)
		node = p.nodes[n]
		if node.kind != 'Call':
			continue
		consts = node_const_rets (node) or []
		const_rets[n] = [(v.name, v.typ) for v in consts]
	return const_rets
1334
def problematic_synthetic ():
	"""Report synthetic symbols that may be problematic.

	Synthetic symbols are those in target_objects.symbols whose names
	contain '.clone.', '.part.' or '.constprop.', with '.' mapped to '_'.
	Stages, each printed as it is computed:
	  1. the synthetic symbols;
	  2. those that are functions making calls;
	  3. of those, the ones assigning r13 in some Basic node;
	  4. of those, the ones called by a function that calls more than
	     one distinct function.
	Returns the stage 4 set, or None if stage 1 or 2 is empty."""
	markers = ('.clone.', '.part.', '.constprop.')
	synth = [s.replace ('.', '_') for s in target_objects.symbols
		if any (m in s for m in markers)]
	if not synth:
		return
	printout ('Synthetic symbols: %s' % synth)

	synth_calls = set ([f for f in synth
		if f in functions and functions[f].function_calls ()])
	printout ('Synthetic symbols which make function calls: %s'
		% synth_calls)
	if not synth_calls:
		return

	def moves_sp (f):
		return any (node.kind == 'Basic'
				and ('r13', word32T) in node.get_lvals ()
			for node in functions[f].nodes.values ())
	synth_stack = set ([f for f in synth_calls if moves_sp (f)])
	printout ('Synthetic symbols which call and move sp: %s'
		% synth_stack)

	def has_multi_call_caller (f):
		return any (f in functions[g].function_calls ()
				and len (set (functions[g].function_calls ())) > 1
			for g in functions)
	synth_problems = set ([f for f in synth_stack
		if has_multi_call_caller (f)])
	printout ('Problematic synthetics: %s' % synth_problems)
	return synth_problems
1362
def add_hooks ():
	"""Register this module's hooks with target_objects under the key
	'stack_logic'."""
	hooks = [
		('problem_var_rep', asm_stack_rep_hook),
		('loop_var_analysis', loop_var_analysis),
		('rep_unsafe_const_ret', const_ret_hook),
	]
	for (hook_kind, hook_fn) in hooks:
		target_objects.add_hook (hook_kind, 'stack_logic', hook_fn)
1369
# register this module's hooks with target_objects at import time
add_hooks ()
1371
1372