1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9# pseudo-compiler for use of aggregate types in C-derived function code
10
11import syntax
12from syntax import structs, get_vars, get_expr_typ, get_node_vars, Expr, Node
13import logic
14
15
# Bind the 'mk_*' helpers exported as the sequence syntax.mks to module-level
# names. The unpacking is positional, so this list must have the same length
# and order as syntax.mks.
(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq,
mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8,
mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs,
mk_if, mk_meta_typ, mk_pvalid) = syntax.mks
20
21from syntax import word32T, word8T
22
23from syntax import fresh_name, foldr1
24
25from target_objects import symbols, trace
26
def compile_field_acc (name, expr, replaces):
	'''Build an expression for the field called name of struct-valued expr.

	The shape of expr decides how the field is obtained:
	  - StructCons: the value stored for that field in the constructor.
	  - FieldUpd: the updated value if the update hits this field,
	    otherwise the field of the struct being updated.
	  - Var: the replacement variable listed for this field in
	    replaces[expr.name] (exactly one entry must match).
	  - MemAcc: a memory access at the pointer plus the field offset
	    recorded in structs.
	  - Field / ArrayIndex: the inner access is compiled first, then the
	    field is taken from the result.

	Any other expression is printed and rejected with an assertion.
	'''
	kind = expr.kind

	if kind == 'StructCons':
		return expr.vals[name]

	if kind == 'FieldUpd':
		if expr.field[0] != name:
			# update touches a different field, look through it
			return compile_field_acc (name, expr.struct, replaces)
		return expr.val

	if kind == 'Var':
		assert expr.name in replaces
		matches = [(var_nm, var_typ)
			for (field_nm, var_nm, var_typ) in replaces[expr.name]
			if field_nm == name]
		# exactly one replacement variable is expected per field
		[(var_nm, var_typ)] = matches
		return mk_var (var_nm, var_typ)

	if expr.is_op ('MemAcc'):
		assert expr.typ.kind == 'Struct'
		(field_typ, field_offs, _) = structs[expr.typ.name].fields[name]
		[mem, ptr] = expr.vals
		field_ptr = mk_plus (ptr, mk_word32 (field_offs))
		return mk_memacc (mem, field_ptr, field_typ)

	if kind == 'Field':
		inner = compile_field_acc (expr.field[0], expr.struct, replaces)
		return compile_field_acc (name, inner, replaces)

	if expr.is_op ('ArrayIndex'):
		[arr, i] = expr.vals
		element = compile_array_acc (i, arr, replaces, False)
		assert element, (arr, i)
		return compile_field_acc (name, element, replaces)

	print (expr)
	assert not 'field acc compilable'
57
def compile_array_acc (i, expr, replaces, must = True):
	'''Build an expression for element i of array-valued expr.

	i is either a python integer or an expression; a 'Num' expression is
	first reduced to its integer value (it must have type word32T).

	Handled shapes of expr: literal 'Array', 'ArrayUpdate', 'MemAcc',
	'IfThenElse' and 'Var' (looked up in replaces). With a non-constant
	index the literal-array and variable cases produce a chain of
	if-then-else expressions comparing i against each position.

	For any other expr the result is None when must is false, and a plain
	array-index expression otherwise.
	'''
	# reduce a numeral index expression to a python integer
	if not logic.is_int (i) and i.kind == 'Num':
		assert i.typ == word32T
		i = i.val

	if expr.kind == 'Array':
		if logic.is_int (i):
			return expr.vals[i]
		# unknown index: default to the last element, test the others
		selected = expr.vals[-1]
		for (pos, element) in enumerate (expr.vals[:-1]):
			selected = mk_if (mk_eq (i, mk_word32 (pos)),
				element, selected)
		return selected

	if expr.is_op ('ArrayUpdate'):
		[arr, upd_idx, upd_val] = expr.vals
		if upd_idx.kind == 'Num' and logic.is_int (i):
			# both indices known, resolve the update statically
			if i == upd_idx.val:
				return upd_val
			return compile_array_acc (i, arr, replaces)
		hit = mk_eq (upd_idx, mk_word32_maybe (i))
		miss_val = compile_array_acc (i, arr, replaces)
		return mk_if (hit, upd_val, miss_val)

	if expr.is_op ('MemAcc'):
		[mem, ptr] = expr.vals
		element_ptr = mk_arroffs (ptr, expr.typ, i)
		return mk_memacc (mem, element_ptr, expr.typ.el_typ)

	if expr.is_op ('IfThenElse'):
		[cond, left, right] = expr.vals
		left_acc = compile_array_acc (i, left, replaces)
		right_acc = compile_array_acc (i, right, replaces)
		return mk_if (cond, left_acc, right_acc)

	if expr.kind == 'Var':
		assert expr.name in replaces
		if logic.is_int (i):
			(_, var_nm, var_typ) = replaces[expr.name][i]
			return mk_var (var_nm, var_typ)
		# unknown index: default to the first replacement variable
		candidates = [(mk_word32 (pos), mk_var (var_nm, var_typ))
			for (pos, var_nm, var_typ) in replaces[expr.name]]
		selected = candidates[0][1]
		for (pos_expr, var) in candidates[1:]:
			selected = mk_if (mk_eq (i, pos_expr), var, selected)
		return selected

	if must:
		return mk_arr_index (expr, mk_word32_maybe (i))
	return None
