1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9from solver import Solver, merge_envs_pcs, smt_expr, mk_smt_expr, to_smt_expr
10from syntax import (true_term, false_term, boolT, mk_and, mk_not, mk_implies,
11	builtinTs, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word32, mk_var)
12import syntax
13import logic
14import solver
15from logic import azip
16
17from target_objects import functions, pairings, sections, trace, printout
18import target_objects
19import problem
20
class VisitCount:
	"""Used to represent a target number of visits to a split point.
	Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2),
	or a list of options.

	kind is 'Number' or 'Offset' (the integer is stored in self.n) or
	'Options' (a tuple of Number/Offset counts stored in self.opts)."""
	def __init__ (self, kind, value):
		self.kind = kind
		self.is_visit_count = True
		if kind == 'Number':
			self.n = value
		elif kind == 'Offset':
			self.n = value
		elif kind == 'Options':
			self.opts = tuple (value)
			for opt in self.opts:
				assert opt.kind in ['Number', 'Offset']
		else:
			assert not 'VisitCount type understood'

	def __hash__ (self):
		if self.kind == 'Options':
			return hash (self.opts)
		else:
			return hash (self.kind) + self.n

	def __eq__ (self, other):
		# None and any non-VisitCount value are never equal to a count
		if not getattr (other, 'is_visit_count', False):
			return False
		if self.kind == 'Options':
			return (other.kind == 'Options'
				and self.opts == other.opts)
		else:
			return self.kind == other.kind and self.n == other.n

	def __ne__ (self, other):
		# Python 2 does not derive '!=' from __eq__, so this must be
		# defined explicitly (the hook is __ne__, not __neq__)
		return not (self == other)

	# retained for any caller still using the old misspelt name
	__neq__ = __ne__

	def __str__ (self):
		if self.kind == 'Number':
			return str (self.n)
		elif self.kind == 'Offset':
			return 'i+%s' % self.n
		elif self.kind == 'Options':
			return '_'.join (map (str, self.opts))

	def __repr__ (self):
		(ns, os) = self.get_opts ()
		return 'vc_options (%r, %r)' % (ns, os)

	def get_opts (self):
		"""Return (numbers, offsets): the integers of the Number and
		Offset counts this visit count allows."""
		if self.kind == 'Options':
			opts = self.opts
		else:
			opts = [self]
		ns = [vc.n for vc in opts if vc.kind == 'Number']
		os = [vc.n for vc in opts if vc.kind == 'Offset']
		return (ns, os)

	def serialise (self, ss):
		"""Append 'VC', the number count and numbers, then the offset
		count and offsets (all as strings) to the list ss."""
		ss.append ('VC')
		(ns, os) = self.get_opts ()
		ss.append ('%d' % len (ns))
		ss.extend (['%d' % n for n in ns])
		ss.append ('%d' % len (os))
		ss.extend (['%d' % n for n in os])

	def incr (self, incr):
		"""Shift every option by incr. Options that would become
		negative are dropped; returns None if nothing remains."""
		if self.kind in ['Number', 'Offset']:
			n = self.n + incr
			if n < 0:
				return None
			return VisitCount (self.kind, n)
		elif self.kind == 'Options':
			opts = [vc.incr (incr) for vc in self.opts]
			opts = [opt for opt in opts if opt]
			if opts == []:
				return None
			return mk_vc_opts (opts)
		else:
			assert not 'VisitCount type understood'

	def has_zero (self):
		"""Whether the literal count 0 is one of the allowed options."""
		if self.kind == 'Options':
			return any (vc.has_zero () for vc in self.opts)
		else:
			return self.kind == 'Number' and self.n == 0
109
def mk_vc_opts (opts):
	"""Combine a list of visit counts: a singleton list yields its only
	element, anything else is wrapped in an 'Options' count."""
	if len (opts) != 1:
		return VisitCount ('Options', opts)
	return opts[0]
115
def vc_options (nums, offsets):
	"""Visit count allowing each number in nums and each symbolic
	offset i + k for k in offsets."""
	choices = [vc_num (k) for k in nums]
	choices.extend ([vc_offs (k) for k in offsets])
	return mk_vc_opts (choices)
118
def vc_num (n):
	"""Visit count for exactly n visits."""
	kind = 'Number'
	return VisitCount (kind, n)
121
def vc_upto (n):
	"""Visit count allowing any of the numbers 0 .. n - 1."""
	numbers = [vc_num (k) for k in range (n)]
	return mk_vc_opts (numbers)
124
def vc_offs (n):
	"""Visit count for the symbolic visit i + n."""
	kind = 'Offset'
	return VisitCount (kind, n)
127
def vc_offset_upto (n):
	"""Visit count allowing any of the symbolic visits i + 0 .. i + n - 1."""
	offsets = [vc_offs (k) for k in range (n)]
	return mk_vc_opts (offsets)
130
def vc_double_range (n, m):
	"""Visit count allowing the numbers 0 .. n - 1 and the symbolic
	visits i + 0 .. i + m - 1."""
	choices = [vc_num (k) for k in range (n)]
	choices += [vc_offs (k) for k in range (m)]
	return mk_vc_opts (choices)
133
class InlineEvent(Exception):
	"""Raised (by GraphSlice.try_inline) after an inlining step has been
	performed on the problem being analysed."""
