1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7# pseudo-compiler for use of aggregate types in C-derived function code
8
9import syntax
10from syntax import structs, get_vars, get_expr_typ, get_node_vars, Expr, Node
11import logic
12
13
# Bind the expression constructors published as the tuple syntax.mks to
# module-level names. The unpack is positional, so the names listed here must
# stay in the same order (and number) as the tuple built in the syntax module.
(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq,
mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8,
mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs,
mk_if, mk_meta_typ, mk_pvalid) = syntax.mks
18
19from syntax import word32T, word8T
20
21from syntax import fresh_name, foldr1
22
23from target_objects import symbols, trace
24
def compile_field_acc (name, expr, replaces):
	'''Pseudo-compile a read of the field called name out of expr.

	replaces maps the name of a split-up variable to a list of
	(field name, replacement variable name, type) triples.

	Returns an expression for the field. If the form of expr is not
	handled, expr is printed and an assertion fails.'''
	# updates to other fields cannot affect this read, so look through them
	while expr.kind == 'FieldUpd' and expr.field[0] != name:
		expr = expr.struct

	if expr.kind == 'StructCons':
		return expr.vals[name]

	if expr.kind == 'FieldUpd':
		# the loop above only stops on an update of the wanted field
		return expr.val

	if expr.kind == 'Var':
		assert expr.name in replaces
		matches = [(var_nm, var_typ)
			for (field_nm, var_nm, var_typ) in replaces[expr.name]
			if field_nm == name]
		# exactly one replacement variable is expected per field
		[(var_nm, var_typ)] = matches
		return mk_var (var_nm, var_typ)

	if expr.is_op ('MemAcc'):
		# read the field directly from memory at its offset in the struct
		assert expr.typ.kind == 'Struct'
		(field_typ, field_offs, _) = structs[expr.typ.name].fields[name]
		[mem, ptr] = expr.vals
		return mk_memacc (mem, mk_plus (ptr, mk_word32 (field_offs)),
			field_typ)

	if expr.kind == 'Field':
		# nested field access: compile the inner access first
		inner = compile_field_acc (expr.field[0], expr.struct, replaces)
		return compile_field_acc (name, inner, replaces)

	if expr.is_op ('ArrayIndex'):
		[arr, idx] = expr.vals
		elem = compile_array_acc (idx, arr, replaces, False)
		assert elem, (arr, idx)
		return compile_field_acc (name, elem, replaces)

	print (expr)
	assert not 'field acc compilable'
55
def compile_array_acc (i, expr, replaces, must = True):
	'''Pseudo-compile a read of element i of the array-valued expr.

	i is either a plain integer or an index expression; a Num expression
	is reduced to its integer value first. replaces maps the name of a
	split-up variable to a list of (index, replacement variable name,
	type) triples.

	Returns an expression for the element. When expr has no handled form
	the result is an explicit array-index expression if must is true and
	None otherwise.'''
	if not logic.is_int (i) and i.kind == 'Num':
		assert i.typ == word32T
		i = i.val
	const_idx = logic.is_int (i)

	if expr.kind == 'Array':
		if const_idx:
			return expr.vals[i]
		# unknown index: chain of conditionals, last element as default
		chain = expr.vals[-1]
		for (pos, elem) in enumerate (expr.vals[:-1]):
			chain = mk_if (mk_eq (i, mk_word32 (pos)), elem, chain)
		return chain

	if expr.is_op ('ArrayUpdate'):
		[base, upd_idx, upd_val] = expr.vals
		if upd_idx.kind == 'Num' and const_idx:
			# both indices known: decide statically
			if i == upd_idx.val:
				return upd_val
			return compile_array_acc (i, base, replaces)
		hit = mk_eq (upd_idx, mk_word32_maybe (i))
		miss_val = compile_array_acc (i, base, replaces)
		return mk_if (hit, upd_val, miss_val)

	if expr.is_op ('MemAcc'):
		[mem, ptr] = expr.vals
		return mk_memacc (mem, mk_arroffs (ptr, expr.typ, i),
			expr.typ.el_typ)

	if expr.is_op ('IfThenElse'):
		[cond, left, right] = expr.vals
		left_val = compile_array_acc (i, left, replaces)
		right_val = compile_array_acc (i, right, replaces)
		return mk_if (cond, left_val, right_val)

	if expr.kind == 'Var':
		assert expr.name in replaces
		parts = replaces[expr.name]
		if const_idx:
			(_, var_nm, var_typ) = parts[i]
			return mk_var (var_nm, var_typ)
		# unknown index: select among the replacement variables, with the
		# first one as the default
		cases = [(mk_word32 (pos), mk_var (var_nm, var_typ))
			for (pos, var_nm, var_typ) in parts]
		chain = cases[0][1]
		for (pos_expr, var) in cases[1:]:
			chain = mk_if (mk_eq (i, pos_expr), var, chain)
		return chain

	if must:
		return mk_arr_index (expr, mk_word32_maybe (i))
	return None
