1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7import syntax
8from syntax import word32T, word8T, boolT, builtinTs, Expr, Node
9from syntax import true_term, false_term, mk_num
10from syntax import foldr1
11
# unpack the expression constructors exported by syntax.mks into module-level
# names. the order of these names must match the order of syntax.mks exactly
# (NOTE(review): that sequence is defined in syntax.py, not visible here).
(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq,
mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8,
mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs,
mk_if, mk_meta_typ, mk_pvalid) = syntax.mks
16
17from syntax import structs
18from target_objects import trace, printout
19
def is_int (n):
	"""return True if n supports int conversion, i.e. defines __int__.

	this is a duck-typing test: floats and other numeric types that
	define __int__ are accepted as well."""
	has_int_conversion = hasattr (n, '__int__')
	return has_int_conversion
22
def mk_eq_with_cast (a, c):
	"""the equality a = c, with c first cast to the type of a."""
	c_as_a_typ = mk_cast (c, a.typ)
	return mk_eq (a, c_as_a_typ)
25
def mk_rodata (m):
	"""the boolean 'ROData' operator applied to the memory expression m.

	m must have the builtin 'Mem' type."""
	mem_typ = builtinTs['Mem']
	assert m.typ == mem_typ
	rodata = Expr ('Op', boolT, name = 'ROData', vals = [m])
	return rodata
29
def cast_pair (pair):
	"""cast the second expression of an equation pair to the first's type.

	pair is ((a, a_addr), (c, c_addr)); the addr/tag components are passed
	through unchanged. a boolean c paired with a non-boolean a is first
	turned into the 32-bit word 1 or 0 before the cast.

	the pair is unpacked in the body rather than in the parameter list,
	since tuple parameters are a syntax error in Python 3 (PEP 3113)."""
	((a, a_addr), (c, c_addr)) = pair
	if a.typ != c.typ and c.typ == boolT:
		c = mk_if (c, mk_word32 (1), mk_word32 (0))
	return ((a, a_addr), (mk_cast (c, a.typ), c_addr))
34
# the type of ghost-assertion variables; split_scalar_globals accepts it as
# one of the allowed global (non-scalar) parameter types.
# NOTE(review): the meaning of the arguments ('WordArray', 50, 32) is
# defined by syntax.Type, which is not visible here.
ghost_assertion_type = syntax.Type ('WordArray', 50, 32)
36
def split_scalar_globals (vs):
	"""partition a parameter list vs into (scalars, mems, others).

	scalars is the longest prefix of vs made of word or boolean variables.
	every remaining variable must have a global type (Mem, Dom, HTD, PMS
	or ghost_assertion_type), otherwise an assertion fails; mems are
	those of type Mem and others the rest, both in their original order."""
	def is_scalar (v):
		return v.typ.kind == 'Word' or v.typ == boolT

	split_at = len (vs)
	for (idx, v) in enumerate (vs):
		if not is_scalar (v):
			split_at = idx
			break
	scalars = vs[:split_at]
	global_vars = vs[split_at:]

	global_typs = [builtinTs['Mem'], builtinTs['Dom'],
		builtinTs['HTD'], builtinTs['PMS'], ghost_assertion_type]
	for v in global_vars:
		if v.typ not in global_typs:
			assert not "scalar_global split expected", vs

	mem_typ = builtinTs['Mem']
	mems = [v for v in global_vars if v.typ == mem_typ]
	others = [v for v in global_vars if v.typ != mem_typ]
	return (scalars, mems, others)
54
def mk_vars (tups):
	"""build a variable expression for each (name, type) pair in tups."""
	result = []
	for (nm, typ) in tups:
		result.append (mk_var (nm, typ))
	return result
57
def split_scalar_pairs (var_pairs):
	"""split_scalar_globals for a list of (name, type) pairs."""
	as_vars = mk_vars (var_pairs)
	return split_scalar_globals (as_vars)
60
def azip (xs, ys):
	"""zip xs with ys, asserting that both have the same length."""
	(n_xs, n_ys) = (len (xs), len (ys))
	assert n_xs == n_ys
	return zip (xs, ys)
64
def mk_mem_eqs (a_imem, c_imem, a_omem, c_omem, tags):
	"""build the memory equations relating two functions a and c.

	a_imem must contain exactly one memory; c_imem and c_omem are each
	empty or a single memory; a_omem is a list of memories. tags is the
	pair (a_tag, c_tag), suffixed with '_IN' / '_OUT' to tag each side.
	returns (input_eqs, output_eqs), lists of ((expr, tag), (expr, tag)).

	without a C input memory, a's input memory only has to satisfy
	ROData; without a C output memory, each of a's output memories is
	equated with a's input memory."""
	[a_in_mem] = a_imem
	(a_tag, c_tag) = tags
	c_in = c_tag + '_IN'
	c_out = c_tag + '_OUT'
	a_in = a_tag + '_IN'
	a_out = a_tag + '_OUT'

	if not c_imem:
		input_eqs = [((mk_rodata (a_in_mem), a_in), (true_term, c_in))]
	else:
		[c_in_mem] = c_imem
		input_eqs = [((a_in_mem, a_in), (c_in_mem, c_in)),
			((mk_rodata (c_in_mem), c_in), (true_term, c_in))]

	if not c_omem:
		output_eqs = [((a_out_mem, a_out), (a_in_mem, a_in))
			for a_out_mem in a_omem]
	else:
		[a_out_mem] = a_omem
		[c_out_mem] = c_omem
		output_eqs = [((a_out_mem, a_out), (c_out_mem, c_out)),
			((mk_rodata (c_out_mem), c_out), (true_term, c_out))]

	return (input_eqs, output_eqs)
85
def mk_fun_eqs (as_f, c_f, prunes = None):
	"""build the input and output equations pairing as_f with c_f.

	as_f and c_f have .inputs and .outputs lists of (name, type) pairs.
	prunes, if given, is a pair (c_keys, a_keys): c_keys[i] labels the
	i-th scalar C argument and a_keys[j] the j-th scalar ASM argument;
	arguments with equal labels are paired, and C arguments whose label
	has no ASM match are left out. by default both sides are labelled
	with the ASM arguments, which pairs the arguments positionally.

	returns (input_eqs, output_eqs) as lists of ((expr, tag), (expr, tag))
	pairs, with each C expression cast to the ASM expression's type."""
	(var_a_args, a_imem, _) = split_scalar_pairs (as_f.inputs)
	(var_c_args, c_imem, _) = split_scalar_pairs (c_f.inputs)
	(var_a_rets, a_omem, _) = split_scalar_pairs (as_f.outputs)
	(var_c_rets, c_omem, _) = split_scalar_pairs (c_f.outputs)

	(mem_ieqs, mem_oeqs) = mk_mem_eqs (a_imem, c_imem, a_omem, c_omem,
		['ASM', 'C'])

	if not prunes:
		prunes = (var_a_args, var_a_args)
	# the diagnostic payload may only mention names in scope here; the
	# old message referred to an undefined 'params', which turned a
	# failed check into a NameError
	assert len (prunes[0]) == len (var_c_args), (var_a_args,
		var_c_args, prunes)
	a_map = dict (azip (prunes[1], var_a_args))
	ivar_pairs = [((a_map[p], 'ASM_IN'), (c, 'C_IN')) for (p, c)
		in azip (prunes[0], var_c_args) if p in a_map]

	ovar_pairs = [((a_ret, 'ASM_OUT'), (c_ret, 'C_OUT')) for (a_ret, c_ret)
		in azip (var_a_rets, var_c_rets)]
	# explicit lists: under Python 3, map returns a one-shot iterator
	return ([cast_pair (eq) for eq in mem_ieqs + ivar_pairs],
		[cast_pair (eq) for eq in mem_oeqs + ovar_pairs])
107
def mk_var_list (vs, typ):
	"""make a variable of type typ for each name in vs."""
	return list (map (lambda nm: syntax.mk_var (nm, typ), vs))
110
def mk_offs_sequence (init, offs, n, do_reverse = False):
	"""return the n expressions init + (offs * m) for m = 0 .. n - 1.

	each offs * m is built as a numeric literal of init's type. with
	do_reverse the sequence runs from m = n - 1 down to 0."""
	indices = list (range (n))
	if do_reverse:
		# range (n).reverse () only works in Python 2, where range
		# returns a list
		indices.reverse ()
	return [mk_plus (init, Expr ('Num', init.typ, val = offs * m))
		for m in indices]
118
def mk_stack_sequence (sp, offs, stack, typ, n, do_reverse = False):
	"""return n (value, address) pairs for consecutive slots of stack.

	the addresses are those of mk_offs_sequence (sp, offs, n, ...); each
	value is the typ-typed access of memory stack at that address."""
	addrs = mk_offs_sequence (sp, offs, n, do_reverse)
	return [(mk_memacc (stack, a, typ), a) for a in addrs]
122
def mk_aligned (w, n):
	"""the condition that the word expression w has its low n bits clear,
	i.e. is a multiple of 2 ** n."""
	assert w.typ.kind == 'Word'
	low_bits = (1 << n) - 1
	masked = mk_bwand (w, Expr ('Num', w.typ, val = low_bits))
	return mk_eq (masked, mk_num (0, w.typ))
127
def mk_eqs_arm_none_eabi_gnu (var_c_args, var_c_rets, c_imem, c_omem,
		min_stack_size):
	"""build ASM/C equations for the arm-none-eabi-gnu calling convention.

	var_c_args / var_c_rets are the C function's scalar arguments / return
	values, c_imem / c_omem its memory input / output lists. arguments are
	taken from r0-r3 and then from 4-byte stack slots starting at r13 (sp).
	with more than one return value, r0 instead points to a save area in
	which the results are stored.

	returns (input_eqs, output_eqs), lists of ((expr, tag), (expr, tag))
	pairs. preconditions on the ASM side appear as input equations
	(cond, 'ASM_IN') = (true, 'ASM_IN'); values the ASM code must preserve
	appear as output equations between their ASM_IN and ASM_OUT values."""
	arg_regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T)
	r0 = arg_regs[0]
	sp = mk_var ('r13', word32T)
	st = mk_var ('stack', builtinTs['Mem'])
	r0_input = mk_var ('ret_addr_input', word32T)
	sregs = mk_stack_sequence (sp, 4, st, word32T, len (var_c_args) + 1)

	ret = mk_var ('ret', word32T)
	preconds = [mk_aligned (sp, 2), mk_eq (ret, mk_var ('r14', word32T)),
		mk_aligned (ret, 2), mk_eq (r0_input, r0),
		mk_less_eq (min_stack_size, sp)]
	# r4-r11 and sp (r13) must have the same value at ASM_IN and ASM_OUT
	post_eqs = [(x, x) for x in mk_var_list (['r4', 'r5', 'r6', 'r7', 'r8',
			'r9', 'r10', 'r11', 'r13'], word32T)]

	arg_seq = [(r, None) for r in arg_regs] + sregs
	if len (var_c_rets) > 1:
		# the 'return-too-much' issue.
		# instead r0 is a save-returns-here pointer
		arg_seq.pop (0)
		preconds += [mk_aligned (r0, 2), mk_less_eq (sp, r0)]
		save_seq = mk_stack_sequence (r0_input, 4, st, word32T,
			len (var_c_rets))
		save_addrs = [addr for (_, addr) in save_seq]
		post_eqs += [(r0_input, r0_input)]
		out_eqs = zip (var_c_rets, [x for (x, _) in save_seq])
		out_eqs = [(c, mk_cast (a, c.typ)) for (c, a) in out_eqs]
		init_save_seq = mk_stack_sequence (r0, 4, st, word32T,
			len (var_c_rets))
		(_, last_arg_addr) = arg_seq[len (var_c_args) - 1]
		preconds += [mk_less_eq (sp, addr)
			for (_, addr) in init_save_seq[-1:]]
		if last_arg_addr:
			preconds += [mk_less (last_arg_addr, addr)
				for (_, addr) in init_save_seq[:1]]
	else:
		out_eqs = list (zip (var_c_rets, [r0]))
		save_addrs = []
	arg_seq_addrs = [addr for ((_, addr), _) in zip (arg_seq, var_c_args)
		if addr != None]
	swrap = mk_stack_wrapper (sp, st, arg_seq_addrs)
	swrap2 = mk_stack_wrapper (sp, st, save_addrs)
	post_eqs += [(swrap, swrap2)]

	mem = mk_var ('mem', builtinTs['Mem'])
	(mem_ieqs, mem_oeqs) = mk_mem_eqs ([mem], c_imem, [mem], c_omem,
		['ASM', 'C'])

	arg_pairs = list (zip (var_c_args, arg_seq))
	arg_eqs = [cast_pair (((a_x, 'ASM_IN'), (c_x, 'C_IN')))
		for (c_x, (a_x, _)) in arg_pairs]
	# the slot of the last argument in use, if it is on the stack, must
	# be at or above sp. this is taken explicitly from the last pair; it
	# used to be read from the comprehension's leaked loop variable,
	# which only works in Python 2 (in Python 3 the check vanished)
	if arg_pairs:
		(_, (_, addr)) = arg_pairs[-1]
	else:
		addr = None
	if addr:
		preconds += [mk_less_eq (sp, addr)]
	ret_eqs = [cast_pair (((a_x, 'ASM_OUT'), (c_x, 'C_OUT')))
		for (c_x, a_x) in out_eqs]
	preconds = [((a_x, 'ASM_IN'), (true_term, 'ASM_IN')) for a_x in preconds]
	asm_invs = [((vin, 'ASM_IN'), (vout, 'ASM_OUT')) for (vin, vout) in post_eqs]

	return (arg_eqs + mem_ieqs + preconds,
		ret_eqs + mem_oeqs + asm_invs)