136
class Hyp:
	"""Used to represent a proposition about path conditions or data at
	various points in execution.

	A visit here is a pair ((n, vcount), tag).
	'PCImp': self.pcs = [pc1, pc2], each a visit or ('Bool', term); the
	proposition is that pc1 implies pc2.
	'Eq' / 'EqIfAt': self.vals = [(x, xvis), (y, yvis)], expressions
	paired with visits; the proposition is that x at xvis equals y at
	yvis ('EqIfAt': only under the path conditions of both visits).
	self.induct, if set, is passed to rep.get_induct_var by interpret
	and the resulting variable is substituted into both expressions."""

	def __init__ (self, kind, arg1, arg2, induct = None):
		self.kind = kind
		if kind == 'PCImp':
			self.pcs = [arg1, arg2]
		elif kind == 'Eq':
			self.vals = [arg1, arg2]
			self.induct = induct
		elif kind == 'EqIfAt':
			self.vals = [arg1, arg2]
			self.induct = induct
		else:
			assert not 'hyp kind understood'

	def __repr__ (self):
		if self.kind == 'PCImp':
			vals = [repr (pc) for pc in self.pcs]
		elif self.kind in ['Eq', 'EqIfAt']:
			vals = [repr (val) for val in self.vals]
			if self.induct:
				vals.append (repr (self.induct))
		else:
			assert not 'hyp kind understood'
		return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals))

	def hyp_tuple (self):
		if self.kind == 'PCImp':
			return ('PCImp', self.pcs[0], self.pcs[1])
		elif self.kind in ['Eq', 'EqIfAt']:
			return (self.kind, self.vals[0],
				self.vals[1], self.induct)
		else:
			assert not 'hyp kind understood'

	def __hash__ (self):
		return hash (self.hyp_tuple ())

	def __eq__ (self, other):
		# explicit equality on hyp_tuple, consistent with __hash__;
		# Python 3 ignores __cmp__
		if not isinstance (other, Hyp):
			return False
		return self.hyp_tuple () == other.hyp_tuple ()

	def __ne__ (self, other):
		return not other or not (self == other)

	def __cmp__ (self, other):
		return cmp (self.hyp_tuple (), other.hyp_tuple ())

	def visits (self):
		"""The ((n, vcount), tag) visits this hypothesis mentions."""
		if self.kind == 'PCImp':
			return [vis for vis in self.pcs
				if vis[0] != 'Bool']
		elif self.kind in ['Eq', 'EqIfAt']:
			return [vis for (_, vis) in self.vals]
		else:
			assert not 'hyp kind understood'

	def get_vals (self):
		if self.kind == 'PCImp':
			return []
		else:
			return [val for (val, _) in self.vals]

	def serialise_visit (self, visit, ss):
		"""Append node n, the number of restrictions, then each
		(split, count) restriction of visit = (n, restrs) to ss."""
		(n, restrs) = visit
		ss.append ('%s' % n)
		ss.append ('%d' % len (restrs))
		for (n2, vc) in restrs:
			ss.append ('%d' % n2)
			vc.serialise (ss)

	def serialise_pc (self, pc, ss):
		if pc[0] == 'Bool' and pc[1] == true_term:
			ss.append ('True')
		elif pc[0] == 'Bool' and pc[1] == false_term:
			ss.append ('False')
		else:
			ss.append ('PC')
			self.serialise_visit (pc[0], ss)
			ss.append (pc[1])

	def serialise_hyp (self, ss):
		if self.kind == 'PCImp':
			(visit1, visit2) = self.pcs
			ss.append ('PCImp')
			self.serialise_pc (visit1, ss)
			self.serialise_pc (visit2, ss)
		elif self.kind in ['Eq', 'EqIfAt']:
			assert len (self.vals) == 2
			# append the kind as one token (extend would split it
			# into characters)
			ss.append (self.kind)
			for (exp, (visit, tag)) in self.vals:
				exp.serialise (ss)
				self.serialise_visit (visit, ss)
				ss.append (tag)
			if self.induct:
				ss.append ('%d' % self.induct[0])
				ss.append ('%d' % self.induct[1])
			else:
				ss.extend (['None', 'None'])
		else:
			assert not 'hyp kind understood'

	def interpret (self, rep):
		"""Translate this hypothesis to a boolean term using rep (a
		GraphSlice-like object providing the pcs and envs)."""
		if self.kind == 'PCImp':
			((visit1, tag1), (visit2, tag2)) = self.pcs
			if visit1 == 'Bool':
				pc1 = tag1
			else:
				pc1 = rep.get_pc (visit1, tag = tag1)
			if visit2 == 'Bool':
				pc2 = tag2
			else:
				pc2 = rep.get_pc (visit2, tag = tag2)
			return mk_implies (pc1, pc2)
		elif self.kind in ['Eq', 'EqIfAt']:
			[(x, xvis), (y, yvis)] = self.vals
			if self.induct:
				v = rep.get_induct_var (self.induct)
				x = subst_induct (x, v)
				y = subst_induct (y, v)
			x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1])
			y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1])
			if x_pc_env == None or y_pc_env == None:
				# a visit with no pc/env: EqIfAt holds vacuously
				if self.kind == 'EqIfAt':
					return syntax.true_term
				else:
					return syntax.false_term
			((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env)
			eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv)
			if self.kind == 'EqIfAt':
				x_pc = rep.get_pc (xvis[0], tag = xvis[1])
				y_pc = rep.get_pc (yvis[0], tag = yvis[1])
				return syntax.mk_n_implies ([x_pc, y_pc], eq)
			else:
				return eq
		else:
			assert not 'hypothesis type understood'
270
def check_vis_is_vis (vis):
	"""Sanity check that vis has the shape ((n, vcount), tag) where
	vcount is a tuple; fails (unpacking error or assertion) otherwise."""
	((n, vc), tag) = vis
	assert vc[:0] == (), vc
273
def eq_hyp (lhs, rhs, induct = None, use_if_at = False):
	"""Hypothesis that the (expr, visit) values lhs and rhs are equal.
	With use_if_at the kind is 'EqIfAt' rather than 'Eq'."""
	for side in (lhs, rhs):
		check_vis_is_vis (side[1])
	kind = 'EqIfAt' if use_if_at else 'Eq'
	return Hyp (kind, lhs, rhs, induct = induct)
281
def true_if_at_hyp (expr, vis, induct = None):
	"""'EqIfAt' hypothesis that expr equals true_term at visit vis."""
	check_vis_is_vis (vis)
	lhs = (expr, vis)
	rhs = (true_term, vis)
	return Hyp ('EqIfAt', lhs, rhs, induct = induct)
286
def pc_true_hyp (vis):
	"""Hypothesis that true implies the path condition of vis."""
	check_vis_is_vis (vis)
	antecedent = ('Bool', true_term)
	return Hyp ('PCImp', antecedent, vis)
290
def pc_false_hyp (vis):
	"""Hypothesis that the path condition of vis implies false."""
	check_vis_is_vis (vis)
	consequent = ('Bool', false_term)
	return Hyp ('PCImp', vis, consequent)
294
def pc_triv_hyp (vis):
	"""Trivial hypothesis: the path condition of vis implies itself."""
	check_vis_is_vis (vis)
	hyp = Hyp ('PCImp', vis, vis)
	return hyp
298
299class GraphSlice:
300	"""Used to represent a slice of potential execution in a graph where
301	looping is limited to certain specific examples. For instance, we
302	might say that execution through node n will be represented only
303	by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The
304	variable state at visits 4 and i + 2 will be calculated but no
305	further execution will be done."""
306
	def __init__ (self, p, solv, inliner = None, fast = False):
		"""Create a slice of problem p using solver solv.

		inliner, if given, is called by try_inline with the pair (p, n)
		for a Call node n; a truthy result is then called to inline.
		fast enables the heuristic analyses in get_loop_pc_env,
		is_synt_const and fast_const_ret."""
		self.p = p
		self.solv = solv
		# entry node -> input environment (filled by add_input_envs)
		self.inp_envs = {}
		# parsed memory s-expression -> mem-call summary (see add_var)
		self.mem_calls = {}
		# NOTE: runs before the attributes below are set; mk_inp_env
		# is handed this partially-initialised object
		self.add_input_envs ()

		# (tag, n, vcount) -> (pc, env), or None
		self.node_pc_envs = {}
		# keys of node_pc_envs with a non-None value, in creation order
		self.node_pc_env_order = []
		# (n, vcount) -> {cont node: (pc, env)} for emitted nodes
		self.arc_pc_envs = {}
		self.inliner = inliner
		# call visit -> (inputs, outputs, success); also pairing name ->
		# list of call visits (see add_func)
		self.funcs = {}
		# (visit, tag) requests, replayed by rebuild
		self.pc_env_requests = set ()
		self.fast = fast
		# (n1, n2) -> SMT name of the induction variable
		self.induct_var_env = {}
		# long SMT value -> name of its definition (see contract)
		self.contractions = {}

		# when set, add_local_def asserts tagged equalities instead of
		# adding solver definitions
		self.local_defs_unsat = False
		# gate for fetch_known_eqs
		self.use_known_eqs = True

		# hyps that may be / have been used when applying known eqs
		self.avail_hyps = set ()
		self.used_hyps = set ()