103
def num_fields (container, typ):
	'''Count the leaves of type typ inside the type container.

	A container equal to typ counts as one. Arrays multiply the count of
	their element type by their length, structs add up the counts of
	their fields, and any other type contributes nothing.'''
	if container == typ:
		return 1
	kind = container.kind
	if kind == 'Array':
		return container.num * num_fields (container.el_typ, typ)
	if kind == 'Struct':
		total = 0
		for (_, member_typ) in structs[container.name].field_list:
			total += num_fields (member_typ, typ)
		return total
	return 0
115
def get_const_global_acc_offset (expr, offs, typ):
	'''Reduce a chain of accesses on a ConstGlobal to a flat offset.

	Walks inwards through ArrayIndex and Field accesses, adding to the
	offset expression offs the number of leaves of type typ (as counted
	by num_fields) that precede the accessed part.

	Returns (const_global_expr, offset_expr), or None if the chain does
	not bottom out in a ConstGlobal or names a field that its struct
	does not list.'''
	while expr.kind != 'ConstGlobal':
		if expr.is_op ('ArrayIndex'):
			[arr, idx] = expr.vals
			# each element spans this many leaves of the wanted type
			stride = mk_word32 (num_fields (expr.typ, typ))
			offs = mk_plus (offs, mk_times (idx, stride))
			expr = arr
		elif expr.kind == 'Field':
			# count the leaves in the fields that come before this one
			skipped = 0
			for (f_nm, f_typ) in structs[expr.struct.typ.name].field_list:
				if (f_nm, f_typ) == expr.field:
					break
				skipped += num_fields (f_typ, typ)
			else:
				return None
			offs = mk_plus (offs, mk_word32 (skipped))
			expr = expr.struct
		else:
			return None
	return (expr, offs)
136
def compile_const_global_acc (expr):
	'''Turn a nested access into a ConstGlobal into one flat array index.

	Returns None (nothing to do) when expr is itself a ConstGlobal, is
	already a direct ArrayIndex on a ConstGlobal, is not of Word type, or
	is not an access chain rooted at a ConstGlobal. Otherwise returns an
	array-index expression on the ConstGlobal at the computed offset.'''
	if expr.kind == 'ConstGlobal':
		return None
	if expr.is_op ('ArrayIndex') and expr.vals[0].kind == 'ConstGlobal':
		return None
	if expr.typ.kind != 'Word':
		return None
	found = get_const_global_acc_offset (expr, mk_word32 (0), expr.typ)
	if found is None:
		return None
	(const_global, offset) = found
	return mk_arr_index (const_global, offset)
148
def compile_val_fields (expr, replaces):
	'''Flatten expr into the list of its compiled basic-typed leaves.

	Arrays are expanded element by element and structs field by field (in
	field_list order), recursively; a value of any other type is compiled
	with compile_accs and returned as a one-element list.'''
	kind = expr.typ.kind
	if kind == 'Array':
		return [leaf
			for i in range (expr.typ.num)
			for leaf in compile_val_fields (
				compile_array_acc (i, expr, replaces), replaces)]
	if kind == 'Struct':
		return [leaf
			for (nm, _) in structs[expr.typ.name].field_list
			for leaf in compile_val_fields (
				compile_field_acc (nm, expr, replaces), replaces)]
	return [compile_accs (replaces, expr)]
164
def compile_val_fields_of_typ (expr, replaces, typ):
	'''Return the compiled leaves of expr whose type equals typ.'''
	matching = []
	for leaf in compile_val_fields (expr, replaces):
		if leaf.typ == typ:
			matching.append (leaf)
	return matching