105
def num_fields (container, typ):
	'''Count the occurrences of typ inside the type container.

	A type equal to typ counts as one. Arrays multiply the count of
	their element type by their length, structs add up the counts of
	their fields (from structs), and any other type contributes zero.
	'''
	if container == typ:
		return 1

	kind = container.kind
	if kind == 'Array':
		return container.num * num_fields (container.el_typ, typ)
	if kind == 'Struct':
		total = 0
		for (_, field_typ) in structs[container.name].field_list:
			total += num_fields (field_typ, typ)
		return total
	return 0
117
def get_const_global_acc_offset (expr, offs, typ):
	'''Walk an access path down to a 'ConstGlobal' and accumulate an offset.

	expr is followed through array-index and field expressions. offs is
	an expression to which each step adds its contribution, measured in
	numbers of values of type typ (as counted by num_fields).

	Returns (const_global_expr, offset_expr) when the path ends at a
	'ConstGlobal'. Returns None when expr is of some other shape, or when
	a field expression names a field not found in its struct.
	'''
	if expr.kind == 'ConstGlobal':
		return (expr, offs)

	if expr.is_op ('ArrayIndex'):
		[arr, idx] = expr.vals
		# each element spans this many values of the target type
		stride = mk_word32 (num_fields (expr.typ, typ))
		offs = mk_plus (offs, mk_times (idx, stride))
		return get_const_global_acc_offset (arr, offs, typ)

	if expr.kind != 'Field':
		return None

	# count the values of the target type in the fields before ours
	skipped = 0
	for (nm, field_typ) in structs[expr.struct.typ.name].field_list:
		if (nm, field_typ) == expr.field:
			offs = mk_plus (offs, mk_word32 (skipped))
			return get_const_global_acc_offset (
				expr.struct, offs, typ)
		skipped += num_fields (field_typ, typ)
	return None
138
def compile_const_global_acc (expr):
	'''Flatten a word-typed access into a constant global, if expr is one.

	Returns an array-index expression on the 'ConstGlobal' itself, using
	the offset computed by get_const_global_acc_offset. Returns None if
	expr is already a 'ConstGlobal' or a direct index into one, if its
	type is not a 'Word', or if its access path does not end at a
	'ConstGlobal'.
	'''
	if expr.kind == 'ConstGlobal':
		return None
	if expr.is_op ('ArrayIndex') and expr.vals[0].kind == 'ConstGlobal':
		# already in the flattened form, nothing to do
		return None
	if expr.typ.kind != 'Word':
		return None

	found = get_const_global_acc_offset (expr, mk_word32 (0), expr.typ)
	if found is None:
		return None
	(const_global, offs) = found
	return mk_arr_index (const_global, offs)
150
def compile_val_fields (expr, replaces):
	'''List the compiled basic-typed components of expr, in order.

	Array values are expanded element by element and struct values field
	by field (in field_list order), recursively. A value of any other
	type yields a one-element list holding its compile_accs result.
	'''
	kind = expr.typ.kind
	if kind == 'Array':
		accessors = [(compile_array_acc, i)
			for i in range (expr.typ.num)]
	elif kind == 'Struct':
		accessors = [(compile_field_acc, nm)
			for (nm, _) in structs[expr.typ.name].field_list]
	else:
		return [compile_accs (replaces, expr)]

	# both accessors take (key, expr, replaces)
	leaves = []
	for (access, key) in accessors:
		component = access (key, expr, replaces)
		leaves.extend (compile_val_fields (component, replaces))
	return leaves