329
330	def add_input_envs (self):
331		for (entry, _, _, args) in self.p.entries:
332			self.inp_envs[entry] = mk_inp_env (entry, args, self)
333
334	def get_reachable (self, split, n):
335		return self.p.is_reachable_from (split, n)
336
337	class TooGeneral (Exception):
338		def __init__ (self, split):
339			self.split = split
340
	def get_tag_vcount (self, (n, vcount), tag):
		"""Normalise the visit count vcount for node n under tag.

		tag defaults to n's own tag. Returns (tag, vcount') where
		vcount' is a sorted tuple keeping only the restrictions on
		same-tag split points for which get_reachable (split, n) holds.
		Returns (tag, None) if some same-tag split failing that check
		has a count that does not allow zero.
		Raises TooGeneral (split) if n is in a loop and a split of the
		same loop has an 'Options' count."""
		if tag == None:
			tag = self.p.node_tags[n][0]
		vcount_r = [(split, count, self.get_reachable (split, n))
			for (split, count) in vcount
			if self.p.node_tags[split][0] == tag]
		for (split, count, r) in vcount_r:
			# if n cannot be reached from split, paths to n visit split
			# zero times, so a count excluding zero is inconsistent
			if not r and not count.has_zero ():
				return (tag, None)
			assert count.is_visit_count
		vcount = [(s, c) for (s, c, r) in vcount_r if r]
		vcount = tuple (sorted (vcount))

		loop_id = self.p.loop_id (n)
		if loop_id != None:
			for (split, visits) in vcount:
				if (self.p.loop_id (split) == loop_id
						and visits.kind == 'Options'):
					raise self.TooGeneral (split)

		return (tag, vcount)
362
	def get_node_pc_env (self, (n, vcount), tag = None, request = True):
		"""Return (pc, env) for visit (n, vcount) under tag: the path
		condition of the visit and the SMT values of the variables
		there, or None (e.g. when get_tag_vcount rules the visit out or
		no incoming arc provides a pc/env).

		Results, including None, are cached in self.node_pc_envs. If
		request is set and the visit is not yet cached, it is recorded
		in self.pc_env_requests (which rebuild passes to
		run_requests)."""
		tag, vcount = self.get_tag_vcount ((n, vcount), tag)
		if vcount == None:
			return None

		if (tag, n, vcount) in self.node_pc_envs:
			return self.node_pc_envs[(tag, n, vcount)]

		if request:
			self.pc_env_requests.add (((n, vcount), tag))

		# compute a chain of predecessors first, iteratively, so the
		# recursive computation below stays shallow
		self.warm_pc_env_cache ((n, vcount), tag)

		pc_env = self.get_node_pc_env_raw ((n, vcount), tag)
		if pc_env:
			pc_env = self.apply_known_eqs_pc_env ((n, vcount),
				tag, pc_env)

		assert not (tag, n, vcount) in self.node_pc_envs
		self.node_pc_envs[(tag, n, vcount)] = pc_env
		if pc_env:
			self.node_pc_env_order.append ((tag, n, vcount))

		return pc_env
387
	def warm_pc_env_cache (self, n_vc, tag):
		"""Precompute pc/envs along a chain of predecessors of n_vc.

		Walks backwards (at most 5000 steps) through the first
		predecessor that is not yet cached and whose normalised
		(tag, vcount) matches, stopping on TooGeneral or when there is
		none. The chain is then computed farthest-first, so later
		recursive get_node_pc_env calls hit the cache. This avoids
		Python recursion limits; a chain of 5000 is treated as a bug."""
		prev_chain = []
		for i in range (5000):
			prevs = self.prevs (n_vc)
			try:
				prevs = [p for p in prevs
					if (tag, p[0], p[1])
						not in self.node_pc_envs
					if self.get_tag_vcount (p, None)
						== (tag, n_vc[1])]
			except self.TooGeneral:
				break
			if not prevs:
				break
			n_vc = prevs[0]
			prev_chain.append(n_vc)
		if not (len (prev_chain) < 5000):
			printout ([n for (n, vc) in prev_chain])
			assert len (prev_chain) < 5000, (prev_chain[:10],
				prev_chain[-10:])

		prev_chain.reverse ()
		for n_vc in prev_chain:
			self.get_node_pc_env (n_vc, tag, request = False)
413
	def get_loop_pc_env (self, split, vcount):
		"""Build the abstract (pc, env) for the symbolic loop visit of
		split point split (used for the count i+0).

		Starts from the pc/env at visit 0 of split (None if that has
		none). Variables recognised as loop-constant keep their value;
		the others get fresh variables named from
		'<nm>_after_loop_at_<split>', possibly overridden by a
		'problem_var_rep' hook. The pc is a fresh boolean
		'pc_of_loop_at_<split>'; in fast mode it is asserted to imply
		the visit-0 pc."""
		vcount2 = dict (vcount)
		vcount2[split] = vc_num (0)
		vcount2 = tuple (sorted (vcount2.items ()))
		prev_pc_env = self.get_node_pc_env ((split, vcount2))
		if prev_pc_env == None:
			return None
		(_, prev_env) = prev_pc_env
		# memory summaries: calls in the loop body get unbounded counts
		mem_calls = self.scan_mem_calls (prev_env)
		mem_calls = self.add_loop_mem_calls (split, mem_calls)
		def av (nm, typ, mem_name = None):
			nm2 = '%s_loop_at_%s' % (nm, split)
			return self.add_var (nm2, typ,
				mem_name = mem_name, mem_calls = mem_calls)
		env = {}
		consts = set ()
		for (nm, typ) in prev_env:
			# the (expensive) constant check is only tried in fast mode
			# or for HTD/Dom-typed variables
			check_const = self.fast or (typ in
				[builtinTs['HTD'], builtinTs['Dom']])
			if check_const and self.is_synt_const (nm, typ, split):
				env[(nm, typ)] = prev_env[(nm, typ)]
				consts.add ((nm, typ))
			else:
				env[(nm, typ)] = av (nm + '_after', typ,
					('Loop', prev_env[(nm, typ)]))
		for (nm, typ) in prev_env:
			if (nm, typ) in consts:
				continue
			z = self.var_rep_request ((nm, typ), 'Loop',
				(split, vcount), env)
			if z:
				env[(nm, typ)] = z

		pc = mk_smt_expr (av ('pc_of', boolT), boolT)
		if self.fast:
			imp = syntax.mk_implies (pc, prev_pc_env[0])
			self.solv.assert_fact (imp, prev_env,
				unsat_tag = ('LoopPCImp', split))

		return (pc, env)