168
169# it helps in this compilation to know that the outermost expression we are
170# trying to fetch is always of basic type, never struct or array.
171
172# sort of fudged in the array indexing case here
def compile_accs (replaces, expr):
	'''Rewrite expr so that the aggregate forms handled below are compiled away.

	replaces maps the name of each split-up struct/array variable to a
	list of (field name or index, replacement variable name, type)
	triples. Most branches rewrite one layer of expr and then recurse on
	the result.

	Returns the rewritten expression. Expressions of kind Var,
	ConstGlobal, Num, Invent or Type are returned unchanged; any other
	unhandled kind is printed and fails an assertion.'''
	# accesses into constant globals are flattened first, then compiled
	r = compile_const_global_acc (expr)
	if r:
		return compile_accs (replaces, r)

	if expr.kind == 'Field':
		expr = compile_field_acc (expr.field[0], expr.struct, replaces)
		return compile_accs (replaces, expr)
	elif expr.is_op ('ArrayIndex'):
		[arr, n] = expr.vals
		# first try to compile the access as it stands (must = False, so
		# None comes back if the array expression has no handled form)
		expr2 = compile_array_acc (n, arr, replaces, False)
		if expr2:
			return compile_accs (replaces, expr2)
		# otherwise compile the array and index separately and retry
		arr = compile_accs (replaces, arr)
		n = compile_accs (replaces, n)
		expr2 = compile_array_acc (n, arr, replaces, False)
		if expr2:
			return compile_accs (replaces, expr2)
		else:
			# still not compilable: keep an explicit array index
			return mk_arr_index (arr, n)
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].is_op ('MemAcc')
			and expr.vals[2].vals[0] == expr.vals[0]
			and expr.vals[2].vals[1] == expr.vals[1]):
		# null memory copy: the value written is a read of the same memory
		# at the same pointer, so the update is dropped. probably created
		# by the MemUpdate rewrites below.
		return compile_accs (replaces, expr.vals[0])
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].kind == 'FieldUpd'):
		# storing an updated struct: store the struct being updated, then
		# store the new field value at the field's offset
		[m, p, f_upd] = expr.vals
		assert f_upd.typ.kind == 'Struct'
		(typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]]
		assert f_upd.val.typ == typ
		return compile_accs (replaces,
			mk_memupd (mk_memupd (m, p, f_upd.struct),
				mk_plus (p, mk_word32 (offs)), f_upd.val))
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].typ.kind == 'Struct'):
		# storing any other struct value: one store per field, each at the
		# field's offset from p
		[m, p, s_val] = expr.vals
		struct = structs[s_val.typ.name]
		for (nm, (typ, offs, _)) in struct.fields.iteritems ():
			f = compile_field_acc (nm, s_val, replaces)
			assert f.typ == typ
			m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f)
		return compile_accs (replaces, m)
	elif (expr.is_op ('MemUpdate')
			and expr.vals[2].is_op ('ArrayUpdate')):
		# storing an updated array: store the array being updated, then
		# store the new element at its offset
		[m, p, arr_upd] = expr.vals
		[arr, i, v] = arr_upd.vals
		return compile_accs (replaces,
			mk_memupd (mk_memupd (m, p, arr),
				mk_arroffs (p, arr.typ, i), v))
	elif (expr.is_op ('MemUpdate')
				and expr.vals[2].typ.kind == 'Array'):
		# storing any other array value: one store per element
		[m, p, arr] = expr.vals
		n = arr.typ.num
		typ = arr.typ.el_typ
		for i in range (n):
			offs = i * typ.size ()
			# holds when the element size is 1 or a multiple of 4
			assert offs == i or offs % 4 == 0
			e = compile_array_acc (i, arr, replaces)
			m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e)
		return compile_accs (replaces, m)
	elif expr.is_op ('Equals') \
			and expr.vals[0].typ.kind in ['Struct', 'Array']:
		# equality of aggregates becomes the conjunction of the equalities
		# of their corresponding leaves
		[x, y] = expr.vals
		assert x.typ == y.typ
		xs = compile_val_fields (x, replaces)
		ys = compile_val_fields (y, replaces)
		eq = foldr1 (mk_and, map (mk_eq, xs, ys))
		return compile_accs (replaces, eq)
	elif expr.is_op ('PAlignValid'):
		# replaced by the inequality built by logic.mk_align_valid_ineq,
		# applied to the compiled pointer
		[typ, p] = expr.vals
		p = compile_accs (replaces, p)
		assert typ.kind == 'Type'
		return logic.mk_align_valid_ineq (('Type', typ.val), p)
	elif expr.kind == 'Op':
		# any other operator: compile the operands and rebuild
		vals = [compile_accs (replaces, v) for v in expr.vals]
		return syntax.adjust_op_vals (expr, vals)
	elif expr.kind == 'Symbol':
		# replaced by a word32 built from the first component of the
		# symbol's entry in the symbols table
		return mk_word32 (symbols[expr.name][0])
	else:
		# the dict is used only as a set of the kinds allowed through
		if expr.kind not in {'Var':True, 'ConstGlobal':True,
				'Num':True, 'Invent':True, 'Type':True}:
			print expr
			assert not 'field acc compiled'
		return expr
