1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7from syntax import (Expr, mk_var, Node, true_term, false_term,
8  fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs)
9import syntax
10
11from target_objects import functions, pairings, trace, printout
12import sys
13import logic
14from logic import azip
15
class Abort (Exception):
	"""Raised to abandon the current problem (e.g. an underspecified
	function, undefined symbols or a complex loop). At the raise sites
	in this module the reason is printed with printout first."""

# Single-element list holding the most recently constructed Problem
# (Problem.__init__ stores itself here); presumably a list so that it
# can be updated without a 'global' statement.
last_problem = [None]
20
class Problem:
	"""A graph-refine problem: one node graph into which one or more
	functions are copied (each under a tag), together with analysis
	results (predecessors, loop structure, caches) computed over it."""

	def __init__ (self, pairing, name = None):
		"""Create an empty problem for pairing, named after the pairing
		unless name is given. Also records itself in last_problem."""
		if name is None:
			name = pairing.name
		self.name = 'Problem (%s)' % name
		self.pairing = pairing

		# node address -> Node
		self.nodes = {}
		# variables in use (passed to fresh_name to avoid clashes)
		self.vs = {}
		self.next_node_name = 1
		# node -> predecessor nodes (see compute_preds)
		self.preds = {}
		# node -> ('Head', loop body set) or ('Mem', loop head)
		self.loop_data = {}
		# node -> (tag, detail), and (tag, detail) -> [nodes] in the
		# order they were allocated
		self.node_tags = {}
		self.node_tag_revs = {}
		# tag -> [(detail, idx, fname)] record of inlined calls
		self.inline_scripts = {}
		# [(entry node, tag, function name, input vars)]
		self.entries = []
		# tag -> output vars
		self.outputs = {}
		# nodes in Tarjan component order (see do_loop_analysis)
		self.tarjan_order = []
		# (split node, sorted loop body) -> [(fingerprint, analysis)]
		self.loop_var_analysis_cache = {}

		self.known_eqs = {}
		# analysis results; cleared by do_analysis and after inlining
		self.cached_analysis = {}
		self.hook_tag_hints = {}

		last_problem[0] = self

	def fail_msg (self):
		"""Failure message naming the problem and its node count."""
		return 'FAILED %s (size %05d)' % (self.name, len(self.nodes))

	def alloc_node (self, tag, detail, loop_id = None, hint = None):
		"""Allocate and return a fresh node address tagged (tag, detail).

		If loop_id is given the new address is recorded as a member of
		that loop. hint is accepted but currently unused. No Node is
		created; callers fill in self.nodes[name] themselves."""
		name = self.next_node_name
		self.next_node_name = name + 1

		self.node_tags[name] = (tag, detail)
		self.node_tag_revs.setdefault ((tag, detail), [])
		self.node_tag_revs[(tag, detail)].append (name)

		if loop_id is not None:
			self.loop_data[name] = ('Mem', loop_id)

		return name

	def fresh_var (self, name, typ):
		"""Return a variable of type typ whose name is produced by
		fresh_name from name and self.vs."""
		name = fresh_name (name, self.vs, typ)
		return mk_var (name, typ)

	def clone_function (self, fun, tag):
		"""Replace the graph with fun's reachable nodes (same addresses,
		same Node objects), all tagged with tag; fun becomes the sole
		entry."""
		self.nodes = {}
		self.vs = syntax.get_vars (fun)
		for n in fun.reachable_nodes ():
			self.nodes[n] = fun.nodes[n]
			detail = (fun.name, n)
			self.node_tags[n] = (tag, detail)
			self.node_tag_revs.setdefault ((tag, detail), [])
			self.node_tag_revs[(tag, detail)].append (n)
		self.outputs[tag] = fun.outputs
		self.entries = [(fun.entry, tag, fun.name, fun.inputs)]
		# allocate new addresses above all existing ones (and above 2)
		self.next_node_name = max (list (self.nodes) + [2]) + 1
		self.inline_scripts[tag] = []

	def add_function (self, fun, tag, node_renames, loop_id = None):
		"""Copy fun's reachable nodes into the graph under tag, giving
		every node a fresh address and every variable a fresh name.

		node_renames is updated in place (with 'Ret' and 'Err' mapping
		to themselves unless already set); it lets the caller redirect
		fun's exits. Returns (renames of fun's own nodes, variable
		renames). Prints a message and raises Abort if fun has no entry;
		check_no_symbols is also applied to the copied nodes."""
		if not fun.entry:
			printout ('Aborting %s: underspecified %s' % (
				self.name, fun.name))
			raise Abort ()
		node_renames.setdefault('Ret', 'Ret')
		node_renames.setdefault('Err', 'Err')
		new_node_renames = {}
		vs = syntax.get_vars (fun)
		vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs])
		ns = fun.reachable_nodes ()
		check_no_symbols ([fun.nodes[n] for n in ns])
		for n in ns:
			assert n not in node_renames
			node_renames[n] = self.alloc_node (tag, (fun.name, n),
				loop_id = loop_id, hint = n)
			new_node_renames[n] = node_renames[n]
		# all addresses are allocated before copying so that every
		# continuation can be renamed
		for n in ns:
			self.nodes[node_renames[n]] = syntax.copy_rename (
				fun.nodes[n], (vs, node_renames))

		return (new_node_renames, vs)

	def add_entry_function (self, fun, tag):
		"""Add fun (see add_function) as a new entry point under tag.

		Returns (renamed inputs, renamed outputs, entry node)."""
		(ns, vs) = self.add_function (fun, tag, {})

		entry = ns[fun.entry]
		args = [(vs[v], typ) for (v, typ) in fun.inputs]
		rets = [(vs[v], typ) for (v, typ) in fun.outputs]
		self.entries.append((entry, tag, fun.name, args))
		self.outputs[tag] = rets

		self.inline_scripts[tag] = []

		return (args, rets, entry)

	def get_entry_details (self, tag):
		"""Return (entry node, function name, inputs) for tag; raises
		ValueError unless exactly one entry has that tag."""
		[(e, t, fname, args)] = [(e, t, fname, args)
			for (e, t, fname, args) in self.entries if t == tag]
		return (e, fname, args)

	def get_entry (self, tag):
		"""Return the entry node for tag."""
		(e, fname, args) = self.get_entry_details (tag)
		return e

	def tags (self):
		"""Return the tags that have outputs recorded."""
		return self.outputs.keys ()

	def entry_exit_renames (self, tags = None):
		"""computes the rename set of a function's formal parameters
		to the actual input/output variable names at the various entry
		and exit points"""
		mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in
			azip (xs, ys)])
		renames = {}
		if tags is None:
			tags = self.tags ()
		for tag in tags:
			(_, fname, args) = self.get_entry_details (tag)
			fun = functions[fname]
			out = self.outputs[tag]
			renames[tag + '_IN'] = mk (fun.inputs, args)
			renames[tag + '_OUT'] = mk (fun.outputs, out)
		return renames

	def redirect_conts (self, reds):
		"""Rewrite, in place, every node continuation found as a key of
		reds to the corresponding value."""
		for node in self.nodes.itervalues():
			if node.kind == 'Cond':
				node.left = reds.get(node.left, node.left)
				node.right = reds.get(node.right, node.right)
			else:
				node.cont = reds.get(node.cont, node.cont)

	def do_analysis (self):
		"""Drop cached analyses and recompute preds and loop data."""
		self.cached_analysis.clear ()
		self.compute_preds ()
		self.do_loop_analysis ()

	def mk_node_graph (self, node_subset = None):
		"""Return {node: [continuations within node_subset]} for every
		node of node_subset (default: all nodes)."""
		if node_subset is None:
			node_subset = self.nodes
		return dict ([(n, [c for c in self.nodes[n].get_conts ()
				if c in node_subset])
			for n in node_subset])

	def do_loop_analysis (self):
		"""Recompute loop_data and tarjan_order.

		Components come from logic.tarjan over the node graph, starting
		at every entry. A component is a loop if it has more than one
		node or its head continues to itself; force_single_loop_return
		may add a 'LoopReturn' node to it. tarjan_order is reversed at
		the end into first-to-last order."""
		entries = [e for (e, tag, nm, args) in self.entries]
		self.loop_data = {}

		graph = self.mk_node_graph ()
		comps = logic.tarjan (graph, entries)
		self.tarjan_order = []

		for (head, tail) in comps:
			self.tarjan_order.append (head)
			self.tarjan_order.extend (tail)
			if not tail and head not in graph[head]:
				continue
			trace ('Loop (%d, %s)' % (head, tail))

			loop_set = set (tail)
			loop_set.add (head)

			r = self.force_single_loop_return (head, loop_set)
			if r is not None:
				tail.append (r)
				loop_set.add (r)
				self.tarjan_order.append (r)
				self.compute_preds ()

			self.loop_data[head] = ('Head', loop_set)
			for t in tail:
				self.loop_data[t] = ('Mem', head)

		# put this in first-to-last order.
		self.tarjan_order.reverse ()

	def check_no_inner_loops (self):
		"""Abort (via check_no_inner_loop) if any loop has inner loops."""
		for loop in self.loop_heads ():
			check_no_inner_loop (self, loop)

	def force_single_loop_return (self, head, loop_set):
		"""Make the loop at head re-enter head through one no-op node.

		If head's only in-loop predecessor is a different, no-op node,
		nothing changes and None is returned. Otherwise a new empty
		Basic node tagged 'LoopReturn' (continuing to head) is allocated,
		every in-loop predecessor is redirected to it, and it is
		returned."""
		rets = [n for n in self.preds[head] if n in loop_set]
		if (len (rets) == 1 and rets[0] != head and
				self.nodes[rets[0]].is_noop ()):
			return None
		r = self.alloc_node (self.node_tags[head][0],
			'LoopReturn', loop_id = head)
		self.nodes[r] = Node ('Basic', head, [])
		for r2 in rets:
			self.nodes[r2] = syntax.copy_rename (self.nodes[r2],
				({}, {head: r}))
		return r

	def splittable_points (self, n):
		"""splittable points are points which when removed, the loop
		'splits' and ceases to be a loop.

		equivalently, the set of splittable points is the intersection
		of all sub-loops of the loop.

		The result is cached per loop head, except when the loop has
		inner loops and logic.get_one_loop_splittable finds no split
		point, in which case the empty set is returned uncached."""
		head = self.loop_id (n)
		assert head is not None
		k = ('Splittables', head)
		if k in self.cached_analysis:
			return self.cached_analysis[k]

		# check if the head point is a split (the inner loop
		# check does exactly that)
		if has_inner_loop (self, head):
			head = logic.get_one_loop_splittable (self,
				self.loop_body (head))
			if head is None:
				return set ()

		splits = self.get_loop_splittables (head)
		self.cached_analysis[k] = splits
		return splits

	def get_loop_splittables (self, head):
		"""Compute the splittable points of the loop containing head.

		One cycle ("arc") is walked from head back to head, always
		taking the first in-loop successor not yet on the arc; every arc
		node starts out splittable. An arc node whose successors can
		reach a later arc position makes every arc node strictly in
		between avoidable, so those are unmarked."""
		loop_set = self.loop_body (head)
		splittable = dict ([(n, False) for n in loop_set])
		arc = [head]
		n = head
		while True:
			ns = [n2 for n2 in self.nodes[n].get_conts ()
				if n2 in loop_set]
			ns2 = [x for x in ns if x == head or x not in arc]
			n = ns2[0]
			arc.append (n)
			splittable[n] = True
			if n == head:
				break
		# last_descs: arc nodes map to their (last) arc position; other
		# loop nodes map to the largest position among the arc nodes
		# first reached from them
		last_descs = {}
		for i in range (len (arc)):
			last_descs[arc[i]] = i
		def last_desc (n):
			if n in last_descs:
				return last_descs[n]
			n2s = [n2 for n2 in self.nodes[n].get_conts()
				if n2 in loop_set]
			# provisional None stops the recursion revisiting n
			last_descs[n] = None
			for n2 in n2s:
				x = last_desc(n2)
				if last_descs[n] == None or x >= last_descs[n]:
					last_descs[n] = x
			return last_descs[n]
		for i in range (len (arc)):
			max_arc = max ([last_desc (n)
				for n in self.nodes[arc[i]].get_conts ()
				if n in loop_set])
			for j in range (i + 1, max_arc):
				splittable[arc[j]] = False
		return set ([n for n in splittable if splittable[n]])

	def loop_heads (self):
		"""Return the nodes recorded as loop heads."""
		return [n for n in self.loop_data
			if self.loop_data[n][0] == 'Head']

	def loop_id (self, n):
		"""Return the head of the loop containing n, or None."""
		if n not in self.loop_data:
			return None
		elif self.loop_data[n][0] == 'Head':
			return n
		else:
			assert self.loop_data[n][0] == 'Mem'
			return self.loop_data[n][1]

	def loop_body (self, n):
		"""Return the node set of the loop containing n."""
		head = self.loop_id (n)
		return self.loop_data[head][1]

	def compute_preds (self):
		"""Recompute the predecessor map with logic.compute_preds."""
		self.preds = logic.compute_preds (self.nodes)

	def var_dep_outputs (self, n):
		"""Return the outputs of the tag that node n belongs to."""
		return self.outputs[self.node_tags[n][0]]

	def compute_var_dependencies (self):
		"""Return (cached) {node: {var: None}} built from the variables
		logic.compute_var_deps reports for each node."""
		if 'var_dependencies' in self.cached_analysis:
			return self.cached_analysis['var_dependencies']
		var_deps = logic.compute_var_deps (self.nodes,
			self.var_dep_outputs, self.preds)
		var_deps2 = dict ([(n, dict ([(v, None)
				for v in var_deps.get (n, [])]))
			for n in self.nodes])
		self.cached_analysis['var_dependencies'] = var_deps2
		return var_deps2

	def get_loop_var_analysis (self, var_deps, n):
		"""Return logic.compute_loop_var_analysis (self, var_deps, n) for
		the split point n, memoised per (n, sorted loop body).

		A memo entry is reused only if every loop node, its own
		predecessors and its var_deps keys are unchanged; the 10 most
		recent results are kept per key."""
		head = self.loop_id (n)
		assert head, n
		assert n in self.splittable_points (n)
		loop_sort = tuple (sorted (self.loop_body (head)))
		# fingerprint each loop node by its OWN predecessors (n2, not
		# n), so a predecessor change anywhere in the loop invalidates
		# the memo entry
		node_data = [(self.nodes[n2], sorted (self.preds[n2]),
				sorted (var_deps[n2].keys ()))
			for n2 in loop_sort]
		k = (n, loop_sort)
		data = (node_data, n)
		if k in self.loop_var_analysis_cache:
			for (data2, va) in self.loop_var_analysis_cache[k]:
				if data2 == data:
					return va
		va = logic.compute_loop_var_analysis (self, var_deps, n)
		group = self.loop_var_analysis_cache.setdefault (k, [])
		group.append ((data, va))
		del group[:-10]
		return va

	def save_graph (self, fname):
		"""Write the graph to fname as dot, coloured by tag."""
		cols = mk_graph_cols (self.node_tags)
		save_graph (self.nodes, fname, cols = cols,
			node_tags = self.node_tags)

	def save_graph_summ (self, fname):
		"""Write a summarised graph to fname as dot.

		Trivial nodes (one predecessor; Basic, or Cond whose right
		branch is 'Err') are dropped and edges into a chain of them are
		redirected to the first non-trivial node after the chain.
		Requires self.preds to be up to date."""
		node_ids = {}
		def is_triv (n):
			if n not in self.nodes:
				return False
			if len (self.preds[n]) != 1:
				return False
			node = self.nodes[n]
			if node.kind == 'Basic':
				return (True, node.cont)
			elif node.kind == 'Cond' and node.right == 'Err':
				return (True, node.left)
			else:
				return False
		for n in self.nodes:
			if n in node_ids:
				continue
			ns = []
			# seen guards against looping forever on a cycle made
			# only of trivial nodes
			seen = set ()
			while is_triv (n) and n not in seen:
				seen.add (n)
				ns.append (n)
				n = is_triv (n)[1]
			for n2 in ns:
				node_ids[n2] = n
		nodes = {}
		for n in self.nodes:
			if is_triv (n):
				continue
			nodes[n] = syntax.copy_rename (self.nodes[n],
				({}, node_ids))
		cols = mk_graph_cols (self.node_tags)
		save_graph (nodes, fname, cols = cols,
			node_tags = self.node_tags)

	def serialise (self):
		"""Return the problem as text lines: 'Problem', one 'Entry' line
		per entry (node, tag, function name, then inputs and outputs,
		each list prefixed by its length), one line per node, and
		'EndProblem'. deserialise reads this format back."""
		ss = ['Problem']
		for (n, tag, fname, inputs) in self.entries:
			xs = ['Entry', '%d' % n, tag, fname,
				'%d' % len (inputs)]
			for (nm, typ) in inputs:
				xs.append (nm)
				typ.serialise (xs)
			xs.append ('%d' % len (self.outputs[tag]))
			for (nm, typ) in self.outputs[tag]:
				xs.append (nm)
				typ.serialise (xs)
			ss.append (' '.join (xs))
		for n in self.nodes:
			xs = ['%d' % n]
			self.nodes[n].serialise (xs)
			ss.append (' '.join (xs))
		ss.append ('EndProblem')
		return ss

	def save_serialise (self, fname):
		"""Write serialise () to file fname, one line each; the file is
		closed even if writing fails."""
		ss = self.serialise ()
		with open (fname, 'w') as f:
			for s in ss:
				f.write (s + '\n')

	def pad_merge_points (self):
		"""Insert an empty 'MergePadding' Basic node on every edge into
		a node with several predecessors, unless the source is already
		a Basic node with no updates."""
		self.compute_preds ()

		arcs = [(pred, n) for n in self.preds
			if len (self.preds[n]) > 1
			if n in self.nodes
			for pred in self.preds[n]
			if (self.nodes[pred].kind != 'Basic'
				or self.nodes[pred].upds != [])]

		for (pred, n) in arcs:
			(tag, _) = self.node_tags[pred]
			name = self.alloc_node (tag, 'MergePadding')
			self.nodes[name] = Node ('Basic', n, [])
			self.nodes[pred] = syntax.copy_rename (self.nodes[pred],
				({}, {n: name}))

	def function_call_addrs (self):
		"""Return [(node, called function name)] for all Call nodes."""
		return [(n, self.nodes[n].fname)
			for n in self.nodes if self.nodes[n].kind == 'Call']

	def function_calls (self):
		"""Return the set of function names called in the graph."""
		return set ([fn for (n, fn) in self.function_call_addrs ()])

	def get_extensions (self):
		"""Return (cached) the union of syntax.get_extensions over all
		nodes."""
		if 'extensions' in self.cached_analysis:
			return self.cached_analysis['extensions']
		extensions = set ()
		for node in self.nodes.itervalues ():
			extensions.update (syntax.get_extensions (node))
		self.cached_analysis['extensions'] = extensions
		return extensions

	def replay_inline_script (self, tag, script):
		"""Repeat recorded inlining steps for tag. Each step
		(detail, idx, fname) names the idx-th node tagged (tag, detail),
		which must be a call to fname. Analysis is redone once at the
		end if the script is non-empty."""
		# iterate a snapshot: inline_at_point appends to
		# self.inline_scripts[tag], which may be script itself
		for (detail, idx, fname) in list (script):
			n = self.node_tag_revs[(tag, detail)][idx]
			assert self.nodes[n].kind == 'Call', self.nodes[n]
			assert self.nodes[n].fname == fname, self.nodes[n]
			inline_at_point (self, n, do_analysis = False)
		if script:
			self.do_analysis ()

	def is_reachable_from (self, source, target):
		'''discover if graph addr "target" is reachable
			from starting node "source"

		source itself only counts as reachable if some path leads back
		to it. Results are cached per source.'''
		k = ('is_reachable_from', source)
		if k in self.cached_analysis:
			reachable = self.cached_analysis[k]
			if target in reachable:
				return reachable[target]

		reachable = {}
		visit = [source]
		while visit:
			n = visit.pop ()
			if n not in self.nodes:
				continue
			for n2 in self.nodes[n].get_conts ():
				if n2 not in reachable:
					reachable[n2] = True
					visit.append (n2)
		for n in list (self.nodes) + ['Ret', 'Err']:
			if n not in reachable:
				reachable[n] = False
		self.cached_analysis[k] = reachable
		return reachable[target]

	def is_reachable_without (self, cutpoint, target):
		'''discover if graph addr "target" is reachable
			without visiting node "cutpoint"
			(an oddity: cutpoint itself is considered reachable)

		Computed in one pass over tarjan_order, so loops are not
		followed backwards; results are cached per cutpoint.'''
		k = ('is_reachable_without', cutpoint)
		if k in self.cached_analysis:
			reachable = self.cached_analysis[k]
			if target in reachable:
				return reachable[target]

		reachable = dict ([(self.get_entry (t), True)
			for t in self.tags ()])
		for n in self.tarjan_order + ['Ret', 'Err']:
			if n in reachable:
				continue
			reachable[n] = bool ([pred for pred in self.preds[n]
				if pred != cutpoint
				if reachable.get (pred) == True])
		self.cached_analysis[k] = reachable
		return reachable[target]