454
	def is_synt_const (self, nm, typ, split):
		"""check if a variable at a split point is a syntactic constant
		which is always unmodified by the loop.
		we allow cases where a variable is renamed and renamed back
		during the loop (this often happens because of inlining).
		the check is done by depth-first-search backward through the
		graph looking for a source of a variant value.
		always False if the loop of split has an inner loop."""
		# NOTE(review): 'loop' is not used below
		loop = self.p.loop_id (split)
		if problem.has_inner_loop (self.p, split):
			return False
		loop_set = set (self.p.loop_body (split))

		orig_nm = nm
		# (name, node) pairs known not to lead to a variant value
		safe = set ([(orig_nm, split)])
		first_step = True
		# DFS stack of (name, node) pairs still to check
		visit = []
		count = 0
		while first_step or visit:
			if first_step:
				(nm, n) = (orig_nm, split)
				first_step = False
			else:
				(nm, n) = visit.pop ()
				if (nm, n) in safe:
					continue
				# back at the split under a different name: the value
				# is not carried around the loop unchanged
				elif n == split:
					return False
			new_nm = nm
			node = self.p.nodes[n]
			if node.kind == 'Call':
				if (nm, typ) not in node.rets:
					pass
				elif self.fast_const_ret (n, nm, typ):
					pass
				else:
					return False
			elif node.kind == 'Basic':
				upds = [arg for (lv, arg) in node.upds
					if lv == (nm, typ)]
				# only plain variable copies (renamings) are allowed
				if [v for v in upds if v.kind != 'Var']:
					return False
				if upds:
					new_nm = upds[0].name
			preds = [(new_nm, n2) for n2 in self.p.preds[n]
				if n2 in loop_set]
			unknowns = [p for p in preds if p not in safe]
			if unknowns:
				# revisit (nm, n) after its predecessors are checked
				visit.extend ([(nm, n)] + unknowns)
			else:
				safe.add ((nm, n))
			count += 1
			if count % 100000 == 0:
				trace ('is_synt_const: %d iterations' % count)
				trace ('visit length %d' % len (visit))
				trace ('visit tail %s' % visit[-20:])
		return True
511
512	def fast_const_ret (self, n, nm, typ):
513		"""determine if we can heuristically consider this return
514		value to be the same as an input. this is known for some
515		function returns, e.g. memory.
516		this is important for heuristic "fast" analysis."""
517		if not self.fast:
518			return False
519		node = self.p.nodes[n]
520		assert node.kind == 'Call'
521		for hook in target_objects.hooks ('rep_unsafe_const_ret'):
522			if hook (node, nm, typ):
523				return True
524		return False
525
	def get_node_pc_env_raw (self, (n, vcount), tag):
		"""Compute (uncached) the (pc, env) for visit (n, vcount) under
		tag, or None if no incoming arc provides one.

		Entry nodes use their input env with pc true; the i+0 visit of
		a split point uses get_loop_pc_env; otherwise the pc/envs of all
		incoming arcs from same-tag predecessors are merged. Env values
		longer than 80 characters are replaced via contract."""
		if n in self.inp_envs:
			return (true_term, self.inp_envs[n])

		for (split, count) in vcount:
			if split == n and count == vc_offs (0):
				return self.get_loop_pc_env (split, vcount)

		pc_envs = [pc_env for n_prev in self.p.preds[n]
			if self.p.node_tags[n_prev][0] == tag
			for pc_env in self.get_arc_pc_envs (n_prev,
				(n, vcount))]

		pc_envs = [pc_env for pc_env in pc_envs if pc_env]
		if pc_envs == []:
			return None

		if n == 'Err':
			# we'll never care about variable values here
			# and there are sometimes a LOT of arcs to Err
			# so we save a lot of merge effort
			pc_envs = [(to_smt_expr (pc, env, self.solv), {})
				for (pc, env) in pc_envs]

		(pc, env, large) = merge_envs_pcs (pc_envs, self.solv)

		# give a non-SMT merged pc a solver definition and refer to it
		# by name
		if pc.kind != 'SMTExpr':
			name = self.path_cond_name ((n, vcount), tag)
			name = self.solv.add_def (name, pc, env)
			pc = mk_smt_expr (name, boolT)

		for (nm, typ) in env:
			if len (env[(nm, typ)]) > 80:
				env[(nm, typ)] = self.contract (nm, (n, vcount),
					env[(nm, typ)], typ)

		return (pc, env)