259
def expand_arg_fields (replaces, args):
	'''Flatten argument expressions into a list of basic-typed expressions.

	A struct argument is replaced by its fields (in field_list order) and
	an array argument by its elements, recursively; any other argument is
	compiled with compile_accs.'''
	flat = []
	for arg in args:
		kind = arg.typ.kind
		if kind == 'Struct':
			parts = [compile_field_acc (nm, arg, replaces)
				for (nm, _) in structs[arg.typ.name].field_list]
		elif kind == 'Array':
			parts = [compile_array_acc (i, arg, replaces)
				for i in range (arg.typ.num)]
		else:
			flat.append (compile_accs (replaces, arg))
			continue
		# the parts may themselves be aggregates
		flat.extend (expand_arg_fields (replaces, parts))
	return flat
274
def expand_lval_list (replaces, lvals):
	'''Flatten a list of (name, type) lvalues using the replaces table.

	A name found in replaces is substituted by its replacement variables,
	recursively; any other name is kept as it is and must not be of
	Struct or Array type.'''
	flat = []
	for (nm, typ) in lvals:
		if nm not in replaces:
			assert typ.kind not in ['Struct', 'Array']
			flat.append ((nm, typ))
			continue
		parts = [(part_nm, part_typ)
			for (_, part_nm, part_typ) in replaces[nm]]
		flat.extend (expand_lval_list (replaces, parts))
	return flat
285
def mk_acc (idx, expr, replaces):
	'''Compile one access step into expr.

	An integer idx is an array element access; anything else is taken to
	be a field name. The type of expr is asserted to match.'''
	if not logic.is_int (idx):
		assert expr.typ.kind == 'Struct'
		return compile_field_acc (idx, expr, replaces)
	assert expr.typ.kind == 'Array'
	return compile_array_acc (idx, expr, replaces)
293
def compile_upds (replaces, upds):
	'''Compile a list of (lvalue, value) updates down to basic types.

	Lvalues and values are flattened separately and then paired up
	again; their types are asserted to agree position by position.
	Updates that assign a variable to itself are dropped.'''
	targets = expand_lval_list (replaces, [lv for (lv, _) in upds])
	values = expand_arg_fields (replaces, [v for (_, v) in upds])

	target_typs = [typ for (_, typ) in targets]
	value_typs = [get_expr_typ (v) for v in values]
	assert target_typs == value_typs, (targets, values)

	kept = []
	for (lv, v) in zip (targets, values):
		if v.is_var (lv):
			# x := x has no effect
			continue
		kept.append ((lv, v))
	return kept
