1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7from solver import Solver, merge_envs_pcs, smt_expr, mk_smt_expr, to_smt_expr
8from syntax import (true_term, false_term, boolT, mk_and, mk_not, mk_implies,
9	builtinTs, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word32, mk_var)
10import syntax
11import logic
12import solver
13from logic import azip
14
15from target_objects import functions, pairings, sections, trace, printout
16import target_objects
17import problem
18
class VisitCount:
	"""Used to represent a target number of visits to a split point.
	Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2),
	or a list of options.

	kind is 'Number' or 'Offset' (the integer is stored in n), or
	'Options' (a tuple of 'Number'/'Offset' VisitCounts stored in opts)."""
	def __init__ (self, kind, value):
		self.kind = kind
		self.is_visit_count = True
		if kind == 'Number':
			self.n = value
		elif kind == 'Offset':
			self.n = value
		elif kind == 'Options':
			self.opts = tuple (value)
			for opt in self.opts:
				assert opt.kind in ['Number', 'Offset']
		else:
			assert not 'VisitCount type understood'

	def __hash__ (self):
		if self.kind == 'Options':
			return hash (self.opts)
		else:
			return hash (self.kind) + self.n

	def __eq__ (self, other):
		"""Structural equality. Values that are not VisitCounts
		(including None, as returned by incr) compare unequal instead
		of raising AttributeError."""
		if not getattr (other, 'is_visit_count', False):
			return False
		if self.kind == 'Options':
			return (other.kind == 'Options'
				and self.opts == other.opts)
		else:
			return self.kind == other.kind and self.n == other.n

	def __ne__ (self, other):
		# Python 2 does not derive != from __eq__; without this method
		# two equal VisitCounts would compare as different objects.
		return not (self == other)

	# the previous, misspelt name, kept for any direct callers
	__neq__ = __ne__

	def __str__ (self):
		if self.kind == 'Number':
			return str (self.n)
		elif self.kind == 'Offset':
			return 'i+%s' % self.n
		elif self.kind == 'Options':
			return '_'.join (map (str, self.opts))

	def __repr__ (self):
		(ns, os) = self.get_opts ()
		return 'vc_options (%r, %r)' % (ns, os)

	def get_opts (self):
		"""Return (numbers, offsets): the integers of the 'Number' and
		'Offset' alternatives of this count."""
		if self.kind == 'Options':
			opts = self.opts
		else:
			opts = [self]
		ns = [vc.n for vc in opts if vc.kind == 'Number']
		os = [vc.n for vc in opts if vc.kind == 'Offset']
		return (ns, os)

	def serialise (self, ss):
		"""Append to the list ss: 'VC', the number of numeric options
		and their values, then the number of offsets and their values."""
		ss.append ('VC')
		(ns, os) = self.get_opts ()
		ss.append ('%d' % len (ns))
		ss.extend (['%d' % n for n in ns])
		ss.append ('%d' % len (os))
		ss.extend (['%d' % n for n in os])

	def incr (self, incr):
		"""Return this count shifted by incr. Alternatives that would
		become negative are dropped; None if nothing remains."""
		if self.kind in ['Number', 'Offset']:
			n = self.n + incr
			if n < 0:
				return None
			return VisitCount (self.kind, n)
		elif self.kind == 'Options':
			opts = [vc.incr (incr) for vc in self.opts]
			opts = [opt for opt in opts if opt]
			if opts == []:
				return None
			return mk_vc_opts (opts)
		else:
			assert not 'VisitCount type understood'

	def has_zero (self):
		"""True if this count, or any of its options, is the number 0."""
		if self.kind == 'Options':
			return any (vc.has_zero () for vc in self.opts)
		else:
			return self.kind == 'Number' and self.n == 0
107
def mk_vc_opts (opts):
	"""Combine a list of VisitCounts into one: a single entry is
	returned as-is, anything else becomes an 'Options' VisitCount."""
	if len (opts) != 1:
		return VisitCount ('Options', opts)
	return opts[0]
113
def vc_options (nums, offsets):
	"""VisitCount allowing each number in nums and each offset i+k for
	k in offsets."""
	counts = [vc_num (k) for k in nums]
	counts.extend ([vc_offs (k) for k in offsets])
	return mk_vc_opts (counts)
116
def vc_num (n):
	"""VisitCount for exactly n visits."""
	count = VisitCount ('Number', n)
	return count
119
def vc_upto (n):
	"""VisitCount allowing any of the numbers 0 .. n - 1."""
	return mk_vc_opts ([vc_num (k) for k in range (n)])
122
def vc_offs (n):
	"""VisitCount for the symbolic visit i + n."""
	count = VisitCount ('Offset', n)
	return count
125
def vc_offset_upto (n):
	"""VisitCount allowing any of the offsets i+0 .. i+(n - 1)."""
	return mk_vc_opts ([vc_offs (k) for k in range (n)])
128
def vc_double_range (n, m):
	"""VisitCount allowing numbers 0 .. n - 1 and offsets i+0 .. i+(m - 1)."""
	counts = [vc_num (k) for k in range (n)]
	counts.extend ([vc_offs (k) for k in range (m)])
	return mk_vc_opts (counts)
131
class InlineEvent(Exception):
	"""Raised by GraphSlice.try_inline after it has invoked the
	inliner's callback for a call node."""