563
564	def contract (self, name, n_vc, val, typ):
565		if val in self.contractions:
566			return self.contractions[val]
567
568		name = self.local_name_before (name, n_vc)
569		name = self.solv.add_def (name, mk_smt_expr (val, typ), {})
570
571		self.contractions[val] = name
572		return name
573
574	def get_arc_pc_envs (self, n, n_vc2):
575		try:
576			prevs = [n_vc for n_vc in self.prevs (n_vc2)
577				if n_vc[0] == n]
578			assert len (prevs) <= 1
579			return [self.get_arc_pc_env (n_vc, n_vc2)
580				for n_vc in prevs]
581		except self.TooGeneral, e:
582			# consider specialisations of the target
583			specs = self.specialise (n_vc2, e.split)
584			specs = [(n_vc2[0], spec) for spec in specs]
585			return [pc_env for spec in specs
586				for pc_env in self.get_arc_pc_envs (n, spec)]
587
588	def get_arc_pc_env (self, (n, vcount), n2):
589		tag, vcount = self.get_tag_vcount ((n, vcount), None)
590
591		if vcount == None:
592			return None
593
594		assert self.is_cont ((n, vcount), n2), ((n, vcount),
595			n2, self.p.nodes[n].get_conts ())
596
597		if (n, vcount) in self.arc_pc_envs:
598			return self.arc_pc_envs[(n, vcount)].get (n2[0])
599
600		if self.get_node_pc_env ((n, vcount), request = False) == None:
601			return None
602
603		arcs = self.emit_node ((n, vcount))
604		self.post_emit_node_hooks ((n, vcount))
605		arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs])
606
607		self.arc_pc_envs[(n, vcount)] = arcs
608		return arcs.get (n2[0])
609
610	def add_local_def (self, n, vname, name, val, env):
611		if self.local_defs_unsat:
612			smt_name = self.solv.add_var (name, val.typ)
613			eq = mk_eq (mk_smt_expr (smt_name, val.typ), val)
614			self.solv.assert_fact (eq, env, unsat_tag
615				= ('Def', n, vname))
616		else:
617			smt_name = self.solv.add_def (name, val, env)
618		return smt_name
619
620	def add_var (self, name, typ, mem_name = None, mem_calls = None):
621		r = self.solv.add_var_restr (name, typ, mem_name = mem_name)
622		if typ == syntax.builtinTs['Mem']:
623			r_x = solver.parse_s_expression (r)
624			self.mem_calls[r_x] = mem_calls
625		return r
626
627	def var_rep_request (self, (nm, typ), kind, n_vc, env):
628		assert type (n_vc[0]) != str
629		for hook in target_objects.hooks ('problem_var_rep'):
630			z = hook (self.p, (nm, typ), kind, n_vc[0])
631			if z == None:
632				continue
633			if z[0] == 'SplitMem':
634				assert typ == builtinTs['Mem']
635				(_, addr) = z
636				addr = smt_expr (addr, env, self.solv)
637				name = '%s_for_%s' % (nm,
638					self.node_count_name (n_vc))
639				return self.solv.add_split_mem_var (addr, name,
640					typ, mem_name = 'SplitMemNonsense')
641			else:
642				assert z == None
643
	def emit_node (self, n):
		"""Emit the SMT content of the node at visit n (a (node, vcount)
		pair) and return its outgoing arcs as (cont, pc, env) triples.

		Basic nodes define non-variable updates, Cond nodes define a
		fresh condition and split the pc, Call nodes declare fresh
		return and success variables and register the call with
		add_func. A Call node may first be inlined, in which case
		try_inline raises InlineEvent. Known equalities available at
		the visit are substituted into emitted terms."""
		(pc, env) = self.get_node_pc_env (n, request = False)
		tag = self.p.node_tags[n[0]][0]
		app_eqs = self.apply_known_eqs_tm (n, tag)
		# node = logic.simplify_node_elementary (self.p.nodes[n[0]])
		# whether to ignore unreachable Cond arcs seems to be a huge
		# dilemma. if we ignore them, some reachable sites become
		# unreachable and we can't interpret all hyps
		# if we don't ignore them, the variable set disagrees with
		# var_deps and so the abstracted loop pc/env may not be
		# sufficient and we get EnvMiss again. I don't really know
		# what to do about this corner case.
		node = self.p.nodes[n[0]]
		# work on a copy: the env from get_node_pc_env is cached
		env = dict (env)

		if node.kind == 'Call':
			self.try_inline (n[0], pc, env)

		if pc == false_term:
			# unreachable: every continuation gets a false pc
			return [(c, false_term, {}) for c in node.get_conts()]
		elif node.kind == 'Cond' and node.left == node.right:
			return [(node.left, pc, env)]
		elif node.kind == 'Cond' and node.cond == true_term:
			return [(node.left, pc, env),
				(node.right, false_term, env)]
		elif node.kind == 'Basic':
			# evaluate all updates in the old env, then apply them
			upds = []
			for (lv, v) in node.upds:
				if v.kind == 'Var':
					upds.append ((lv, env[(v.name, v.typ)]))
				else:
					name = self.local_name (lv[0], n)
					v = app_eqs (v)
					vname = self.add_local_def (n,
						('Var', lv), name, v, env)
					upds.append ((lv, vname))
			for (lv, v) in upds:
				env[lv] = v
			return [(node.cont, pc, env)]
		elif node.kind == 'Cond':
			name = self.cond_name (n)
			cond = self.p.fresh_var (name, boolT)
			env[(cond.name, boolT)] = self.add_local_def (n,
				'Cond', name, app_eqs (node.cond), env)
			lpc = mk_and (cond, pc)
			rpc = mk_and (mk_not (cond), pc)
			return [(node.left, lpc, env), (node.right, rpc, env)]
		elif node.kind == 'Call':
			nm = self.success_name (node.fname, n)
			success = self.solv.add_var (nm, boolT)
			success = mk_smt_expr (success, boolT)
			fun = functions[node.fname]
			ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv))
				for ((x, typ), arg) in azip (fun.inputs, node.args)])
			mem_name = None
			# reversed, so the first Mem-typed input sets mem_name
			for (x, typ) in reversed (fun.inputs):
				if typ == builtinTs['Mem']:
					inp_mem = ins[(x, typ)]
					mem_name = (node.fname, inp_mem)
			mem_calls = self.scan_mem_calls (ins)
			mem_calls = self.add_mem_call (node.fname, mem_calls)
			outs = {}
			for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs):
				assert typ2 == typ
				# heuristic: keep x's current value across the call
				if self.fast_const_ret (n[0], x, typ):
					outs[(y, typ2)] = env [(x, typ)]
					continue
				name = self.local_name (x, n)
				env[(x, typ)] = self.add_var (name, typ,
					mem_name = mem_name,
					mem_calls = mem_calls)
				outs[(y, typ2)] = env[(x, typ)]
			for ((x, typ), (y, _)) in azip (node.rets, fun.outputs):
				z = self.var_rep_request ((x, typ),
					'Call', n, env)
				if z != None:
					env[(x, typ)] = z
					outs[(y, typ)] = z
			self.add_func (node.fname, ins, outs, success, n)
			return [(node.cont, pc, env)]
		else:
			assert not 'node kind understood'