302
def compile_struct_use (function):
	'''Replace struct and array variables of function by per-field variables.

	Every variable of Struct or Array type reported by get_vars is split
	into one fresh variable per field or element; the fresh variables
	are queued for the same treatment, so nested aggregates are split
	too. The updates, call arguments/returns and conditions of all nodes,
	and the inputs and outputs of the function, are then rewritten in
	place in terms of the new variables.

	Returns True if no variable needed replacing.'''
	trace ('Compiling in %s.' % function.name)
	vs = get_vars (function)

	# worklist of variable names still to examine.
	# NOTE(review): the lookup vs[v] below is also done for the fresh
	# names, so this relies on fresh_name recording each new name and its
	# type in vs -- confirm against syntax.fresh_name.
	visit_vs = vs.keys ()
	replaces = {}
	while visit_vs:
		v = visit_vs.pop ()
		typ = vs[v]
		if typ.kind == 'Struct':
			fields = structs[typ.name].field_list
		elif typ.kind == 'Array':
			fields = [(i, typ.el_typ) for i in range (typ.num)]
		else:
			continue
		new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ)
			for (nm, f_typ) in fields]
		replaces[v] = new_vs
		visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs])

	for n in function.nodes:
		node = function.nodes[n]
		if node.kind == 'Basic':
			node.upds = compile_upds (replaces, node.upds)
		elif node.kind == 'Call':
			node.args = expand_arg_fields (replaces, node.args)
			node.rets = expand_lval_list (replaces, node.rets)
		elif node.kind == 'Cond':
			node.cond = compile_accs (replaces, node.cond)
		else:
			assert not 'node kind understood'

	function.inputs = expand_lval_list (replaces, function.inputs)
	function.outputs = expand_lval_list (replaces, function.outputs)
	return len (replaces) == 0
342
def check_compile (func):
	'''Check that no node of func still mentions an aggregate variable.

	For the first variable of Struct or Array type found in a node, a
	message and the node are printed and an assertion fails.'''
	failure_msgs = [
		('Struct', 'Failed to compile struct %s in %s'),
		('Array', 'Failed to compile array %s in %s'),
	]
	for node in func.nodes.itervalues ():
		node_vars = {}
		get_node_vars (node, node_vars)
		for (v_nm, typ) in node_vars.iteritems ():
			for (bad_kind, msg) in failure_msgs:
				if typ.kind == bad_kind:
					print (msg % (v_nm, func))
					print (node)
					assert not 'compiled'
356
def subst_expr (expr):
	'''Substitution callback used by substitute_simple.

	Returns the replacement for expr, or None to leave it alone:
	a Symbol with an entry in symbols becomes a word32 built from the
	first component of that entry, and a PAlignValid operator becomes
	the inequality built by logic.mk_align_valid_ineq. Kinds other than
	Symbol, Op, Var, Num and Type fail an assertion.'''
	if expr.kind == 'Symbol':
		if expr.name not in symbols:
			return None
		return mk_word32 (symbols[expr.name][0])
	if expr.is_op ('PAlignValid'):
		[typ, p] = expr.vals
		assert typ.kind == 'Type'
		return logic.mk_align_valid_ineq (('Type', typ.val), p)
	if expr.kind not in ['Op', 'Var', 'Num', 'Type']:
		assert not 'expression simple-substitutable', expr
	return None
371
def substitute_simple (func):
	'''Apply subst_expr to every node of func, replacing the nodes in place.

	The set passed as ss names the 'Symbol' and 'PAlignValid' forms that
	subst_expr rewrites.
	NOTE(review): presumably subst_exprs uses ss to restrict which
	expressions it visits -- confirm against Node.subst_exprs.'''
	# items () gives a list copy in Python 2, so assigning into func.nodes
	# while looping is safe
	for (n, node) in func.nodes.items ():
		func.nodes[n] = node.subst_exprs (subst_expr,
			ss = set (['Symbol', 'PAlignValid']))
377
def nodes_symbols (nodes):
	'''Return the set of names of all Symbol expressions visited in nodes.'''
	found = set ()

	def ignore_lval (lval):
		# lvalues are of no interest here
		return ()

	def note_symbol (expr):
		if expr.kind == 'Symbol':
			found.add (expr.name)

	for node in nodes:
		node.visit (ignore_lval, note_symbol)
	return found
386
def missing_symbols (functions):
	'''Report symbols used by the functions that the symbols table lacks.

	Prints a message if any are missing. Returns the set of missing
	symbol names.'''
	all_nodes = []
	for func in functions.itervalues ():
		all_nodes.extend (func.nodes.itervalues ())
	trouble = nodes_symbols (all_nodes) - set (symbols)
	if trouble:
		print ('Symbols missing for substitution: %r' % trouble)
	return trouble