166
def compile_val_fields_of_typ (expr, replaces, typ):
	'''Like compile_val_fields, keeping only components whose type is typ.'''
	matching = []
	for leaf in compile_val_fields (expr, replaces):
		if leaf.typ == typ:
			matching.append (leaf)
	return matching
170
171# it helps in this compilation to know that the outermost expression we are
172# trying to fetch is always of basic type, never struct or array.
173
174# sort of fudged in the array indexing case here
def compile_accs (replaces, expr):
	'''Rewrite expr so that struct/array accesses and updates are compiled away.

	replaces maps the name of an aggregate variable to a list of
	(field name or index, replacement variable name, type) entries; it is
	passed through to compile_field_acc and compile_array_acc.

	The result is built recursively. Field and array-index expressions
	are resolved via compile_field_acc / compile_array_acc, memory updates
	that store an aggregate value are split into a sequence of updates of
	the components, equalities on aggregates become a conjunction of
	component equalities, 'PAlignValid' is replaced using
	logic.mk_align_valid_ineq, and 'Symbol' expressions are replaced by
	the first entry of their symbols record as a word32. Other operators
	have their operands compiled. Expressions of kind Var, ConstGlobal,
	Num, Invent and Type are returned unchanged; any other kind is printed
	and rejected with an assertion.
	'''
	# accesses into constant globals are first flattened into a single
	# array index on the global, then compiled again
	r = compile_const_global_acc (expr)
	if r:
		return compile_accs (replaces, r)

	if expr.kind == 'Field':
		expr = compile_field_acc (expr.field[0], expr.struct, replaces)
		return compile_accs (replaces, expr)
	elif expr.is_op ('ArrayIndex'):
		[arr, n] = expr.vals
		# first attempt on the raw operands; must = False yields None
		# if the array expression has no shape compile_array_acc handles
		expr2 = compile_array_acc (n, arr, replaces, False)
		if expr2:
			return compile_accs (replaces, expr2)
		# second attempt after compiling the operands themselves
		arr = compile_accs (replaces, arr)
		n = compile_accs (replaces, n)
		expr2 = compile_array_acc (n, arr, replaces, False)
		if expr2:
			return compile_accs (replaces, expr2)
		else:
			# give up and keep an explicit array index
			return mk_arr_index (arr, n)
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].is_op ('MemAcc')
			and expr.vals[2].vals[0] == expr.vals[0]
			and expr.vals[2].vals[1] == expr.vals[1]):
		# null memory copy. probably created by ops below
		return compile_accs (replaces, expr.vals[0])
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].kind == 'FieldUpd'):
		# storing a struct with one field updated: store the underlying
		# struct at p, then store the new field value at p + offset
		[m, p, f_upd] = expr.vals
		assert f_upd.typ.kind == 'Struct'
		(typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]]
		assert f_upd.val.typ == typ
		return compile_accs (replaces,
			mk_memupd (mk_memupd (m, p, f_upd.struct),
				mk_plus (p, mk_word32 (offs)), f_upd.val))
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].typ.kind == 'Struct'):
		# storing any other struct value: one update per field, each at
		# the field's offset from p
		[m, p, s_val] = expr.vals
		struct = structs[s_val.typ.name]
		for (nm, (typ, offs, _)) in struct.fields.iteritems ():
			f = compile_field_acc (nm, s_val, replaces)
			assert f.typ == typ
			m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f)
		return compile_accs (replaces, m)
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].is_op ('ArrayUpdate')):
		# storing an array with one element updated: store the
		# underlying array at p, then store the element at its offset
		[m, p, arr_upd] = expr.vals
		[arr, i, v] = arr_upd.vals
		return compile_accs (replaces,
			mk_memupd (mk_memupd (m, p, arr),
				mk_arroffs (p, arr.typ, i), v))
	elif (expr.is_op ('MemUpdate')
				and expr.vals[2].typ.kind == 'Array'):
		# storing any other array value: one update per element
		[m, p, arr] = expr.vals
		n = arr.typ.num
		typ = arr.typ.el_typ
		for i in range (n):
			offs = i * typ.size ()
			# holds when the element size is 1 or a multiple of 4
			assert offs == i or offs % 4 == 0
			e = compile_array_acc (i, arr, replaces)
			m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e)
		return compile_accs (replaces, m)
	elif expr.is_op ('Equals') \
			and expr.vals[0].typ.kind in ['Struct', 'Array']:
		# aggregate equality becomes the conjunction of the pairwise
		# equalities of the flattened components
		[x, y] = expr.vals
		assert x.typ == y.typ
		xs = compile_val_fields (x, replaces)
		ys = compile_val_fields (y, replaces)
		eq = foldr1 (mk_and, map (mk_eq, xs, ys))
		return compile_accs (replaces, eq)
	elif expr.is_op ('PAlignValid'):
		[typ, p] = expr.vals
		p = compile_accs (replaces, p)
		assert typ.kind == 'Type'
		return logic.mk_align_valid_ineq (('Type', typ.val), p)
	elif expr.kind == 'Op':
		# generic operator: compile the operands and rebuild
		vals = [compile_accs (replaces, v) for v in expr.vals]
		return syntax.adjust_op_vals (expr, vals)
	elif expr.kind == 'Symbol':
		return mk_word32 (symbols[expr.name][0])
	else:
		# only these kinds are expected to remain; the dict is used
		# purely for its keys
		if expr.kind not in {'Var':True, 'ConstGlobal':True,
				'Num':True, 'Invent':True, 'Type':True}:
			print expr
			assert not 'field acc compiled'
		return expr