726
727	def post_emit_node_hooks (self, (n, vcount)):
728		for hook in target_objects.hooks ('post_emit_node'):
729			hook (self, (n, vcount))
730
731	def fetch_known_eqs (self, n_vc, tag):
732		if not self.use_known_eqs:
733			return None
734		eqs = self.p.known_eqs.get ((n_vc, tag))
735		if eqs == None:
736			return None
737		avail = []
738		for (x, n_vc_y, tag_y, y, hyps) in eqs:
739			if hyps <= self.avail_hyps:
740				(_, env) = self.get_node_pc_env (n_vc_y, tag_y)
741				avail.append ((x, smt_expr (y, env, self.solv)))
742				self.used_hyps.update (hyps)
743		if avail:
744			return avail
745		return None
746
747	def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)):
748		eqs = self.fetch_known_eqs (n_vc, tag)
749		if eqs == None:
750			return (pc, env)
751		env = dict (env)
752		for (x, sx) in eqs:
753			if x.kind == 'Var':
754				cur_rhs = env[x.name]
755				for y in env:
756					if env[y] == cur_rhs:
757						trace ('substituted %s at %s.' % (y, n_vc))
758						env[y] = sx
759		return (pc, env)
760
761	def apply_known_eqs_tm (self, n_vc, tag):
762		eqs = self.fetch_known_eqs (n_vc, tag)
763		if eqs == None:
764			return lambda x: x
765		eqs = dict ([(x, mk_smt_expr (sexpr, x.typ))
766			for (x, sexpr) in eqs])
767		return lambda tm: logic.recursive_term_subst (eqs, tm)
768
769	def rebuild (self, solv = None):
770		requests = self.pc_env_requests
771
772		self.node_pc_env_order = []
773		self.node_pc_envs = {}
774		self.arc_pc_envs = {}
775		self.funcs = {}
776		self.pc_env_requests = set ()
777		self.induct_var_env = {}
778		self.contractions = {}
779
780		if not solv:
781			solv = Solver (produce_unsat_cores
782				= self.local_defs_unsat)
783		self.solv = solv
784
785		self.add_input_envs ()
786
787		self.used_hyps = set ()
788		run_requests (self, requests)
789
790	def add_func (self, name, inputs, outputs, success, n_vc):
791		assert n_vc not in self.funcs
792		self.funcs[n_vc] = (inputs, outputs, success)
793		for pair in pairings.get (name, []):
794			self.funcs.setdefault (pair.name, [])
795			group = self.funcs[pair.name]
796			for n_vc2 in group:
797				if self.get_func_pairing (n_vc, n_vc2):
798					self.add_func_assert (n_vc, n_vc2)
799			group.append (n_vc)
800
801	def get_func (self, n_vc, tag = None):
802		"""returns (input_env, output_env, success_var) for
803		function call at given n_vc."""
804		tag, vc = self.get_tag_vcount (n_vc, tag)
805		n_vc = (n_vc[0], vc)
806		assert self.p.nodes[n_vc[0]].kind == 'Call'
807
808		if n_vc not in self.funcs:
809			# try to ensure n_vc has been emitted
810			cont = self.get_cont (n_vc)
811			self.get_node_pc_env (cont, tag = tag)
812
813		return self.funcs[n_vc]
814
815	def get_func_pairing_nocheck (self, n_vc, n_vc2):
816		fnames = [self.p.nodes[n_vc[0]].fname,
817			self.p.nodes[n_vc2[0]].fname]
818		pairs = [pair for pair in pairings[list (fnames)[0]]
819			if set (pair.funs.values ()) == set (fnames)]
820		if not pairs:
821			return None
822		[pair] = pairs
823		if pair.funs[pair.tags[0]] == fnames[0]:
824			return (pair, n_vc, n_vc2)
825		else:
826			return (pair, n_vc2, n_vc)
827
828	def get_func_pairing (self, n_vc, n_vc2):
829		res = self.get_func_pairing_nocheck (n_vc, n_vc2)
830		if not res:
831			return res
832		(pair, l_n_vc, r_n_vc) = res
833		(lin, _, _) = self.funcs[l_n_vc]
834		(rin, _, _) = self.funcs[r_n_vc]
835		l_mem_calls = self.scan_mem_calls (lin)
836		r_mem_calls = self.scan_mem_calls (rin)
837		tags = pair.tags
838		(c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls)
839		if not c:
840			trace ('skipped emitting func pairing %s -> %s'
841				% (l_n_vc, r_n_vc))
842			trace ('  ' + s)
843			return None
844		return res
845
846	def get_func_assert (self, n_vc, n_vc2):
847		(pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2)
848		(ltag, rtag) = pair.tags
849		(inp_eqs, out_eqs) = pair.eqs
850		(lin, lout, lsucc) = self.funcs[l_n_vc]
851		(rin, rout, rsucc) = self.funcs[r_n_vc]
852		lpc = self.get_pc (l_n_vc)
853		rpc = self.get_pc (r_n_vc)
854		envs = {ltag + '_IN': lin, rtag + '_IN': rin,
855			ltag + '_OUT': lout, rtag + '_OUT': rout}
856		inp_eqs = inst_eqs (inp_eqs, envs, self.solv)
857		out_eqs = inst_eqs (out_eqs, envs, self.solv)
858		succ_imp = mk_implies (rsucc, lsucc)
859
860		return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]),
861			foldr1 (mk_and, out_eqs + [succ_imp]))
862
863	def add_func_assert (self, n_vc, n_vc2):
864		imp = self.get_func_assert (n_vc, n_vc2)
865		imp = logic.weaken_assert (imp)
866		if self.local_defs_unsat:
867			self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq',
868				ln, rn))
869		else:
870			self.solv.assert_fact (imp, {})
871
872	def node_count_name (self, (n, vcount)):
873		name = str (n)
874		bits = [str (n)] + ['%s=%s' % (split, count)
875			for (split, count) in vcount]
876		return '_'.join (bits)
877
878	def get_mem_calls (self, mem_sexpr):
879		mem_sexpr = solver.parse_s_expression (mem_sexpr)
880		return self.get_mem_calls_sexpr (mem_sexpr)
881
882	def get_mem_calls_sexpr (self, mem_sexpr):
883		stores = set (['store-word32', 'store-word8', 'store-word64'])
884		if mem_sexpr in self.mem_calls:
885			return self.mem_calls[mem_sexpr]
886		elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores:
887			return self.get_mem_calls_sexpr (mem_sexpr[1])
888		elif mem_sexpr[:1] == ('ite', ):
889			(_, _, x, y) = mem_sexpr
890			x_calls = self.get_mem_calls_sexpr (x)
891			y_calls = self.get_mem_calls_sexpr (y)
892			return merge_mem_calls (x_calls, y_calls)
893		elif mem_sexpr in self.solv.defs:
894			mem_sexpr = self.solv.defs[mem_sexpr]
895			return self.get_mem_calls_sexpr (mem_sexpr)
896		assert not "mem_calls fallthrough", mem_sexpr
897
898	def scan_mem_calls (self, env):
899		mem_vs = [env[(nm, typ)]
900			for (nm, typ) in env
901			if typ == syntax.builtinTs['Mem']]
902		mem_calls = [self.get_mem_calls (v)
903			for v in mem_vs if v[0] != 'SplitMem']
904		if mem_calls:
905			return foldr1 (merge_mem_calls, mem_calls)
906		else:
907			return None
908
909	def add_mem_call (self, fname, mem_calls):
910		if mem_calls == None:
911			return None
912		mem_calls = dict (mem_calls)
913		(min_calls, max_calls) = mem_calls.get (fname, (0, 0))
914		if max_calls == None:
915			mem_calls[fname] = (min_calls + 1, None)
916		else:
917			mem_calls[fname] = (min_calls + 1, max_calls + 1)
918		return mem_calls
919
920	def add_loop_mem_calls (self, split, mem_calls):
921		if mem_calls == None:
922			return None
923		fnames = set ([self.p.nodes[n].fname
924			for n in self.p.loop_body (split)
925			if self.p.nodes[n].kind == 'Call'])
926		if not fnames:
927			return mem_calls
928		mem_calls = dict (mem_calls)
929		for fname in fnames:
930			if fname not in mem_calls:
931				mem_calls[fname] = (0, None)
932			else:
933				(min_calls, max_calls) = mem_calls[fname]
934				mem_calls[fname] = (min_calls, None)
935		return mem_calls
936
937	# note these names are designed to be unique by suffix
938	# (so that smt names are independent of order of requests)
939	def local_name (self, s, n_vc):
940		return '%s_after_%s' % (s, self.node_count_name (n_vc))
941
942	def local_name_before (self, s, n_vc):
943		return '%s_v_at_%s' % (s, self.node_count_name (n_vc))
944
945	def cond_name (self, n_vc):
946		return 'cond_at_%s' % self.node_count_name (n_vc)
947
948	def path_cond_name (self, n_vc, tag):
949		return 'path_cond_to_%s_%s' % (
950			self.node_count_name (n_vc), tag)
951
952	def success_name (self, fname, n_vc):
953		bits = fname.split ('.')
954		nms = ['_'.join (bits[i:]) for i in range (len (bits))
955			if bits[i:][0].isalpha ()]
956		if nms:
957			nm = nms[-1]
958		else:
959			nm = 'fun'
960		return '%s_success_at_%s' % (nm, self.node_count_name (n_vc))
961
962	def try_inline (self, n, pc, env):
963		if not self.inliner:
964			return False
965
966		inline = self.inliner ((self.p, n))
967		if not inline:
968			return False
969
970		# make sure this node is reachable before inlining
971		if self.solv.test_hyp (mk_not (pc), env):
972			trace ('Skipped inlining at %d.' % n)
973			return False
974
975		trace ('Inlining at %d.' % n)
976		inline ()
977		raise InlineEvent ()
978
979	def incr (self, vcount, n, incr):
980		vcount2 = dict (vcount)
981		vcount2[n] = vcount2[n].incr (incr)
982		if vcount2[n] == None:
983			return None
984		return tuple (sorted (vcount2.items ()))
985
986	def get_cont (self, (n, vcount)):
987		[c] = self.p.nodes[n].get_conts ()
988		vcount2 = dict (vcount)
989		if n in vcount2:
990			vcount = self.incr (vcount, n, 1)
991		cont = (c, vcount)
992		assert self.is_cont ((n, vcount), cont)
993		return cont
994
995	def is_cont (self, (n, vcount), (n2, vcount2)):
996		if n2 not in self.p.nodes[n].get_conts ():
997			trace ('Not a graph cont.')
998			return False
999
1000		vcount_d = dict (vcount)
1001		vcount_d2 = dict (vcount2)
1002		if n in vcount_d2:
1003			if n in vcount_d:
1004				assert vcount_d[n].kind != 'Options'
1005			vcount_d2[n] = vcount_d2[n].incr (-1)
1006
1007		if not vcount_d <= vcount_d2:
1008			trace ('Restrictions not subset.')
1009			return False
1010
1011		for (split, count) in vcount_d2.iteritems ():
1012			if split in vcount_d:
1013				continue
1014			if self.get_reachable (split, n):
1015				return False
1016			if not count.has_zero ():
1017				return False
1018
1019		return True
1020
1021	def prevs (self, (n, vcount)):
1022		prevs = []
1023		vcount_d = dict (vcount)
1024		for p in self.p.preds[n]:
1025			if p in vcount_d:
1026				vcount2 = self.incr (vcount, p, -1)
1027				if vcount2 == None:
1028					continue
1029				prevs.append ((p, vcount2))
1030			else:
1031				prevs.append ((p, vcount))
1032		return prevs
1033
1034	def specialise (self, (n, vcount), split):
1035		vcount = dict (vcount)
1036		assert vcount[split].kind == 'Options'
1037		specs = []
1038		for n in vcount[split].opts:
1039			v = dict (vcount)
1040			v[split] = n
1041			specs.append (tuple (sorted (v.items ())))
1042		return specs
1043
1044	def get_pc (self, (n, vcount), tag = None):
1045		pc_env = self.get_node_pc_env ((n, vcount), tag = tag)
1046		if pc_env == None:
1047			trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag))
1048			return false_term
1049		(pc, env) = pc_env
1050		return to_smt_expr (pc, env, self.solv)
1051
1052	def to_smt_expr (self, expr, (n, vcount), tag = None):
1053		pc_env = self.get_node_pc_env ((n, vcount), tag = tag)
1054		(pc, env) = pc_env
1055		return to_smt_expr (expr, env, self.solv)
1056
1057	def get_induct_var (self, (n1, n2)):
1058		if (n1, n2) not in self.induct_var_env:
1059			vname = self.solv.add_var ('induct_i_%d_%d' % (n1, n2),
1060				word32T)
1061			self.induct_var_env[(n1, n2)] = vname
1062			self.pc_env_requests.add (((n1, n2), 'InductVar'))
1063		else:
1064			vname = self.induct_var_env[(n1, n2)]
1065		return mk_smt_expr (vname, word32T)
1066
1067	def interpret_hyp (self, hyp):
1068		return hyp.interpret (self)
1069
1070	def interpret_hyp_imps (self, hyps, concl):
1071		hyps = map (self.interpret_hyp, hyps)
1072		return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl))
1073
1074	def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False,
1075			model = None):
1076		self.avail_hyps = set (hyps)
1077		if not self.used_hyps <= self.avail_hyps:
1078			self.rebuild ()
1079
1080		last_test[0] = (hyp, hyps, list (self.pc_env_requests))
1081
1082		expr = self.interpret_hyp_imps (hyps, hyp)
1083
1084		trace ('Testing hyp whyps', push = 1)
1085		trace ('requests = %s' % self.pc_env_requests)
1086
1087		expr_s = smt_expr (expr, {}, self.solv)
1088		if cache and expr_s in cache:
1089			trace ('Cached: %s' % cache[expr_s])
1090			return cache[expr_s]
1091		if fast:
1092			trace ('(not in cache)')
1093			return None
1094
1095		self.solv.add_pvalid_dom_assertions ()
1096
1097		(result, _, _) = self.solv.parallel_test_hyps ([(None, expr)],
1098			{}, model = model)
1099		trace ('Result: %s' % result, push = -1)
1100		if cache != None:
1101			cache[expr_s] = result
1102		if not result:
1103			last_failed_test[0] = last_test[0]
1104		return result
1105
1106	def test_hyp_imp (self, hyps, hyp, model = None):
1107		return self.test_hyp_whyps (self.interpret_hyp (hyp), hyps,
1108			model = model)
1109
1110	def test_hyp_imps (self, imps):
1111		last_hyp_imps[0] = imps
1112		if imps == []:
1113			return (True, None)
1114		interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps,
1115				self.interpret_hyp (hyp))
1116			for (hyps, hyp) in imps]))
1117		reqs = list (self.pc_env_requests)
1118		last_test[0] = (self.interpret_hyp (hyp), hyps, reqs)
1119		self.solv.add_pvalid_dom_assertions ()
1120		result = self.solv.parallel_test_hyps (interp_imps, {})
1121		assert result[0] in [True, False], result
1122		if result[0] == False:
1123			(hyps, hyp) = imps[result[1]]
1124			last_test[0] = (self.interpret_hyp (hyp), hyps, reqs)
1125			last_failed_test[0] = last_test[0]
1126		return result
1127
1128	def replay_requests (self, reqs):
1129		for ((n, vc), tag) in reqs:
1130			self.get_node_pc_env ((n, vc), tag = tag)
1131
# Single-element lists used as mutable module-level cells.
# - last_test[0]: set by GraphSlice.test_hyp_whyps / test_hyp_imps to a
#   (hypothesis, hyps, pc_env_requests) triple.
# - last_failed_test[0]: receives a copy of last_test[0] when a test fails.
# - last_hyp_imps[0]: the latest argument passed to test_hyp_imps.
# Presumably kept for post-mortem inspection; confirm with the callers.
last_test = [0]
last_failed_test = [0]
last_hyp_imps = [0]
1135
def to_smt_expr_under_op (expr, env, solv):
	"""Convert `expr` with to_smt_expr, except that for an 'Op' node only
	the operands are converted. The operator is kept, and the node is
	rebuilt with syntax.adjust_op_vals."""
	if expr.kind != 'Op':
		return to_smt_expr (expr, env, solv)
	operands = [to_smt_expr (arg, env, solv) for arg in expr.vals]
	return syntax.adjust_op_vals (expr, operands)