134
class Hyp:
	"""Used to represent a proposition about path conditions or data at
	various points in execution.

	kind 'PCImp': pcs = [pc1, pc2], each either ('Bool', term) or a
	visit ((n, vcount), tag); interpreted as pc1 implies pc2.
	kind 'Eq': vals = [(expr1, vis1), (expr2, vis2)]; expr1 at vis1
	equals expr2 at vis2 (false if either visit has no pc env).
	kind 'EqIfAt': as 'Eq', but only required when both visits' path
	conditions hold (true if either visit has no pc env).
	induct, for the equality kinds, is passed to rep.get_induct_var and
	the resulting variable is substituted into both expressions."""

	def __init__ (self, kind, arg1, arg2, induct = None):
		self.kind = kind
		if kind == 'PCImp':
			self.pcs = [arg1, arg2]
		elif kind == 'Eq':
			self.vals = [arg1, arg2]
			self.induct = induct
		elif kind == 'EqIfAt':
			self.vals = [arg1, arg2]
			self.induct = induct
		else:
			assert not 'hyp kind understood'

	def __repr__ (self):
		if self.kind == 'PCImp':
			vals = [repr (pc) for pc in self.pcs]
		elif self.kind in ['Eq', 'EqIfAt']:
			vals = [repr (val) for val in self.vals]
			if self.induct:
				vals.append (repr (self.induct))
		else:
			assert not 'hyp kind understood'
		return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals))

	def hyp_tuple (self):
		"""A tuple identifying this hypothesis (used for hashing,
		equality and ordering)."""
		if self.kind == 'PCImp':
			return ('PCImp', self.pcs[0], self.pcs[1])
		elif self.kind in ['Eq', 'EqIfAt']:
			return (self.kind, self.vals[0],
				self.vals[1], self.induct)
		else:
			assert not 'hyp kind understood'

	def __hash__ (self):
		return hash (self.hyp_tuple ())

	def __eq__ (self, other):
		# explicit equality, consistent with __hash__; non-Hyp values
		# (e.g. None) are simply unequal
		if not isinstance (other, Hyp):
			return False
		return self.hyp_tuple () == other.hyp_tuple ()

	def __ne__ (self, other):
		return not other or not (self == other)

	def __cmp__ (self, other):
		# ordering for Python 2 sorting
		return cmp (self.hyp_tuple (), other.hyp_tuple ())

	def visits (self):
		"""The non-constant ((n, vcount), tag) visits mentioned."""
		if self.kind == 'PCImp':
			return [vis for vis in self.pcs
				if vis[0] != 'Bool']
		elif self.kind in ['Eq', 'EqIfAt']:
			return [vis for (_, vis) in self.vals]
		else:
			assert not 'hyp kind understood'

	def get_vals (self):
		"""The expressions compared by an equality hyp ([] for PCImp)."""
		if self.kind == 'PCImp':
			return []
		else:
			return [val for (val, _) in self.vals]

	def serialise_visit (self, visit, ss):
		"""Append a visit (n, restrs) to ss: the node, the number of
		restrictions, then each split and its serialised count."""
		(n, restrs) = visit
		ss.append ('%s' % n)
		ss.append ('%d' % len (restrs))
		for (n2, vc) in restrs:
			ss.append ('%d' % n2)
			vc.serialise (ss)

	def serialise_pc (self, pc, ss):
		"""Append pc to ss: 'True'/'False' for constant pcs, otherwise
		'PC' followed by the visit and its tag."""
		if pc[0] == 'Bool' and pc[1] == true_term:
			ss.append ('True')
		elif pc[0] == 'Bool' and pc[1] == false_term:
			ss.append ('False')
		else:
			ss.append ('PC')
			self.serialise_visit (pc[0], ss)
			ss.append (pc[1])

	def serialise_hyp (self, ss):
		"""Append a textual serialisation of this hypothesis to ss."""
		if self.kind == 'PCImp':
			(visit1, visit2) = self.pcs
			ss.append ('PCImp')
			self.serialise_pc (visit1, ss)
			self.serialise_pc (visit2, ss)
		elif self.kind in ['Eq', 'EqIfAt']:
			assert len (self.vals) == 2
			# a single token for the kind (extend would add its letters)
			ss.append (self.kind)
			for (exp, (visit, tag)) in self.vals:
				exp.serialise (ss)
				# same layout as a non-constant pc: visit, then tag
				self.serialise_visit (visit, ss)
				ss.append (tag)
			if self.induct:
				ss.append ('%d' % self.induct[0])
				ss.append ('%d' % self.induct[1])
			else:
				ss.extend (['None', 'None'])
		else:
			assert not 'hyp kind understood'

	def interpret (self, rep):
		"""Translate this hypothesis into a boolean term using rep
		(a GraphSlice) to look up path conditions and environments."""
		if self.kind == 'PCImp':
			((visit1, tag1), (visit2, tag2)) = self.pcs
			if visit1 == 'Bool':
				pc1 = tag1
			else:
				pc1 = rep.get_pc (visit1, tag = tag1)
			if visit2 == 'Bool':
				pc2 = tag2
			else:
				pc2 = rep.get_pc (visit2, tag = tag2)
			return mk_implies (pc1, pc2)
		elif self.kind in ['Eq', 'EqIfAt']:
			[(x, xvis), (y, yvis)] = self.vals
			if self.induct:
				v = rep.get_induct_var (self.induct)
				x = subst_induct (x, v)
				y = subst_induct (y, v)
			x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1])
			y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1])
			if x_pc_env == None or y_pc_env == None:
				if self.kind == 'EqIfAt':
					return syntax.true_term
				else:
					return syntax.false_term
			((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env)
			eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv)
			if self.kind == 'EqIfAt':
				x_pc = rep.get_pc (xvis[0], tag = xvis[1])
				y_pc = rep.get_pc (yvis[0], tag = yvis[1])
				return syntax.mk_n_implies ([x_pc, y_pc], eq)
			else:
				return eq
		else:
			assert not 'hypothesis type understood'
268
def check_vis_is_vis (vis):
	"""Assert that vis has the shape ((n, vcount), tag) with vcount a
	tuple (a zero-length slice of a tuple is ())."""
	((_n, vcount), _tag) = vis
	assert vcount[:0] == (), vcount
271
def eq_hyp (lhs, rhs, induct = None, use_if_at = False):
	"""Equality hypothesis between lhs and rhs, each an
	(expr, ((n, vcount), tag)) pair; 'EqIfAt' when use_if_at is set."""
	check_vis_is_vis (lhs[1])
	check_vis_is_vis (rhs[1])
	if use_if_at:
		return Hyp ('EqIfAt', lhs, rhs, induct = induct)
	return Hyp ('Eq', lhs, rhs, induct = induct)
279
def true_if_at_hyp (expr, vis, induct = None):
	"""'EqIfAt' hypothesis that expr equals true_term at vis."""
	check_vis_is_vis (vis)
	lhs = (expr, vis)
	rhs = (true_term, vis)
	return Hyp ('EqIfAt', lhs, rhs, induct = induct)
284
def pc_true_hyp (vis):
	"""Hypothesis that the path condition of vis holds (true implies it)."""
	check_vis_is_vis (vis)
	always = ('Bool', true_term)
	return Hyp ('PCImp', always, vis)
288
def pc_false_hyp (vis):
	"""Hypothesis that the path condition of vis is false (it implies
	false)."""
	check_vis_is_vis (vis)
	never = ('Bool', false_term)
	return Hyp ('PCImp', vis, never)
292
def pc_triv_hyp (vis):
	"""Trivially valid hypothesis 'pc of vis implies pc of vis'."""
	check_vis_is_vis (vis)
	hyp = Hyp ('PCImp', vis, vis)
	return hyp
296
297class GraphSlice:
298	"""Used to represent a slice of potential execution in a graph where
299	looping is limited to certain specific examples. For instance, we
300	might say that execution through node n will be represented only
301	by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The
302	variable state at visits 4 and i + 2 will be calculated but no
303	further execution will be done."""
304
	def __init__ (self, p, solv, inliner = None, fast = False):
		"""Create an empty slice of problem p that emits into solv.

		inliner, if given, is called as inliner ((p, n)) at Call nodes
		(see try_inline). fast selects the heuristic mode used by
		get_loop_pc_env, is_synt_const and fast_const_ret."""
		self.p = p
		self.solv = solv
		# entry node -> input environment (filled by add_input_envs)
		self.inp_envs = {}
		# parsed memory SMT name -> mem_calls summary (see add_var)
		self.mem_calls = {}
		# NOTE(review): runs before the attributes below exist;
		# mk_inp_env is passed self, so presumably it only needs
		# p, solv and mem_calls -- confirm.
		self.add_input_envs ()

		# (tag, n, vcount) -> (pc, env) or None (get_node_pc_env)
		self.node_pc_envs = {}
		# non-None node_pc_envs keys, in the order they were computed
		self.node_pc_env_order = []
		# source (n, vcount) -> {successor node: (pc, env)}
		self.arc_pc_envs = {}
		self.inliner = inliner
		# n_vc -> (inputs, outputs, success), plus pairing name ->
		# list of call visits (see add_func)
		self.funcs = {}
		# ((n, vcount), tag) queries, handed to run_requests by rebuild
		self.pc_env_requests = set ()
		self.fast = fast
		# (n1, n2) -> SMT name of the induction variable
		self.induct_var_env = {}
		# long SMT value -> defined short name (see contract)
		self.contractions = {}

		# when set, local definitions are asserted as tagged facts
		# (add_local_def) and rebuild enables unsat cores
		self.local_defs_unsat = False
		self.use_known_eqs = True

		# hypotheses usable by fetch_known_eqs / those actually used
		self.avail_hyps = set ()
		self.used_hyps = set ()