482
def deserialise (name, lines):
	"""Rebuild a Problem from the lines produced by Problem.serialise.

	lines must begin with 'Problem' and end with 'EndProblem'. The
	'Entry' lines come first (entry node, tag, function name, then the
	input and output variable lists); every other line in between is
	'<node> <serialised node>'. The pairing is not part of the format,
	so the result has pairing None."""
	assert lines[0] == 'Problem', lines[0]
	assert lines[-1] == 'EndProblem', lines[-1]
	prob = Problem (pairing = None, name = name)
	pos = 1
	while lines[pos].startswith ('Entry'):
		fields = lines[pos].split ()
		entry_node = int (fields[1])
		entry_tag = fields[2]
		entry_fname = fields[3]
		(idx, ins) = syntax.parse_list (syntax.parse_lval, fields, 4)
		(idx, outs) = syntax.parse_list (syntax.parse_lval, fields, idx)
		# the two variable lists must consume the whole line
		assert idx == len (fields), (idx, fields)
		prob.entries.append ((entry_node, entry_tag, entry_fname, ins))
		prob.outputs[entry_tag] = outs
		pos += 1
	for line in lines[pos:-1]:
		fields = line.split ()
		addr = int (fields[0])
		parsed = syntax.parse_node (fields, 1)
		prob.nodes[addr] = parsed
	return prob