189
# maps a target name to the function building the ASM/C equations for that
# target's calling convention; looked up by mk_fun_eqs_CPU.
known_CPUs = {
	'arm-none-eabi-gnu': mk_eqs_arm_none_eabi_gnu
}
193
def mk_fun_eqs_CPU (cpu_f, c_f, cpu_name, funcall_depth = 1):
	"""build ASM/C equations for c_f following the convention of cpu_name.

	cpu_f is not consulted; only c_f's signature is used. the minimum
	stack size passed on is 256 * (funcall_depth + 1), i.e. 512 at the
	default depth of 1."""
	mk_cpu_eqs = known_CPUs[cpu_name]
	(c_args, c_imem, _) = split_scalar_pairs (c_f.inputs)
	(c_rets, c_omem, _) = split_scalar_pairs (c_f.outputs)
	min_stack_size = (funcall_depth * 256) + 256
	return mk_cpu_eqs (c_args, c_rets, c_imem, c_omem, min_stack_size)
201
class Pairing:
	"""a claimed correspondence between two functions.

	tags names the two sides, left first; funs maps each tag to a function
	name; eqs is the (input_eqs, output_eqs) pair relating the sides.
	notes is an optional mapping copied into self.notes. two pairings are
	equal when they have the same name (built from tags and function
	names) and equal eqs; hashing uses the name only."""
	def __init__ (self, tags, funs, eqs, notes = None):
		[l_tag, r_tag] = tags
		self.tags = tags
		assert set (funs) == set (tags)
		self.funs = funs
		self.eqs = eqs

		self.l_f = funs[l_tag]
		self.r_f = funs[r_tag]
		self.name = 'Pairing (%s (%s) <= %s (%s))' % (self.l_f,
			l_tag, self.r_f, r_tag)

		self.notes = {}
		if notes is not None:
			self.notes.update (notes)

	def __str__ (self):
		return self.name

	def __hash__ (self):
		return hash (self.name)

	def __eq__ (self, other):
		# other objects (None, strings, ...) have no .name/.eqs; comparing
		# with them used to raise AttributeError
		if not isinstance (other, Pairing):
			return False
		return self.name == other.name and self.eqs == other.eqs

	def __ne__ (self, other):
		return not self == other
230
def mk_pairing (functions, c_f, as_f, prunes = None, cpu = None):
	"""build the Pairing of ASM function as_f with C function c_f.

	functions maps function names to function objects. if cpu is given,
	the equations follow that target's calling convention (mk_fun_eqs_CPU)
	with the call depth of c_f; otherwise the two signatures are paired
	directly (mk_fun_eqs), optionally guided by prunes."""
	(asm_fun, c_fun) = (functions[as_f], functions[c_f])
	if not cpu:
		eqs = mk_fun_eqs (asm_fun, c_fun, prunes = prunes)
	else:
		depth = funcall_depth (functions, c_f)
		eqs = mk_fun_eqs_CPU (asm_fun, c_fun, cpu, funcall_depth = depth)
	return Pairing (['ASM', 'C'], {'C': c_f, 'ASM': as_f}, eqs)
239
def inst_eqs_pattern (pattern, params):
	"""instantiate an equation pattern with concrete parameters.

	pattern is (pat_params, inp_eqs, out_eqs) where pat_params is a list
	of lists of variables; params has the same nested shape (checked by
	azip) and gives the expression replacing each pattern variable.
	returns the instantiated (inp_eqs, out_eqs) lists of pairs."""
	(pat_params, inp_eqs, out_eqs) = pattern
	substs = {}
	for (pat_vs, vs) in azip (pat_params, params):
		for (x, y) in azip (pat_vs, vs):
			substs[(x.name, x.typ)] = y

	def inst_pair (eq):
		(lhs, rhs) = eq
		return (var_subst (lhs, substs), var_subst (rhs, substs))

	return ([inst_pair (eq) for eq in inp_eqs],
		[inst_pair (eq) for eq in out_eqs])
249
def inst_eqs_pattern_tuples (pattern, params):
	"""inst_eqs_pattern where each parameter group in params is a list of
	(name, type) pairs rather than a list of variable expressions."""
	# a list rather than map (): inst_eqs_pattern hands this to azip,
	# which calls len (), and map returns an iterator in Python 3
	return inst_eqs_pattern (pattern, [mk_vars (ps) for ps in params])
252
def inst_eqs_pattern_exprs (pattern, params):
	"""instantiate pattern (see inst_eqs_pattern) and return the input and
	output equations each folded into one conjunction of equalities."""
	def conj_of_eqs (eqs):
		return foldr1 (mk_and, [mk_eq (a, c) for (a, c) in eqs])

	(inp_eqs, out_eqs) = inst_eqs_pattern (pattern, params)
	inp_conj = conj_of_eqs (inp_eqs)
	out_conj = conj_of_eqs (out_eqs)
	return (inp_conj, out_conj)
257
def var_match (var_exp, conc_exp, assigns):
	"""match the pattern var_exp against the concrete expression conc_exp.

	types must agree at every level. a pattern Var binds to the concrete
	subexpression: assigns maps (name, typ) keys to bound expressions, is
	updated in place, and a repeated variable must match an equal
	expression. an Op must match an Op of the same name, argument by
	argument; any other pattern kind fails. returns True on a match."""
	if var_exp.typ != conc_exp.typ:
		return False
	kind = var_exp.kind
	if kind == 'Op':
		if conc_exp.kind != 'Op' or var_exp.name != conc_exp.name:
			return False
		# every argument is matched (no short-circuit), so assigns
		# receives all bindings even if an earlier argument fails
		results = [var_match (a, b, assigns)
			for (a, b) in azip (var_exp.vals, conc_exp.vals)]
		return all (results)
	if kind != 'Var':
		return False
	key = (var_exp.name, var_exp.typ)
	if key not in assigns:
		assigns[key] = conc_exp
		return True
	return conc_exp == assigns[key]
275
def var_subst (var_exp, assigns, must_subst = True):
	"""substitute the variables of var_exp using assigns, keyed by
	(name, typ), via syntax.do_subst.

	with must_subst every variable must have an entry (a missing one
	raises KeyError); otherwise variables without an entry get no
	replacement (None is handed back to do_subst for them)."""
	def replacement (sub_exp):
		if sub_exp.kind != 'Var':
			return None
		key = (sub_exp.name, sub_exp.typ)
		if not must_subst and key not in assigns:
			return None
		return assigns[key]
	return syntax.do_subst (var_exp, replacement)
287
def recursive_term_subst (eqs, expr):
	"""rewrite expr using the term mapping eqs, outermost terms first.

	a term found in eqs is replaced as a whole, without rewriting inside
	it; otherwise the arguments of an Op are rewritten recursively and
	other terms are returned unchanged."""
	if expr in eqs:
		return eqs[expr]
	if expr.kind != 'Op':
		return expr
	new_vals = [recursive_term_subst (eqs, v) for v in expr.vals]
	return syntax.adjust_op_vals (expr, new_vals)
295
def mk_accum_rewrites (typ):
	"""accumulator rewrite rules over words of type typ.

	each rule is (x, i, pattern, rewrite, offset), sharing the variables
	x, y, z and i of type typ. pattern is an update of x by y (or by y and
	z, in several sum arrangements, or by subtraction); rewrite is the
	matching closed form x +/- i * step; offset is a per-rule term: y, y,
	-y, and y + z for all the two-term rules."""
	(x, y, z, i) = [mk_var (nm, typ) for nm in ['x', 'y', 'z', 'i']]

	def times_i (step):
		return mk_times (i, step)

	def y_plus_z ():
		# a fresh expression per use, as every rule builds its own
		return mk_plus (y, z)

	rules = [
		(mk_plus (x, y), mk_plus (x, times_i (y)),
			y),
		(mk_plus (y, x), mk_plus (x, times_i (y)),
			y),
		(mk_minus (x, y), mk_minus (x, times_i (y)),
			mk_uminus (y)),
		(mk_plus (mk_plus (x, y), z),
			mk_plus (x, times_i (y_plus_z ())),
			y_plus_z ()),
		(mk_plus (mk_plus (y, x), z),
			mk_plus (x, times_i (y_plus_z ())),
			y_plus_z ()),
		(mk_plus (y, mk_plus (x, z)),
			mk_plus (x, times_i (y_plus_z ())),
			y_plus_z ()),
		(mk_plus (y, mk_plus (z, x)),
			mk_plus (x, times_i (y_plus_z ())),
			y_plus_z ()),
		# NOTE(review): by analogy with the (x - y) rule, whose offset is
		# -y, this offset might be expected to be -(y + z); kept as is -
		# confirm how callers use the offset
		(mk_minus (mk_minus (x, y), z),
			mk_minus (x, times_i (y_plus_z ())),
			y_plus_z ()),
	]
	return [(x, i, pattern, rewrite, offset)
		for (pattern, rewrite, offset) in rules]
323
def mk_all_accum_rewrites ():
	"""the accumulator rewrite rules for word8T, word32T, word16T and
	word64T, concatenated in that order."""
	all_rewrites = []
	for typ in [word8T, word32T, syntax.word16T, syntax.word64T]:
		all_rewrites.extend (mk_accum_rewrites (typ))
	return all_rewrites

# the rule table is built once, when the module is loaded
accum_rewrites = mk_all_accum_rewrites ()
330
def default_val (typ):
	"""a default constant of type typ: zero for words, false for bool.

	any other type fails an assertion (and yields None if assertions
	are disabled)."""
	if typ.kind != 'Word':
		if typ == boolT:
			return false_term
		assert not 'default value for type %s created', typ
		return None
	return Expr ('Num', typ, val = 0)
338
# when non-empty, accumulator_closed_form traces the expressions it cannot
# match (only the list's truthiness is consulted)
trace_accumulators = []