327
328	def add_input_envs (self):
329		for (entry, _, _, args) in self.p.entries:
330			self.inp_envs[entry] = mk_inp_env (entry, args, self)
331
332	def get_reachable (self, split, n):
333		return self.p.is_reachable_from (split, n)
334
335	class TooGeneral (Exception):
336		def __init__ (self, split):
337			self.split = split
338
	def get_tag_vcount (self, (n, vcount), tag):
		"""Normalise the visit restriction vcount for node n.

		tag defaults to n's own tag (p.node_tags[n][0]), and only the
		(split, count) restrictions whose split has that tag are
		considered. Returns (tag, None) if some restriction on a split
		for which get_reachable (split, n) is false does not admit a
		zero count. Otherwise returns (tag, vcount') where vcount' is
		the sorted tuple of restrictions whose split does satisfy
		get_reachable (split, n).

		Raises TooGeneral (split) if n is in a loop (p.loop_id) and a
		kept restriction on a split of that same loop is an 'Options'
		count."""
		if tag == None:
			tag = self.p.node_tags[n][0]
		# restrictions for this tag, with the reachability of n
		vcount_r = [(split, count, self.get_reachable (split, n))
			for (split, count) in vcount
			if self.p.node_tags[split][0] == tag]
		for (split, count, r) in vcount_r:
			if not r and not count.has_zero ():
				return (tag, None)
			assert count.is_visit_count
		# keep only restrictions on splits that can reach n
		vcount = [(s, c) for (s, c, r) in vcount_r if r]
		vcount = tuple (sorted (vcount))

		loop_id = self.p.loop_id (n)
		if loop_id != None:
			for (split, visits) in vcount:
				if (self.p.loop_id (split) == loop_id
						and visits.kind == 'Options'):
					raise self.TooGeneral (split)

		return (tag, vcount)
360
	def get_node_pc_env (self, (n, vcount), tag = None, request = True):
		"""Return (pc, env) for the visit (n, vcount) under tag, or None
		if get_tag_vcount or get_node_pc_env_raw produces nothing.

		Results, including None, are memoised in node_pc_envs under the
		normalised key (tag, n, vcount). With request set, the query is
		also recorded in pc_env_requests (rebuild hands these to
		run_requests)."""
		tag, vcount = self.get_tag_vcount ((n, vcount), tag)
		if vcount == None:
			return None

		if (tag, n, vcount) in self.node_pc_envs:
			return self.node_pc_envs[(tag, n, vcount)]

		if request:
			self.pc_env_requests.add (((n, vcount), tag))

		# compute a chain of predecessors iteratively first, so the
		# recursion below mostly hits the cache
		self.warm_pc_env_cache ((n, vcount), tag)

		pc_env = self.get_node_pc_env_raw ((n, vcount), tag)
		if pc_env:
			pc_env = self.apply_known_eqs_pc_env ((n, vcount),
				tag, pc_env)

		assert not (tag, n, vcount) in self.node_pc_envs
		self.node_pc_envs[(tag, n, vcount)] = pc_env
		if pc_env:
			self.node_pc_env_order.append ((tag, n, vcount))

		return pc_env
385
	def warm_pc_env_cache (self, n_vc, tag):
		"""this is to avoid recursion limits and spot bugs.

		Starting at n_vc, repeatedly step to the first predecessor (from
		prevs) that is not yet in node_pc_envs and whose get_tag_vcount
		(with its own tag) equals (tag, current vcount). Stops when
		there is none or on TooGeneral. The collected chain is then
		computed from its far end forwards, so each get_node_pc_env call
		finds its predecessor already cached. A chain reaching 5000
		steps is printed and fails an assertion."""
		prev_chain = []
		for i in range (5000):
			prevs = self.prevs (n_vc)
			try:
				prevs = [p for p in prevs
					if (tag, p[0], p[1])
						not in self.node_pc_envs
					if self.get_tag_vcount (p, None)
						== (tag, n_vc[1])]
			except self.TooGeneral:
				break
			if not prevs:
				break
			n_vc = prevs[0]
			prev_chain.append(n_vc)
		if not (len (prev_chain) < 5000):
			printout ([n for (n, vc) in prev_chain])
			assert len (prev_chain) < 5000, (prev_chain[:10],
				prev_chain[-10:])

		# furthest predecessor first
		prev_chain.reverse ()
		for n_vc in prev_chain:
			self.get_node_pc_env (n_vc, tag, request = False)
411
	def get_loop_pc_env (self, split, vcount):
		"""Build the (pc, env) of the symbolic loop visit of split.

		Starts from the pc env of the visit with split's count set to
		vc_num (0) and returns None if that has none. Variables that
		is_synt_const shows constant keep their value (all variables
		are checked in fast mode, otherwise only HTD/Dom-typed ones).
		Every other variable gets a fresh variable based on the name
		'<nm>_after_loop_at_<split>', which a var_rep_request hook may
		replace. The pc is a fresh boolean; in fast mode it is asserted
		to imply the count-0 pc."""
		vcount2 = dict (vcount)
		vcount2[split] = vc_num (0)
		vcount2 = tuple (sorted (vcount2.items ()))
		prev_pc_env = self.get_node_pc_env ((split, vcount2))
		if prev_pc_env == None:
			return None
		(_, prev_env) = prev_pc_env
		# memory-call summary at entry, widened for calls in the loop
		mem_calls = self.scan_mem_calls (prev_env)
		mem_calls = self.add_loop_mem_calls (split, mem_calls)
		def av (nm, typ, mem_name = None):
			nm2 = '%s_loop_at_%s' % (nm, split)
			return self.add_var (nm2, typ,
				mem_name = mem_name, mem_calls = mem_calls)
		env = {}
		consts = set ()
		for (nm, typ) in prev_env:
			check_const = self.fast or (typ in
				[builtinTs['HTD'], builtinTs['Dom']])
			if check_const and self.is_synt_const (nm, typ, split):
				env[(nm, typ)] = prev_env[(nm, typ)]
				consts.add ((nm, typ))
			else:
				env[(nm, typ)] = av (nm + '_after', typ,
					('Loop', prev_env[(nm, typ)]))
		# hooks may supply a special representation for the
		# non-constant variables
		for (nm, typ) in prev_env:
			if (nm, typ) in consts:
				continue
			z = self.var_rep_request ((nm, typ), 'Loop',
				(split, vcount), env)
			if z:
				env[(nm, typ)] = z

		pc = mk_smt_expr (av ('pc_of', boolT), boolT)
		if self.fast:
			imp = syntax.mk_implies (pc, prev_pc_env[0])
			self.solv.assert_fact (imp, prev_env,
				unsat_tag = ('LoopPCImp', split))

		return (pc, env)