1142
1143def inst_eq_with_envs ((x, env1), (y, env2), solv):
1144	x = to_smt_expr_under_op (x, env1, solv)
1145	y = to_smt_expr_under_op (y, env2, solv)
1146	if x.typ == syntax.builtinTs['RelWrapper']:
1147		return logic.apply_rel_wrapper (x, y)
1148	else:
1149		return mk_eq (x, y)
1150
def inst_eqs (eqs, envs, solv):
	"""Instantiate each ((x, x_addr), (y, y_addr)) equation, looking up
	the environments of x and y in `envs` by their addresses."""
	instantiated = []
	for ((x, x_addr), (y, y_addr)) in eqs:
		instantiated.append (inst_eq_with_envs ((x, envs[x_addr]),
			(y, envs[y_addr]), solv))
	return instantiated
1154
def subst_induct (expr, induct_var):
	"""Substitute `induct_var` for the word32T placeholder variable '%n'
	in `expr`. It is not an error if '%n' does not occur."""
	return logic.var_subst (expr, {('%n', word32T): induct_var},
		must_subst = False)
1158
# Maps each hypothesis tuple already traced to its abbreviation.
printed_hyps = {}
def print_hyps (hyps):
	"""Trace a hypothesis collection, abbreviating repeats.

	The first time a collection is seen it is traced in full under a fresh
	name 'hyp_set_<k>'. Later calls with an equal collection trace only that
	name. `hyps` must consist of hashable elements.
	"""
	hyps = tuple (hyps)
	if hyps in printed_hyps:
		trace ('hyps = %s' % printed_hyps[hyps])
	else:
		hname = 'hyp_set_%d' % (len (printed_hyps) + 1)
		trace ('%s = %s' % (hname, list (hyps)))
		# Key by the hyps themselves so the membership test above can hit.
		printed_hyps[hyps] = hname
		trace ('hyps = %s' % hname)