def accumulator_closed_form (expr, var, add_mask = None):
	"""try to find a closed form for the per-iteration update of a variable.

	expr is the value assigned to the variable var = (nm, typ). expr is
	first normalised with toplevel_split_out_cast, and a top-level
	low-bits mask is peeled off once (and re-applied to results). if
	expr then matches one of accum_rewrites, returns (do_rewrite, offs):
	do_rewrite (x2, i2) substitutes x2 for the rule's x and i2 for its i
	in the rule's closed form (defaults: the zero/false value of typ, and
	a zero word32), and offs is the rule's offset for this match.
	returns (None, None) when no rule matches.

	var is unpacked in the body: tuple parameters are a syntax error in
	Python 3 (PEP 3113)."""
	(nm, typ) = var
	expr = toplevel_split_out_cast (expr)
	n = get_bwand_mask (expr)
	if n and not add_mask:
		return accumulator_closed_form (expr.vals[0], (nm, typ),
			add_mask = n)

	# the same variable expression serves every rule
	acc_var = mk_var (nm, typ)
	for (x, i, pattern, rewrite, offset) in accum_rewrites:
		ass = {(x.name, x.typ): acc_var}
		if var_match (pattern, expr, ass):
			x2_def = default_val (typ)
			i2_def = default_val (word32T)
			def do_rewrite (x2 = x2_def, i2 = i2_def):
				ass[(x.name, x.typ)] = x2
				ass[(i.name, i.typ)] = i2
				vs = var_subst (rewrite, ass)
				if add_mask:
					vs = mk_bwand_mask (vs, add_mask)
				return vs
			offs = var_subst (offset, ass)
			return (do_rewrite, offs)
	if trace_accumulators:
		trace ('no accumulator %s' % ((expr, nm, typ), ))
	return (None, None)
367
def split_out_cast (expr, target_typ, bits):
	"""given a word-type expression expr (of any word length), compute a
	simplified expression expr' of type target_typ with the property that
	expr' && mask bits = cast expr, where && is bitwise-and (BWAnd),
	mask n is the pattern with the bottom n bits set, (1 << n) - 1, and
	cast is WordCast.

	casts and BWAnd masks that keep at least the bottom bits bits are
	stripped, and Plus/Minus carried out in at least bits bits are
	rebuilt from simplified arguments; anything else is cast directly."""
	def recurse (sub_exp):
		return split_out_cast (sub_exp, target_typ, bits)

	if expr.is_op (['WordCast', 'WordCastSigned']):
		[inner] = expr.vals
		if inner.typ.num >= bits and expr.typ.num >= bits:
			return recurse (inner)
	elif expr.is_op ('BWAnd'):
		[lhs, rhs] = expr.vals
		if rhs.kind == 'Num':
			mask_val = rhs.val
		else:
			mask_val = 0
		full_mask = (1 << bits) - 1
		if (mask_val & full_mask) == full_mask:
			return recurse (lhs)
	elif expr.is_op (['Plus', 'Minus']):
		# rounding issues will appear if this arithmetic is done
		# at a smaller number of bits than we'll eventually report
		if expr.typ.num >= bits:
			new_vals = [recurse (v) for v in expr.vals]
			if expr.is_op ('Plus'):
				return mk_plus (new_vals[0], new_vals[1])
			return mk_minus (new_vals[0], new_vals[1])
	return mk_cast (expr, target_typ)
405
def toplevel_split_out_cast (expr):
	"""normalise a top-level cast or low-bits mask of expr.

	for a WordCast/WordCastSigned the relevant width is the smaller of the
	two widths; for a BWAnd it is the mask width from get_bwand_mask. with
	a non-zero width, expr is simplified by split_out_cast (keeping its
	type) and the low-bits mask is re-applied explicitly. otherwise expr
	is returned unchanged."""
	if expr.is_op (['WordCast', 'WordCastSigned']):
		bits = min (expr.typ.num, expr.vals[0].typ.num)
	elif expr.is_op ('BWAnd'):
		bits = get_bwand_mask (expr)
	else:
		bits = None
	if not bits:
		return expr
	simplified = split_out_cast (expr, expr.typ, bits)
	return mk_bwand_mask (simplified, bits)
418
# filled lazily by get_bwand_mask: maps 2 ** k to k for 0 <= k <= 128
two_powers = {}

def get_bwand_mask (expr):
	"""recognise (x && mask) opers, where mask = ((1 << n) - 1)
	for some n, and return that n.

	the mask literal is first truncated to the width of its own type.
	returns None if expr is not a BWAnd, its second argument is not a
	literal, or the truncated mask is not of that form with n <= 128."""
	if not expr.is_op ('BWAnd'):
		return None
	(_, mask_exp) = expr.vals
	if mask_exp.kind != 'Num':
		return None
	width = mask_exp.typ.num
	mask_val = mask_exp.val & ((1 << width) - 1)
	if not two_powers:
		two_powers.update (dict ((1 << k, k) for k in range (129)))
	return two_powers.get (mask_val + 1)
434
def mk_bwand_mask (expr, n):
	"""expr && ((1 << n) - 1): keep only the bottom n bits of expr."""
	low_mask = mk_num ((1 << n) - 1, expr.typ)
	return mk_bwand (expr, low_mask)
437
def end_addr (p, typ):
	"""the address of the last byte of an object of type typ at p.

	typ is ('Type', T) for one T object, or ('Array', T, n) for n
	consecutive T objects, n being an expression."""
	if typ[0] != 'Array':
		assert typ[0] == 'Type', typ
		(_, obj_typ) = typ
		size = mk_word32 (obj_typ.size ())
	else:
		(_, el_typ, count) = typ
		size = mk_times (mk_word32 (el_typ.size ()), count)
	return mk_plus (p, mk_minus (size, mk_word32 (1)))
447
def pvalid_assertion1 (pvalid1, pvalid2):
	"""first pointer validity assertion: incompatibility.
	pvalid1 & pvalid2 --> non-overlapping OR somehow-contained.

	each argument is a tuple (typ, k, p, pv): typ is ('Type', syntax.Type)
	or ('Array', Type, Expr) for dynamically sized arrays, p is the
	pointer and pv its validity condition; k is not used here.

	the tuples are unpacked in the body because tuple parameters are a
	syntax error in Python 3 (PEP 3113).
	"""
	(typ, _, p, pv) = pvalid1
	(typ2, _, p2, pv2) = pvalid2
	offs1 = mk_minus (p, p2)
	cond1 = get_styp_condition (offs1, typ, typ2)
	offs2 = mk_minus (p2, p)
	cond2 = get_styp_condition (offs2, typ2, typ)

	out1 = mk_less (end_addr (p, typ), p2)
	out2 = mk_less (end_addr (p2, typ2), p)
	return mk_implies (mk_and (pv, pv2), foldr1 (mk_or,
		[cond1, cond2, out1, out2]))
463
def pvalid_assertion2 (pvalid1, pvalid2):
	"""second pointer validity assertion: implication.
	pvalid1 & strictly-contained --> pvalid2

	the arguments are (typ, k, p, pv) tuples as for pvalid_assertion1;
	k is not used. for two array types the assertion is just true.
	the tuples are unpacked in the body because tuple parameters are a
	syntax error in Python 3 (PEP 3113).
	"""
	(typ, _, p, pv) = pvalid1
	(typ2, _, p2, pv2) = pvalid2
	if typ[0] == 'Array' and typ2[0] == 'Array':
		# this is such a vague notion it's not worth it
		return true_term
	offs1 = mk_minus (p, p2)
	cond1 = get_styp_condition (offs1, typ, typ2)
	imp1 = mk_implies (mk_and (cond1, pv2), pv)
	offs2 = mk_minus (p2, p)
	cond2 = get_styp_condition (offs2, typ2, typ)
	imp2 = mk_implies (mk_and (cond2, pv), pv2)
	return mk_and (imp1, imp2)
478
def sym_distinct_assertion (pvalid, sym_range):
	"""if pv holds, the object at p lies wholly outside [start, end].

	pvalid is (typ, p, pv): object type, pointer and validity condition.
	sym_range is (start, end), numeric bounds turned into word32 literals
	(presumably the address range of a symbol). the result is
	pv --> (p + size (typ) - 1 < start) OR (end < p).

	the tuples are unpacked in the body because tuple parameters are a
	syntax error in Python 3 (PEP 3113)."""
	(typ, p, pv) = pvalid
	(start, end) = sym_range
	out1 = mk_less (mk_plus (p, mk_word32 (typ.size () - 1)), mk_word32 (start))
	out2 = mk_less (mk_word32 (end), p)
	return mk_implies (pv, mk_or (out1, out2))
483
def norm_array_type (t):
	"""normalise a type descriptor into the 4-tuple array form if possible.

	('Type', T) with T an array type becomes ('Array', element type,
	length, 'Strong'), the length being exact. a 3-tuple ('Array', T, n)
	becomes ('Array', T, n, 'Weak'). anything else is returned as is."""
	if t[0] == 'Array':
		if len (t) != 3:
			return t
		(_, el_typ, length) = t
		# these derive from PArrayValid assertions. we know the array is
		# at least this long, but it might be longer.
		return ('Array', el_typ, length, 'Weak')
	if t[0] == 'Type' and t[1].kind == 'Array':
		(_, atyp) = t
		return ('Array', atyp.el_typ_symb, mk_word32 (atyp.num), 'Strong')
	return t
495
# memo table of get_styp_condition_inner1, keyed by the normalised
# (inner, outer) type pair
stored_styp_conditions = {}

def get_styp_condition (offs, inner_typ, outer_typ):
	"""the condition that an inner_typ object at byte offset offs within an
	outer_typ object sits at a matching position inside it (see
	get_styp_condition_inner2); false_term when no position can match."""
	cond_fn = get_styp_condition_inner1 (inner_typ, outer_typ)
	return cond_fn (offs) if cond_fn else false_term
504
def get_styp_condition_inner1 (inner_typ, outer_typ):
	"""get_styp_condition_inner2 on the normalised type pair, memoised in
	stored_styp_conditions.

	returns a function from an offset expression to a condition, or None."""
	key = (norm_array_type (inner_typ), norm_array_type (outer_typ))
	if key not in stored_styp_conditions:
		stored_styp_conditions[key] = get_styp_condition_inner2 (*key)
	return stored_styp_conditions[key]
514
def array_typ_size (arr_typ):
	"""byte size of a normalised array type ('Array', el_typ, num, bound),
	as the expression num * size (el_typ).

	unpacked in the body because tuple parameters are a syntax error in
	Python 3 (PEP 3113)."""
	(_, el_typ, num, _) = arr_typ
	el_size = mk_word32 (el_typ.size ())
	return mk_times (num, el_size)