261
def expand_arg_fields (replaces, args):
	'''Flatten a list of argument expressions into basic-typed expressions.

	Struct arguments are replaced by their fields (in field_list order)
	and array arguments by their elements, recursively. Every other
	argument is passed through compile_accs.
	'''
	flat = []
	for arg in args:
		kind = arg.typ.kind
		if kind == 'Struct':
			parts = [compile_field_acc (nm, arg, replaces)
				for (nm, _) in structs[arg.typ.name].field_list]
		elif kind == 'Array':
			parts = [compile_array_acc (i, arg, replaces)
				for i in range (arg.typ.num)]
		else:
			flat.append (compile_accs (replaces, arg))
			continue
		# components may themselves be aggregates
		flat.extend (expand_arg_fields (replaces, parts))
	return flat
276
def expand_lval_list (replaces, lvals):
	'''Flatten a list of (name, type) variables using replaces.

	A variable named in replaces is substituted by its replacement
	variables, which are expanded recursively. Any other variable is
	kept as is and must not have struct or array type.
	'''
	flat = []
	for (nm, typ) in lvals:
		if nm not in replaces:
			assert typ.kind not in ['Struct', 'Array']
			flat.append ((nm, typ))
			continue
		parts = [(part_nm, part_typ)
			for (_, part_nm, part_typ) in replaces[nm]]
		flat.extend (expand_lval_list (replaces, parts))
	return flat
287
def mk_acc (idx, expr, replaces):
	'''Access component idx of expr: an array element if idx is an integer
	(per logic.is_int), otherwise the struct field named idx.'''
	if not logic.is_int (idx):
		assert expr.typ.kind == 'Struct'
		return compile_field_acc (idx, expr, replaces)
	assert expr.typ.kind == 'Array'
	return compile_array_acc (idx, expr, replaces)
295
def compile_upds (replaces, upds):
	'''Flatten a list of (lval, value) updates to basic-typed updates.

	The lvals and the values are expanded separately and must line up
	type for type. Updates that assign a variable to itself (as judged
	by v.is_var (lv)) are dropped from the result.
	'''
	flat_lvals = expand_lval_list (replaces, [lv for (lv, _) in upds])
	flat_vals = expand_arg_fields (replaces, [v for (_, v) in upds])

	assert ([typ for (_, typ) in flat_lvals]
		== [get_expr_typ (v) for v in flat_vals]), (flat_lvals, flat_vals)

	result = []
	for (lv, v) in zip (flat_lvals, flat_vals):
		if v.is_var (lv):
			continue
		result.append ((lv, v))
	return result