506
507# trivia
508
def check_no_symbols (nodes, problem_name = None):
	"""Abort if any of nodes refers to a symbol.

	Symbols are collected with pseudo_compile.nodes_symbols. If any are
	found, a message (naming problem_name when given) is printed and
	Abort is raised; otherwise returns None.

	This is a module-level function, so it has no self to name the
	problem from; the name is an optional argument instead. Previously
	the message used self.name, so this path raised NameError instead of
	Abort."""
	# imported locally, presumably to avoid an import cycle — confirm
	import pseudo_compile
	symbs = pseudo_compile.nodes_symbols (nodes)
	if not symbs:
		return
	if problem_name is None:
		printout ('Aborting: undefined symbols %s' % (symbs, ))
	else:
		printout ('Aborting %s: undefined symbols %s'
			% (problem_name, symbs))
	raise Abort ()
516
517# printing of problem graphs
518
def sanitise_str (s):
	"""Make s safe inside a dot identifier or quoted attribute: both
	kinds of quote become '_' and spaces are removed."""
	for (bad, good) in (('"', '_'), ("'", '_'), (' ', '')):
		s = s.replace (bad, good)
	return s
521
def graph_name (nodes, node_tags, n, prev=None):
	"""Return a readable identifier for node n (used as its label).

	String n (exit points such as 'Ret'/'Err') gives 't_<n>_<prev>', or
	't_<n>' when prev is not supplied (formatting None with %d used to
	raise TypeError). Addresses missing from nodes give 'unknown_<n>'.
	Otherwise the identifier combines n with its (tag, detail) from
	node_tags and is prefixed by node kind: 'fcall_' for Call, 'ass_'
	for Basic, none for Cond."""
	if type (n) == str:
		if prev is None:
			return 't_%s' % n
		return 't_%s_%d' % (n, prev)
	if n not in nodes:
		return 'unknown_%d' % n
	if n not in node_tags:
		ident = '%d' % n
	else:
		(tag, details) = node_tags[n]
		if len (details) > 1 and logic.is_int (details[1]):
			# e.g. the (function name, original node) detail set by
			# add_function: the integer part is shown in hex
			ident = '%d_%s_%s_0x%x' % (n, tag,
				details[0], details[1])
		elif type (details) != str:
			details = '_'.join (map (str, details))
			ident = '%d_%s_%s' % (n, tag, details)
		else:
			ident = '%d_%s_%s' % (n, tag, details)
	ident = sanitise_str (ident)
	node = nodes[n]
	if node.kind == 'Call':
		return 'fcall_%s' % ident
	if node.kind == 'Cond':
		return ident
	if node.kind == 'Basic':
		return 'ass_%s' % ident
	assert not 'node kind understood'