518
def get_styp_condition_inner2 (inner_typ, outer_typ):
	"""how an inner_typ object may sit inside an outer_typ object.

	both types are normalised descriptors (see norm_array_type). returns
	None when no placement is possible, else a function from a byte-offset
	expression to the condition that an inner_typ object at that offset
	is at a matching position: the whole object, a struct field
	(recursively), or an array element (recursively).
	"""
	if inner_typ[0] == 'Array' and outer_typ[0] == 'Array':
		(_, ityp, inum, _) = inner_typ
		(_, otyp, onum, outer_bound) = outer_typ
		# array fits in another array if the starting element is
		# a sub-element, and if the size of the left array plus
		# the offset fits in the right array
		cond = get_styp_condition_inner1 (('Type', ityp), outer_typ)
		isize = array_typ_size (inner_typ)
		osize = array_typ_size (outer_typ)
		if outer_bound == 'Strong' and cond:
			return lambda offs: mk_and (cond (offs),
				mk_less_eq (mk_plus (isize, offs), osize))
		else:
			# a 'Weak' outer length is only a lower bound, so no size
			# check is made
			return cond
	elif inner_typ == outer_typ:
		return lambda offs: mk_eq (offs, mk_word32 (0))
	elif outer_typ[0] == 'Type' and outer_typ[1].kind == 'Struct':
		# inside some field, at that field's offset. .values () rather
		# than the Python 2-only .itervalues ()
		conds = [(get_styp_condition_inner1 (inner_typ,
				('Type', sf_typ)), mk_word32 (offs2))
			for (_, offs2, sf_typ)
			in structs[outer_typ[1].name].fields.values ()]
		conds = [cond for cond in conds if cond[0]]
		if conds:
			return lambda offs: foldr1 (mk_or,
				[c (mk_minus (offs, offs2))
					for (c, offs2) in conds])
		else:
			return None
	elif outer_typ[0] == 'Array':
		# inside some element: within the array (when its length is
		# exact) and at a matching position modulo the element size
		(_, el_typ, n, bound) = outer_typ
		cond = get_styp_condition_inner1 (inner_typ, ('Type', el_typ))
		el_size = mk_word32 (el_typ.size ())
		size = mk_times (n, el_size)
		if bound == 'Strong' and cond:
			return lambda offs: mk_and (mk_less (offs, size),
				cond (mk_modulus (offs, el_size)))
		elif cond:
			return lambda offs: cond (mk_modulus (offs, el_size))
		else:
			return None
	else:
		return None
562
def all_vars_have_prop (expr, prop):
	"""True iff prop holds for every variable visited in expr.

	prop is called with (name, typ) pairs; the traversal is abandoned at
	the first variable for which it fails."""
	class PropFails (Exception):
		pass

	def check_var (sub_exp):
		if sub_exp.kind == 'Var' and not prop ((sub_exp.name, sub_exp.typ)):
			raise PropFails ()

	try:
		expr.visit (check_var)
	except PropFails:
		return False
	return True
577
def all_vars_in_set (expr, var_set):
	"""True iff every (name, typ) variable of expr is a member of var_set."""
	def in_set (v):
		return v in var_set
	return all_vars_have_prop (expr, in_set)
580
def var_not_in_expr (var, expr):
	"""True iff the variable var does not occur in expr (variables are
	compared by name and type)."""
	target = (var.name, var.typ)
	def differs (v):
		return v != target
	return all_vars_have_prop (expr, differs)
584
def mk_array_size_ineq (typ, num, p):
	"""bound the element count num of an array of typ elements.

	the result is num <= ((2 ** 32) - 4) // size (typ), so the array's
	byte size is at most 2 ** 32 - 4. p is not used."""
	# floor division: '/' would produce a float under Python 3
	size_lim = ((2 ** 32) - 4) // typ.size ()
	return mk_less_eq (num, mk_word32 (size_lim))
590
def mk_align_valid_ineq (typ, p):
	"""basic validity conditions for a pointer p to an object of type typ.

	typ is ('Type', T) or ('Array', T, num). the result conjoins: p is a
	multiple of T's alignment (when that exceeds 1); for arrays, the count
	bound of mk_array_size_ineq; p is non-null; and, when the object's size
	is non-zero, p <= -size, i.e. the object does not wrap past the top of
	the 32-bit address space."""
	if typ[0] != 'Type':
		assert typ[0] == 'Array', typ
		(_, el_typ, num) = typ
		align = el_typ.align ()
		size = mk_times (mk_word32 (el_typ.size ()), num)
		count_reqs = [mk_array_size_ineq (el_typ, num, p)]
	else:
		(_, obj_typ) = typ
		align = obj_typ.align ()
		size = mk_word32 (obj_typ.size ())
		count_reqs = []
	# NOTE(review): the mask test below would also handle an alignment of
	# 2, which this assertion rejects - confirm that is intended
	assert align in [1, 4, 8]
	zero = mk_word32 (0)
	reqs = []
	if align > 1:
		reqs.append (mk_eq (mk_bwand (p, mk_word32 (align - 1)), zero))
	reqs += count_reqs
	reqs.append (mk_not (mk_eq (p, zero)))
	reqs.append (mk_implies (mk_less (zero, size),
		mk_less_eq (p, mk_uminus (size))))
	return foldr1 (mk_and, reqs)
613# generic operations on function/problem graphs
def dict_list (xys, keys = None):
	"""dict_list ([(1, 2), (1, 3), (2, 4)]) = {1: [2, 3], 2: [4]}

	groups the second components of the pairs in xys by their first
	component, keeping their order. each member of keys, if given, is
	also present in the result, with an empty list if it has no pairs."""
	grouped = {}
	for (k, v) in xys:
		grouped.setdefault (k, []).append (v)
	for k in keys or []:
		grouped.setdefault (k, [])
	return grouped
624
def compute_preds (nodes):
	"""map each node name to the sorted, duplicate-free list of its
	predecessors.

	nodes maps names to nodes. every node, every continuation target,
	and 'Ret' and 'Err' get an entry (possibly an empty list)."""
	preds = dict_list ([(c, n) for n in nodes
			for c in nodes[n].get_conts ()],
		keys = nodes)
	for n in ['Ret', 'Err']:
		preds.setdefault (n, [])
	# .items () rather than the Python 2-only .iteritems ()
	return dict ([(n, sorted (set (ps))) for (n, ps) in preds.items ()])
634
def simplify_node_elementary(node):
	"""turn a Cond node with a predetermined outcome into a plain jump.

	a Cond with condition true (or identical branches) becomes an empty
	Basic node continuing left; with condition false, continuing right.
	any other node is returned unchanged."""
	if node.kind != 'Cond':
		return node
	if node.cond == true_term:
		target = node.left
	elif node.cond == false_term:
		target = node.right
	elif node.left == node.right:
		target = node.left
	else:
		return node
	return Node ('Basic', target, [])
644
def compute_var_flows (nodes, outputs, preds, override_lvals_rvals = None):
	"""build the reverse variable-flow graph and its tarjan components.

	vertices are (node, 'Pre'/'Post', var) triples and 'Ret'; each maps
	to the list of vertices its value flows from. in a Basic node each
	updated variable depends on the variables of its update expression;
	in any other non-noop node every written value, plus a 'PC' control
	vertex, depends on every value read. outputs (n) gives the variables
	flowing to 'Ret' after n. override_lvals_rvals optionally maps node
	names to (lvals, rvals) to use instead of the node's own. 'Ret' and
	every 'PC' vertex are passed to tarjan as entries.

	returns (graph, comps)."""
	if override_lvals_rvals is None:
		override_lvals_rvals = {}
	# compute a graph of reverse var flows to pass to tarjan's algorithm
	graph = {}
	entries = ['Ret']
	for (n, node) in nodes.items ():
		if node.kind == 'Basic':
			for (lv, rv) in node.upds:
				graph[(n, 'Post', lv)] = [(n, 'Pre', v)
					for v in syntax.get_expr_var_set (rv)]
		elif node.is_noop ():
			pass
		else:
			if n in override_lvals_rvals:
				(lvals, rvals) = override_lvals_rvals[n]
			else:
				rvals = syntax.get_node_rvals (node)
				rvals = set (rvals.items ())
				lvals = set (node.get_lvals ())
			# Basic nodes were handled above, so every node here gets
			# a PC vertex
			lvals = list (lvals) + ['PC']
			entries.append ((n, 'Post', 'PC'))
			for lv in lvals:
				graph[(n, 'Post', lv)] = [(n, 'Pre', rv)
					for rv in rvals]
	graph['Ret'] = [(n, 'Post', v)
		for n in preds['Ret'] for v in outputs (n)]
	vs = set ([v for k in graph for (_, _, v) in graph[k]])
	for v in vs:
		for n in nodes:
			graph.setdefault ((n, 'Post', v), [(n, 'Pre', v)])
			graph[(n, 'Pre', v)] = [(n2, 'Post', v)
				for n2 in preds[n]]

	comps = tarjan (graph, entries)
	return (graph, comps)
680
def mk_not_red (v):
	"""negate v, cancelling an outer 'Not' rather than adding another."""
	if not v.is_op ('Not'):
		return syntax.mk_not (v)
	[inner] = v.vals
	return inner
687
def cont_with_conds (nodes, n, conds):
	"""follow Cond nodes from n for as long as conds decides them.

	a Cond whose condition is in conds continues left; one whose negation
	(per mk_not_red) is in conds continues right. the walk stops at the
	first node that is absent from nodes, not a Cond, or undecided, and
	that node is returned."""
	while n in nodes and nodes[n].kind == 'Cond':
		node = nodes[n]
		if node.cond in conds:
			n = node.left
		elif mk_not_red (node.cond) in conds:
			n = node.right
		else:
			break
	return n
699
def contextual_conds (nodes, preds):
	"""computes a collection of conditions that can be assumed true
	at any point in the node graph.

	a forward worklist pass starting from the nodes without predecessors.
	returns (arc_conds, pre_conds): arc_conds maps each visited arc
	(n, cont) to the conditions known on it; pre_conds maps each visited
	node to the intersection of its incoming arcs' conditions (an arc not
	yet computed counts as empty). a Cond with distinct branches adds its
	condition, or the negation, on each outgoing arc; other nodes drop the
	conditions mentioning a variable they write."""
	arc_conds = {}
	pre_conds = {}
	work = [n for n in nodes if not (preds[n])]
	while work:
		n = work.pop ()
		if n not in nodes:
			continue
		incoming = [arc_conds.get ((pred, n), set ())
			for pred in preds[n]]
		if incoming:
			conds = set.intersection (* incoming)
		else:
			conds = set ()
		if pre_conds.get (n) == conds:
			continue
		pre_conds[n] = conds
		node = nodes[n]
		if node.kind == 'Cond' and node.left == node.right:
			out_conds = [conds, conds]
		elif node.kind == 'Cond':
			branch_conds = [node.cond, mk_not_red (node.cond)]
			out_conds = [set.union (set ([c]), conds)
				for c in branch_conds]
		else:
			written = set (node.get_lvals ())
			out_conds = [set ([c for c in conds if
				not set.intersection (written,
					syntax.get_expr_var_set (c))])]
		for (cont, cont_conds) in zip (node.get_conts (), out_conds):
			arc_conds[(n, cont)] = cont_conds
			work.append (cont)
	return (arc_conds, pre_conds)
736
def contextual_cond_simps (nodes, preds):
	r"""a common pattern in architectures with conditional operations is
	a sequence of instructions with the same condition.
	we can usually then reduce to a single conditional block.
	  b   e    =>    b-e
	 / \ / \   =>   /   \
	a-c-d-f-g  =>  a-c-f-g
	this is sometimes important if b calculates a register that e uses
	since variable dependency analysis will see this register escape via
	the impossible path a-c-d-e

	returns a copy of nodes in which each non-Cond node's continuation
	skips the Cond nodes already decided by the conditions known on that
	arc (see contextual_conds and cont_with_conds)."""
	(arc_conds, pre_conds) = contextual_conds (nodes, preds)
	nodes = dict (nodes)
	for n in nodes:
		if nodes[n].kind == 'Cond':
			continue
		cont = nodes[n].cont
		# contextual_conds only reaches nodes connected to a node without
		# predecessors; nothing is known on arcs it never visited, and
		# indexing them directly raised KeyError
		conds = arc_conds.get ((n, cont), set ())
		cont2 = cont_with_conds (nodes, cont, conds)
		if cont2 != cont:
			nodes[n] = syntax.copy_rename (nodes[n],
				({}, {cont: cont2}))
	return nodes