452
	def is_synt_const (self, nm, typ, split):
		"""check if a variable at a split point is a syntactic constant
		which is always unmodified by the loop.
		we allow cases where a variable is renamed and renamed back
		during the loop (this often happens because of inlining).
		the check is done by depth-first-search backward through the
		graph looking for a source of a variant value.

		Returns False at once if the loop has an inner loop. Inside the
		loop body, the variable is non-constant if a Call node returns
		it (unless fast_const_ret accepts that), if a Basic node
		assigns it a non-variable expression, or if the search reaches
		split again under a pair not already known safe."""
		# NOTE(review): 'loop' is computed but never used below
		loop = self.p.loop_id (split)
		if problem.has_inner_loop (self.p, split):
			return False
		loop_set = set (self.p.loop_body (split))

		orig_nm = nm
		# (name, node) pairs already shown not to lead to a variant
		# source
		safe = set ([(orig_nm, split)])
		first_step = True
		# explicit DFS stack of (name, node) pairs
		visit = []
		count = 0
		while first_step or visit:
			if first_step:
				(nm, n) = (orig_nm, split)
				first_step = False
			else:
				(nm, n) = visit.pop ()
				if (nm, n) in safe:
					continue
				elif n == split:
					return False
			# name the variable had before node n
			new_nm = nm
			node = self.p.nodes[n]
			if node.kind == 'Call':
				if (nm, typ) not in node.rets:
					pass
				elif self.fast_const_ret (n, nm, typ):
					pass
				else:
					return False
			elif node.kind == 'Basic':
				upds = [arg for (lv, arg) in node.upds
					if lv == (nm, typ)]
				if [v for v in upds if v.kind != 'Var']:
					return False
				if upds:
					# plain rename: follow the source variable
					new_nm = upds[0].name
			preds = [(new_nm, n2) for n2 in self.p.preds[n]
				if n2 in loop_set]
			unknowns = [p for p in preds if p not in safe]
			if unknowns:
				# revisit this pair after its unexplored predecessors
				visit.extend ([(nm, n)] + unknowns)
			else:
				safe.add ((nm, n))
			count += 1
			if count % 100000 == 0:
				trace ('is_synt_const: %d iterations' % count)
				trace ('visit length %d' % len (visit))
				trace ('visit tail %s' % visit[-20:])
		return True
509
510	def fast_const_ret (self, n, nm, typ):
511		"""determine if we can heuristically consider this return
512		value to be the same as an input. this is known for some
513		function returns, e.g. memory.
514		this is important for heuristic "fast" analysis."""
515		if not self.fast:
516			return False
517		node = self.p.nodes[n]
518		assert node.kind == 'Call'
519		for hook in target_objects.hooks ('rep_unsafe_const_ret'):
520			if hook (node, nm, typ):
521				return True
522		return False
523
	def get_node_pc_env_raw (self, (n, vcount), tag):
		"""Compute, without caching, the (pc, env) of the visit.

		Entry nodes get pc true and their input environment. A visit of
		split n at count vc_offs (0) (the symbolic 'i' visit) is
		delegated to get_loop_pc_env. Otherwise the arcs from all
		predecessors with the same tag are merged by merge_envs_pcs,
		and None is returned if no arc contributes. A merged pc that is
		not already an SMTExpr is named by a solver definition, and env
		values longer than 80 characters are contracted to names."""
		if n in self.inp_envs:
			return (true_term, self.inp_envs[n])

		for (split, count) in vcount:
			if split == n and count == vc_offs (0):
				return self.get_loop_pc_env (split, vcount)

		pc_envs = [pc_env for n_prev in self.p.preds[n]
			if self.p.node_tags[n_prev][0] == tag
			for pc_env in self.get_arc_pc_envs (n_prev,
				(n, vcount))]

		# drop arcs that produced no pc env
		pc_envs = [pc_env for pc_env in pc_envs if pc_env]
		if pc_envs == []:
			return None

		if n == 'Err':
			# we'll never care about variable values here
			# and there are sometimes a LOT of arcs to Err
			# so we save a lot of merge effort
			pc_envs = [(to_smt_expr (pc, env, self.solv), {})
				for (pc, env) in pc_envs]

		(pc, env, large) = merge_envs_pcs (pc_envs, self.solv)

		if pc.kind != 'SMTExpr':
			name = self.path_cond_name ((n, vcount), tag)
			name = self.solv.add_def (name, pc, env)
			pc = mk_smt_expr (name, boolT)

		for (nm, typ) in env:
			if len (env[(nm, typ)]) > 80:
				env[(nm, typ)] = self.contract (nm, (n, vcount),
					env[(nm, typ)], typ)

		return (pc, env)
561
562	def contract (self, name, n_vc, val, typ):
563		if val in self.contractions:
564			return self.contractions[val]
565
566		name = self.local_name_before (name, n_vc)
567		name = self.solv.add_def (name, mk_smt_expr (val, typ), {})
568
569		self.contractions[val] = name
570		return name
571
572	def get_arc_pc_envs (self, n, n_vc2):
573		try:
574			prevs = [n_vc for n_vc in self.prevs (n_vc2)
575				if n_vc[0] == n]
576			assert len (prevs) <= 1
577			return [self.get_arc_pc_env (n_vc, n_vc2)
578				for n_vc in prevs]
579		except self.TooGeneral, e:
580			# consider specialisations of the target
581			specs = self.specialise (n_vc2, e.split)
582			specs = [(n_vc2[0], spec) for spec in specs]
583			return [pc_env for spec in specs
584				for pc_env in self.get_arc_pc_envs (n, spec)]
585
	def get_arc_pc_env (self, (n, vcount), n2):
		"""Return the (pc, env) flowing from the visit (n, vcount) into
		the visit n2, or None if the source normalises away
		(get_tag_vcount) or has no pc env.

		The first query for a source visit emits its node (emit_node),
		runs the post-emit hooks and caches all outgoing arcs in
		arc_pc_envs, keyed by the normalised source visit and then by
		successor node."""
		tag, vcount = self.get_tag_vcount ((n, vcount), None)

		if vcount == None:
			return None

		assert self.is_cont ((n, vcount), n2), ((n, vcount),
			n2, self.p.nodes[n].get_conts ())

		if (n, vcount) in self.arc_pc_envs:
			return self.arc_pc_envs[(n, vcount)].get (n2[0])

		if self.get_node_pc_env ((n, vcount), request = False) == None:
			return None

		arcs = self.emit_node ((n, vcount))
		self.post_emit_node_hooks ((n, vcount))
		# successor node -> (pc, env)
		arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs])

		self.arc_pc_envs[(n, vcount)] = arcs
		return arcs.get (n2[0])