548
def graph_node_tooltip (nodes, n):
	"""Return the hover text for node n: a fixed description for the
	'Err'/'Ret' exit points, otherwise a summary based on node kind."""
	for (point, text) in (('Err', 'Error point'), ('Ret', 'Return point')):
		if n == point:
			return text
	node = nodes[n]
	kind = node.kind
	if kind == 'Call':
		return "%s: call to '%s'" % (n, sanitise_str (node.fname))
	elif kind == 'Cond':
		return '%s: conditional node' % n
	elif kind == 'Basic':
		# each update is ((name, type), value); list the names
		assigned = ', '.join ([sanitise_str (upd[0][0])
			for upd in node.upds])
		return '%s: assignment to [%s]' % (n, assigned)
	assert not 'node kind understood'
563
def graph_edges (nodes, n):
	"""Return node n's outgoing edges as (target, label) pairs: 'N' to
	the first continuation of a no-op node, 'T'/'F' to the two branches
	of a conditional, and 'C' to the continuation of any other node."""
	node = nodes[n]
	if node.is_noop ():
		labelled = [(node.get_conts ()[0], 'N')]
	elif node.kind == 'Cond':
		labelled = [(node.left, 'T'), (node.right, 'F')]
	else:
		labelled = [(node.cont, 'C')]
	return labelled