760
def minimal_loop_node_set (p):
	"""discover a minimal set of loop addresses, excluding some operations
	using conditional instructions which are syntactically within the
	loop but semantically must always be followed by an immediate loop
	exit.

	amounts to rerunning loop detection after contextual_cond_simps.

	a loop node is kept if following unique non-'Err' continuations in
	the simplified graph reaches a splittable point or a branching node
	before leaving the loop."""
	loop_ns = set (p.loop_data)
	really_in_loop = {}
	simp_nodes = contextual_cond_simps (p.nodes, p.preds)

	def is_really_in_loop (start):
		if start in really_in_loop:
			return really_in_loop[start]
		path = []
		verdict = None
		cur = start
		while verdict is None:
			path.append (cur)
			if cur not in loop_ns:
				verdict = False
			elif cur in p.splittable_points (cur):
				verdict = True
			else:
				conts = [c for c in simp_nodes[cur].get_conts ()
					if c != 'Err']
				if len (conts) > 1:
					verdict = True
				else:
					[cur] = conts
		for visited in path:
			really_in_loop[visited] = verdict
		return verdict

	return set ([n for n in loop_ns if is_really_in_loop (n)])
794
def possible_graph_divs (p, min_cost = 20, max_cost = 20, ratio = 0.85,
		trace = None):
	"""find Cond nodes of p that split the graph's cost fairly evenly.

	each node gets a set of (key, cost) items: 20 for a call, 50 keyed by
	loop id for a node in a loop, else its number of memory accesses
	(omitted when zero). a Cond node qualifies when the cost up to and
	including it is at most max_cost, the cost through it is at least
	min_cost, and the sum of the squared costs of its two branches, as a
	fraction of the squared total, is below ratio. qualifying nodes are
	returned in p.tarjan_order. if trace is given, trace[0] receives
	(direct_costs, future_costs, prev_costs, int_costs, fracs)."""
	es = [e[0] for e in p.entries]
	divs = []
	direct_costs = {}
	future_costs = {'Ret': set (), 'Err': set ()}
	prev_costs = {}
	int_costs = {}
	fracs = {}

	def total (costs):
		return sum ([c for (_, c) in costs])

	for n in p.nodes:
		node = p.nodes[n]
		if node.kind == 'Call':
			cost = set ([(n, 20)])
		elif p.loop_id (n):
			cost = set ([(p.loop_id (n), 50)])
		else:
			n_accesses = len (node.get_mem_accesses ())
			cost = set ([(n, n_accesses)]) if n_accesses else set ()
		direct_costs[n] = cost

	for n in p.tarjan_order:
		pred_costs = [prev_costs.get (c, set ()) for c in p.preds[n]]
		prev_costs[n] = set.union (direct_costs[n], * pred_costs)

	for n in reversed (p.tarjan_order):
		cont_costs = [future_costs.get (c, set ())
			for c in p.nodes[n].get_conts ()]
		cost = set.union (direct_costs[n], * cont_costs)
		future_costs[n] = cost
		before = prev_costs[n]
		if p.nodes[n].kind != 'Cond' or total (before) > max_cost:
			continue
		ct = total (set.union (cost, before))
		if ct < min_cost:
			continue
		[c1, c2] = [total (set.union (cs, before)) for cs in cont_costs]
		fracs[n] = ((c1 * c1) + (c2 * c2)) / (ct * ct * 1.0)
		if fracs[n] < ratio:
			divs.append (n)

	divs.reverse ()
	if trace is not None:
		trace[0] = (direct_costs, future_costs, prev_costs,
			int_costs, fracs)
	return divs
839
def compute_var_deps (nodes, outputs, preds, override_lvals_rvals = None,
		trace = None):
	"""compute, for each node, the set of variables it depends on.

	a backward worklist pass from the predecessors of 'Ret' and 'Err' over
	the graph simplified by contextual_cond_simps. a node's set is the
	values it reads plus those its continuations need that it does not
	write; outputs (n) gives the variables needed at 'Ret' after n, and
	nothing is needed at 'Err'. override_lvals_rvals optionally maps node
	names to (lvals, rvals) to use instead of the node's own. trace, if
	given, is a collection of node names whose updates are printed.

	returns a dict from node name to a set of variables (presumably
	(name, typ) pairs, as produced by syntax.get_node_rvals)."""
	if override_lvals_rvals is None:
		override_lvals_rvals = {}
	var_deps = {}
	visit = set ()
	visit.update (preds['Ret'])
	visit.update (preds['Err'])

	nodes = contextual_cond_simps (nodes, preds)

	while visit:
		n = visit.pop ()

		node = simplify_node_elementary (nodes[n])
		if n in override_lvals_rvals:
			(lvals, rvals) = override_lvals_rvals[n]
			lvals = set (lvals)
			rvals = set (rvals)
		elif node.is_noop ():
			lvals = set ([])
			rvals = set ([])
		else:
			rvals = syntax.get_node_rvals (node)
			# .items () rather than the Python 2-only .iteritems ()
			rvals = set (rvals.items ())
			lvals = set (node.get_lvals ())
		cont_vs = set ()

		for c in node.get_conts ():
			if c == 'Ret':
				cont_vs.update (outputs (n))
			elif c == 'Err':
				pass
			else:
				cont_vs.update (var_deps.get (c, []))
		vs = set.union (rvals, cont_vs - lvals)

		if n in var_deps and vs <= var_deps[n]:
			continue
		if trace and n in trace:
			diff = vs - var_deps.get (n, set())
			printout ('add %s at %d' % (diff, n))
			printout ('  %s, %s, %s, %s' % (len (vs), len (cont_vs), len (lvals), len (rvals)))
		var_deps[n] = vs
		visit.update (preds[n])

	return var_deps
886
def compute_loop_var_analysis (p, var_deps, n, override_nodes = None):
	"""run compute_var_cycle_analysis on the loop with head n.

	variables needed somewhere in the loop body but written nowhere in it
	are passed as loop constants, and var_deps[n] as the variables to
	analyse. the result is returned as a list of (variable expression,
	data) pairs. override_nodes, if given, replaces p.nodes."""
	if override_nodes == None:
		nodes = p.nodes
	else:
		nodes = override_nodes

	updated = set ()
	for n2 in p.loop_body (n):
		if not nodes[n2].is_noop ():
			updated.update (nodes[n2].get_lvals ())
	consts = set ()
	for n2 in p.loop_body (n):
		consts.update ([v for v in var_deps[n2] if v not in updated])

	analysis = compute_var_cycle_analysis (p, nodes, n,
		consts, set (var_deps[n]))
	return [(syntax.mk_var (nm, typ), data)
		for ((nm, typ), data) in analysis.items ()]
904
# (node, var) visits and '(' / ')' markers recorded by
# compute_var_cycle_analysis, which empties it at the start of each call
cvca_trace = []
# one-element list holding the default 'diag' flag of
# compute_var_cycle_analysis (presumably a list so it can be set in place)
cvca_diag = [False]
# NOTE(review): not referenced in the visible part of this file
no_accum_expressions = set ()
908
909def compute_var_cycle_analysis (p, nodes, n, const_vars, vs, diag = None):
910
911	if diag == None:
912		diag = cvca_diag[0]
913
914	cache = {}
915	del cvca_trace[:]
916	impossible_nodes = {}
917	loop = p.loop_body (n)
918
919	def warm_cache_before (n2, v):
920		cvca_trace.append ((n2, v))
921		cvca_trace.append ('(')
922		arc = []
923		for i in range (100000):
924			opts = [n3 for n3 in p.preds[n2] if n3 in loop
925				if v not in nodes[n3].get_lvals ()
926				if n3 != n
927				if (n3, v) not in cache]
928			if not opts:
929				break
930			n2 = opts[0]
931			arc.append (n2)
932		if not (len (arc) < 100000):
933			trace ('warmup overrun in compute_var_cycle_analysis')
934			trace ('chasing %s in %s' % (v, set (arc)))
935			assert False, (v, arc[-500:])
936		for n2 in reversed (arc):
937			var_eval_before (n2, v)
938		cvca_trace.append (')')
939
940	def var_eval_before (n2, v, do_cmp = True):
941		if (n2, v) in cache and do_cmp:
942			return cache[(n2, v)]
943		if n2 == n and do_cmp:
944			var_exp = mk_var (v[0], v[1])
945			vs = set ([v for v in [v] if v not in const_vars])
946			return (vs, var_exp)
947		warm_cache_before (n2, v)
948		ps = [n3 for n3 in p.preds[n2] if n3 in loop
949			if not node_impossible (n3)]
950		if not ps:
951			return None
952		vs = [var_eval_after (n3, v) for n3 in ps]
953		if not all ([v3 == vs[0] for v3 in vs]):
954			if diag:
955				trace ('vs disagree for %s @ %d: %s' % (v, n2, vs))
956			r = None
957		else:
958			r = vs[0]
959		if do_cmp:
960			cache[(n2, v)] = r
961		return r
962	def var_eval_after (n2, v):
963		node = nodes[n2]
964		if node.kind == 'Call' and v in node.rets:
965			if diag:
966				trace ('fetched %s from call at %d' % (v, n2))
967			return None
968		elif node.kind == 'Basic':
969			for (lv, val) in node.upds:
970				if lv == v:
971					return expr_eval_before (n2, val)
972			return var_eval_before (n2, v)
973		else:
974			return var_eval_before (n2, v)
975	def expr_eval_before (n2, expr):
976		if expr.kind == 'Op':
977			if expr.vals == []:
978				return (set(), expr)
979			vals = [expr_eval_before (n2, v)
980				for v in expr.vals]
981			if None in vals:
982				return None
983			s = set.union (* [s for (s, v) in vals])
984			if len(s) > 1:
985				if diag:
986					trace ('too many vars for %s @ %d: %s' % (expr, n2, s))
987				return None
988			return (s, Expr ('Op', expr.typ,
989				name = expr.name,
990				vals = [v for (s, v) in vals]))
991		elif expr.kind == 'Num':
992			return (set(), expr)
993		elif expr.kind == 'Var':
994			return var_eval_before (n2,
995				(expr.name, expr.typ))
996		else:
997			if diag:
998				trace ('Unwalkable expr %s' % expr)
999			return None
1000	def node_impossible (n2):
1001		if n2 in impossible_nodes:
1002			return impossible_nodes[n2]
1003		if n2 == n or n2 in p.get_loop_splittables (n):
1004			imposs = False
1005		else:
1006			pres = [n3 for n3 in p.preds[n2]
1007				if n3 in loop if not node_impossible (n3)]
1008			if n2 in impossible_nodes:
1009				imposs = impossible_nodes[n2]
1010			else:
1011				imposs = not bool (pres)
1012		impossible_nodes[n2] = imposs
1013		node = nodes[n2]
1014		if imposs or node.kind != 'Cond':
1015			return imposs
1016		if 1 >= len ([n3 for n3 in node.get_conts ()
1017				if n3 in loop]):
1018			return imposs
1019		c = expr_eval_before (n2, node.cond)
1020		if c != None:
1021			c = try_eval_expr (c[1])
1022		if c != None:
1023			trace ('determined loop inner cond at %d equals %s'
1024				% (n2, c == syntax.true_term))
1025		if c == syntax.true_term:
1026			impossible_nodes[node.right] = True
1027		elif c == syntax.false_term:
1028			impossible_nodes[node.left] = True
1029		return imposs
1030
1031	vca = {}
1032	for v in vs:
1033		rv = var_eval_before (n, v, do_cmp = False)
1034		if rv == None:
1035			vca[v] = 'LoopVariable'
1036			continue
1037		(s, expr) = rv
1038		if expr == mk_var (v[0], v[1]):
1039			vca[v] = 'LoopConst'
1040			continue
1041		if all_vars_in_set (expr, const_vars):
1042			# a repeatedly evaluated const expression, is const
1043			vca[v] = 'LoopConst'
1044			continue
1045		if var_not_in_expr (mk_var (v[0], v[1]), expr):
1046			# leaf calculations do not have data flow to
1047			# themselves. the search algorithm doesn't
1048			# have to worry about these.
1049			vca[v] = 'LoopLeaf'
1050			continue
1051		(form, offs) = accumulator_closed_form (expr, v)
1052		if form != None and all_vars_in_set (form (), const_vars):
1053			vca[v] = ('LoopLinearSeries', form, offs)
1054		else:
1055			if diag:
1056				trace ('No accumulator %s => %s'
1057					% (v, expr))
1058			no_accum_expressions.add ((v, expr))
1059			vca[v] = 'LoopVariable'
1060	return vca
1061
# Lazily-created solver shared by every try_eval_expr call; kept in a
# one-element list so the function can fill it in without 'global'.
eval_expr_solver = [None]