607
608	def add_local_def (self, n, vname, name, val, env):
609		if self.local_defs_unsat:
610			smt_name = self.solv.add_var (name, val.typ)
611			eq = mk_eq (mk_smt_expr (smt_name, val.typ), val)
612			self.solv.assert_fact (eq, env, unsat_tag
613				= ('Def', n, vname))
614		else:
615			smt_name = self.solv.add_def (name, val, env)
616		return smt_name
617
618	def add_var (self, name, typ, mem_name = None, mem_calls = None):
619		r = self.solv.add_var_restr (name, typ, mem_name = mem_name)
620		if typ == syntax.builtinTs['Mem']:
621			r_x = solver.parse_s_expression (r)
622			self.mem_calls[r_x] = mem_calls
623		return r
624
625	def var_rep_request (self, (nm, typ), kind, n_vc, env):
626		assert type (n_vc[0]) != str
627		for hook in target_objects.hooks ('problem_var_rep'):
628			z = hook (self.p, (nm, typ), kind, n_vc[0])
629			if z == None:
630				continue
631			if z[0] == 'SplitMem':
632				assert typ == builtinTs['Mem']
633				(_, addr) = z
634				addr = smt_expr (addr, env, self.solv)
635				name = '%s_for_%s' % (nm,
636					self.node_count_name (n_vc))
637				return self.solv.add_split_mem_var (addr, name,
638					typ, mem_name = 'SplitMemNonsense')
639			else:
640				assert z == None
641
	def emit_node (self, n):
		"""Emit the node visit n = (node, vcount) and return its outgoing
		arcs as a list of (cont, pc, env) triples.

		Basic nodes bind their updated variables (non-variable values
		become local definitions); Cond nodes define a condition and
		split the pc between the successors; Call nodes get fresh
		output variables and a success variable and are registered via
		add_func. Known equalities at the visit are substituted into
		the node's expressions via apply_known_eqs_tm. For Call nodes
		try_inline is consulted first and may raise InlineEvent."""
		(pc, env) = self.get_node_pc_env (n, request = False)
		tag = self.p.node_tags[n[0]][0]
		app_eqs = self.apply_known_eqs_tm (n, tag)
		# node = logic.simplify_node_elementary (self.p.nodes[n[0]])
		# whether to ignore unreachable Cond arcs seems to be a huge
		# dilemma. if we ignore them, some reachable sites become
		# unreachable and we can't interpret all hyps
		# if we don't ignore them, the variable set disagrees with
		# var_deps and so the abstracted loop pc/env may not be
		# sufficient and we get EnvMiss again. I don't really know
		# what to do about this corner case.
		node = self.p.nodes[n[0]]
		# copy, so the updates below do not alter the cached env
		env = dict (env)

		if node.kind == 'Call':
			self.try_inline (n[0], pc, env)

		if pc == false_term:
			# unreachable: every arc gets a false pc and no variables
			return [(c, false_term, {}) for c in node.get_conts()]
		elif node.kind == 'Cond' and node.left == node.right:
			# both branches lead to the same node
			return [(node.left, pc, env)]
		elif node.kind == 'Cond' and node.cond == true_term:
			return [(node.left, pc, env),
				(node.right, false_term, env)]
		elif node.kind == 'Basic':
			upds = []
			for (lv, v) in node.upds:
				if v.kind == 'Var':
					# plain copy: reuse the source's current value
					upds.append ((lv, env[(v.name, v.typ)]))
				else:
					name = self.local_name (lv[0], n)
					v = app_eqs (v)
					vname = self.add_local_def (n,
						('Var', lv), name, v, env)
					upds.append ((lv, vname))
			# assign only after evaluating every update, so all
			# right-hand sides see the env from before the node
			for (lv, v) in upds:
				env[lv] = v
			return [(node.cont, pc, env)]
		elif node.kind == 'Cond':
			name = self.cond_name (n)
			# fresh problem variable, bound to the defined condition
			cond = self.p.fresh_var (name, boolT)
			env[(cond.name, boolT)] = self.add_local_def (n,
				'Cond', name, app_eqs (node.cond), env)
			lpc = mk_and (cond, pc)
			rpc = mk_and (mk_not (cond), pc)
			return [(node.left, lpc, env), (node.right, rpc, env)]
		elif node.kind == 'Call':
			nm = self.success_name (node.fname, n)
			success = self.solv.add_var (nm, boolT)
			success = mk_smt_expr (success, boolT)
			fun = functions[node.fname]
			# callee input name -> SMT value of the actual argument
			ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv))
				for ((x, typ), arg) in azip (fun.inputs, node.args)])
			mem_name = None
			# iterating in reverse, the first Mem input (in
			# declaration order) is the one that sticks
			for (x, typ) in reversed (fun.inputs):
				if typ == builtinTs['Mem']:
					inp_mem = ins[(x, typ)]
					mem_name = (node.fname, inp_mem)
			mem_calls = self.scan_mem_calls (ins)
			mem_calls = self.add_mem_call (node.fname, mem_calls)
			outs = {}
			for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs):
				assert typ2 == typ
				if self.fast_const_ret (n[0], x, typ):
					# heuristically unchanged: keep x's value
					# from before the call
					outs[(y, typ2)] = env [(x, typ)]
					continue
				name = self.local_name (x, n)
				env[(x, typ)] = self.add_var (name, typ,
					mem_name = mem_name,
					mem_calls = mem_calls)
				outs[(y, typ2)] = env[(x, typ)]
			# hooks may override the representation of returns
			for ((x, typ), (y, _)) in azip (node.rets, fun.outputs):
				z = self.var_rep_request ((x, typ),
					'Call', n, env)
				if z != None:
					env[(x, typ)] = z
					outs[(y, typ)] = z
			self.add_func (node.fname, ins, outs, success, n)
			return [(node.cont, pc, env)]
		else:
			assert not 'node kind understood'
724
725	def post_emit_node_hooks (self, (n, vcount)):
726		for hook in target_objects.hooks ('post_emit_node'):
727			hook (self, (n, vcount))
728
729	def fetch_known_eqs (self, n_vc, tag):
730		if not self.use_known_eqs:
731			return None
732		eqs = self.p.known_eqs.get ((n_vc, tag))
733		if eqs == None:
734			return None
735		avail = []
736		for (x, n_vc_y, tag_y, y, hyps) in eqs:
737			if hyps <= self.avail_hyps:
738				(_, env) = self.get_node_pc_env (n_vc_y, tag_y)
739				avail.append ((x, smt_expr (y, env, self.solv)))
740				self.used_hyps.update (hyps)
741		if avail:
742			return avail
743		return None
744
745	def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)):
746		eqs = self.fetch_known_eqs (n_vc, tag)
747		if eqs == None:
748			return (pc, env)
749		env = dict (env)
750		for (x, sx) in eqs:
751			if x.kind == 'Var':
752				cur_rhs = env[x.name]
753				for y in env:
754					if env[y] == cur_rhs:
755						trace ('substituted %s at %s.' % (y, n_vc))
756						env[y] = sx
757		return (pc, env)
758
759	def apply_known_eqs_tm (self, n_vc, tag):
760		eqs = self.fetch_known_eqs (n_vc, tag)
761		if eqs == None:
762			return lambda x: x
763		eqs = dict ([(x, mk_smt_expr (sexpr, x.typ))
764			for (x, sexpr) in eqs])
765		return lambda tm: logic.recursive_term_subst (eqs, tm)
766
	def rebuild (self, solv = None):
		"""Discard emitted state and re-emit into a fresh solver.

		Clears the pc-env, arc, function, request, induction-variable
		and contraction caches; switches to solv, or to a new Solver
		with unsat cores enabled iff local_defs_unsat; re-adds the input
		environments; resets used_hyps; and passes the previously
		recorded pc_env_requests to run_requests. avail_hyps and
		mem_calls are left as they are."""
		requests = self.pc_env_requests

		self.node_pc_env_order = []
		self.node_pc_envs = {}
		self.arc_pc_envs = {}
		self.funcs = {}
		self.pc_env_requests = set ()
		self.induct_var_env = {}
		self.contractions = {}

		if not solv:
			solv = Solver (produce_unsat_cores
				= self.local_defs_unsat)
		self.solv = solv

		self.add_input_envs ()

		self.used_hyps = set ()
		run_requests (self, requests)
787
	def add_func (self, name, inputs, outputs, success, n_vc):
		"""Record the call to function name emitted at n_vc, and assert
		pairing facts against previously recorded partner calls.

		self.funcs serves two purposes: n_vc -> (inputs, outputs,
		success) for each emitted call, and pair.name -> list of call
		visits seen so far for that pairing."""
		assert n_vc not in self.funcs
		self.funcs[n_vc] = (inputs, outputs, success)
		for pair in pairings.get (name, []):
			self.funcs.setdefault (pair.name, [])
			group = self.funcs[pair.name]
			for n_vc2 in group:
				# only assert for compatible partners
				if self.get_func_pairing (n_vc, n_vc2):
					self.add_func_assert (n_vc, n_vc2)
			group.append (n_vc)