572
def get_graph_font (n, col):
	"""Return the dot attribute string for drawing node n; a truthy
	col also sets the line and font colour. n itself is not used."""
	attrs = ['fontname = "Arial"', 'fontsize = 20', 'penwidth=3']
	if col:
		attrs.extend (['color=%s' % col, 'fontcolor=%s' % col])
	return ', '.join (attrs)
578
def get_graph_loops (nodes):
	"""Return the set of (src, dst) edges whose endpoints belong to the
	same component as reported by logic.tarjan, i.e. edges that lie on
	a loop. Edges to string targets ('Ret'/'Err') are ignored."""
	graph = {}
	for n in nodes:
		graph[n] = [c for c in nodes[n].get_conts ()
			if type (c) != str]
	# a synthetic root reaching every node, so tarjan sees them all
	graph['ENTRY'] = list (nodes)
	comp_of = {}
	for (head, tail) in logic.tarjan (graph, ['ENTRY']):
		for member in [head] + list (tail):
			comp_of[member] = head
	return set ([(src, dst) for src in graph for dst in graph[src]
		if comp_of[src] == comp_of[dst]])
592
def make_graph (nodes, cols, node_tags = None, entries = None):
	"""Render nodes as a list of Graphviz dot source lines.

	nodes: node address -> Node; the address is the dot node id.
	cols: node -> colour name (see mk_graph_cols).
	node_tags: node -> (tag, detail), used for the labels (graph_name).
	entries: records whose first two fields are (entry node, tag); both
	the (node, tag, inputs) form and Problem.entries' 4-tuple
	(node, tag, fname, inputs) form are accepted.

	Each edge to 'Ret'/'Err' gets its own exit node; edges that lie on a
	loop (get_graph_loops) are drawn thicker."""
	if node_tags is None:
		node_tags = {}
	if entries is None:
		entries = []
	graph = []
	graph.append ('digraph foo {')

	loops = get_graph_loops (nodes)

	for n in nodes:
		n_nm = graph_name (nodes, node_tags, n)
		f = get_graph_font (n, cols.get (n))
		graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n,
			f, n_nm, graph_node_tooltip (nodes, n)))
		for (c, l) in graph_edges (nodes, n):
			if c in ['Ret', 'Err']:
				# a separate exit node per source keeps the layout
				# from funnelling every exit into one node
				c_nm = '%s_%s' % (c, n)
				if c == 'Ret':
					f2 = f + ', shape=doubleoctagon'
				else:
					f2 = f + ', shape=Mdiamond'
				graph.append ('%s [label="%s", %s];'
					% (c_nm, c, f2))
			else:
				c_nm = c
			ft = f
			if (n, c) in loops:
				ft = f + ', penwidth=6'
			graph.append ('%s -> %s [label=%s, %s];' % (
				n, c_nm, l, ft))

	for (i, entry) in enumerate (entries):
		(n, tag) = (entry[0], entry[1])
		f = get_graph_font (n, cols.get (n))
		nm1 = tag + ' ENTRY_POINT'
		nm2 = 'entry_point_%d' % i
		graph.extend (['%s -> %s [%s];' % (nm2, n, f),
			'%s [label = "%s", shape=none, %s];' % (nm2, nm1, f)])

	graph.append ('}')
	return graph