def try_eval_expr (expr):
	"""attempt to reduce an expression to a single result, vaguely like
	what constant propagation would do. it might work!

	Calls search.eval_model_expr with an empty dict (presumably an empty
	model) and the shared solver, creating that solver.Solver on first
	use. Returns that call's result, or None if it raises any Exception
	(a deliberate best-effort fallback). KeyboardInterrupt propagates."""
	import search
	if not eval_expr_solver[0]:
		import solver
		eval_expr_solver[0] = solver.Solver ()
	try:
		return search.eval_model_expr ({}, eval_expr_solver[0], expr)
	except KeyboardInterrupt:
		# bare 'raise' keeps the original traceback; 'raise e' would
		# restart it at this frame under Python 2
		raise
	except Exception:
		return None
1077
# Op names treated as linear in their arguments by possibly_linear and
# lv_expr: sums, word casts, and (via expr_linear_all) products and
# left shifts.
expr_linear_sum = set (('Plus', 'Minus'))
expr_linear_cast = set (('WordCast', 'WordCastSigned'))

expr_linear_all = (expr_linear_sum | expr_linear_cast
	| set (('Times', 'ShiftLeft')))

def possibly_linear (expr):
	"""Syntactic check: True when expr is built only from leaves
	(Var, Num, Symbol, Type, Token) and ops named in expr_linear_all,
	False otherwise."""
	if expr.kind in ('Var', 'Num', 'Symbol', 'Type', 'Token'):
		return True
	if not expr.is_op (expr_linear_all):
		return False
	# every operand is checked before combining the answers
	sub_results = [possibly_linear (sub) for sub in expr.vals]
	return all (sub_results)
1091
def lv_expr (expr, env):
	"""Classify expr, in the loop-analysis environment env, as a loop
	constant, a linear series, or unknown.

	env maps expressions (presumably Var Exprs) to 4-tuples of the form
	returned here: (expr, kind, offs, offs_set), where kind is
	'LoopConst', 'LoopLinearSeries' or None. For a linear series, offs
	is an offset term built by rebuilding expr's op over its arguments'
	offsets (get_loop_linear_offs reports it as the offset, and 0 for
	loop constants - presumably the per-iteration step), and offs_set
	is the union of the arguments' offs_sets. Unknown results are
	(None, None, None, None).

	Recognised linear forms: Plus/Minus mixing constants and series,
	Times of exactly one series and one constant, ShiftLeft of a series
	by a constant, and WordCast/WordCastSigned of a series.
	NOTE(review): a Plus/Minus whose arguments are all series (no
	constant) is classified as unknown - confirm this is intended."""
	if expr in env:
		return env[expr]
	elif expr.kind in set (['Num', 'Symbol', 'Type', 'Token']):
		return (expr, 'LoopConst', None, set ())
	elif expr.kind == 'Var':
		# a variable with no entry in env is unknown
		return (None, None, None, None)
	elif expr.kind != 'Op':
		# always fails: expr is already known not to be in env
		assert expr in env, expr

	lvs = [lv_expr (v, env) for v in expr.vals]
	rs = [lv[1] for lv in lvs]
	# rebuild expr's op over a new list of argument values
	mk_offs = lambda vals: syntax.adjust_op_vals (expr, vals)
	if None in rs:
		return (None, None, None, None)
	if set (rs) == set (['LoopConst']):
		return (expr, 'LoopConst', None, set ())
	# the trailing set () keeps set.union from being called with no args
	offs_set = set.union (* ([lv[3] for lv in lvs] + [set ()]))
	arg_offs = []
	for (expr2, k, offs, _) in lvs:
		# constant Word arguments contribute a zero offset
		if k == 'LoopConst' and expr2.typ.kind == 'Word':
			arg_offs.append (syntax.mk_num (0, expr2.typ))
		else:
			arg_offs.append (offs)
	if expr.is_op (expr_linear_sum):
		if set (rs) == set (['LoopConst', 'LoopLinearSeries']):
			return (expr, 'LoopLinearSeries', mk_offs (arg_offs),
				offs_set)
	elif expr.is_op ('Times'):
		if set (rs) == set (['LoopLinearSeries', 'LoopConst']):
			# the new offset is the product of the linear offset
			# and the constant value
			[linear_offs] = [offs for (_, k, offs, _) in lvs
				if k == 'LoopLinearSeries']
			[const_value] = [v for (v, k, _, _) in lvs
				if k == 'LoopConst']
			return (expr, 'LoopLinearSeries',
				mk_offs ([linear_offs, const_value]), offs_set)
	if expr.is_op ('ShiftLeft'):
		# only a series shifted by a constant amount (in that order)
		if rs == ['LoopLinearSeries', 'LoopConst']:
			return (expr, 'LoopLinearSeries',
				mk_offs ([arg_offs[0], lvs[1][0]]), offs_set)
	if expr.is_op (expr_linear_cast):
		if rs == ['LoopLinearSeries']:
			return (expr, 'LoopLinearSeries', mk_offs (arg_offs),
				offs_set)
	return (None, None, None, None)
1139
1140# FIXME: this should probably be unified with compute_var_cycle_analysis,
1141# but doing so is complicated
def linear_series_exprs (p, loop, va):
	"""Propagate linear-series classifications (see lv_expr) from the
	head of loop through its body.

	va is iterated as (var, data) pairs giving each variable's
	classification at the loop head: 'LoopConst', a tuple whose [0] is
	'LoopLinearSeries' and whose [2] is the offset, or anything else
	(unknown). Returns a dict mapping each processed node of the loop
	body (including loop itself) to its pre-state environment: a dict
	from (presumably Var) expressions to lv_expr-style 4-tuples.

	A node is processed only once all of its in-loop predecessors have
	been. NOTE(review): a node with an in-loop predecessor that can only
	be reached through the node itself (presumably an inner loop head)
	is therefore never processed and is missing from the result."""
	def lv_init (v, data):
		# data[0] of a string such as 'LoopConst' is just its first
		# character, so only the tuple form matches the first test
		if data[0] == 'LoopLinearSeries':
			return (v, 'LoopLinearSeries', data[2], set ([data[2]]))
		elif data == 'LoopConst':
			return (v, 'LoopConst', None, set ())
		else:
			return (None, None, None, None)
	cache = {loop: dict ([(v, lv_init (v, data)) for (v, data) in va])}
	post_cache = {}
	loop_body = p.loop_body (loop)
	frontier = [n2 for n2 in p.nodes[loop].get_conts ()
		if n2 in loop_body]
	def lv_merge ((v1, lv1, offs1, oset1), (v2, lv2, offs2, oset2)):
		# merge two predecessors' views of one variable: unknown unless
		# both describe the same expr, in which case the first tuple
		# (including its offs_set) is kept
		if v1 != v2:
			return (None, None, None, None)
		assert lv1 == lv2 and offs1 == offs2
		return (v1, lv1, offs1, oset1)
	def compute_post (n):
		# post-state env of node n (memoised in post_cache). All Basic
		# updates are evaluated against the pre-state env; Call outputs
		# become unknown; other nodes leave the env unchanged.
		if n in post_cache:
			return post_cache[n]
		pre_env = cache[n]
		env = dict (cache[n])
		if p.nodes[n].kind == 'Basic':
			for ((v, typ), rexpr) in p.nodes[n].upds:
				env[mk_var (v, typ)] = lv_expr (rexpr, pre_env)
		elif p.nodes[n].kind == 'Call':
			for (v, typ) in p.nodes[n].get_lvals ():
				env[mk_var (v, typ)] = (None, None, None, None)
		post_cache[n] = env
		return env
	while frontier:
		n = frontier.pop ()
		# wait until every in-loop predecessor has an env; n is pushed
		# again whenever one of them is processed
		if [n2 for n2 in p.preds[n] if n2 in loop_body
				if n2 not in cache]:
			continue
		if n in cache:
			continue
		envs = [compute_post (n2) for n2 in p.preds[n]
			if n2 in loop_body]
		all_vs = set.union (* [set (env) for env in envs])
		# a variable absent from some predecessor's env merges as
		# unknown there
		cache[n] = dict ([(v, foldr1 (lv_merge,
				[env.get (v, (None, None, None, None))
					for env in envs]))
			for v in all_vs])
		frontier.extend ([n2 for n2 in p.nodes[n].get_conts ()
			if n2 in loop_body])
	return cache
1190
def get_loop_linear_offs (p, loop_head):
	"""Build offs_fn (n, expr) for the loop headed at loop_head.

	offs_fn asserts that n belongs to that loop and returns None when
	lv_expr cannot classify expr at n, a zero of expr's type when expr
	is loop-constant, and the offset term when expr is a linear series.
	"""
	import search
	var_analysis = search.get_loop_var_analysis_at (p, loop_head)
	node_envs = linear_series_exprs (p, loop_head, var_analysis)

	def offs_fn (n, expr):
		assert p.loop_id (n) == loop_head
		lv = lv_expr (expr, node_envs[n])
		kind = lv[1]
		if kind is None:
			return None
		if kind == 'LoopConst':
			return mk_num (0, expr.typ)
		if kind == 'LoopLinearSeries':
			return lv[2]
		assert not 'lv_expr kind understood', lv

	return offs_fn
1208
def interesting_node_exprs (p, n, tags = None, use_pairings = True):
	"""Collect (kind, expr) pairs worth tracking at node n of p.

	Always includes (kind, ptr) for each memory access at the node,
	followed by ('MemUpdateArg', value) for each 'MemUpdate' access.
	When the node is a Call and use_pairings is set, and the callee has
	exactly one pairing whose tags equal tags (default p.pairing.tags),
	also includes each side of that pairing's '_IN' equations (between
	distinct, solver-representable sites) that belongs to this node's
	tag. These are keyed ('Call', pairing name, equation index) and
	passed through var_subst with a map from the callee's inputs to the
	call's arguments."""
	if tags == None:
		tags = p.pairing.tags
	node = p.nodes[n]
	accesses = node.get_mem_accesses ()
	result = [(kind, ptr) for (kind, ptr, _, _) in accesses]
	for (kind, _, val, _) in accesses:
		if kind == 'MemUpdate':
			result.append (('MemUpdateArg', val))

	if not (node.kind == 'Call' and use_pairings):
		return result

	tag = p.node_tags[n][0]
	from target_objects import functions, pairings
	import solver
	callee = functions[node.fname]
	arg_input_map = dict (azip (callee.inputs, node.args))
	matching = [pair for pair in pairings.get (node.fname, [])
		if pair.tags == tags]
	if not matching:
		return result
	[pair] = matching
	in_site = '%s_IN' % tag
	in_eq_vs = []
	for (i, ((lhs, l_s), (rhs, r_s))) in enumerate (pair.eqs[0]):
		if not (l_s.endswith ('_IN') and r_s.endswith ('_IN')):
			continue
		if l_s == r_s or not solver.typ_representable (lhs.typ):
			continue
		for (v, site) in [(lhs, l_s), (rhs, r_s)]:
			if site == in_site:
				in_eq_vs.append ((('Call', pair.name, i),
					var_subst (v, arg_input_map)))
	result.extend (in_eq_vs)
	return result