798
799	def get_func (self, n_vc, tag = None):
800		"""returns (input_env, output_env, success_var) for
801		function call at given n_vc."""
802		tag, vc = self.get_tag_vcount (n_vc, tag)
803		n_vc = (n_vc[0], vc)
804		assert self.p.nodes[n_vc[0]].kind == 'Call'
805
806		if n_vc not in self.funcs:
807			# try to ensure n_vc has been emitted
808			cont = self.get_cont (n_vc)
809			self.get_node_pc_env (cont, tag = tag)
810
811		return self.funcs[n_vc]
812
813	def get_func_pairing_nocheck (self, n_vc, n_vc2):
814		fnames = [self.p.nodes[n_vc[0]].fname,
815			self.p.nodes[n_vc2[0]].fname]
816		pairs = [pair for pair in pairings[list (fnames)[0]]
817			if set (pair.funs.values ()) == set (fnames)]
818		if not pairs:
819			return None
820		[pair] = pairs
821		if pair.funs[pair.tags[0]] == fnames[0]:
822			return (pair, n_vc, n_vc2)
823		else:
824			return (pair, n_vc2, n_vc)
825
826	def get_func_pairing (self, n_vc, n_vc2):
827		res = self.get_func_pairing_nocheck (n_vc, n_vc2)
828		if not res:
829			return res
830		(pair, l_n_vc, r_n_vc) = res
831		(lin, _, _) = self.funcs[l_n_vc]
832		(rin, _, _) = self.funcs[r_n_vc]
833		l_mem_calls = self.scan_mem_calls (lin)
834		r_mem_calls = self.scan_mem_calls (rin)
835		tags = pair.tags
836		(c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls)
837		if not c:
838			trace ('skipped emitting func pairing %s -> %s'
839				% (l_n_vc, r_n_vc))
840			trace ('  ' + s)
841			return None
842		return res
843
844	def get_func_assert (self, n_vc, n_vc2):
845		(pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2)
846		(ltag, rtag) = pair.tags
847		(inp_eqs, out_eqs) = pair.eqs
848		(lin, lout, lsucc) = self.funcs[l_n_vc]
849		(rin, rout, rsucc) = self.funcs[r_n_vc]
850		lpc = self.get_pc (l_n_vc)
851		rpc = self.get_pc (r_n_vc)
852		envs = {ltag + '_IN': lin, rtag + '_IN': rin,
853			ltag + '_OUT': lout, rtag + '_OUT': rout}
854		inp_eqs = inst_eqs (inp_eqs, envs, self.solv)
855		out_eqs = inst_eqs (out_eqs, envs, self.solv)
856		succ_imp = mk_implies (rsucc, lsucc)
857
858		return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]),
859			foldr1 (mk_and, out_eqs + [succ_imp]))
860
861	def add_func_assert (self, n_vc, n_vc2):
862		imp = self.get_func_assert (n_vc, n_vc2)
863		imp = logic.weaken_assert (imp)
864		if self.local_defs_unsat:
865			self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq',
866				ln, rn))
867		else:
868			self.solv.assert_fact (imp, {})
869
870	def node_count_name (self, (n, vcount)):
871		name = str (n)
872		bits = [str (n)] + ['%s=%s' % (split, count)
873			for (split, count) in vcount]
874		return '_'.join (bits)
875
876	def get_mem_calls (self, mem_sexpr):
877		mem_sexpr = solver.parse_s_expression (mem_sexpr)
878		return self.get_mem_calls_sexpr (mem_sexpr)
879
880	def get_mem_calls_sexpr (self, mem_sexpr):
881		stores = set (['store-word32', 'store-word8', 'store-word64'])
882		if mem_sexpr in self.mem_calls:
883			return self.mem_calls[mem_sexpr]
884		elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores:
885			return self.get_mem_calls_sexpr (mem_sexpr[1])
886		elif mem_sexpr[:1] == ('ite', ):
887			(_, _, x, y) = mem_sexpr
888			x_calls = self.get_mem_calls_sexpr (x)
889			y_calls = self.get_mem_calls_sexpr (y)
890			return merge_mem_calls (x_calls, y_calls)
891		elif mem_sexpr in self.solv.defs:
892			mem_sexpr = self.solv.defs[mem_sexpr]
893			return self.get_mem_calls_sexpr (mem_sexpr)
894		assert not "mem_calls fallthrough", mem_sexpr
895
896	def scan_mem_calls (self, env):
897		mem_vs = [env[(nm, typ)]
898			for (nm, typ) in env
899			if typ == syntax.builtinTs['Mem']]
900		mem_calls = [self.get_mem_calls (v)
901			for v in mem_vs if v[0] != 'SplitMem']
902		if mem_calls:
903			return foldr1 (merge_mem_calls, mem_calls)
904		else:
905			return None
906
907	def add_mem_call (self, fname, mem_calls):
908		if mem_calls == None:
909			return None
910		mem_calls = dict (mem_calls)
911		(min_calls, max_calls) = mem_calls.get (fname, (0, 0))
912		if max_calls == None:
913			mem_calls[fname] = (min_calls + 1, None)
914		else:
915			mem_calls[fname] = (min_calls + 1, max_calls + 1)
916		return mem_calls
917
918	def add_loop_mem_calls (self, split, mem_calls):
919		if mem_calls == None:
920			return None
921		fnames = set ([self.p.nodes[n].fname
922			for n in self.p.loop_body (split)
923			if self.p.nodes[n].kind == 'Call'])
924		if not fnames:
925			return mem_calls
926		mem_calls = dict (mem_calls)
927		for fname in fnames:
928			if fname not in mem_calls:
929				mem_calls[fname] = (0, None)
930			else:
931				(min_calls, max_calls) = mem_calls[fname]
932				mem_calls[fname] = (min_calls, None)
933		return mem_calls
934
935	# note these names are designed to be unique by suffix
936	# (so that smt names are independent of order of requests)
937	def local_name (self, s, n_vc):
938		return '%s_after_%s' % (s, self.node_count_name (n_vc))
939
940	def local_name_before (self, s, n_vc):
941		return '%s_v_at_%s' % (s, self.node_count_name (n_vc))
942
943	def cond_name (self, n_vc):
944		return 'cond_at_%s' % self.node_count_name (n_vc)
945
946	def path_cond_name (self, n_vc, tag):
947		return 'path_cond_to_%s_%s' % (
948			self.node_count_name (n_vc), tag)
949
950	def success_name (self, fname, n_vc):
951		bits = fname.split ('.')
952		nms = ['_'.join (bits[i:]) for i in range (len (bits))
953			if bits[i:][0].isalpha ()]
954		if nms:
955			nm = nms[-1]
956		else:
957			nm = 'fun'
958		return '%s_success_at_%s' % (nm, self.node_count_name (n_vc))
959
960	def try_inline (self, n, pc, env):
961		if not self.inliner:
962			return False
963
964		inline = self.inliner ((self.p, n))
965		if not inline:
966			return False
967
968		# make sure this node is reachable before inlining
969		if self.solv.test_hyp (mk_not (pc), env):
970			trace ('Skipped inlining at %d.' % n)
971			return False
972
973		trace ('Inlining at %d.' % n)
974		inline ()
975		raise InlineEvent ()
976
977	def incr (self, vcount, n, incr):
978		vcount2 = dict (vcount)
979		vcount2[n] = vcount2[n].incr (incr)
980		if vcount2[n] == None:
981			return None
982		return tuple (sorted (vcount2.items ()))
983
984	def get_cont (self, (n, vcount)):
985		[c] = self.p.nodes[n].get_conts ()
986		vcount2 = dict (vcount)
987		if n in vcount2:
988			vcount = self.incr (vcount, n, 1)
989		cont = (c, vcount)
990		assert self.is_cont ((n, vcount), cont)
991		return cont
992
993	def is_cont (self, (n, vcount), (n2, vcount2)):
994		if n2 not in self.p.nodes[n].get_conts ():
995			trace ('Not a graph cont.')
996			return False
997
998		vcount_d = dict (vcount)
999		vcount_d2 = dict (vcount2)
1000		if n in vcount_d2:
1001			if n in vcount_d:
1002				assert vcount_d[n].kind != 'Options'
1003			vcount_d2[n] = vcount_d2[n].incr (-1)
1004
1005		if not vcount_d <= vcount_d2:
1006			trace ('Restrictions not subset.')
1007			return False
1008
1009		for (split, count) in vcount_d2.iteritems ():
1010			if split in vcount_d:
1011				continue
1012			if self.get_reachable (split, n):
1013				return False
1014			if not count.has_zero ():
1015				return False
1016
1017		return True
1018
1019	def prevs (self, (n, vcount)):
1020		prevs = []
1021		vcount_d = dict (vcount)
1022		for p in self.p.preds[n]:
1023			if p in vcount_d:
1024				vcount2 = self.incr (vcount, p, -1)
1025				if vcount2 == None:
1026					continue
1027				prevs.append ((p, vcount2))
1028			else:
1029				prevs.append ((p, vcount))
1030		return prevs
1031
1032	def specialise (self, (n, vcount), split):
1033		vcount = dict (vcount)
1034		assert vcount[split].kind == 'Options'
1035		specs = []
1036		for n in vcount[split].opts:
1037			v = dict (vcount)
1038			v[split] = n
1039			specs.append (tuple (sorted (v.items ())))
1040		return specs
1041
1042	def get_pc (self, (n, vcount), tag = None):
1043		pc_env = self.get_node_pc_env ((n, vcount), tag = tag)
1044		if pc_env == None:
1045			trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag))
1046			return false_term
1047		(pc, env) = pc_env
1048		return to_smt_expr (pc, env, self.solv)
1049
1050	def to_smt_expr (self, expr, (n, vcount), tag = None):
1051		pc_env = self.get_node_pc_env ((n, vcount), tag = tag)
1052		(pc, env) = pc_env
1053		return to_smt_expr (expr, env, self.solv)
1054
1055	def get_induct_var (self, (n1, n2)):
1056		if (n1, n2) not in self.induct_var_env:
1057			vname = self.solv.add_var ('induct_i_%d_%d' % (n1, n2),
1058				word32T)
1059			self.induct_var_env[(n1, n2)] = vname
1060			self.pc_env_requests.add (((n1, n2), 'InductVar'))
1061		else:
1062			vname = self.induct_var_env[(n1, n2)]
1063		return mk_smt_expr (vname, word32T)
1064
1065	def interpret_hyp (self, hyp):
1066		return hyp.interpret (self)
1067
1068	def interpret_hyp_imps (self, hyps, concl):
1069		hyps = map (self.interpret_hyp, hyps)
1070		return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl))
1071
1072	def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False,
1073			model = None):
1074		self.avail_hyps = set (hyps)
1075		if not self.used_hyps <= self.avail_hyps:
1076			self.rebuild ()
1077
1078		last_test[0] = (hyp, hyps, list (self.pc_env_requests))
1079
1080		expr = self.interpret_hyp_imps (hyps, hyp)
1081
1082		trace ('Testing hyp whyps', push = 1)
1083		trace ('requests = %s' % self.pc_env_requests)
1084
1085		expr_s = smt_expr (expr, {}, self.solv)
1086		if cache and expr_s in cache:
1087			trace ('Cached: %s' % cache[expr_s])
1088			return cache[expr_s]
1089		if fast:
1090			trace ('(not in cache)')
1091			return None
1092
1093		self.solv.add_pvalid_dom_assertions ()
1094
1095		(result, _, _) = self.solv.parallel_test_hyps ([(None, expr)],
1096			{}, model = model)
1097		trace ('Result: %s' % result, push = -1)
1098		if cache != None:
1099			cache[expr_s] = result
1100		if not result:
1101			last_failed_test[0] = last_test[0]
1102		return result
1103
1104	def test_hyp_imp (self, hyps, hyp, model = None):
1105		return self.test_hyp_whyps (self.interpret_hyp (hyp), hyps,
1106			model = model)
1107
1108	def test_hyp_imps (self, imps):
1109		last_hyp_imps[0] = imps
1110		if imps == []:
1111			return (True, None)
1112		interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps,
1113				self.interpret_hyp (hyp))
1114			for (hyps, hyp) in imps]))
1115		reqs = list (self.pc_env_requests)
1116		last_test[0] = (self.interpret_hyp (hyp), hyps, reqs)
1117		self.solv.add_pvalid_dom_assertions ()
1118		result = self.solv.parallel_test_hyps (interp_imps, {})
1119		assert result[0] in [True, False], result
1120		if result[0] == False:
1121			(hyps, hyp) = imps[result[1]]
1122			last_test[0] = (self.interpret_hyp (hyp), hyps, reqs)
1123			last_failed_test[0] = last_test[0]
1124		return result
1125
1126	def replay_requests (self, reqs):
1127		for ((n, vc), tag) in reqs:
1128			self.get_node_pc_env ((n, vc), tag = tag)
1129
# Single-element lists used as mutable module-level slots. test_hyp_whyps
# and test_hyp_imps store the most recent test as (hyp, hyps, requests) in
# last_test[0], and copy it into last_failed_test[0] when the test fails.
# test_hyp_imps also stores its imps argument in last_hyp_imps[0].
# NOTE(review): nothing in this part of the file reads them back --
# presumably kept for post-mortem debugging; confirm.
last_test = [0]
last_failed_test = [0]
last_hyp_imps = [0]
1133
def to_smt_expr_under_op (expr, env, solv):
	"""Convert expr to SMT form, keeping a top-level Op structurally intact.

	For an 'Op' expression only the operands are converted with to_smt_expr,
	then reassembled with syntax.adjust_op_vals. Any other expression is
	converted as a whole.
	"""
	if expr.kind != 'Op':
		return to_smt_expr (expr, env, solv)
	smt_vals = [to_smt_expr (val, env, solv) for val in expr.vals]
	return syntax.adjust_op_vals (expr, smt_vals)