304
def compile_struct_use (function):
	'''Compile struct and array variables out of function, in place.

	Every variable of struct or array type (as reported by get_vars) is
	given one fresh replacement variable per field or element; the
	replacements are themselves visited, so nested aggregates are
	expanded all the way down. The nodes, inputs and outputs of function
	are then rewritten in terms of the replacement variables.

	Returns True if no variable needed replacing, False otherwise.
	'''
	trace ('Compiling in %s.' % function.name)
	vs = get_vars (function)

	# worklist of variable names still to be examined
	visit_vs = list (vs.keys ())
	replaces = {}
	while visit_vs:
		v = visit_vs.pop ()
		typ = vs[v]
		if typ.kind == 'Struct':
			fields = structs[typ.name].field_list
		elif typ.kind == 'Array':
			fields = [(i, typ.el_typ) for i in range (typ.num)]
		else:
			continue
		# NOTE(review): fresh_name is presumably recording the new
		# name and its type in vs, since the names pushed on the
		# worklist below are looked up in vs again - confirm.
		new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ)
			for (nm, f_typ) in fields]
		replaces[v] = new_vs
		visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs])

	for n in function.nodes:
		node = function.nodes[n]
		if node.kind == 'Basic':
			node.upds = compile_upds (replaces, node.upds)
		elif node.kind == 'Call':
			node.args = expand_arg_fields (replaces, node.args)
			node.rets = expand_lval_list (replaces, node.rets)
		elif node.kind == 'Cond':
			node.cond = compile_accs (replaces, node.cond)
		else:
			assert not 'node kind understood'

	function.inputs = expand_lval_list (replaces, function.inputs)
	function.outputs = expand_lval_list (replaces, function.outputs)
	return len (replaces) == 0
344
def check_compile (func):
	'''Assert that no node of func still mentions a struct or array variable.

	On the first offending variable the variable, the function and the
	node are printed and an assertion fails.
	'''
	failure_msgs = [
		('Struct', 'Failed to compile struct %s in %s'),
		('Array', 'Failed to compile array %s in %s'),
	]
	for node in func.nodes.itervalues ():
		node_vars = {}
		get_node_vars (node, node_vars)
		for (v_nm, typ) in node_vars.iteritems ():
			for (bad_kind, msg) in failure_msgs:
				if typ.kind != bad_kind:
					continue
				print (msg % (v_nm, func))
				print (node)
				assert not 'compiled'
358
def subst_expr (expr):
	'''Substitution callback: return a replacement for expr, or None.

	A 'Symbol' found in symbols becomes the first entry of its symbols
	record as a word32; 'PAlignValid' is replaced using
	logic.mk_align_valid_ineq. None is returned for unknown symbols and
	for expressions of kind Op, Var, Num or Type. Any other kind fails
	an assertion.
	'''
	if expr.kind == 'Symbol':
		if expr.name not in symbols:
			return None
		return mk_word32 (symbols[expr.name][0])

	if expr.is_op ('PAlignValid'):
		[typ, p] = expr.vals
		assert typ.kind == 'Type'
		return logic.mk_align_valid_ineq (('Type', typ.val), p)

	if expr.kind in ['Op', 'Var', 'Num', 'Type']:
		return None

	assert not 'expression simple-substitutable', expr
373
def substitute_simple (func):
	'''Replace every node of func by node.subst_exprs (subst_expr, ...),
	passing the names 'Symbol' and 'PAlignValid' as the ss argument.

	func.nodes is updated in place.
	'''
	# iterate over a snapshot, since the mapping is assigned to in the loop
	for (n, node) in list (func.nodes.items ()):
		func.nodes[n] = node.subst_exprs (subst_expr,
			ss = set (['Symbol', 'PAlignValid']))
379
def nodes_symbols (nodes):
	'''Return the set of names of all 'Symbol' expressions that the
	visit method of the given nodes reports.'''
	found = set ()

	def collect (expr):
		if expr.kind != 'Symbol':
			return
		found.add (expr.name)

	# the first visitor is given the argument it does not care about
	ignore = lambda l: ()
	for node in nodes:
		node.visit (ignore, collect)
	return found
388
def missing_symbols (functions):
	'''Return the set of symbol names used in the nodes of functions that
	are absent from symbols, printing a warning if there are any.'''
	all_nodes = []
	for func in functions.itervalues ():
		all_nodes.extend (func.nodes.itervalues ())

	trouble = nodes_symbols (all_nodes) - set (symbols)
	if trouble:
		print ('Symbols missing for substitution: %r' % trouble)
	return trouble