395
def compile_funcs (functions):
	'''Run the simple substitutions and the aggregate check on all functions.

	Missing symbols are reported first (the result is not acted on).
	Each function then has its Symbol and PAlignValid expressions
	substituted and is checked for leftover struct/array variables.'''
	missing_symbols (functions)
	for (_, func) in functions.iteritems ():
		substitute_simple (func)
		check_compile (func)
401
def combine_duplicate_nodes (nodes):
	'''Merge nodes that compare equal, editing the nodes dict in place.

	Works in rounds: within a round, every node equal to one already seen
	is deleted and recorded as renamed to the one seen first; then the
	continuations of the remaining nodes are renamed, which may make
	further nodes equal. Stops after a round that removes nothing.

	Prints a summary if anything was removed. Returns a dict mapping the
	name of every removed node to the name of the node that survives in
	its place.'''
	orig_size = len (nodes)
	node_renames = {}
	progress = True
	while progress:
		progress = False
		node_names = {}
		# iterate over a copy, since entries are deleted on the way
		for (n, node) in list (nodes.items ()):
			if node in node_names:
				node_renames[n] = node_names[node]
				del nodes[n]
				progress = True
			else:
				node_names[node] = n

		if not progress:
			break

		for (n, node) in list (nodes.items ()):
			nodes[n] = rename_node_conts (node, node_renames)

	# A node kept in one round can be removed in a later one, leaving
	# chains a -> b -> c in the map. Resolve them so that a single lookup
	# always yields a node that still exists. A removed node never comes
	# back, so the chains cannot form a cycle.
	for n in list (node_renames):
		target = node_renames[n]
		while target in node_renames:
			target = node_renames[target]
		node_renames[n] = target

	if len (nodes) < orig_size:
		print ('Trimmed duplicates %d -> %d' % (orig_size, len (nodes)))
	return node_renames
426
def rename_node_conts (node, renames):
	'''Return a copy of node with its continuations mapped through renames.

	Continuations without an entry in renames are kept as they are.'''
	conts = []
	for cont in node.get_conts ():
		conts.append (renames.get (cont, cont))
	return Node (node.kind, conts, node.get_args ())
430
def recommended_rename (s):
	'''Suggest a simpler name for the variable name s.

	A name of the form base.digits (exactly one dot, non-empty base,
	non-empty all-digit suffix) is shortened to base. Any other name is
	returned unchanged.'''
	bits = s.split ('.')
	if len (bits) != 2:
		return s
	(base, suffix) = bits
	# all () is true for an empty sequence, so the empty cases have to be
	# excluded explicitly: 'x.' has no numeric suffix to strip, and '.5'
	# would otherwise be renamed to the empty string.
	if base and suffix and all (c in '0123456789' for c in suffix):
		return base
	return s
439
def rename_vars (function):
	'''Rename variables of function to their recommended shorter names.

	A variable is left alone if, in the variable dependencies of some
	node, another variable would be given the same new name. The inputs,
	outputs and nodes of function are updated in place.'''
	preds = logic.compute_preds (function.nodes)
	var_deps = logic.compute_var_deps (function.nodes,
		lambda x: function.outputs, preds)

	all_vars = set ()
	clashing = set ()
	for n in var_deps:
		# group the variables at this node by the name they would get
		by_target = {}
		for (v, t) in var_deps[n]:
			by_target.setdefault (recommended_rename (v), []).append ((v, t))
			all_vars.add ((v, t))
		for group in by_target.values ():
			if len (group) > 1:
				clashing.update (group)

	renames = {}
	for (v, t) in all_vars:
		if (v, t) not in clashing:
			renames[v] = recommended_rename (v)

	function.inputs = [(renames.get (v, v), t)
		for (v, t) in function.inputs]
	function.outputs = [(renames.get (v, v), t)
		for (v, t) in function.outputs]
	for n in function.nodes:
		function.nodes[n] = syntax.copy_rename (function.nodes[n],
			(renames, {}))
466
def rename_and_combine_function_duplicates (functions):
	'''Rename variables and merge duplicate nodes in every function.

	The entry point of each function is mapped through the node renames
	returned by the merge.'''
	for (_, fun) in functions.iteritems ():
		rename_vars (fun)
		node_renames = combine_duplicate_nodes (fun.nodes)
		entry = fun.entry
		fun.entry = node_renames.get (entry, entry)
472
473
474