630
def print_graph (nodes, cols = None, entries = None):
	"""Print the dot source for nodes (see make_graph) to stdout.

	entries is passed to make_graph by keyword; passed positionally (as
	before) it was bound to make_graph's node_tags parameter instead."""
	if cols is None:
		cols = {}
	if entries is None:
		entries = []
	for line in make_graph (nodes, cols, entries = entries):
		print (line)
634
def save_graph (nodes, fname, cols = None, entries = None, node_tags = None):
	"""Write the dot source for nodes (see make_graph) to file fname.

	The graph is built before the file is opened, and the file is
	closed even if writing fails."""
	if cols is None:
		cols = {}
	if entries is None:
		entries = []
	if node_tags is None:
		node_tags = {}
	lines = make_graph (nodes, cols = cols, node_tags = node_tags,
		entries = entries)
	with open (fname, 'w') as f:
		for line in lines:
			f.write (line + '\n')
641
def mk_graph_cols (node_tags):
	"""Return {node: colour} for every node whose tag (first element of
	its node_tags entry) is 'C', 'ASM_adj' or 'ASM'; other nodes are
	left out."""
	palette = {'C': "forestgreen", 'ASM_adj': "darkblue",
		'ASM': "darkorange"}
	return dict ([(n, palette[node_tags[n][0]]) for n in node_tags
		if node_tags[n][0] in palette])