1240
def interesting_linear_series_exprs (p, loop, va, tags = None,
		use_pairings = True):
	"""For each node of loop's body, list the interesting expressions
	(see interesting_node_exprs) that lv_expr classifies as a linear
	series at that node.

	Returns a dict mapping node -> list of (kind, expr, offs, offs_set),
	containing only nodes with at least one such expression. tags
	(default p.pairing.tags) and use_pairings are passed through to
	interesting_node_exprs."""
	if tags == None:
		tags = p.pairing.tags
	expr_env = linear_series_exprs (p, loop, va)
	res_env = {}
	for (n, env) in expr_env.iteritems ():
		# pass the caller's tags/use_pairings through; without this
		# both parameters would have no effect at all
		vs = interesting_node_exprs (p, n, tags = tags,
			use_pairings = use_pairings)

		vs = [(kind, v, lv_expr (v, env)) for (kind, v) in vs]
		vs = [(kind, v, offs, offs_set)
			for (kind, v, (_, lv, offs, offs_set)) in vs
			if lv == 'LoopLinearSeries']
		if vs:
			res_env[n] = vs
	return res_env
1257
def mk_var_renames (xs, ys):
	"""Map each Var name in xs to the name of the Var at the same
	position in ys. The lists must have equal length (checked by azip),
	hold only Vars, and xs must not repeat a name (asserted)."""
	result = dict ()
	for pair in azip (xs, ys):
		(old, new) = pair
		assert old.kind == 'Var' and new.kind == 'Var'
		assert old.name not in result
		result[old.name] = new.name
	return result
1265
def first_aligned_address (nodes, radix):
	"""Return the smallest address (key) in nodes that is a multiple of
	radix, or None if there is none."""
	aligned = [addr for addr in nodes if addr % radix == 0]
	if not aligned:
		return None
	return min (aligned)
1273
def entry_aligned_address (fun, radix):
	"""Follow fun's control flow from fun.entry to the first address
	that is a multiple of radix, and return that address.

	Every node passed on the way must have exactly one successor
	(asserted, with (fun.name, address) as the message)."""
	addr = fun.entry
	while addr % radix:
		succs = fun.nodes[addr].get_conts ()
		assert len (succs) == 1, (fun.name, addr)
		[addr] = succs
	return addr
1281
def aligned_address_sanity (functions, symbols, radix):
	"""Check each function's radix-aligned addresses against symbols.

	For every function that has a symbol and an entry, both the
	smallest radix-aligned node address and the aligned address reached
	from the entry must equal symbols[f][0]. The first mismatch is
	reported with printout and makes the result False. A function with
	no aligned nodes only triggers a warning. Otherwise returns True."""
	for f in functions:
		func = functions[f]
		# static or invented functions sometimes have no symbol
		if f not in symbols or not func.entry:
			continue
		first_addr = first_aligned_address (func.nodes, radix)
		if first_addr == None:
			printout ('Warning: %s: no aligned instructions' % f)
			continue
		sym_addr = symbols[f][0]
		if first_addr != sym_addr:
			printout ('target mismatch on func %s' % f)
			printout ('  (starts at 0x%x not 0x%x)'
				% (first_addr, sym_addr))
			return False
		entry_addr = entry_aligned_address (func, radix)
		if entry_addr != sym_addr:
			printout ('entry mismatch on func %s' % f)
			printout ('  (enters at 0x%x not 0x%x)'
				% (entry_addr, sym_addr))
			return False
	return True
1303
1304# variant of tarjan's strongly connected component algorithm
def tarjan (graph, entries):
	"""Find the strongly connected components reachable from entries.

	graph maps every reachable node to the list of its successors.
	Returns a list of (head, tail) pairs, one per component, where tail
	lists the component's other members. A single node gets an empty
	tail whether or not it has a self-loop. Each entry must not already
	have been reached from an earlier entry (asserted).

	e.g. tarjan ({1: [2], 2: [3], 3: [2]}, [1]) == [(2, [3]), (1, [])]"""
	visit_data = {}
	components = []
	for entry in entries:
		assert entry not in visit_data
		tarjan1 (graph, entry, visit_data, [], set (), components)
	return components
1316
def tarjan1 (graph, v, data, stack, stack_set, comps):
	"""Recursive step of tarjan, starting from the unvisited node v.

	data maps each visited node to [visit index, low-link]; stack and
	stack_set hold visited nodes not yet assigned to a component. Each
	completed component is appended to comps as (root, other members).
	To limit recursion depth, a chain of nodes that each have exactly
	one, not yet visited, successor is walked iteratively: its tree
	edges are recorded in vs and their low-links are fixed up after
	the chain's last node has been processed."""
	vs = []
	while True:
		# skip through nodes with single successors
		data[v] = [len(data), len(data)]
		stack.append(v)
		stack_set.add(v)
		cs = graph[v]
		if len (cs) != 1 or cs[0] in data:
			break
		vs.append ((v, cs[0]))
		[v] = cs

	# v is now the last node of the chain: handle its successors as in
	# the standard algorithm
	for c in graph[v]:
		if c not in data:
			tarjan1 (graph, c, data, stack, stack_set, comps)
			data[v][1] = min (data[v][1], data[c][1])
		elif c in stack_set:
			data[v][1] = min (data[v][1], data[c][0])

	# propagate low-links back up the chain, deepest edge first
	vs.reverse ()
	for (v2, c) in vs:
		data[v2][1] = min (data[v2][1], data[c][1])

	# check the chain's nodes for component roots, deepest first, so
	# each root pops only the nodes still stacked above it
	for (v2, _) in [(v, 0)] + vs:
		if data[v2][1] == data[v2][0]:
			comp = []
			while True:
				x = stack.pop ()
				stack_set.remove (x)
				if x == v2:
					break
				comp.append (x)
			comps.append ((v2, comp))
1351
def divides_loop (graph, split_set):
	"""True if cutting the out-edges of every node in split_set leaves
	graph with no strongly connected component of two or more nodes.

	A synthetic 'ENTRY_POINT' node with an edge to every node of graph
	is added so the whole graph is explored; that name must not already
	be in use (asserted)."""
	cut_graph = dict (graph)
	cut_graph.update ([(node, []) for node in split_set])
	assert 'ENTRY_POINT' not in cut_graph
	cut_graph['ENTRY_POINT'] = list (graph)
	components = tarjan (cut_graph, ['ENTRY_POINT'])
	return all ([not tail for (_, tail) in components])
1360
def strongly_connected_split_points1 (graph):
	"""find the nodes of a strongly connected
	component which, when removed, disconnect the component.
	complex loops lack such a split point.

	graph maps each node to its list of successors and is assumed to be
	a single strongly connected component (so every node has at least
	one successor). Returns the set of split points, possibly empty.
	strongly_connected_split_points cross-checks this result against a
	brute-force search."""

	# find one simple cycle in the graph, following each node's first
	# successor from the least node
	walk = []
	walk_set = set ()
	n = min (graph)
	while n not in walk_set:
		walk.append (n)
		walk_set.add (n)
		n = graph[n][0]
	i = walk.index (n)
	cycle = walk[i:]

	def subgraph_test (subgraph):
		# True if the graph restricted to subgraph still contains a
		# cycle through two or more nodes
		graph2 = dict ([(n, [n2 for n2 in graph[n] if n2 in subgraph])
			for n in subgraph])
		graph2['HEAD'] = list (subgraph)
		comps = tarjan (graph2, ['HEAD'])
		return bool ([h for (h, t) in comps if t])

	# each cycle entry is (kind, nodes, needs_test, unvisited): a single
	# cycle 'Node' or a merged 'Group', whether the group must finally
	# be checked with subgraph_test, and successors still to explore
	cycle_set = set (cycle)
	cycle = [('Node', set ([n]), False,
			[n2 for n2 in graph[n] if n2 != graph[n][0]])
		for n in cycle]
	i = 0
	while i < len (cycle):
		(kind, ns, test, unvisited) = cycle[i]
		if not unvisited:
			i += 1
			continue
		# follow first successors from an unexplored successor until
		# the walk rejoins the (growing) cycle
		n = unvisited.pop ()
		arc_set = set ()
		while n not in cycle_set:
			if n in arc_set:
				# found two totally disjoint loops, so there
				# are no splitting points
				return set ()
			arc_set.add (n)
			n = graph[n][0]
		if n in ns:
			if kind == 'Node':
				# only this node can be a splittable now.
				if subgraph_test (set (graph) - set ([n])):
					return set ()
				else:
					return set ([n])
			else:
				cycle[i] = (kind, ns, True, unvisited)
				ns.update (arc_set)
				continue
		# the arc rejoins further round the cycle: merge every entry it
		# bypasses, plus the arc itself, into one group
		j = (i + 1) % len (cycle)
		new_ns = set ()
		new_unvisited = set ()
		new_test = False
		while n not in cycle[j][1]:
			new_ns.update (cycle[j][1])
			new_unvisited.update (cycle[j][3])
			new_test = cycle[j][2] or new_test
			j = (j + 1) % len (cycle)
		new_ns.update (arc_set)
		new_unvisited.update ([n3 for n2 in arc_set for n3 in graph[n2]])
		new_v = ('Group', new_ns, new_test, list (new_unvisited - new_ns))
		if j > i:
			cycle[i + 1:j] = [new_v]
		else:
			cycle = [cycle[i], new_v] + cycle[j:i]
			i = 0
		cycle_set.update (new_ns)
	# a group marked for testing must not itself contain a cycle
	for (kind, ns, test, unvisited) in cycle:
		if test and subgraph_test (ns):
			return set ()
	return set ([n for (kind, ns, _, _) in cycle
		if kind == 'Node' for n in ns])
1439
def strongly_connected_split_points (graph):
	"""Split points of the strongly connected graph, as computed by
	strongly_connected_split_points1. The result is cross-checked
	(asserted) against a brute-force search that cuts each node's
	out-edges in turn and looks for any remaining multi-node component.
	"""
	fast = strongly_connected_split_points1 (graph)
	slow = set ()
	for node in graph:
		cut = dict (graph)
		cut[node] = []
		cut['ENTRY'] = list (graph)
		components = tarjan (cut, ['ENTRY'])
		if all ([not tail for (_, tail) in components]):
			slow.add (node)
	assert fast == slow, (graph, fast, slow)
	return fast
1452
def get_one_loop_splittable (p, loop_set):
	"""discover a component of a strongly connected
	component which, when removed, disconnects the component.
	complex loops lack such a split point.

	Returns one such node of loop_set, or None if there is none. A
	split point must lie on every cycle, so candidates are repeatedly
	narrowed to the nodes of some cycle. loop_set is assumed to be
	strongly connected within p's control-flow graph."""
	succs = {}
	for x in loop_set:
		succs[x] = [y for y in p.nodes[x].get_conts () if y in loop_set]
	candidates = set (loop_set)
	while candidates:
		cycle = find_loop_avoiding (succs, loop_set, candidates)
		candidates = cycle & candidates
		if not candidates:
			return None
		n = candidates.pop ()
		# drop every edge into n: any multi-node component left is a
		# cycle that avoids n
		cut = {}
		for x in succs:
			cut[x] = [y for y in succs[x] if y != n]
		cycles = [(h, t) for (h, t) in tarjan (cut, [n]) if t]
		if not cycles:
			return n
		for (h, t) in cycles:
			candidates = set ([h] + t) & candidates
	return None