1140
def inst_eq_with_envs (x_env1, y_env2, solv):
	"""Instantiate x in env1 and y in env2, then relate them.

	If x's instantiated type is the RelWrapper builtin, the relation comes
	from logic.apply_rel_wrapper. Otherwise it is a plain equality.
	"""
	(x, env1) = x_env1
	(y, env2) = y_env2
	lhs = to_smt_expr_under_op (x, env1, solv)
	rhs = to_smt_expr_under_op (y, env2, solv)
	return (logic.apply_rel_wrapper (lhs, rhs)
		if lhs.typ == syntax.builtinTs['RelWrapper']
		else mk_eq (lhs, rhs))
1148
def inst_eqs (eqs, envs, solv):
	"""Instantiate each ((x, x_addr), (y, y_addr)) equation.

	x_addr and y_addr are keys into envs selecting each side's environment.
	"""
	results = []
	for ((x, x_addr), (y, y_addr)) in eqs:
		lhs = (x, envs[x_addr])
		rhs = (y, envs[y_addr])
		results.append (inst_eq_with_envs (lhs, rhs, solv))
	return results
1152
def subst_induct (expr, induct_var):
	"""Replace the word32 variable '%n' in expr with induct_var (if present)."""
	placeholder = ('%n', word32T)
	return logic.var_subst (expr, {placeholder: induct_var},
		must_subst = False)