650
def make_graph_with_eqs (p, invis = False):
	"""Render p's graph (as make_graph does) plus one blue edge per
	entry of p.known_eqs (other than the 'Hyps' key), drawn between the
	two nodes the entry relates. invis hides these extra edges."""
	if invis:
		invis_s = ', style=invis'
	else:
		invis_s = ''
	cols = mk_graph_cols (p.node_tags)
	graph = make_graph (p.nodes, cols = cols)
	# drop the closing '}' so the extra edges land inside the digraph
	graph.pop ()
	for k in p.known_eqs:
		if k == 'Hyps':
			continue
		(n_vc_x, tag_x) = k
		for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]:
			# make_graph uses the raw node address as the dot node id
			# (graph_name only supplies the label), so these edges must
			# too; graph_name strings created disconnected (and, when
			# digit-leading, syntactically invalid) extra nodes
			graph.append (('%s -> %s [ dir = back, color = blue, '
				'penwidth = 3, weight = 0 %s ]')
					% (n_vc_y[0], n_vc_x[0], invis_s))
	graph.append ('}')
	return graph
671
def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False):
	"""Write make_graph_with_eqs (p, invis) to file fname, closing the
	file even if writing fails."""
	graph = make_graph_with_eqs (p, invis = invis)
	with open (fname, 'w') as f:
		for s in graph:
			f.write (s + '\n')
678
def get_problem_vars (p):
	"""Return a dict of the variables of problem p: the (name, type)
	pairs of every entry's inputs and every tag's outputs, then whatever
	syntax.get_node_vars adds for each node (presumably in place into
	the dict passed to it).

	Works for a problem with no entries and no outputs (set.union()
	with no arguments would raise TypeError there)."""
	inout = set ()
	for xs in p.outputs.values ():
		inout.update (xs)
	for (_, _, _, args) in p.entries:
		inout.update (args)

	vs = dict (inout)
	for node in p.nodes.values ():
		syntax.get_node_vars (node, vs)
	return vs
687
def is_trivial_fun (fun):
	"""Return True if every node of fun is either a no-op, a Basic node
	assigning only variables or numerals, or a Cond node branching on a
	variable or on true_term/false_term; any Call makes it non-trivial."""
	def node_is_trivial (node):
		if node.is_noop ():
			return True
		if node.kind == 'Call':
			return False
		if node.kind == 'Basic':
			return all (v.kind in ['Var', 'Num'] for (_, v) in node.upds)
		if node.kind == 'Cond':
			return (node.cond.kind == 'Var'
				or node.cond in [true_term, false_term])
		return True
	return all (node_is_trivial (node) for node in fun.nodes.itervalues ())
703
# NOTE(review): not referenced anywhere in this file; presumably used by
# other modules — verify before removing.
last_alt_nodes = [0]

def avail_val (vs, typ):
	"""Return a value of type typ: a variable built from the first
	(name, type) pair in vs whose type equals typ, or
	logic.default_val (typ) when there is none."""
	found = next (((nm, t) for (nm, t) in vs if t == typ), None)
	if found is None:
		return logic.default_val (typ)
	return mk_var (found[0], found[1])
711
def inline_at_point (p, n, do_analysis = True):
	"""Inline, in place, the function called at node n of problem p.

	Does nothing (returns None) if n is not a Call node. Otherwise:
	the callee is copied into p under n's tag (Problem.add_function)
	with its 'Ret' redirected to a new 'RetToCaller' node; n becomes a
	Basic node assigning the call arguments to the callee's renamed
	inputs and jumping to its entry; the 'RetToCaller' node assigns the
	callee's renamed outputs to the call's return variables and
	continues where the call did. The step is appended to
	p.inline_scripts[tag] (see Problem.replay_inline_script).

	Analysis caches are cleared; with do_analysis, preds and loops are
	also recomputed. Returns a list of the callee's new node addresses."""
	node = p.nodes[n]
	if node.kind != 'Call':
		return

	f_nm = node.fname
	fun = functions[f_nm]
	(tag, detail) = p.node_tags[n]
	# identify the call by its position among nodes sharing its
	# (tag, detail); replay_inline_script looks it up the same way
	idx = p.node_tag_revs[(tag, detail)].index (n)
	p.inline_scripts[tag].append ((detail, idx, f_nm))

	trace ('Inlining %s into %s' % (f_nm, p.name))
	if n in p.loop_data:
		trace ('  inlining into loop %d!' % p.loop_id (n))

	ex = p.alloc_node (tag, (f_nm, 'RetToCaller'))

	(ns, vs) = p.add_function (fun, tag, {'Ret': ex})
	en = ns[fun.entry]

	inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
	p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args))

	out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs]
	p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs))

	p.cached_analysis.clear ()

	if do_analysis:
		p.do_analysis ()

	trace ('Problem size now %d' % len(p.nodes))
	# flush stdout (flushing stdin, an input stream, achieved nothing)
	# so that progress output appears promptly — presumably trace
	# writes to stdout; confirm
	sys.stdout.flush ()

	return list (ns.values ())