1169
def merge_mem_calls (mem_calls_x, mem_calls_y):
	"""Merge two call-count summaries into one that covers both.

	Each summary maps a function name to a (min, max) pair of call counts.
	max may be None (no upper bound), and a missing name counts as (0, 0).
	The merge takes the smaller minimum and the larger maximum, with None
	dominating. A None summary is treated as unknown, as in
	mem_calls_compatible, so merging with it yields None.
	"""
	if mem_calls_x == mem_calls_y:
		return mem_calls_x
	if mem_calls_x is None or mem_calls_y is None:
		return None
	mem_calls = {}
	for fname in set (mem_calls_x) | set (mem_calls_y):
		(min_x, max_x) = mem_calls_x.get (fname, (0, 0))
		(min_y, max_y) = mem_calls_y.get (fname, (0, 0))
		if max_x is None or max_y is None:
			max_v = None
		else:
			max_v = max (max_x, max_y)
		mem_calls[fname] = (min (min_x, min_y), max_v)
	return mem_calls
1183
def mem_calls_compatible (tags, l_mem_calls, r_mem_calls):
	"""Check left-side call counts against right-side ones.

	Each left function in `l_mem_calls` is mapped through its unique pairing
	with `tags` to its right-side function. Right functions with no Mem
	output are skipped. The mapped counts are then compared with
	`r_mem_calls`: each is a (min, max) range (max None = unbounded, a
	missing name = (0, 0)), and corresponding ranges must overlap.

	Returns (True, None) if compatible, or if either summary is None
	(unknown). Otherwise returns (None, reason).
	"""
	if l_mem_calls == None or r_mem_calls == None:
		return (True, None)
	r_cast_calls = {}
	for (fname, calls) in l_mem_calls.items ():
		# A function with no pairings at all is reported rather than
		# raising KeyError.
		pairs = [pair for pair in pairings.get (fname, [])
			if pair.tags == tags]
		if not pairs:
			return (None, 'no pairing for %s' % fname)
		assert len (pairs) <= 1, pairs
		[pair] = pairs
		r_fun = pair.funs[tags[1]]
		if not any (typ == syntax.builtinTs['Mem']
				for (_, typ) in functions[r_fun].outputs):
			continue
		r_cast_calls[r_fun] = calls
	for fname in set (r_cast_calls) | set (r_mem_calls):
		r_cast = r_cast_calls.get (fname, (0, 0))
		r_actual = r_mem_calls.get (fname, (0, 0))
		s = 'mismatch in calls to %s and pairs, %s / %s' % (fname,
			r_cast, r_actual)
		if r_cast[1] != None and r_cast[1] < r_actual[0]:
			return (None, s)
		if r_actual[1] != None and r_actual[1] < r_cast[0]:
			return (None, s)
	return (True, None)
1210
def mk_inp_env (n, args, rep):
	"""Build the initial variable environment for entry node `n`.

	The work is done in two passes over `args`, a sequence of
	(name, type) pairs:
	1. Register a fresh '<name>_init' variable for every argument, with
	   mem_name 'Init' and an empty mem_calls dict.
	2. Let rep.var_rep_request replace an entry whenever it returns
	   something truthy.
	Returns the environment dict keyed by (name, type).
	"""
	trace ('rep_graph setting up input env at %d' % n, push = 1)
	env = {}

	for (nm, typ) in args:
		env[(nm, typ)] = rep.add_var (nm + '_init', typ,
			mem_name = 'Init', mem_calls = {})

	for (nm, typ) in args:
		key = (nm, typ)
		replacement = rep.var_rep_request (key, 'Init', (n, ()), env)
		if replacement:
			env[key] = replacement

	trace ('done setting up input env at %d' % n, push = -1)
	return env
1225
def mk_graph_slice (p, inliner = None, fast = False, mk_solver = Solver):
	"""Create a fresh solver with `mk_solver` and return a GraphSlice of
	problem `p` using it, passing `inliner` and `fast` through."""
	trace ('rep_graph setting up solver', push = 1)
	solv = mk_solver ()
	# Closing message of the trace section; distinct from the opening one,
	# matching mk_inp_env's 'done setting up ...' convention.
	trace ('rep_graph done setting up solver', push = -1)
	return GraphSlice (p, solv, inliner, fast = fast)
1231
def run_requests (rep, requests):
	"""Replay (n_vc, tag) requests against `rep`, then add pvalid domain
	assertions.

	Requests tagged 'InductVar' go to rep.get_induct_var; all others go to
	rep.get_pc with the tag.
	"""
	for (n_vc, tag) in requests:
		if tag != 'InductVar':
			rep.get_pc (n_vc, tag = tag)
		else:
			rep.get_induct_var (n_vc)
	rep.solv.add_pvalid_dom_assertions ()
1239
import re
# Matches a single '(' or ')' or a maximal run of word characters (\w+).
# Presumably used elsewhere to tokenise parenthesised text; confirm at the
# call sites.
paren_w_re = re.compile (r"(\(|\)|\w+)")
1242
def mk_function_link_hyps (p, call_vis, tag, adjust_eq_seq = None):
	"""Build hypotheses linking a call site to the entry of the callee.

	call_vis: ((call_site, restrs), call_tag). The call_site node must be
		a 'Call' node (this is asserted).
	tag: selects the entry via p.get_entry_details. The entry visit uses
		no restrictions and the entry node's first tag.
	adjust_eq_seq: optional function applied to the list of
		(actual argument, formal argument) pairs.

	Returns [pc_true_hyp (call_vis)] followed by one eq_hyp per pair.
	Pairs are kept only when the actual's type is a Word, the Mem builtin,
	or a WordArray.
	"""
	(entry, _, args) = p.get_entry_details (tag)
	((call_site, restrs), call_tag) = call_vis
	assert p.nodes[call_site].kind == 'Call'
	entry_vis = ((entry, ()), p.node_tags[entry][0])

	formals = [syntax.mk_var (nm, typ) for (nm, typ) in args]

	pc = pc_true_hyp (call_vis)
	eq_seq = logic.azip (p.nodes[call_site].args, formals)
	if adjust_eq_seq:
		eq_seq = adjust_eq_seq (eq_seq)

	def is_linked (actual):
		return (actual.typ.kind == 'Word'
			or actual.typ == syntax.builtinTs['Mem']
			or actual.typ.kind == 'WordArray')

	arg_eqs = [eq_hyp ((x, call_vis), (y, entry_vis))
		for (x, y) in eq_seq if is_linked (x)]
	return [pc] + arg_eqs
1261
1262