1476
def find_loop_avoiding (graph, loop, avoid):
	"""Walk graph from a node of loop until a node repeats, and return
	the set of nodes on the cycle that closes the walk.

	The walk starts outside avoid if it can. At each step it prefers a
	successor already on the walk (closing the cycle), then one outside
	avoid, then any successor. loop and avoid are sets."""
	start = (list (loop - avoid) or list (loop))[0]
	path = [start]
	seen = set ([start])
	node = start
	while True:
		succs = set (graph[node])
		revisits = set.intersection (succs, seen)
		if revisits:
			node = revisits.pop ()
			break
		preferred = succs - avoid
		node = (preferred or succs).pop ()
		seen.add (node)
		path.append (node)
	return set (path[path.index (node):])
1496
# Non-equality relations in proof hypotheses are recorded as a pretend
# equality between 'RelWrapper' terms (built by the wrapper constructors
# below) and are reverted to their 'real' meaning by apply_rel_wrapper.
def mk_stack_wrapper (stack_ptr, stack, excepts):
	"""Wrap a stack pointer, a stack and a list of excepted addresses as
	a 'StackWrapper' relation term."""
	args = [stack_ptr, stack] + excepts
	return syntax.mk_rel_wrapper ('StackWrapper', args)
1502
def mk_mem_acc_wrapper (addr, v):
	"""Wrap an address and a value as a 'MemAccWrapper' relation term."""
	args = [addr, v]
	return syntax.mk_rel_wrapper ('MemAccWrapper', args)
1505
def mk_mem_wrapper (m):
	"""Wrap a memory term as a 'MemWrapper' relation term."""
	args = [m]
	return syntax.mk_rel_wrapper ('MemWrapper', args)
1508
def tm_with_word32_list (xs):
	"""Encode a list of ints as a term. The word32 constants of xs are
	combined with mk_plus via foldr1; the empty list becomes
	mk_uminus (mk_word32 (0)). word32_list_from_tm decodes it."""
	if not xs:
		return mk_uminus (mk_word32 (0))
	return foldr1 (mk_plus, map (mk_word32, xs))
1514
def word32_list_from_tm (t):
	"""Decode a term built by tm_with_word32_list into its list of ints.

	Each Plus node must have a word32 Num on its left (asserted). A
	final Num contributes its value; anything else (e.g. the empty-list
	encoding) contributes nothing."""
	vals = []
	rest = t
	while rest.is_op ('Plus'):
		[head, rest] = rest.vals
		assert head.kind == 'Num' and head.typ == word32T
		vals.append (head.val)
	if rest.kind == 'Num':
		vals.append (rest.val)
	return vals
1524
1525def mk_eq_selective_wrapper (v, (xs, ys)):
1526	# this is a huge hack, but we need to put these lists somewhere
1527	xs = tm_with_word32_list (xs)
1528	ys = tm_with_word32_list (ys)
1529	return syntax.mk_rel_wrapper ('EqSelectiveWrapper', [v, xs, ys])
1530
def apply_rel_wrapper (lhs, rhs):
	"""Turn a pretend equality between two RelWrapper terms into the
	boolean relation it stands for.

	  - two 'StackWrapper's become 'StackEquals' over both stack pointers
	    and both stacks, after writing 0 to every excepted address (from
	    either side) in both stacks;
	  - a 'MemAccWrapper' and a 'MemWrapper' become an equality between
	    the wrapped value and the memory read at the wrapped address;
	  - two 'EqSelectiveWrapper's are unwrapped: nested RelWrappers are
	    applied recursively, plain values are compared with mk_eq.

	Any other combination fails an assertion."""
	assert lhs.typ == syntax.builtinTs['RelWrapper']
	assert rhs.typ == syntax.builtinTs['RelWrapper']
	assert lhs.kind == 'Op'
	assert rhs.kind == 'Op'
	ops = set ([lhs.name, rhs.name])
	if ops == set (['StackWrapper']):
		[sp1, st1] = lhs.vals[:2]
		[sp2, st2] = rhs.vals[:2]
		# excepted addresses are ignored by zeroing them in both stacks
		excepts = list (set (lhs.vals[2:] + rhs.vals[2:]))
		for p in excepts:
			st1 = syntax.mk_memupd (st1, p, syntax.mk_word32 (0))
			st2 = syntax.mk_memupd (st2, p, syntax.mk_word32 (0))
		return syntax.Expr ('Op', boolT, name = 'StackEquals',
			vals = [sp1, st1, sp2, st2])
	elif ops == set (['MemAccWrapper', 'MemWrapper']):
		[acc] = [v for v in [lhs, rhs] if v.is_op ('MemAccWrapper')]
		[addr, val] = acc.vals
		assert addr.typ == syntax.word32T
		[m] = [v for v in [lhs, rhs] if v.is_op ('MemWrapper')]
		[m] = m.vals
		assert m.typ == builtinTs['Mem']
		expr = mk_eq (mk_memacc (m, addr, val.typ), val)
		return expr
	elif ops == set (['EqSelectiveWrapper']):
		[lhs_v, _, _] = lhs.vals
		[rhs_v, _, _] = rhs.vals
		if lhs_v.typ == syntax.builtinTs['RelWrapper']:
			return apply_rel_wrapper (lhs_v, rhs_v)
		else:
			# compare the wrapped values, not the RelWrapper terms
			return mk_eq (lhs_v, rhs_v)
	else:
		assert not 'rel wrapper opname understood', (lhs.name, rhs.name)
1564
def inst_eq_at_visit (exp, vis):
	"""Decide whether hypothesis exp applies at visit vis.

	Anything other than an 'EqSelectiveWrapper' always applies. A
	wrapper applies at a 'Number' visit whose n is in its first list,
	or at an 'Offset' visit whose n is in its second list. Any other
	visit kind fails an assertion."""
	if not exp.is_op ('EqSelectiveWrapper'):
		return True
	# hacks: the two visit lists are stored as word32 sum terms
	[_, number_tm, offset_tm] = exp.vals
	numbers = word32_list_from_tm (number_tm)
	offsets = word32_list_from_tm (offset_tm)
	by_kind = {'Number': numbers, 'Offset': offsets}
	if vis.kind not in by_kind:
		assert not 'visit kind useable', vis
		return None
	return vis.n in by_kind[vis.kind]
1578
def strengthen_hyp (expr, sign = 1):
	"""Rewrite the relation ops in expr for use as a hypothesis (sign 1)
	or, through weaken_assert, as an assertion (sign -1).

	Polarity is tracked through And/Or (kept), Implies (flipped on the
	left) and Not (flipped). In positive position 'StackEquals' becomes
	'ImpliesStackEquals' and 'ROData' becomes 'ImpliesROData'. In
	negative position 'StackEquals' becomes 'StackEqualsImplies' and
	'ROData' is left alone. A boolean Equals with a true/false side is
	reduced to its other side (negated for false). All other terms are
	returned unchanged."""
	if expr.kind != 'Op':
		return expr
	name = expr.name
	flipped = - sign
	if name in ('And', 'Or'):
		return syntax.adjust_op_vals (expr,
			[strengthen_hyp (v, sign) for v in expr.vals])
	if name == 'Implies':
		[premise, conclusion] = expr.vals
		premise = strengthen_hyp (premise, flipped)
		conclusion = strengthen_hyp (conclusion, sign)
		return syntax.mk_implies (premise, conclusion)
	if name == 'Not':
		[inner] = expr.vals
		return syntax.mk_not (strengthen_hyp (inner, flipped))
	if name == 'StackEquals':
		if sign == 1:
			new_name = 'ImpliesStackEquals'
		else:
			new_name = 'StackEqualsImplies'
		return syntax.Expr ('Op', boolT, name = new_name,
			vals = expr.vals)
	if name == 'ROData':
		if sign != 1:
			return expr
		return syntax.Expr ('Op', boolT,
			name = 'ImpliesROData', vals = expr.vals)
	if name == 'Equals' and expr.vals[0].typ == boolT:
		(const, other) = (expr.vals[0], expr.vals[1])
		# put a true/false constant (if any) on the left
		if other in (syntax.true_term, syntax.false_term):
			(const, other) = (other, const)
		if const == syntax.true_term:
			return strengthen_hyp (other, sign)
		if const == syntax.false_term:
			return strengthen_hyp (syntax.mk_not (other), sign)
	return expr
1619
def weaken_assert (expr):
	"""Counterpart of strengthen_hyp for assertions: the same rewrite in
	negative position (sign -1)."""
	return strengthen_hyp (expr, sign = -1)
1622
# propositional connectives that norm_neg pushes a negation through
pred_logic_ops = set (('Not', 'And', 'Or', 'Implies'))
1624
def norm_neg (expr):
	"""Push a top-level negation one step inwards.

	If expr is Not applied to a Not, And, Or or Implies, return an
	equivalent term with the negation moved inside. Double negations
	are dropped (recursively), And/Or use De Morgan with norm_mk_not on
	each side, and Not (x --> y) becomes x and (not y) with y negated by
	plain mk_not. Any other expr is returned unchanged."""
	if not expr.is_op ('Not'):
		return expr
	[body] = expr.vals
	if not body.is_op (pred_logic_ops):
		return expr
	if body.is_op ('Not'):
		[inner] = body.vals
		return norm_neg (inner)
	[x, y] = body.vals
	if body.is_op ('Implies'):
		return mk_and (x, mk_not (y))
	if body.is_op ('And'):
		return mk_or (norm_mk_not (x), norm_mk_not (y))
	if body.is_op ('Or'):
		return mk_and (norm_mk_not (x), norm_mk_not (y))
1641
def norm_mk_not (expr):
	"""Negate expr, then push the new negation inwards with norm_neg."""
	negated = mk_not (expr)
	return norm_neg (negated)
1644
def _split_binop (expr, op):
	"""Flatten nested binary op nodes into their operands, left to
	right, applying norm_neg at every level first."""
	expr = norm_neg (expr)
	if not expr.is_op (op):
		return [expr]
	[x, y] = expr.vals
	return _split_binop (x, op) + _split_binop (y, op)

def split_conjuncts (expr):
	"""List the conjuncts of expr, pushing negations inwards first."""
	return _split_binop (expr, 'And')

def split_disjuncts (expr):
	"""List the disjuncts of expr, pushing negations inwards first."""
	return _split_binop (expr, 'Or')
1660
def binary_search_least (test, minimum, maximum):
	"""find least n, minimum <= n <= maximum, for which test (n).

	Assumes test is monotone over the range (false up to some point,
	true from then on). Returns None if test (maximum) is false."""
	assert maximum >= minimum
	if test (minimum):
		return minimum
	if maximum == minimum or not test (maximum):
		return None
	# invariant: test (minimum) is false and test (maximum) is true
	while maximum > minimum + 1:
		# '//' keeps cur an int under both Python 2 and 3
		cur = (minimum + maximum) // 2
		if test (cur):
			maximum = cur
		else:
			# cur is known to fail, so it becomes the new lower
			# bound ('cur + 1' would skip an untested value)
			minimum = cur
	assert minimum + 1 == maximum
	return maximum
1676
def binary_search_greatest (test, minimum, maximum):
	"""find greatest n, minimum <= n <= maximum, for which test (n).

	Assumes test is monotone over the range (true up to some point,
	false from then on). Returns None if test (minimum) is false."""
	assert maximum >= minimum
	if test (maximum):
		return maximum
	if maximum == minimum or not test (minimum):
		return None
	# invariant: test (minimum) is true and test (maximum) is false
	while maximum > minimum + 1:
		# '//' keeps cur an int under both Python 2 and 3
		cur = (minimum + maximum) // 2
		if test (cur):
			minimum = cur
		else:
			# cur is known to fail, so it becomes the new upper
			# bound ('cur - 1' would skip an untested value)
			maximum = cur
	assert minimum + 1 == maximum
	return minimum
1692
1693