747
def loop_body_inner_loops (p, head, loop_body):
	"""Return the components nested inside the loop headed at head.

	Edges back into head are ignored, and the remaining graph over
	loop_body is split into components by logic.tarjan starting from
	head. Components with a non-empty tail are returned as inner loops.

	NOTE(review): a single body node that continues straight to itself
	forms a component with an empty tail and so is not reported here,
	whereas do_loop_analysis does treat such self-loops as loops —
	confirm whether that is intended."""
	whole = set (loop_body)
	without_head = whole - set ([head])
	graph = {}
	for n in whole:
		graph[n] = [c for c in p.nodes[n].get_conts ()
			if c in without_head]

	comps = logic.tarjan (graph, [head])
	# sanity check: the components account for exactly the body's nodes
	assert sum ([1 + len (tail) for (_, tail) in comps]) == len (whole)
	return [c for c in comps if c[1]]
758
def loop_inner_loops (p, head):
	"""Return loop_body_inner_loops for the loop at head, memoised in
	p.cached_analysis under ('inner_loop_set', head)."""
	key = ('inner_loop_set', head)
	if key not in p.cached_analysis:
		p.cached_analysis[key] = loop_body_inner_loops (p, head,
			p.loop_body (head))
	return p.cached_analysis[key]
766
def loop_heads_including_inner (p):
	"""Return p's loop heads followed by the heads of all loops nested
	inside them, at any depth (searched with loop_body_inner_loops)."""
	heads = p.loop_heads ()
	pending = [(h, p.loop_body (h)) for h in heads]
	while pending:
		(outer, outer_body) = pending.pop ()
		for (inner, inner_tail) in loop_body_inner_loops (p, outer,
				outer_body):
			heads.append (inner)
			# the inner loop's body is its head plus its tail
			pending.append ((inner, [inner] + list (inner_tail)))
	return heads
777
def check_no_inner_loop (p, head):
	"""Raise Abort if the loop at head contains inner loops, after
	printing a message and tracing each inner loop head and its tag."""
	inner = loop_inner_loops (p, head)
	if not inner:
		return
	printout ('Aborting %s, complex loop' % p.name)
	trace ('  sub-loops %s of loop at %s' % (inner, head))
	for (inner_head, _) in inner:
		trace ('    head %d tagged %s'
			% (inner_head, p.node_tags[inner_head]))
	raise Abort ()
786
def has_inner_loop (p, head):
	"""Return True if the loop at head contains at least one inner loop."""
	return len (loop_inner_loops (p, head)) != 0
789
def fun_has_inner_loop (f):
	"""Return True if function f, converted with f.as_problem and
	analysed, has a loop that contains an inner loop."""
	prob = f.as_problem (Problem)
	prob.do_analysis ()
	nested = [h for h in prob.loop_heads () if has_inner_loop (prob, h)]
	return len (nested) > 0
795
def loop_var_analysis (p, head, tail):
	"""Return the variables that "go round" the loop headed at head.

	tail is the rest of the loop body. head is processed first; after
	that a loop node is processed only once all of its predecessors have
	been. For each node, created_vs_at records the variables assigned on
	some path from head up to and including that node (the union over
	its predecessors). A variable read by a node (the (key, value) pairs
	of syntax.get_node_rvals, presumably (name, type)) that is not in
	its incoming set is recorded in used_vs.

	The result is the intersection of used_vs with the variables
	created along the in-loop predecessors of head: variables that are
	written in the loop and read before being written. The result is
	also traced.

	NOTE(review): the final lookup raises KeyError if some in-loop
	predecessor of head was never processed (e.g. it lies behind a cycle
	that avoids head); presumably callers only use this on loops without
	inner loops — confirm."""
	# getting the set of variables that go round the loop
	nodes = set (tail)
	nodes.add (head)
	used_vs = set ([])
	created_vs_at = {}
	visit = []

	# note reads of node n not covered by created, record what has been
	# created up to and including n, then queue n's successors
	def process_node (n, created):
		if p.nodes[n].is_noop ():
			lvals = set ([])
		else:
			vs = syntax.get_node_rvals (p.nodes[n])
			for rv in vs.iteritems ():
				if rv not in created:
					used_vs.add (rv)
			lvals = set (p.nodes[n].get_lvals ())

		created = set.union (created, lvals)
		created_vs_at[n] = created

		visit.extend (p.nodes[n].get_conts ())

	process_node (head, set ([]))

	while visit:
		n = visit.pop ()
		if (n not in nodes) or (n in created_vs_at):
			continue
		# defer n until all its predecessors are done; it is pushed
		# again each time one of them is processed
		if not all ([pr in created_vs_at for pr in p.preds[n]]):
			continue

		pre_created = [created_vs_at[pr] for pr in p.preds[n]]
		process_node (n, set.union (* pre_created))

	# what has been created by the time control returns to head
	final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
		if pr in nodes]
	created = set.union (* final_pre_created)

	loop_vs = set.intersection (created, used_vs)
	trace ('Loop vars at head: %s' % loop_vs)

	return loop_vs
839
840
841