397
def compile_funcs (functions):
	'''Report missing symbols, then run substitute_simple and check_compile
	on every function in functions.'''
	missing_symbols (functions)
	for func in functions.itervalues ():
		substitute_simple (func)
		check_compile (func)
403
def combine_duplicate_nodes (nodes):
	'''Merge equal nodes of the mapping nodes (name -> node), in place.

	In each pass, every node equal to one seen earlier in the pass is
	deleted and its name recorded as renamed to the earlier one; then
	the continuations of the remaining nodes are renamed accordingly.
	Passes repeat until one of them removes nothing, since renaming can
	make further nodes equal.

	Returns the dict of renames (removed name -> surviving name). Every
	target in the returned dict is a node that still exists in nodes.
	'''
	orig_size = len (nodes)
	node_renames = {}
	while True:
		progress = False
		node_names = {}
		# snapshot, since entries are deleted during the loop
		for (n, node) in list (nodes.items ()):
			if node in node_names:
				node_renames[n] = node_names[node]
				del nodes[n]
				progress = True
			else:
				node_names[node] = n

		if not progress:
			break

		for (n, node) in list (nodes.items ()):
			nodes[n] = rename_node_conts (node, node_renames)

	# A node kept in one pass may itself have been merged away in a later
	# one, leaving a chain n -> m -> k. Point every entry at the final
	# survivor. Chains cannot loop, since a deleted name is never
	# reintroduced.
	for n in list (node_renames):
		target = node_renames[n]
		while target in node_renames:
			target = node_renames[target]
		node_renames[n] = target

	if len (nodes) < orig_size:
		print ('Trimmed duplicates %d -> %d' % (orig_size, len (nodes)))
	return node_renames
428
def rename_node_conts (node, renames):
	'''Return a new Node like node, with each continuation that appears as a
	key in renames replaced by the corresponding value.'''
	new_conts = []
	for cont in node.get_conts ():
		new_conts.append (renames.get (cont, cont))
	return Node (node.kind, new_conts, node.get_args ())
432
def recommended_rename (s):
	'''Suggest a shorter name for s.

	If s has the form 'base.suffix' with exactly one dot and a suffix
	made only of the characters 0-9 (an empty suffix included), return
	base. Otherwise return s unchanged.
	'''
	bits = s.split ('.')
	if len (bits) == 2:
		(base, suffix) = bits
		if all ([ch in '0123456789' for ch in suffix]):
			return base
	return s
441
def rename_vars (function):
	'''Shorten variable names in function where that causes no clash.

	Each variable is considered for renaming to recommended_rename of
	its name. If, in the variable dependencies computed for some node,
	two or more variables would end up with the same name, none of them
	is renamed. The inputs, outputs and nodes of function are updated in
	place.
	'''
	preds = logic.compute_preds (function.nodes)
	var_deps = logic.compute_var_deps (function.nodes,
		lambda x: function.outputs, preds)

	all_vs = set ()
	clashing = set ()
	for n in var_deps:
		# group the variables at this node by their proposed new name
		by_target = {}
		for (v, t) in var_deps[n]:
			by_target.setdefault (recommended_rename (v), []).append ((v, t))
			all_vs.add ((v, t))
		for group in by_target.itervalues ():
			if len (group) > 1:
				clashing.update (group)

	renames = {}
	for (v, t) in all_vs:
		if (v, t) not in clashing:
			renames[v] = recommended_rename (v)

	function.inputs = [(renames.get (v, v), t)
		for (v, t) in function.inputs]
	function.outputs = [(renames.get (v, v), t)
		for (v, t) in function.outputs]
	for n in function.nodes:
		function.nodes[n] = syntax.copy_rename (function.nodes[n],
			(renames, {}))
468
def rename_and_combine_function_duplicates (functions):
	'''For every function: rename its variables, merge its duplicate nodes,
	and redirect its entry point if the entry node was merged away.'''
	for fun in functions.itervalues ():
		rename_vars (fun)
		node_renames = combine_duplicate_nodes (fun.nodes)
		fun.entry = node_renames.get (fun.entry, fun.entry)
474
475
476