1156
# Hypothesis sets already traced, mapping the hyps tuple to the short name
# it was printed under, so a repeated set is traced only by name.
printed_hyps = {}
def print_hyps (hyps):
	"""Trace hyps, printing each distinct hypothesis set in full only once.

	The first time a set is seen it is traced as 'hyp_set_<k> = [...]'.
	Every call then traces 'hyps = hyp_set_<k>'. hyps must be hashable
	once converted to a tuple.
	"""
	hyps = tuple (hyps)
	if hyps in printed_hyps:
		trace ('hyps = %s' % printed_hyps[hyps])
	else:
		hname = 'hyp_set_%d' % (len (printed_hyps) + 1)
		trace ('%s = %s' % (hname, list (hyps)))
		# key by the hyps themselves so the lookup above can hit
		printed_hyps[hyps] = hname
		trace ('hyps = %s' % hname)
1167
def merge_mem_calls (mem_calls_x, mem_calls_y):
	"""Merge two memory-call summaries into one that covers both.

	Each summary maps a function name to a (min, max) pair of call counts.
	A max of None absorbs (the merged max is None). A name missing from one
	side counts as (0, 0). Each merged entry is (smaller min, larger max).
	Equal summaries return the first argument itself (not a copy).
	"""
	if mem_calls_x == mem_calls_y:
		return mem_calls_x
	mem_calls = {}
	# union of key sets; concatenating .keys () only works in Python 2
	for fname in set (mem_calls_x) | set (mem_calls_y):
		(min_x, max_x) = mem_calls_x.get (fname, (0, 0))
		(min_y, max_y) = mem_calls_y.get (fname, (0, 0))
		if None in [max_x, max_y]:
			max_v = None
		else:
			max_v = max (max_x, max_y)
		mem_calls[fname] = (min (min_x, min_y), max_v)
	return mem_calls
1181
def mem_calls_compatible (tags, l_mem_calls, r_mem_calls):
	"""Check that left and right memory-call summaries can agree.

	Summaries map function names to (min, max) call counts. A max of None
	imposes no upper bound, so the corresponding check is skipped. Each
	left-hand entry is transferred to the right-hand function it is paired
	with under tags, provided that function has a Mem output. The resulting
	ranges are then compared with r_mem_calls, with missing entries counting
	as (0, 0).

	Returns (True, None) when compatible or when either summary is None.
	Otherwise returns (None, reason).
	"""
	if l_mem_calls is None or r_mem_calls is None:
		return (True, None)
	r_cast_calls = {}
	for (fname, calls) in l_mem_calls.items ():
		# .get: a function with no pairings at all is reported, not KeyError
		pairs = [pair for pair in pairings.get (fname, [])
			if pair.tags == tags]
		if not pairs:
			return (None, 'no pairing for %s' % fname)
		assert len (pairs) <= 1, pairs
		[pair] = pairs
		r_fun = pair.funs[tags[1]]
		if not [nm for (nm, typ) in functions[r_fun].outputs
				if typ == syntax.builtinTs['Mem']]:
			continue
		r_cast_calls[r_fun] = calls
	# union of key sets; concatenating .keys () only works in Python 2
	for fname in set (r_cast_calls) | set (r_mem_calls):
		r_cast = r_cast_calls.get (fname, (0, 0))
		r_actual = r_mem_calls.get (fname, (0, 0))
		s = 'mismatch in calls to %s and pairs, %s / %s' % (fname,
			r_cast, r_actual)
		if r_cast[1] != None and r_cast[1] < r_actual[0]:
			return (None, s)
		if r_actual[1] != None and r_actual[1] < r_cast[0]:
			return (None, s)
	return (True, None)
1208
def mk_inp_env (n, args, rep):
	"""Build the initial variable environment for entry node n.

	Every (name, type) in args first gets a fresh '<name>_init' variable
	from rep.add_var (memory name 'Init'). Then rep.var_rep_request may
	supply a replacement for each, which is used whenever it is truthy.
	"""
	trace ('rep_graph setting up input env at %d' % n, push = 1)
	inp_env = dict ([((v_nm, typ), rep.add_var (v_nm + '_init', typ,
			mem_name = 'Init', mem_calls = {}))
		for (v_nm, typ) in args])
	# second pass: requests see the fully populated initial environment
	for (v_nm, typ) in args:
		var = (v_nm, typ)
		replacement = rep.var_rep_request (var, 'Init', (n, ()), inp_env)
		if replacement:
			inp_env[var] = replacement

	trace ('done setting up input env at %d' % n, push = -1)
	return inp_env
1223
def mk_graph_slice (p, inliner = None, fast = False, mk_solver = Solver):
	"""Create a GraphSlice over problem p backed by a fresh solver.

	mk_solver is called with no arguments to build the solver (Solver by
	default). inliner and fast are passed through to GraphSlice.
	"""
	trace ('rep_graph setting up solver', push = 1)
	solv = mk_solver ()
	# pop the trace level with a completion message, as mk_inp_env does
	trace ('done setting up solver', push = -1)
	return GraphSlice (p, solv, inliner, fast = fast)
1229
def run_requests (rep, requests):
	"""Replay (n_vc, tag) requests on rep.

	'InductVar' requests re-create induction variables. All other tags
	re-request the path condition. pvalid domain assertions are added last.
	"""
	for (n_vc, tag) in requests:
		if tag != 'InductVar':
			rep.get_pc (n_vc, tag = tag)
		else:
			rep.get_induct_var (n_vc)
	rep.solv.add_pvalid_dom_assertions ()
1237
import re
# Token pattern: matches a single '(' or ')', or a run of word characters
# (one capturing group around the whole alternative).
# NOTE(review): not referenced within this part of the file.
paren_w_re = re.compile (r"(\(|\)|\w+)")
1240
def mk_function_link_hyps (p, call_vis, tag, adjust_eq_seq = None):
	"""Hypotheses linking a call-site visit to the callee entry under tag.

	The result is the call visit's path-condition hypothesis, followed by
	one equality per (actual argument at the call, formal parameter at
	entry) pair. Only Word, WordArray and Mem typed actuals are included.
	adjust_eq_seq, if given, may rewrite the paired sequence first.
	"""
	(entry, _, formals) = p.get_entry_details (tag)
	((call_site, _), _) = call_vis
	assert p.nodes[call_site].kind == 'Call'
	entry_vis = ((entry, ()), p.node_tags[entry][0])

	formal_vars = [syntax.mk_var (nm, typ) for (nm, typ) in formals]

	pc = pc_true_hyp (call_vis)
	eq_seq = logic.azip (p.nodes[call_site].args, formal_vars)
	if adjust_eq_seq:
		eq_seq = adjust_eq_seq (eq_seq)

	def linkable (x):
		return (x.typ.kind == 'Word' or x.typ == syntax.builtinTs['Mem']
			or x.typ.kind == 'WordArray')

	eqs = [eq_hyp ((x, call_vis), (y, entry_vis))
		for (x, y) in eq_seq if linkable (x)]
	return [pc] + eqs
1259
1260